分治法解矩阵乘积

看到此文,是否觉得体内洪荒之力爆发,饥渴难耐想吐槽、情不自禁想捐赠
本文为原创文章,尊重辛勤劳动,可以免费摘要、推荐或聚合,亦可完整转载,但完整转载需要标明原出处,违者必究。

支付宝微  信

题目

假设有A,B两个矩阵,且其均为n*n维矩阵,n为2的幂(n>=2)。求A与B的乘积。

分治法解矩阵乘积

分治法解矩阵乘积

通过上图我们可以看到书中的利用分治法解决的伪代码。

解决思路一

暴力解法

时间复杂度

O(n^3)

程序范例

void solve_matrix_multiply_n3(int a[][MATRIX_LENGTH], int b[][MATRIX_LENGTH], int c[][MATRIX_LENGTH])
{
    for (int i = 0; i < MATRIX_LENGTH; i++) {
        for (int j = 0; j < MATRIX_LENGTH; j++) {
            for (int k = 0; k < MATRIX_LENGTH; k++) {
                c[i][j] += a[i][k] * b[j][k];
            }
        }
    }
}

解决思路二

利用分治法解决矩阵乘积

程序范例

void solve_matrix_multiply_dc(int a[][MATRIX_LENGTH], int b[][MATRIX_LENGTH], int c[][MATRIX_LENGTH],
        int a_row_start, int a_row_end, int a_col_start, int a_col_end,
        int b_row_start, int b_row_end, int b_col_start, int b_col_end)
{
    if (a_row_start == a_row_end && a_col_start == a_col_end) {
        c[a_row_start][b_col_start] += a[a_row_start][a_col_start] * b[b_row_start][b_col_start];
    } else {
        /* divide and combine */
        solve_matrix_multiply_dc(a, b, c,
                a_row_start, (a_row_start + a_row_end)>>1, a_col_start, (a_col_start + a_col_end)>>1,
                b_row_start, (b_row_start + b_row_end)>>1, b_col_start, (b_col_start + b_col_end)>>1);
        solve_matrix_multiply_dc(a, b, c,
                a_row_start, (a_row_start + a_row_end)>>1, ((a_col_start + a_col_end)>>1) + 1, a_col_end,
                ((b_row_start + b_row_end)>>1) + 1, b_row_end, b_col_start, (b_col_start + b_col_end)>>1);

        solve_matrix_multiply_dc(a, b, c,
                a_row_start, (a_row_start + a_row_end)>>1, a_col_start, (a_col_start + a_col_end)>>1,
                b_row_start, (b_row_start + b_row_end)>>1, ((b_col_start + b_col_end)>>1) + 1, b_col_end);
        solve_matrix_multiply_dc(a, b, c,
                a_row_start, (a_row_start + a_row_end)>>1, ((a_col_start + a_col_end)>>1) + 1, a_col_end,
                ((b_row_start + b_row_end)>>1) + 1, b_row_end, ((b_col_start + b_col_end)>>1) + 1, b_col_end);

        solve_matrix_multiply_dc(a, b, c,
                ((a_row_start + a_row_end)>>1) + 1, a_row_end, a_col_start, (a_col_start + a_col_end)>>1,
                b_row_start, (b_row_start + b_row_end)>>1, b_col_start, (b_col_start + b_col_end)>>1);
        solve_matrix_multiply_dc(a, b, c,
                ((a_row_start + a_row_end)>>1) + 1, a_row_end, ((a_col_start + a_col_end)>>1) + 1, a_col_end,
                ((b_row_start + b_row_end)>>1) + 1, b_row_end, b_col_start, (b_col_start + b_col_end)>>1);

        solve_matrix_multiply_dc(a, b, c,
                ((a_row_start + a_row_end)>>1) + 1, a_row_end, a_col_start, (a_col_start + a_col_end)>>1,
                b_row_start, (b_row_start + b_row_end)>>1, ((b_col_start + b_col_end)>>1) + 1, b_col_end);
        solve_matrix_multiply_dc(a, b, c,
                ((a_row_start + a_row_end)>>1) + 1, a_row_end, ((a_col_start + a_col_end)>>1) + 1, a_col_end,
                ((b_row_start + b_row_end)>>1) + 1, b_row_end, ((b_col_start + b_col_end)>>1) + 1, b_col_end);
    }
}

时间复杂度

其时间复杂度不等式为:T(n) = 8*T(n/2) + Θ(n^2),所以其时间复杂度为:

O(n^3)

测试

假设A,B均为521*521的矩阵,并利用0-99的随机数初始化。分别利用暴力解法和分治法进行求解。

程序如下

#include <iostream>
#include <time.h>

using namespace std;

#define MATRIX_LENGTH  512

int main()
{
    /* O(n^3) */
    int a[MATRIX_LENGTH][MATRIX_LENGTH] {};
    int b[MATRIX_LENGTH][MATRIX_LENGTH] {};
    int c[MATRIX_LENGTH][MATRIX_LENGTH] {};

    std::cout<<"initialization..."<<std::endl;
    for (int i = 0; i < MATRIX_LENGTH; i++) {
        for (int j = 0; j < MATRIX_LENGTH; j++) {
            a[i][j] = rand() % 100;
            b[i][j] = rand() % 100;
//            std::cout<<a[i][j]<<std::endl;
            //std::cout<<b[i][j]<<std::endl;
        }
    }


    std::cout<<"solving by n3"<<std::endl;
    clock_t start, finish;
    double duration;
    /* 测量一个事件持续的时间*/
    start = clock();
    solve_matrix_multiply_n3(a, b, c);
    finish = clock();
    duration = (double)(finish - start) / CLOCKS_PER_SEC;
    printf( "%f seconds\n", duration );
    for (int i = 0; i < MATRIX_LENGTH; i++) {
        for (int j = 0; j < MATRIX_LENGTH; j++) {
            //std::cout<<c[i][j]<<std::endl;
        }
    }

    std::cout<<"solving by dc"<<std::endl;
    /* 8*T(n/2) + n^2 */
    int d[MATRIX_LENGTH][MATRIX_LENGTH] {};
    start = finish = 0;
    duration = 0.0;
    start = clock();
    solve_matrix_multiply_dc(a, b, d, 0, MATRIX_LENGTH -1, 0, MATRIX_LENGTH -1, 0, MATRIX_LENGTH - 1, 0, MATRIX_LENGTH -1);
    finish = clock();
    duration = (double)(finish - start) / CLOCKS_PER_SEC;
    printf( "%f seconds\n", duration );


    for (int i = 0; i < MATRIX_LENGTH; i++) {
        for (int j = 0; j < MATRIX_LENGTH; j++) {
            //std::cout<<d[i][j]<<std::endl;
        }
    }
    std::cout<<"solved"<<std::endl;
    return 0;
}

测试结果

# ./a.out                                                                                                                                                              :( 20 16-12-17 - 23:59:31
initialization...
solving by n3
0.820000 seconds
solving by dc
2.430000 seconds
solved

总结

看来该情景下,分治法慎用。


文章来源:胡旭博客 -> 分治法解矩阵乘积

参考文章:算法导论(p75  - p79)

转载请注明出处,违者必究!


这是一篇原创文章,如果您觉得有价值,可以通过捐赠来支持我的创作~
捐赠者会展示在博客的某个页面,钱将会用在有价值的地方,思考中...


分类: C/C++, 技术, 算法, 编程 | 标签: , , , , , , | 评论 | Permalink

发表评论

电子邮件地址不会被公开。