前言:
不少朋友看到我寫的《算法導論》系列,可能會以爲雲裏霧裏,不知所云。這裏我再次說明,本系列博文是配合《算法導論》一書,給出該書涉及的算法的C++實現。請結合《算法導論》一書閱讀該系列博文。我這裏有該書的電子版,有需要的朋友可以留言。
正題:
今天討論的算法是矩陣乘法的Strassen算法,該算法的精髓在於減少 $n/2 \times n/2$ 子矩陣相乘的次數(由分塊乘法所需的 8 次減少到 7 次)。首先,做一些編寫該算法所需的基礎工作:
/*
 * Element-wise matrix addition: matrixResult = matrixA + matrixB.
 * All three arguments are length x length matrices stored as arrays of row
 * pointers. Each entry is read before it is written, so matrixResult may be
 * the same matrix as either input.
 */
void Add(int** matrixA, int** matrixB, int** matrixResult,int length)
{
    for(int i = 0; i < length; i++) {
        for(int j = 0; j < length; j++) {
            matrixResult[i][j] = matrixA[i][j] + matrixB[i][j];
        }
    }
}

/*
 * Element-wise matrix subtraction: matrixResult = matrixA - matrixB.
 * Same layout and aliasing rules as Add.
 */
void Sub(int** matrixA, int** matrixB, int** matrixResult,int length)
{
    for(int i = 0; i < length; i++) {
        for(int j = 0; j < length; j++) {
            matrixResult[i][j] = matrixA[i][j] - matrixB[i][j];
        }
    }
}

/*
 * Schoolbook matrix multiplication: matrixResult = matrixA * matrixB for
 * length x length matrices stored as arrays of row pointers.
 * matrixResult must not alias either input: input rows are still read after
 * result entries have been written.
 */
void Mul(int** matrixA, int** matrixB, int** matrixResult, int length)
{
    for(int i = 0; i < length; ++i) {
        for(int j = 0; j < length; ++j) {
            // Accumulate locally and store once per entry.
            int sum = 0;
            for(int k = 0; k < length; ++k) {
                sum += matrixA[i][k] * matrixB[k][j];
            }
            matrixResult[i][j] = sum;
        }
    }
}

/*
 * 2 x 2 matrix multiplication (the original, fixed-size entry point).
 * Kept as a separate overload rather than a default argument so that an
 * existing three-parameter declaration elsewhere stays unambiguous.
 */
void Mul(int** matrixA, int** matrixB, int** matrixResult){
    Mul(matrixA, matrixB, matrixResult, 2);
}
接着進入核心部分:
void Strassen(int** matrixA, int** matrixB, int** matrixResult,int length) { int halfLength=length/2; int** a11=new int*[halfLength]; int** a12=new int*[halfLength]; int** a21=new int*[halfLength]; int** a22=new int*[halfLength]; int** b11=new int*[halfLength]; int** b12=new int*[halfLength]; int** b21=new int*[halfLength]; int** b22=new int*[halfLength]; int** s1=new int*[halfLength]; int** s2=new int*[halfLength]; int** s3=new int*[halfLength]; int** s4=new int*[halfLength]; int** s5=new int*[halfLength]; int** s6=new int*[halfLength]; int** s7=new int*[halfLength]; int** matrixResult11=new int*[halfLength]; int** matrixResult12=new int*[halfLength]; int** matrixResult21=new int*[halfLength]; int** matrixResult22=new int*[halfLength]; int** temp=new int*[halfLength]; int** temp1=new int*[halfLength]; if(halfLength==1){ Mul(matrixA, matrixB, matrixResult); }else{ //首先將矩陣A,B 分爲4塊 for(int i = 0; i < halfLength; i++) { a11[i]=new int[halfLength]; a12[i]=new int[halfLength]; a21[i]=new int[halfLength]; a22[i]=new int[halfLength]; b11[i]=new int[halfLength]; b12[i]=new int[halfLength]; b21[i]=new int[halfLength]; b22[i]=new int[halfLength]; s1[i]=new int[halfLength]; s2[i]=new int[halfLength]; s3[i]=new int[halfLength]; s4[i]=new int[halfLength]; s5[i]=new int[halfLength]; s6[i]=new int[halfLength]; s7[i]=new int[halfLength]; matrixResult11[i]=new int[halfLength]; matrixResult12[i]=new int[halfLength]; matrixResult21[i]=new int[halfLength]; matrixResult22[i]=new int[halfLength]; temp[i]=new int[halfLength]; temp1[i]=new int[halfLength]; for(int j = 0; j < halfLength; j++) { a11[i][j]=matrixA[i][j]; a12[i][j]=matrixA[i][j+halfLength]; a21[i][j]=matrixA[i+halfLength][j]; a22[i][j]=matrixA[i+halfLength][j+halfLength]; b11[i][j]=matrixB[i][j]; b12[i][j]=matrixB[i][j+halfLength]; b21[i][j]=matrixB[i+halfLength][j]; b22[i][j]=matrixB[i+halfLength][j+halfLength]; } } //計算s1 Sub(b12, b22, temp,halfLength); Strassen(a11, temp, s1,halfLength); //計算s2 Add(a11, a12, 
temp,halfLength); Strassen(temp, b22, s2,halfLength); //計算s3 Add(a21, a22, temp,halfLength); Strassen(temp, b11, s3,halfLength); //計算s4 Sub(b21, b11, temp,halfLength); Strassen(a22, temp, s4,halfLength); //計算s5 Add(a11, a22, temp1,halfLength); Add(b11, b22, temp,halfLength); Strassen(temp1, temp, s5,halfLength); //計算s6 Sub(a12, a22, temp1,halfLength); Add(b21, b22, temp,halfLength); Strassen(temp1, temp, s6,halfLength); //計算s7 Sub(a11, a21, temp1,halfLength); Add(b11, b12, temp,halfLength); Strassen(temp1, temp, s7,halfLength); //計算matrixResult11 Add(s5, s4, temp1,halfLength); Sub(temp1, s2, temp,halfLength); Add(temp, s6, matrixResult11,halfLength); //計算matrixResult12 Add(s1, s2, matrixResult12,halfLength); //計算matrixResult21 Add(s3, s4, matrixResult21,halfLength); //計算matrixResult22 Add(s5, s1, temp1,halfLength); Sub(temp1, s3, temp,halfLength); Sub(temp, s7, matrixResult22,halfLength); //結果送回matrixResult中 for(int i = 0; i < halfLength; i++) { for(int j = 0; j < halfLength; j++) { matrixResult[i][j]=matrixResult11[i][j]; matrixResult[i][j+halfLength]=matrixResult12[i][j]; matrixResult[i+halfLength][j]=matrixResult21[i][j]; matrixResult[i+halfLength][j+halfLength]=matrixResult22[i][j]; } delete(a11[i]); delete(a12[i]); delete(a21[i]); delete(a22[i]); delete(b11[i]); delete(b12[i]); delete(b21[i]); delete(b22[i]); delete(s1[i]); delete(s2[i]); delete(s3[i]); delete(s4[i]); delete(s5[i]); delete(s6[i]); delete(s7[i]); delete(matrixResult11[i]); delete(matrixResult12[i]); delete(matrixResult21[i]); delete(matrixResult22[i]); delete(temp[i]); delete(temp1[i]); } delete(a11); delete(a12); delete(a21); delete(a22); delete(b11); delete(b12); delete(b21); delete(b22); delete(s1); delete(s2); delete(s3); delete(s4); delete(s5); delete(s6); delete(s7); delete(matrixResult11); delete(matrixResult12); delete(matrixResult21); delete(matrixResult22); delete(temp); delete(temp1); } }
該算法看着或許有些冗長,幾乎一半的代碼都在進行動態指針的初始化和刪除。利用該算法計算矩陣乘法的時間複雜度爲 $\Theta(n^{\lg 7})$(約爲 $\Theta(n^{2.81})$)。
測試一下吧:
#include "stdafx.h"
#include <cstdlib>
#include <iostream>
#include <vector>
#include "SquareMatrix.h"
using namespace dksl;

// Demo for the Strassen matrix multiplication algorithm.
// N is the matrix dimension and is used everywhere below. Previously it was
// declared as 8 while 4 was hard-coded. Keep N a power of two so every
// recursive split is even.
const int N = 4;

int _tmain(int argc, _TCHAR* argv[])
{
    // Contiguous storage plus row-pointer tables give the int** layout that
    // Strassen takes. std::vector releases both automatically; previously
    // every row and pointer table was leaked.
    std::vector<int> aCells(N * N, 1);   // A: all ones
    std::vector<int> bCells(N * N, 2);   // B: all twos
    std::vector<int> cCells(N * N, 0);   // C = A * B
    std::vector<int*> a(N), b(N), c(N);
    for (int i = 0; i < N; i++) {
        a[i] = &aCells[i * N];
        b[i] = &bCells[i * N];
        c[i] = &cCells[i * N];
    }

    Strassen(&a[0], &b[0], &c[0], N);

    for (int i = 0; i < N; i++) {
        for (int j = 0; j < N; j++)
            std::cout << c[i][j] << " ";
        std::cout << std::endl;
    }
    std::system("PAUSE");   // Windows console pause, as in the original demo
    return 0;
}