《算法導論》——矩陣乘法的Strassen算法

前言:

  不少朋友看到我寫的《算法導論》系列,可能會覺得雲裏霧裏,不知所云。這裏我再次說明:本系列博文是配合《算法導論》一書,給出該書所涉及算法的 C++ 實現,請結合《算法導論》一書閱讀該系列博文。我這裏有該書的電子版,有需要的朋友可以留言。

正題:

  今天討論的算法是矩陣乘法的Strassen算法,該算法的精髓在於減少 n/2 × n/2 子矩陣相乘的次數(由 8 次減爲 7 次)。首先,做一些實現該算法的基礎工作:

/*
     * Element-wise matrix addition: matrixResult = matrixA + matrixB.
     * Each argument is indexed as [row][column]; the first `length` rows and
     * columns of every matrix are visited in row-major order.
     */
    void Add(int** matrixA, int** matrixB, int** matrixResult,int length)
    {
        for (int row = 0; row < length; ++row) {
            for (int col = 0; col < length; ++col) {
                const int* lhsRow = matrixA[row];
                const int* rhsRow = matrixB[row];
                int* outRow = matrixResult[row];
                outRow[col] = lhsRow[col] + rhsRow[col];
            }
        }
    }

    /*
     * Element-wise matrix subtraction: matrixResult = matrixA - matrixB.
     * Each argument is indexed as [row][column]; the first `length` rows and
     * columns of every matrix are visited in row-major order.
     */
    void Sub(int** matrixA, int** matrixB, int** matrixResult,int length)
    {
        for (int row = 0; row < length; ++row) {
            for (int col = 0; col < length; ++col) {
                const int* minuendRow = matrixA[row];
                const int* subtrahendRow = matrixB[row];
                int* outRow = matrixResult[row];
                outRow[col] = minuendRow[col] - subtrahendRow[col];
            }
        }
    }

    /*
     * Classical (triple-loop) matrix product: matrixResult = matrixA * matrixB.
     *
     * matrixA, matrixB : square inputs indexed as [row][column].
     * matrixResult     : receives the product; every one of its
     *                    length x length entries is overwritten. It must not
     *                    share storage with either input, because the inputs
     *                    are still being read while results are written.
     * length           : number of rows/columns to multiply. Defaults to 2 so
     *                    that existing three-argument calls keep computing a
     *                    2x2 product.
     */
    void Mul(int** matrixA, int** matrixB, int** matrixResult, int length = 2){
        for(int i = 0; i < length; ++i) {
            for(int j = 0; j < length; ++j) {
                // Accumulate in a local so the output cell is written once.
                int sum = 0;
                for(int k = 0; k < length; ++k) {
                    sum += matrixA[i][k] * matrixB[k][j];
                }
                matrixResult[i][j] = sum;
            }
        }
    }
View Code

接着進入核心部分:

/*
     * Allocates an uninitialised size x size matrix as an array of row
     * pointers. Release it with StrassenFreeSquare.
     */
    static int** StrassenNewSquare(int size)
    {
        int** matrix = new int*[size];
        for(int i = 0; i < size; i++) {
            matrix[i] = new int[size];
        }
        return matrix;
    }

    /*
     * Releases a matrix obtained from StrassenNewSquare. Both levels were
     * allocated with new[], so both are released with delete[].
     */
    static void StrassenFreeSquare(int** matrix, int size)
    {
        for(int i = 0; i < size; i++) {
            delete[] matrix[i];
        }
        delete[] matrix;
    }

    /*
     * Multiplies two square matrices with Strassen's seven-product recursion:
     * matrixResult = matrixA * matrixB.
     *
     * matrixA, matrixB : length x length inputs indexed as [row][column];
     *                    they are only read.
     * matrixResult     : length x length output; every entry is overwritten.
     *                    It must not share storage with either input.
     * length           : number of rows/columns. Values <= 0 do nothing.
     *
     * The recursion halves the matrices for as long as the current size is
     * even and greater than 2. A 2x2 block is handed to Mul, and a block with
     * an odd size (including 1x1) is multiplied with the classical triple
     * loop, so sizes that are not powers of two still give the full product.
     */
    void Strassen(int** matrixA, int** matrixB, int** matrixResult,int length)
    {
        if(length <= 0) {
            return;
        }

        // Base cases are handled before any allocation so they cannot leak.
        if(length == 2) {
            Mul(matrixA, matrixB, matrixResult);
            return;
        }
        if(length % 2 != 0) {
            // An odd size cannot be split into four equal quadrants.
            for(int i = 0; i < length; i++) {
                for(int j = 0; j < length; j++) {
                    int sum = 0;
                    for(int k = 0; k < length; k++) {
                        sum += matrixA[i][k] * matrixB[k][j];
                    }
                    matrixResult[i][j] = sum;
                }
            }
            return;
        }

        const int halfLength = length / 2;

        // Quadrants of A and B.
        int** a11 = StrassenNewSquare(halfLength);
        int** a12 = StrassenNewSquare(halfLength);
        int** a21 = StrassenNewSquare(halfLength);
        int** a22 = StrassenNewSquare(halfLength);
        int** b11 = StrassenNewSquare(halfLength);
        int** b12 = StrassenNewSquare(halfLength);
        int** b21 = StrassenNewSquare(halfLength);
        int** b22 = StrassenNewSquare(halfLength);

        // The seven recursive products.
        int** p1 = StrassenNewSquare(halfLength);
        int** p2 = StrassenNewSquare(halfLength);
        int** p3 = StrassenNewSquare(halfLength);
        int** p4 = StrassenNewSquare(halfLength);
        int** p5 = StrassenNewSquare(halfLength);
        int** p6 = StrassenNewSquare(halfLength);
        int** p7 = StrassenNewSquare(halfLength);

        // Quadrants of the result.
        int** c11 = StrassenNewSquare(halfLength);
        int** c12 = StrassenNewSquare(halfLength);
        int** c21 = StrassenNewSquare(halfLength);
        int** c22 = StrassenNewSquare(halfLength);

        // Scratch space for the sums/differences fed into each product.
        int** temp = StrassenNewSquare(halfLength);
        int** temp1 = StrassenNewSquare(halfLength);

        // Split A and B into four halfLength x halfLength blocks each.
        for(int i = 0; i < halfLength; i++) {
            for(int j = 0; j < halfLength; j++) {
                a11[i][j] = matrixA[i][j];
                a12[i][j] = matrixA[i][j + halfLength];
                a21[i][j] = matrixA[i + halfLength][j];
                a22[i][j] = matrixA[i + halfLength][j + halfLength];
                b11[i][j] = matrixB[i][j];
                b12[i][j] = matrixB[i][j + halfLength];
                b21[i][j] = matrixB[i + halfLength][j];
                b22[i][j] = matrixB[i + halfLength][j + halfLength];
            }
        }

        // p1 = a11 * (b12 - b22)
        Sub(b12, b22, temp, halfLength);
        Strassen(a11, temp, p1, halfLength);
        // p2 = (a11 + a12) * b22
        Add(a11, a12, temp, halfLength);
        Strassen(temp, b22, p2, halfLength);
        // p3 = (a21 + a22) * b11
        Add(a21, a22, temp, halfLength);
        Strassen(temp, b11, p3, halfLength);
        // p4 = a22 * (b21 - b11)
        Sub(b21, b11, temp, halfLength);
        Strassen(a22, temp, p4, halfLength);
        // p5 = (a11 + a22) * (b11 + b22)
        Add(a11, a22, temp1, halfLength);
        Add(b11, b22, temp, halfLength);
        Strassen(temp1, temp, p5, halfLength);
        // p6 = (a12 - a22) * (b21 + b22)
        Sub(a12, a22, temp1, halfLength);
        Add(b21, b22, temp, halfLength);
        Strassen(temp1, temp, p6, halfLength);
        // p7 = (a11 - a21) * (b11 + b12)
        Sub(a11, a21, temp1, halfLength);
        Add(b11, b12, temp, halfLength);
        Strassen(temp1, temp, p7, halfLength);

        // c11 = p5 + p4 - p2 + p6
        Add(p5, p4, temp1, halfLength);
        Sub(temp1, p2, temp, halfLength);
        Add(temp, p6, c11, halfLength);
        // c12 = p1 + p2
        Add(p1, p2, c12, halfLength);
        // c21 = p3 + p4
        Add(p3, p4, c21, halfLength);
        // c22 = p5 + p1 - p3 - p7
        Add(p5, p1, temp1, halfLength);
        Sub(temp1, p3, temp, halfLength);
        Sub(temp, p7, c22, halfLength);

        // Copy the four result quadrants back into matrixResult.
        for(int i = 0; i < halfLength; i++) {
            for(int j = 0; j < halfLength; j++) {
                matrixResult[i][j] = c11[i][j];
                matrixResult[i][j + halfLength] = c12[i][j];
                matrixResult[i + halfLength][j] = c21[i][j];
                matrixResult[i + halfLength][j + halfLength] = c22[i][j];
            }
        }

        // Release every temporary allocated by this call.
        int** owned[] = {
            a11, a12, a21, a22,
            b11, b12, b21, b22,
            p1, p2, p3, p4, p5, p6, p7,
            c11, c12, c21, c22,
            temp, temp1
        };
        const int ownedCount = static_cast<int>(sizeof(owned) / sizeof(owned[0]));
        for(int i = 0; i < ownedCount; i++) {
            StrassenFreeSquare(owned[i], halfLength);
        }
    }

該算法看着或許有些冗長,幾乎一半都在進行動態指針的初始化和刪除。利用該算法計算矩陣乘法的時間複雜度爲 $\Theta(n^{\lg 7})$,約爲 $\Theta(n^{2.81})$。

測試一下吧:

#include "stdafx.h"
#include <iostream>
#include "SquareMatrix.h"

using namespace std;
using namespace dksl;


//STRASSEN矩陣乘法算法


const int N=8; //常量N用來定義矩陣的大小
// Demo driver: multiplies a 4x4 matrix of ones by a 4x4 matrix of twos with
// Strassen and prints the product row by row (every entry is 4 * 1 * 2 = 8).
int _tmain(int argc, _TCHAR* argv[])
{
    const int size = 4; // dimension of the demo matrices

    int **a=new int*[size];
    int **b=new int*[size];
    int **c=new int*[size];
    for(int i=0;i<size;i++)
    {
        a[i]=new int[size];
        b[i]=new int[size];
        c[i]=new int[size];
        for(int j=0;j<size;j++)
        {
            a[i][j]=1;
            b[i][j]=2;
        }
    }

    Strassen(a,b,c,size);

    for(int i=0;i<size;i++)
    {
        for(int j=0;j<size;j++)
            cout<<c[i][j]<<" ";
        cout<<endl;
    }

    // Release the matrices; everything above was allocated with new[].
    for(int i=0;i<size;i++)
    {
        delete[] a[i];
        delete[] b[i];
        delete[] c[i];
    }
    delete[] a;
    delete[] b;
    delete[] c;

    system("PAUSE");
    return 0;
}

相關文章
相關標籤/搜索