# 分治法解矩阵乘积

朴素三重循环算法的时间复杂度为 $O(n^3)$。

## 程序范例：朴素算法

```void solve_matrix_multiply_n3(int a[][MATRIX_LENGTH], int b[][MATRIX_LENGTH], int c[][MATRIX_LENGTH])
{
for (int i = 0; i < MATRIX_LENGTH; i++) {
for (int j = 0; j < MATRIX_LENGTH; j++) {
for (int k = 0; k < MATRIX_LENGTH; k++) {
c[i][j] += a[i][k] * b[j][k];
}
}
}
}
```

## 程序范例：分治法

```void solve_matrix_multiply_dc(int a[][MATRIX_LENGTH], int b[][MATRIX_LENGTH], int c[][MATRIX_LENGTH],
int a_row_start, int a_row_end, int a_col_start, int a_col_end,
int b_row_start, int b_row_end, int b_col_start, int b_col_end)
{
if (a_row_start == a_row_end && a_col_start == a_col_end) {
c[a_row_start][b_col_start] += a[a_row_start][a_col_start] * b[b_row_start][b_col_start];
} else {
/* divide and combine */
solve_matrix_multiply_dc(a, b, c,
a_row_start, (a_row_start + a_row_end)>>1, a_col_start, (a_col_start + a_col_end)>>1,
b_row_start, (b_row_start + b_row_end)>>1, b_col_start, (b_col_start + b_col_end)>>1);
solve_matrix_multiply_dc(a, b, c,
a_row_start, (a_row_start + a_row_end)>>1, ((a_col_start + a_col_end)>>1) + 1, a_col_end,
((b_row_start + b_row_end)>>1) + 1, b_row_end, b_col_start, (b_col_start + b_col_end)>>1);

solve_matrix_multiply_dc(a, b, c,
a_row_start, (a_row_start + a_row_end)>>1, a_col_start, (a_col_start + a_col_end)>>1,
b_row_start, (b_row_start + b_row_end)>>1, ((b_col_start + b_col_end)>>1) + 1, b_col_end);
solve_matrix_multiply_dc(a, b, c,
a_row_start, (a_row_start + a_row_end)>>1, ((a_col_start + a_col_end)>>1) + 1, a_col_end,
((b_row_start + b_row_end)>>1) + 1, b_row_end, ((b_col_start + b_col_end)>>1) + 1, b_col_end);

solve_matrix_multiply_dc(a, b, c,
((a_row_start + a_row_end)>>1) + 1, a_row_end, a_col_start, (a_col_start + a_col_end)>>1,
b_row_start, (b_row_start + b_row_end)>>1, b_col_start, (b_col_start + b_col_end)>>1);
solve_matrix_multiply_dc(a, b, c,
((a_row_start + a_row_end)>>1) + 1, a_row_end, ((a_col_start + a_col_end)>>1) + 1, a_col_end,
((b_row_start + b_row_end)>>1) + 1, b_row_end, b_col_start, (b_col_start + b_col_end)>>1);

solve_matrix_multiply_dc(a, b, c,
((a_row_start + a_row_end)>>1) + 1, a_row_end, a_col_start, (a_col_start + a_col_end)>>1,
b_row_start, (b_row_start + b_row_end)>>1, ((b_col_start + b_col_end)>>1) + 1, b_col_end);
solve_matrix_multiply_dc(a, b, c,
((a_row_start + a_row_end)>>1) + 1, a_row_end, ((a_col_start + a_col_end)>>1) + 1, a_col_end,
((b_row_start + b_row_end)>>1) + 1, b_row_end, ((b_col_start + b_col_end)>>1) + 1, b_col_end);
}
}

```

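上面的实现中，子矩阵的划分只是下标计算，结果直接累加到 c 中，因此递推式为 $T(n) = 8T(n/2) + \Theta(1)$。由主定理得 $T(n) = \Theta(n^3)$，与朴素算法同阶；再加上递归调用的开销，实际运行反而更慢（见下方测试结果）。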

## 测试

### 程序如下

编译前需把上面两个函数的定义放到 `#define MATRIX_LENGTH` 之后、`main` 之前。

```cpp
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <time.h>

using namespace std;

#define MATRIX_LENGTH  512

int main()
{
    /*
     * Benchmark the naive and divide-and-conquer products on random
     * MATRIX_LENGTH x MATRIX_LENGTH matrices, then check that they agree.
     *
     * With 4-byte int, each 512 x 512 matrix is 1 MiB. Declaring the matrices
     * static keeps them off the stack, where four of them can exceed small
     * default stack limits. Static storage is zero-initialized, and both
     * solvers rely on that because they accumulate into their output.
     */
    static int a[MATRIX_LENGTH][MATRIX_LENGTH] {};
    static int b[MATRIX_LENGTH][MATRIX_LENGTH] {};
    static int c[MATRIX_LENGTH][MATRIX_LENGTH] {};
    static int d[MATRIX_LENGTH][MATRIX_LENGTH] {};

    std::cout << "initialization..." << std::endl;
    for (int i = 0; i < MATRIX_LENGTH; i++) {
        for (int j = 0; j < MATRIX_LENGTH; j++) {
            a[i][j] = rand() % 100;
            b[i][j] = rand() % 100;
        }
    }

    /* O(n^3) */
    std::cout << "solving by n3" << std::endl;
    /* Measure how long one call takes in CPU time. */
    clock_t start = clock();
    solve_matrix_multiply_n3(a, b, c);
    clock_t finish = clock();
    double duration = static_cast<double>(finish - start) / CLOCKS_PER_SEC;
    printf("%f seconds\n", duration);

    /* T(n) = 8*T(n/2) + O(1) => O(n^3) */
    std::cout << "solving by dc" << std::endl;
    start = clock();
    solve_matrix_multiply_dc(a, b, d,
                             0, MATRIX_LENGTH - 1, 0, MATRIX_LENGTH - 1,
                             0, MATRIX_LENGTH - 1, 0, MATRIX_LENGTH - 1);
    finish = clock();
    duration = static_cast<double>(finish - start) / CLOCKS_PER_SEC;
    printf("%f seconds\n", duration);

    /* Both algorithms compute a * b, so their results must be identical. */
    for (int i = 0; i < MATRIX_LENGTH; i++) {
        for (int j = 0; j < MATRIX_LENGTH; j++) {
            if (c[i][j] != d[i][j]) {
                std::cout << "mismatch at (" << i << ", " << j << "): "
                          << c[i][j] << " != " << d[i][j] << std::endl;
                return 1;
            }
        }
    }

    std::cout << "solved" << std::endl;
    return 0;
}
```

### 测试结果

```
# ./a.out
initialization...
solving by n3
0.820000 seconds
solving by dc
2.430000 seconds
solved
```