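So I implemented Strassen's algorithm in C, and even for large matrices it cannot beat the classical ikj algorithm. I only use matrices whose dimension is a power of 2, filled with doubles. I have read that Strassen's algorithm beats the classical algorithm for "large" n, e.g. for n >= 1000. For me this is not the case. I do not use a cutoff point, since it is not part of the original algorithm. If I do use one, it is in fact faster than the classical algorithm, but without it, it performs worse. Is this the expected behavior? Are the sources I found online about Strassen's algorithm actually using a cutoff point?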
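As a rough sanity check on the "n >= 1000" figure, the arithmetic can be counted directly. The sketch below assumes the textbook scheme of 7 recursive products and 18 block additions or subtractions per level, recursing all the way down to n == 1, and compares it with the 2n^3 - n^2 operations of the classical algorithm. The function names are made up for illustration, and the counts ignore memory allocation and cache effects, which may well dominate in practice.

```c
#include <stdint.h>

/* Floating-point operations of Strassen recursing down to 1x1 blocks:
   7 half-size products plus 18 additions/subtractions of (n/2) x (n/2) blocks.
   Assumes n is a power of 2. */
uint64_t strassen_flops(uint64_t n) {
    if (n == 1) return 1;               /* a single scalar multiplication */
    uint64_t k = n / 2;
    return 7 * strassen_flops(k) + 18 * k * k;
}

/* Classical triple loop: n^3 multiplications and n^2 * (n - 1) additions. */
uint64_t classic_flops(uint64_t n) {
    return 2 * n * n * n - n * n;
}
```

Under these assumptions the count for n = 512 is 280,902,385 against 268,173,312 for the classical loop, and for n = 1024 it is 1,971,035,287 against 2,146,435,072. So on operation count alone, the version without a cutoff only pulls ahead at around n = 1024, and then by less than 10%.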
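I tried to optimize Strassen by using only the two matrices stored as one-dimensional arrays, passing them as parameters and, where possible, not actually creating submatrices but just passing the new size and a pointer to the starting cell of each block, so that I do not allocate as much memory. But maybe for large matrices this is actually bad, because I am accessing the whole matrix and it needs to be stored somewhere in main memory.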
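To make the stride idea concrete: a k x k block inside a larger row-major matrix is just a pointer to its top-left cell plus the row length of the parent matrix. The helpers below are a guess at what add and sub look like, inferred only from how they are called in the Strassen implementation (size first, then a row length and a pointer for each operand). The actual versions in the repository may differ.

```c
/* Z = X + Y for n x n blocks. Each operand has its own row length (stride),
   so a block can live inside a larger matrix without being copied. */
void add(int n, int x_row_length, const double X[],
         int y_row_length, const double Y[],
         int z_row_length, double Z[]) {
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            Z[i * z_row_length + j] = X[i * x_row_length + j] + Y[i * y_row_length + j];
}

/* Z = X - Y, with the same layout conventions as add(). */
void sub(int n, int x_row_length, const double X[],
         int y_row_length, const double Y[],
         int z_row_length, double Z[]) {
    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            Z[i * z_row_length + j] = X[i * x_row_length + j] - Y[i * y_row_length + j];
}
```

For a 4 x 4 matrix M stored in 16 doubles, add(2, 4, M, 4, M + 2, 2, S) would sum the upper-left and upper-right 2 x 2 blocks into a compact 2 x 2 buffer S.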
Here is a GitHub link: https://github.com/AlexLoitzl/matmul The code is not properly documented yet... Here is the Strassen implementation:

    #include <stdlib.h>

    // add() and sub() are block helpers defined elsewhere (not shown here).

    // Z = X * Y for n x n blocks, n a power of 2. Each *_row_length is the row
    // stride of the matrix the block lives in. base and inner_func are not used
    // in this version without a cutoff; they are only passed down.
    void strassen_mult_base(int n, int x_row_length, const double X[], int y_row_length, const double Y[], int z_row_length, double Z[], int base, void (*inner_func)()) {
        if (n == 1) {
            Z[0] = X[0] * Y[0];
            return;
        }
        const int k = n/2;

        // Split matrices X and Y into 4 blocks without any memory allocation by using proper offsets
        const double *A = X;
        const double *B = X + k;
        const double *C = X + k*x_row_length;
        const double *D = C + k;
        const double *E = Y;
        const double *F = Y + k;
        const double *G = Y + k*y_row_length;
        const double *H = G + k;

        // Allocate memory for temporary matrices P0 - P6 for our 7 multiplications
        // (size_t avoids int overflow in the byte count for large k)
        const size_t size = (size_t)k * k * sizeof(double);
        double *P[7];
        for (int i = 0; i < 7; i++) {
            P[i] = (double *) malloc(size);
        }

        // Allocate memory for interim results of calculating P0 - P6
        double *S = (double *) malloc(size);
        double *R = (double *) malloc(size);

        // P0 = A*(F - H)
        sub(k, y_row_length, F, y_row_length, H, k, S);
        strassen_mult_base(k, x_row_length, A, k, S, k, P[0], base, inner_func);

        // P1 = (A + B)*H
        add(k, x_row_length, A, x_row_length, B, k, S);
        strassen_mult_base(k, k, S, y_row_length, H, k, P[1], base, inner_func);

        // P2 = (C + D)*E
        add(k, x_row_length, C, x_row_length, D, k, S);
        strassen_mult_base(k, k, S, y_row_length, E, k, P[2], base, inner_func);

        // P3 = D*(G - E)
        sub(k, y_row_length, G, y_row_length, E, k, S);
        strassen_mult_base(k, x_row_length, D, k, S, k, P[3], base, inner_func);

        // P4 = (A + D)*(E + H)
        add(k, x_row_length, A, x_row_length, D, k, S);
        add(k, y_row_length, E, y_row_length, H, k, R);
        strassen_mult_base(k, k, S, k, R, k, P[4], base, inner_func);

        // P5 = (B - D)*(G + H)
        sub(k, x_row_length, B, x_row_length, D, k, S);
        add(k, y_row_length, G, y_row_length, H, k, R);
        strassen_mult_base(k, k, S, k, R, k, P[5], base, inner_func);

        // P6 = (A - C)*(E + F)
        sub(k, x_row_length, A, x_row_length, C, k, S);
        add(k, y_row_length, E, y_row_length, F, k, R);
        strassen_mult_base(k, k, S, k, R, k, P[6], base, inner_func);

        // Z upper left = (P3 + P4) + (P5 - P1)
        add(k, k, P[4], k, P[3], k, S);
        sub(k, k, P[5], k, P[1], k, R);
        add(k, k, S, k, R, z_row_length, Z);

        // Z lower left = P2 + P3
        add(k, k, P[2], k, P[3], z_row_length, Z + k*z_row_length);

        // Z upper right = P0 + P1
        add(k, k, P[0], k, P[1], z_row_length, Z + k);

        // Z lower right = (P0 + P4) - (P2 + P6)
        add(k, k, P[0], k, P[4], k, S);
        add(k, k, P[2], k, P[6], k, R);
        sub(k, k, S, k, R, z_row_length, Z + k*(z_row_length + 1));

        // Deallocate temporary matrices
        free(R);
        free(S);
        for (int i = 6; i >= 0; i--)
            free(P[i]);
    }
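
For comparison, the variant with a cutoff mentioned at the top only changes the base case. The sketch below is an assumption about how that could look, not the code from the repository: classic_ikj is a hypothetical strided ikj kernel, and the comment after it shows where it would replace the n == 1 test, using the base parameter as the threshold.

```c
/* Hypothetical base-case kernel: Z = X * Y for n x n blocks in ikj loop order,
   with a separate row length (stride) per operand. */
void classic_ikj(int n, int x_row_length, const double X[],
                 int y_row_length, const double Y[],
                 int z_row_length, double Z[]) {
    for (int i = 0; i < n; i++) {
        double *z_row = Z + i * z_row_length;
        for (int j = 0; j < n; j++)
            z_row[j] = 0.0;                  /* clear the output row first */
        for (int k = 0; k < n; k++) {
            const double x_ik = X[i * x_row_length + k];
            const double *y_row = Y + k * y_row_length;
            for (int j = 0; j < n; j++)
                z_row[j] += x_ik * y_row[j]; /* inner loop walks Y and Z row-wise */
        }
    }
}

/* Possible replacement for the n == 1 check in strassen_mult_base:
 *
 *     if (n <= base) {
 *         classic_ikj(n, x_row_length, X, y_row_length, Y, z_row_length, Z);
 *         return;
 *     }
 */
```

A suitable value for base depends on the machine and would have to be found by timing a few powers of 2.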