题目

假设有A,B两个矩阵,且其均为n*n维矩阵,n为2的幂(n>=2)。求A与B的乘积。

分治法解矩阵乘积
分治法解矩阵乘积

上图为《算法导论》中利用分治法求矩阵乘积的伪代码。

解决思路一

暴力解法

时间复杂度

O(n^3)

程序范例

/*
 * Brute-force matrix product: accumulates A * B into C, i.e.
 *     C[i][j] += sum over k of A[i][k] * B[k][j]
 * for N x N matrices. C is accumulated into (+=), so the caller must
 * zero-initialize it to obtain exactly A * B.
 *
 * N is deduced from the array types (e.g. int[MATRIX_LENGTH][MATRIX_LENGTH]),
 * so the function works for any square size, not only powers of two.
 *
 * Time complexity: O(N^3).
 */
template <int N>
void solve_matrix_multiply_n3(int a[][N], int b[][N], int c[][N])
{
    for (int i = 0; i < N; i++) {
        // i-k-j order: the inner loop walks a row of B and a row of C
        // contiguously instead of striding down a column of B.
        for (int k = 0; k < N; k++) {
            const int a_ik = a[i][k];
            for (int j = 0; j < N; j++) {
                c[i][j] += a_ik * b[k][j];  // B is indexed [k][j], not [j][k]
            }
        }
    }
}

解决思路二

利用分治法解决矩阵乘积

程序范例

/*
 * Divide-and-conquer matrix product: accumulates the sub-product
 *     A[a_row_start..a_row_end][a_col_start..a_col_end]
 *   * B[b_row_start..b_row_end][b_col_start..b_col_end]
 * into C[a_row_start..a_row_end][b_col_start..b_col_end].
 * All ranges are inclusive. The A column range and the B row range must have
 * the same length (the shared inner dimension). C is accumulated into (+=),
 * so the caller must zero-initialize it to obtain exactly A * B.
 *
 * Each call splits every range at its midpoint and issues 8 recursive
 * sub-products:
 *   C11 += A11*B11 + A12*B21    C12 += A11*B12 + A12*B22
 *   C21 += A21*B11 + A22*B21    C22 += A21*B12 + A22*B22
 * The sums are accumulated directly in C, so each level does O(1) extra work:
 * T(n) = 8T(n/2) + O(1), i.e. Theta(n^3).
 *
 * N is deduced from the array types. Any size works, not only powers of two:
 * an odd-length range splits into unequal halves, and an empty half returns
 * immediately.
 */
template <int N>
void solve_matrix_multiply_dc(int a[][N], int b[][N], int c[][N],
                              int a_row_start, int a_row_end, int a_col_start, int a_col_end,
                              int b_row_start, int b_row_end, int b_col_start, int b_col_end)
{
    // Splitting a length-1 range yields an empty upper half (start > end);
    // it contributes nothing and must not recurse further.
    if (a_row_start > a_row_end || a_col_start > a_col_end ||
        b_row_start > b_row_end || b_col_start > b_col_end) {
        return;
    }

    // Base case: a 1x1 block of A times a 1x1 block of B. B's column range is
    // checked too so that non-square sub-blocks (e.g. 1x1 times 1x2) keep
    // being split instead of silently dropping terms.
    if (a_row_start == a_row_end && a_col_start == a_col_end && b_col_start == b_col_end) {
        c[a_row_start][b_col_start] += a[a_row_start][a_col_start] * b[b_row_start][b_col_start];
        return;
    }

    // Each range [start, end] splits into a low half [start, mid] and a
    // high half [mid + 1, end].
    const int a_row_mid = (a_row_start + a_row_end) / 2;
    const int a_col_mid = (a_col_start + a_col_end) / 2;
    const int b_row_mid = (b_row_start + b_row_end) / 2;
    const int b_col_mid = (b_col_start + b_col_end) / 2;

    // C11 += A11*B11 + A12*B21
    solve_matrix_multiply_dc(a, b, c,
                             a_row_start, a_row_mid, a_col_start, a_col_mid,
                             b_row_start, b_row_mid, b_col_start, b_col_mid);
    solve_matrix_multiply_dc(a, b, c,
                             a_row_start, a_row_mid, a_col_mid + 1, a_col_end,
                             b_row_mid + 1, b_row_end, b_col_start, b_col_mid);

    // C12 += A11*B12 + A12*B22
    solve_matrix_multiply_dc(a, b, c,
                             a_row_start, a_row_mid, a_col_start, a_col_mid,
                             b_row_start, b_row_mid, b_col_mid + 1, b_col_end);
    solve_matrix_multiply_dc(a, b, c,
                             a_row_start, a_row_mid, a_col_mid + 1, a_col_end,
                             b_row_mid + 1, b_row_end, b_col_mid + 1, b_col_end);

    // C21 += A21*B11 + A22*B21
    solve_matrix_multiply_dc(a, b, c,
                             a_row_mid + 1, a_row_end, a_col_start, a_col_mid,
                             b_row_start, b_row_mid, b_col_start, b_col_mid);
    solve_matrix_multiply_dc(a, b, c,
                             a_row_mid + 1, a_row_end, a_col_mid + 1, a_col_end,
                             b_row_mid + 1, b_row_end, b_col_start, b_col_mid);

    // C22 += A21*B12 + A22*B22
    solve_matrix_multiply_dc(a, b, c,
                             a_row_mid + 1, a_row_end, a_col_start, a_col_mid,
                             b_row_start, b_row_mid, b_col_mid + 1, b_col_end);
    solve_matrix_multiply_dc(a, b, c,
                             a_row_mid + 1, a_row_end, a_col_mid + 1, a_col_end,
                             b_row_mid + 1, b_row_end, b_col_mid + 1, b_col_end);
}

时间复杂度

其时间复杂度的递推式为：$T(n) = 8T(n/2) + \Theta(n^2)$。由主定理（情况 1，$n^{\log_2 8} = n^3$）可得其时间复杂度为：

$\Theta(n^3)$

测试

假设A,B均为512*512的矩阵,并利用0-99的随机数初始化。分别利用暴力解法和分治法进行求解。

程序如下

#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <time.h>

using namespace std;

#define MATRIX_LENGTH 512

/*
 * Benchmark: multiplies the same pair of random MATRIX_LENGTH x MATRIX_LENGTH
 * matrices with the brute-force and the divide-and-conquer method, prints the
 * CPU time of each, and checks that both produce the same product.
 * Returns 0 on success and 1 if the two results disagree.
 */
int main()
{
    /* The matrices are static rather than local: four int matrices of this
       size (4 MiB at 512) can exceed the default stack size on some
       platforms. Static storage is zero-initialized, which the accumulating
       (+=) multiply functions rely on for c and d. */
    static int a[MATRIX_LENGTH][MATRIX_LENGTH] {};
    static int b[MATRIX_LENGTH][MATRIX_LENGTH] {};
    static int c[MATRIX_LENGTH][MATRIX_LENGTH] {};  // product from the O(n^3) method
    static int d[MATRIX_LENGTH][MATRIX_LENGTH] {};  // product from divide and conquer

    std::cout << "initialization..." << std::endl;
    for (int i = 0; i < MATRIX_LENGTH; i++) {
        for (int j = 0; j < MATRIX_LENGTH; j++) {
            a[i][j] = rand() % 100;
            b[i][j] = rand() % 100;
        }
    }

    /* O(n^3) brute force */
    std::cout << "solving by n3" << std::endl;
    clock_t start = clock();
    solve_matrix_multiply_n3(a, b, c);
    clock_t finish = clock();
    double duration = static_cast<double>(finish - start) / CLOCKS_PER_SEC;
    printf("%f seconds\n", duration);

    /* Divide and conquer: 8 recursive sub-products per level */
    std::cout << "solving by dc" << std::endl;
    start = clock();
    solve_matrix_multiply_dc(a, b, d,
                             0, MATRIX_LENGTH - 1, 0, MATRIX_LENGTH - 1,
                             0, MATRIX_LENGTH - 1, 0, MATRIX_LENGTH - 1);
    finish = clock();
    duration = static_cast<double>(finish - start) / CLOCKS_PER_SEC;
    printf("%f seconds\n", duration);

    /* Cross-check: both methods must yield the same product. */
    for (int i = 0; i < MATRIX_LENGTH; i++) {
        for (int j = 0; j < MATRIX_LENGTH; j++) {
            if (c[i][j] != d[i][j]) {
                std::cout << "mismatch at (" << i << ", " << j << ")" << std::endl;
                return 1;
            }
        }
    }

    std::cout << "solved" << std::endl;
    return 0;
}

测试结果

# ./a.out                                                                                                                                                              :( 20 16-12-17 - 23:59:31
initialization...
solving by n3
0.820000 seconds
solving by dc
2.430000 seconds
solved

总结

两种方法的渐近时间复杂度同为 $\Theta(n^3)$，但分治法要一直递归到 $1 \times 1$ 的子矩阵，函数调用开销很大，因此实测反而比暴力解法慢。看来在该情景下，分治法需慎用。


文章来源:胡旭博客 -> 分治法解矩阵乘积

参考文章:算法导论(p75  – p79)

转载请注明出处,违者必究!

Share:

Leave a Reply

Your email address will not be published.

This site uses Akismet to reduce spam. Learn how your comment data is processed.