dp方法論——由矩陣相乘問題學習dp解題思路

前篇戳:dp入門——由分杆問題認識動態規劃html

導語

刷過一些算法題,就會十分珍惜「方法論」這種東西。Leetcode上只有題目、討論和答案,沒有方法論。每每答案看起來十分切中要害,可是從看題目到獲得思路的那一段,就是繞不過去。樓主有段時間曾把這個過程歸結於智商和靈感的結合,直到有天爲了搞懂Leetcode上一位老兄的題型總結,花兩天時間學習了回溯法,忽然有種驚爲天人的感受——原來真正掌握一個算法是應該舉一反三的,而不是將題中一個細節換掉就又成了新題……算法

掌握方法論絕對是一種很爽的感受。看起來好像很花費時間,實際上是一種「由於慢,因此快」的方法。之前可能你學習一個dp題目要大半天;當你花了半個周時間,學會了dp的套路,你會發現,有些medium的dp題甚至不須要半個小時就能作完,並且從頭至尾不需提示,全靠本身!數組

方法論

那麼,怎麼從一個看起來毫無頭緒的問題出發,找到解題的思路並用dp將問題解出來呢?本文以矩陣相乘問題爲例,給出dp問題的通常解題思路。ide

固然,按照思路解題的前提是你已經知道這道題要用dp去解,如何肯定一個問題能夠用dp去解,則是下一篇要討論的話題。post

下面就是動態規劃的通常解題思路:學習

  1. 分析最優解的特徵。
  2. 遞歸地定義最優解的值。
  3. 計算最優解的值。
  4. 根據計算好的信息構造最優解。

看起來很是抽象是吧?在這裏不須要徹底理解。等你看徹底文再回來,保你會有不同的感覺。url

矩陣相乘問題

問題

這是一個看起來可能有點抽象的數學問題,但請你耐心往下看。當你看完解法時,你會驚異於動態規劃的魔力。spa

題目:給出一個由n個矩陣組成的矩陣鏈<A1,A2,...,An>,矩陣Ai的秩爲pi-1×pi。將A1A2...An這個乘積全括號化,使得計算這個乘積所須要的的標量乘法最少。翻譯

全括號化是以一種遞歸的形式定義的:設計

一個全括號化的乘積只有兩種可能:一是一個單個矩陣;二是兩個全括號化的乘積的乘積。

天啦也太繞了,舉個例子吧。對於矩陣鏈<A1,A2,A3,A4>的乘積,共有五種全括號化的方法:

(A1(A2(A3A4))),

(A1((A2A3)A4)),

((A1A2)(A3A4)),

(((A1A2)A3)A4),

((A1(A2A3))A4)

咱們知道矩陣乘法是知足結合律的,因此以上五個式子的乘積相等,可是它們的運算時間是否相等呢?

矩陣乘法的運算時間

咱們知道,矩陣乘法的定義是:

兩個互相兼容的矩陣A,B能夠相乘。互相兼容是指A的列數與B的行數相等。假如A是一個p×q的矩陣,而B是一個q×r的矩陣,則乘積C是一個p×r的矩陣且有

cij = ∑ aik·bkj, k = 1,...,q.

因爲要對C中的每個元素進行計算(共q·r個元素),而每次運算要作q次乘法,因此總的運算時間爲pqr。

來看看讓乘積中的不一樣因子結合對運算時間有什麼影響。假設咱們有 <A1,A2,A3>這個矩陣鏈,三個矩陣的秩分別爲10×100, 100×5和5×50。則

  • ((A1A2)A3)的運算時間爲10×100×5+10×5×50=7500;
  • (A1(A2A3))的運算時間爲100×5×50+10×100×50=75000。

按照不一樣的順序作矩陣乘法,所須要的乘法次數竟相差10倍。

初步分析

按照慣例,咱們來感覺一下窮舉的算法複雜度。

假設有一個長度爲n的矩陣鏈,咱們經過遍歷全部的全括號化的可能性來解題。設全括號化的可能性數目爲P(n)。當n爲1時,矩陣鏈只有一個矩陣,符合全括號化的定義;當n>=2時,全括號化後爲兩個矩陣的乘積,即((...)(...))的形式。用遞歸的思路去分析,則中間兩個括號的分界位置有n-1種可能,以下面豎線所示

A1|A2|A3|...|An

當分界線將矩陣鏈分爲長度爲k和n-k的兩個子矩陣鏈時,全括號化可能性爲P(k)P(n-k)。咱們對全部的k值求和,就得出給整個矩陣鏈全括號化的數目:

P(n) = ∑ P(k)P(n-k), k=1...n-1   (n>=2)

這是一個卡塔蘭數(Catalan Number),它的增加速率爲Ω(4n/n3/2),它的漸進值爲Ω(2n)

對漸進值還不太熟,若是有小夥伴明白「增加速率」和「漸進值」之間的關係,歡迎指教。

總的來講,若是對這個題目使用窮舉法,算法複雜度是指數的。後面咱們分析了dp的算法複雜度,再來比較。

用dp方法論解題

算法的學習永遠沒有「手把手」這一說。若是你在認真學習這篇文章,但願你能作到比你看到的小節思路提早一點。好比,在看第一步前,先對這個題目有一點大體思路,明白讓本身迷茫的點在哪裏;看第x步前,對第x步的內容在心中有一個猜想。這樣作比起徹底放棄思考,只是跟着文章的思路走,收穫會大不少。

第一步:分析最優解的特徵

這一步的精髓是分析最優子解如何構成最優解

在上一節中已經提到,對於n>=2的狀況,全括號化後爲((chain_1)(chain_2))的形式。這樣,問題天然而然地分紅了兩個子問題:求先後兩個子括號中的最優解。

假設對於某種特定的分割(即chain_1chain_2之間的分界線位置固定),chain_1的秩爲m×p,其內部的標量乘法數目爲x;chain_2的秩爲p×n,其內部的標量乘法數目爲y。則整個矩陣鏈的乘法次數爲x+y+mpn。因爲m,p,n是固定的,咱們須要讓x和y爲最小值從而使整個矩陣鏈的乘法次數最小。即,對於某種特定的分割,兩個子括號中的最優解構成整個問題的最優解的一個選項

總結來講,咱們將矩陣乘積簡略地當作兩個子矩陣的乘積,這兩個子矩陣的分界有n-1種可能。對每一種可能,問題被分割成兩個子問題,即求左右兩個子矩陣鏈的最優解。若是遍歷這n-1種可能並選出最好的一個,那就是整個問題的最優解。

第二步:遞歸地定義最優解的值

第二步很是關鍵,是咱們將先後思路打通的一步。

第一步中提出了一個比較簡單的思路,即把矩陣鏈分割成左右兩個子矩陣鏈。既然有了這個初步思路,咱們就來塗鴉一番,看看這個思路是否可行。

對於遞歸性的問題,一個很好的方法是畫遞歸樹,這樣會使得問題看起來比較具象,並且也會暴露一些算法上的問題,好比重疊子樹等。畫遞歸樹的時候,最好舉一個實際的例子。這裏咱們假設有一個長度爲4的矩陣鏈<A1,A2,A3,A4>,簡單地畫一下它的子問題分割:

 

上圖中的數字表示子矩陣鏈的長度,根爲4,即初始矩陣鏈;它能夠分爲1+3,2+2,3+1三種狀況,這三種狀況又能夠各自細分。

這裏暴露了一個問題,請看圖中的兩個塗色的子樹。兩個子樹的節點數字是同樣的。可是左邊這個子樹的根節點3表明的是A2A3A4這個乘積;而右邊這個表明的是A1A2A3這個乘積。因爲A1,A2,A3,A4四個矩陣的秩是未知的,它們極可能不相同,則A1A2A3A2A3A4的最優解也頗有可能不一樣。換言之,它們並非同一個子問題,它們的子子樹也並不相同。

這個問題意味着咱們對子問題的定義不夠嚴謹——子問題不能只用長度這個變量來肯定。也就是說,若是在bottom-up的dp中用一個數組記錄子問題的值,那麼這個數組應該是一個二維數組。子問題不只應該由子矩陣鏈的長度肯定,還要加上起始index這樣的信息。

爲了更通用一些,咱們不用起始index+長度,而選用起始index+結束index的定義方法,這是二維dp的慣用套路,在許多字符串和數組有關的問題中都有用到。

設用一個二位矩陣dp[][]存取子問題的解。定義dp[i][j](1<=i<=j<=n)的值爲Ai...Aj的最小乘法次數。則按照以上的思路,能夠把Ai...Aj再遞歸細分爲子問題Ai...AkAk+1...Aj(i<=k<j),則Ai...Aj的最優解值爲兩個子問題最優解的和+兩個子矩陣鏈相乘的乘法次數。即有

i==j時,dp[i][j] = 0;

i <j時,dp[i][j] = min{dp[i][k] + dp[k+1][j] + pi-1pkpj}, k = i...j-1 (p爲各個矩陣的秩,見題目一節)

到此爲止,最關鍵的一步順利完成啦(樓主寫得好累,擊掌╭(○`∀´○)╯╰(○'◡'○)╮)。在這一步中,咱們遞歸地定義了子問題最優解的值,完成了算法最核心的設計部分。在後面兩步中,咱們只要把上面這兩個式子翻譯成代碼,再注意一些實現細節就能夠了。

第三步:計算最優解的值

細節一

從第二步瓜熟蒂落,咱們會在一個二維數組裏記錄子問題的解。可是按照什麼順序去填這個二維數組是個問題。

仍是舉例子,在<A1,A2,A3,A4>這個矩陣鏈中,咱們會有一個5×5的二維數組,隨便挑選dp[1][4]這個元素舉例。根據第二步中的狀態轉移方程,有

dp[1][4] = min{(dp[1][1]+dp[2][4]+...),(dp[1][2]+dp[3][4]+...),(dp[1][3]+dp[4][4]+...)}

省略號表示咱們此處不需關注pi-1pkpj這一項,只須要看這個格子對其它格子的依賴是什麼樣子。

由上圖能夠看出,要計算某一個元素(粉色邊框),咱們須要其左邊下面的元素(一樣深度的藍色表示一組數據)。

因此,咱們的遍歷方向是從下到上,從左到右

細節二

細心的讀者可能注意到還有一個問題,就是咱們一直在求「最優解的值」,也就是「最小的乘法次數」,但是題目中要求的是「最優解」,也就是「加括號的方式」。

這二者並不矛盾,專一於求解前者可讓咱們先思考相對簡單的問題,一般在求解前者的過程當中,咱們也找出了後者,只是沒有將它記錄下來。

在此題中,咱們能夠選擇用一個一樣的二維矩陣s[][]來記錄後者,其中s[i][j]中記錄Ai...Aj的分割分界線k。

代碼

 1     int matrixChain(int[] p){
 2         int n = p.length - 1; //number of matrices
 3         int[][] dp = new int[n + 1][n + 1]; //we need dp[1][n]
 4         int[][] s = new int[n + 1][n + 1];    //for storing of k
 5         for(int[] row : dp)
 6             Arrays.fill(row, Integer.MAX_VALUE);
 7 
 8         for(int i = 1; i <= n; i++)
 9             dp[i][i] = 0;    //dp[i][j] = 0 when i == j
10         
11         for(int i = n; i >= 1; i--)
12             for(int j = i; j <= n; j++){
13                 if(i == j){
14                     dp[i][j] = 0;
15                 }else{
16                     for(int k = i; k < j; k++){
17                         int count = dp[i][k] + dp[k+1][j] + p[i-1]*p[k]*p[j];
18                         if(count < dp[i][j]){
19                             dp[i][j] = count; //record optimal solution value
20                             s[i][j] = k;      //record splitting point k
21                         }
22                     }
23                 }
24             }
25         return dp[1][n];
26     }

運行一個例子:

即輸入的數組p爲{30,35,15,5,10,20,25}。

若是在return以前打印出dp[][]和s[][]的值,結果爲:

      

從左圖可看出最優解爲dp[1][6] = 15,125,即最少能夠進行一萬五千屢次乘法。右圖記錄了對於每個[i,j]決定的子矩陣鏈如何進行括號分割。

順便分享一個ArrayPrinter的util,能夠直接用,能打印出上圖那樣的二維int數組。

 1 public class ArrayPrinter {
 2     public static void print(int[] arr){
 3         printReplacing(false, arr, 0,"");
 4     }
 5     
 6     public static void print(int[][] matrix){
 7         printReplacing(false, matrix, 0,"");
 8     }
 9     
10     public static void printReplacing(int[] arr, int before, String after){
11         printReplacing(true, arr, before, after);
12     }
13     
14     public static void printReplacing(int[][] matrix, int before, String after){
15         printReplacing(true, matrix, before, after);
16     }
17     
18     /*--------------------------private utils-------------------------------*/
19     
20     private static void printReplacing(boolean replace, int[] arr, int before, String after){
21         int maxLen = maxLength(arr);
22         if(replace){
23             for(int i : arr)
24                 print(((i==before)?after:number(i)), maxLen);
25         }else{
26             for(int i : arr)
27                 print(number(i), maxLen);
28         }
29         print("\n", maxLen);
30     }
31     
32     public static void printReplacing(boolean replace, int[][] matrix, int before, String after){
33         int maxLen = maxLength(matrix);
34         if(replace){
35             for(int[] row : matrix){
36                 for(int i : row)
37                     print(((i==before)?after:number(i)), maxLen);
38                 print("\n", maxLen);
39             }
40         }else{
41             for(int[] row : matrix){
42                 for(int i : row)
43                     print(number(i), maxLen);
44                 print("\n", maxLen);
45             }
46         }
47     }
48 
49     private static int maxLength(int[] arr){
50         int maxLen = 0;
51         for(int aint : arr)
52             maxLen = Math.max(Integer.toString(aint).length(), maxLen);
53         return maxLen;
54     }
55     
56     private static int maxLength(int[][] matrix){
57         int maxLen = 0;
58         for(int row[] : matrix)
59             maxLen = Math.max(maxLength(row), maxLen);
60         return maxLen;
61     }
62     
63     //actual printing 
64     private static void print(String s, int length){
65         System.out.print(String.format("%1$"+(length+1)+"s", s));
66     }
67     
68     //formatting of number
69     private static String number(int i){
70         return NumberFormat.getNumberInstance(Locale.US).format(i);
71     } 
72 }
ArrayPrinter

使用方法:

1 ArrayPrinter.printReplacing(dp, Integer.MAX_VALUE, "/");
2 ArrayPrinter.print(s);

第四步:根據計算好的信息構造最優解

還差一步就大功告成。這一步咱們要拿着上一步計算出的矩陣s把最終的全括號矩陣乘積打印出來。遞歸打印便可。

 1     private void printParenthesis(int[][] s, int i, int j) {
 2         if(i == j)
 3             print("A"+i);
 4         else{
 5             print("(");
 6             printParenthesis(s, i, s[i][j]);
 7             printParenthesis(s, s[i][j]+1, j);
 8             print(")");
 9         }
10     }

打印結果:

複雜度

前面說過,窮舉法的複雜度大概是O(2n)。在以上的dp算法中,主算法須要填滿一個(n+1)×(n+1)的二維數組的上半部分,每填一個元素須要一個長度爲j-i的循環,可經過這個思路對j-i進行求和(i=0...n, j=i...n),也能夠經過大概估算獲得時間複雜度爲O(n3),遠好於窮舉法。

空間複雜度主要由二維數組決定,爲O(n2)。

總結

本文主要介紹瞭解一個dp問題的思路。

dp問題通常有兩個顯著特色,這一點下一篇會詳細講述:

  • 問題的最優解由子問題的最優解構成
  • 子問題互相重疊

也再複習一下解題的四個步驟,看你如今有沒有更深入的理解:

  1. 分析最優解的特徵。               (分析最優子解如何構成最優解)
  2. 遞歸地定義最優解的值。               (畫遞歸樹,定義子問題,寫狀態轉移方程)
  3. 計算最優解的值。                        (寫代碼求出最優解,若是有要求的話,記錄額外信息,爲第4步做準備)
  4. 根據計算好的信息構造最優解。       (從第3步記錄的信息中構建最優解,在本題中就是括號的寫法)

參考資料

算法導論(英文版)3rd Ed. 15.2

相關文章
相關標籤/搜索