對抗搜索(一)

        最近看劉汝佳老師的《算法經典入門訓練指南》,搜相關的算法博客時,發現一本神書《人工智能  一種現代的方法》(如下簡稱《人工智能》),裏面囊括的算法,也讓我對算法有了新的認知。在學校時,相信你們都學過數據結構和算法,這些算法是你們接觸的最基礎的算法,再往上走,你們作人工智能,又涉及到機器學習和深度學習。鑑於本身的認知視野優先,總感受這兩類之間的算法中間跳過了不少東西。所以最近看算法時,難免要搜到不少博客,才知道算法的種類不少,每一個類下面又有不少分支,圖大類下的路徑搜索,約束知足問題,以及這篇要講的對抗搜索等等。東西太多可是在瞭解的過程當中,也解答了之前的一些疑惑,好比路徑搜索在遊戲中的應用,約束知足的應用,以及對抗搜索在棋牌類的應用。最後感嘆在校期間沒有參加ACM訓練,沒有抓住機會擴展本身的視野。強烈建議在校生有機會參加參加ACM訓練,拿不到獎牌也能夠擴展本身的視野。真的很重要。正文開始:java

      對抗搜索

        (本文不少東西都是參考上面兩本書,文末會貼上本身的代碼。)算法

        相信你們在網上或多或少的都玩過不少對抗類遊戲,好比五子棋、象棋、國際象棋、圍棋等。初期時,可能不少時候是和電腦進行人機聯繫。電腦方的出牌策略就是應用了不少對抗搜索的算法(對抗類遊戲可能有多我的參與,本文只討論二人對抗的遊戲,多人對抗的請參考《人工智能》這本書)。數組

        這些問題很是難於求解,例如國際象棋的平均分支因子大約是35,一盤棋通常每一個遊戲者走50步,因此搜索樹大約有35100即10154個結點(儘管整個狀態空間「只」約1040個不一樣結點)。和現實世界同樣,遊戲要求即便沒法找到最優決策也必須能作某種決策,而不能花費太多的時間。換句話說,這些遊戲有嚴格的時間限制(time limit)。因此對博弈的研究也產生了一些有趣的思想,如何儘量充分的利用好時間。數據結構

        咱們知道對抗類遊戲是參與者輪流出招,咱們能夠將其寫成相似於路徑尋找的問題:app

        初始狀態:包括棋盤局面和肯定該哪一個遊戲者出招
        後繼函數:返回(move, state)列表,每一項表示一個合法招數和對應的結果狀態。
        終止測試:判斷遊戲是否結束。遊戲結束的狀態稱爲終止狀態。
        效用函數:也稱目標函數或收益函數,是終止狀態的得分。國際象棋中贏、輸、平分別是1,-1和0分,而圍棋、黑白棋等能夠有更多的結果。機器學習

        考慮一個簡單的遊戲:井字棋。如今有兩個參與者MAX和MIN,在3×3的棋盤上,MAX劃叉,MIN劃圓圈。任何一種圖案佔據了一行或者一列或者一整條斜對角線(主副對角線),那麼斷定相應的遊戲者獲勝。以下圖(摘自《人工智能》)。初始狀態棋盤爲空,而後依次由MAX、MIN方輪流走,這樣就造成了一顆相似於搜索樹的博弈樹。數據結構和算法

           

       這個圖列舉了雙方能接受的全部選擇。咱們能夠看到只有葉子節點纔有評價函數,能夠看到從根節點到葉子節點是雙方按照當前路徑走下來的最終結果(贏、輸、平局)。每條路徑都對應一個結果,雙反不論在何時,確定都要選擇「最利於」本身獲勝的步驟。此時的核心問題就是在每一步的時候,MAX/MIN如何來設計評價函數來選擇「最利於」本身的下棋步驟。好比在第一步時,MAX到底該如何得知本身要選擇9個選擇中的哪個。函數

        這裏採用極大極小值方法: 對MAX方來講,評價函數越大越好,而對MIN方來講,評價函數越小越好。也就是在每一步中,MAX方選擇全部節點中評價函數最大的節點,做爲本身當前的落棋選擇,而MIN則相反。若是一個MIN結點有三個兒子,評價值分別爲3,4,-1。最聰明的對手必定會選擇那個-1的兒子(這樣對MAX最不利),而若是對手並無發現這個走步(或者並不以爲它的後繼狀態對MAX最不利),它可能選擇的是3或者4。學習

        惋惜因爲博弈樹太大,若是要直接追蹤到最終狀態,這對於計算機來講也是一個超大的負荷,所以合理的方案是在固定深度截斷,在這個深度內的「葉子節點」雙發按照極大極小值方法來選擇本身每一步的落棋選擇。對於井字棋遊戲,一個可能的評價函數是:
      e(s) = (MAX可能佔有的行/列/對角線數) - (MIN可能佔有的行/列/對角線數)測試

其中「可能佔有」的意思是「此行/列/對角線」不含對方的符號。更復雜的評價函數每每是對各類特徵進行加權計算。 下圖是深度爲2時的評價函數計算。

能夠驗證對max的第一步來講,選擇走中間那個節點是最優的選擇。若是此時MIN選擇走第一行正中,那麼此時節點的部分搜索樹以下。

剛纔所述的算法成爲MAXMIN算法,咱們採起遞歸的計算方式來描述整個算法:

int max_value ( int dep , state s ){
    if ( terminal ( s )) return e ( s ); //終止狀態
    if ( dep == maxdepth ) return e ( s ); //深度截斷,返回評價函數
    v = - inf ; //初始化爲負無窮
    succ = make_successors ( s ); // succ [ i爲第]個後繼狀態i
    for ( i = 0; i < succ . count ; i ++)
        v = max (v , min_value ( succ [ i ])); //計算全部兒子的最大值
    return v ;
}
int min_value ( int dep , state s ){
    if ( terminal ( s )) return e ( s ); //終止狀態
    if ( dep == maxdepth ) return e ( s ); //深度截斷,返回評價函數
    v = inf ; //初始化爲無窮大
    succ = make_successors ( s ); // succ [ i爲第]個後繼狀態i
    for ( i = 0; i < succ . count ; i ++)
        v = min (v , max_value ( succ [ i ])); //計算全部兒子的最小值(劉汝佳老師的書中是錯的)
    return v ;
}

 

文末附上「井字棋」的完整JAVA代碼

package search;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;

import javax.xml.parsers.FactoryConfigurationError;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * 對抗搜索-井字棋遊戲
 */
public class Search_Ant {
    private static int n = 3;
    private static final int maxPlayer = 1;
    private static final int minPlayer = 2;
    private static int maxDepth = 3; //對抗預測的最大深度
    private static int[][] initArray = new int[n][n];
    private static int depth = 0;
    private static int alpha = -10000;
    private static int beta = 10000;


    public static void main(String[] args) {
        State initState = new State();
        initState.setCurrentState(initArray);
        antSearch(initState, depth);
    }

    /**
     * 開始對抗搜索
     * 偶數max執行;
     * 奇數min執行;
     * 共對抗執行的最大次數: n * n ;
     */
    public static void antSearch(State state, int depth) {
        if(isSuccess(state)){
            System.out.println("某一方成功贏了");
            return;
        }
        if (depth >= n * n) {
            System.out.println("雙方平局");
            return;
        }

        //偶數次由max方走,奇數次由min方走;
        if (depth % 2 == 0) {
            maxValue(state, 0);
            int rowIndex = state.getNextBestState().getRowIndex();
            int columnIndex = state.getNextBestState().getColumnIndex();
            state.setRowIndex(rowIndex);
            state.setColumnIndex(columnIndex);
            state.getCurrentState()[rowIndex][columnIndex] = 1;
            display(state, maxPlayer, depth);
//            System.out.println(String.format("max方執行: (%s,%s)", state.getRowIndex(), state.getColumnIndex()));
        } else if (depth % 2 == 1) {
            minValue(state, 0);
            int rowIndex = state.getNextBestState().getRowIndex();
            int columnIndex = state.getNextBestState().getColumnIndex();
            state.setRowIndex(rowIndex);
            state.setColumnIndex(columnIndex);
            state.getCurrentState()[rowIndex][columnIndex] = 2;
            display(state, minPlayer, depth);
//            System.out.println(String.format("min方執行: (%s,%s)", state.getRowIndex(), state.getColumnIndex()));
        }
//        ++depth;
//        開始下一次迭代
//        antSearch(state, depth);
    }

    /**
     * 得到當前狀態的後繼節點
     *
     * @param currentState
     * @return
     */
    public static List<State> getSuccessor(State currentState, int player) {
        List<State> successorList = new ArrayList<State>();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (currentState.getCurrentState()[i][j] == 0) {
                    int[][] array = copyArray(currentState.getCurrentState());
                    if (player == 1) {
                        array[i][j] = 1;
                    } else {
                        array[i][j] = 2;
                    }
                    State nextState = new State(i, j, array);
                    successorList.add(nextState);
                }
            }
        }
        return successorList;
    }

    public static int[][] copyArray(int[][] array) {
        int[][] copyArray = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                copyArray[i][j] = array[i][j];
            }
        }
        return copyArray;
    }

    /**
     * max方執行
     *
     * @return
     */
    public static int maxValue(State state, int currentDepth) {
        int[][] currentState = state.getCurrentState();
        if (currentDepth >= maxDepth) {
            //計算差值
            return evalFunction(currentState);
        }
        if (isSuccess(state)) {
            //計算差值
            return evalFunction(currentState);
        }

        List<State> successor = getSuccessor(state, 1);
        int v = -10000;
        int target = -1;
        for (int i = 0; i < successor.size(); i++) {
            //這裏須要優化,Alpha-Beta剪枝,縮小檢索空間
            int value = minValue(successor.get(i), currentDepth + 1);
            if (v < value) {
                v = value;
                target = i;
            }
//            System.out.println(String.format("depth=%s, value=%s, maxPlayer,數組: %s", currentDepth, value,display((successor.get(i)))));
        }
        if (target == -1) { //已經此時是平局
            return 0;
        }
        System.out.println(String.format("depth=%s, value=%s, maxPlayer,數組: %s", currentDepth, v,display((successor.get(target)))));
        state.setNextBestState(successor.get(target));

        return v;
    }

    /**
     * min方執行
     *
     * @return
     */
    public static int minValue(State state, int currentDepth) {
        int[][] currentState = state.getCurrentState();
        if (currentDepth >= maxDepth) {
            //計算差值
            return evalFunction(currentState);
        }
        if (isSuccess(state)) {
            //計算差值
            return evalFunction(currentState);
        }

        List<State> successor = getSuccessor(state, 2);
        int v = 10000;
        int target = -1;
        for (int i = 0; i < successor.size(); i++) {
            int value = maxValue(successor.get(i), currentDepth + 1);
            //這裏須要優化,Alpha-Beta剪枝,縮小檢索空間
            if (v > value) {
                target = i;
                v = value;
            }
//            System.out.println(String.format("depth=%s, value=%s, minPlayer,數組: %s", currentDepth,value,display((successor.get(i)))));
        }
        if (target == -1) { //已經此時是平局
            return 0;
        }
        System.out.println(String.format("depth=%s, value=%s, minPlayer,數組: %s", currentDepth, v,display((successor.get(target)))));
        state.setNextBestState(successor.get(target));
        return v;
    }

    /**
     * 當前狀態的評估函數
     *
     * @param currentState
     * @return
     */
    public static int evalFunction(int[][] currentState) {
        int minPlayerResult = getPlayerOccupy(currentState, maxPlayer);
        int maxPlayerResult = getPlayerOccupy(currentState, minPlayer);

        return maxPlayerResult - minPlayerResult;
    }

    /**
     * 得到某一方所佔用的座標
     *
     * @param currentState
     * @return
     */
    public static int getPlayerOccupy(int[][] currentState, int palyer) {
        Set<Integer> rowOccupy = new HashSet<Integer>();
        Set<Integer> columnOccupy = new HashSet<Integer>();
        boolean mainDiag = false;
        boolean viceDiag = false;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (currentState[i][j] == palyer) {
                    rowOccupy.add(i);
                    columnOccupy.add(j);
                    if (i == j) { //在主對角線上
                        mainDiag = true;
                    }
                    if (i + j == n - 1) {//在副對角線上
                        viceDiag = true;
                    }
                }

            }
        }

        int result = 0;
        result += n - rowOccupy.size();
        result += n - columnOccupy.size();
        result += mainDiag ? 0 : 1;
        result += viceDiag ? 0 : 1;

        return result;
    }

    /**
     * 判斷當前狀態是否能夠判斷某一方已經勝利
     *
     * @return
     */
    public static boolean isSuccess(State state) {
        if (isRowSame(state)
                || isColumnSame(state)
                || isMainDiagSame(state)
                || isViceDiagSame(state)) {
            return true;
        }
        return false;
    }

    /**
     * 一行爲相同
     *
     * @return
     */
    public static boolean isRowSame(State state) {
        int[][] currentState = state.getCurrentState();
        int rowIndex = state.getRowIndex();
        int preValue = currentState[rowIndex][0];
        if (preValue == 0) {
            return false;
        }
        for (int i = 1; i < n; i++) {
            if (currentState[rowIndex][i] != preValue) {
                return false;
            }
        }
        return true;
    }

    /**
     * 列相同
     *
     * @return
     */
    public static boolean isColumnSame(State state) {
        int[][] currentState = state.getCurrentState();
        int columnIndex = state.getColumnIndex();
        int preValue = currentState[0][columnIndex];
        if (preValue == 0) {
            return false;
        }
        for (int i = 1; i < n; i++) {
            if (currentState[i][columnIndex] != preValue) {
                return false;
            }
        }
        return true;
    }

    /**
     * 主對角線是否相同
     *
     * @return
     */
    public static boolean isMainDiagSame(State state) {
        int[][] currentState = state.getCurrentState();
        int rowIndex = state.getRowIndex();
        int columnIndex = state.getColumnIndex();
        if (rowIndex == columnIndex) {
            int preValue = currentState[0][0];
            if (preValue == 0) {
                return false;
            }
            for (int i = 1; i < n; i++) {
                if (currentState[i][i] != preValue) {
                    return false;
                }
            }
            return true;
        }
        return false;
    }

    /**
     * 副對角線是否相同
     *
     * @return
     */
    public static boolean isViceDiagSame(State state) {
        int[][] currentState = state.getCurrentState();
        int rowIndex = state.getRowIndex();
        int columnIndex = state.getColumnIndex();
        if (rowIndex + columnIndex == n - 1) {
            int preValue = currentState[0][n - 1];
            if (preValue == 0) {
                return false;
            }
            int m = 0;
            int k = n - 1;
            for (int i = 1; i < n; i++) {
                if (currentState[m + i][k - i] != preValue) {
                    return false;
                }
            }
            return true;
        }
        return false;
    }

    public static void display(State state, int player, int depth) {
        int[][] array = state.getCurrentState();
        System.out.println("===============================================");
        System.out.println(String.format("第%s步: 當前方以及走的座標: %s --> (%s,%s)", depth, player, state.getRowIndex(), state.getColumnIndex()));
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                System.out.print(array[i][j] + " ");
            }
            System.out.println();
        }
    }
    public static String display(State state) {
        StringBuffer buffer = new StringBuffer();
        int[][] array = state.getCurrentState();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
//                System.out.print(array[i][j]+" ");
                buffer.append(array[i][j] + " ");
            }
        }
        return buffer.toString();
    }

}

@Getter
@Setter
@AllArgsConstructor
@NoArgsConstructor
class State {
    private int rowIndex;
    private int columnIndex;
    private int[][] currentState;
    private State nextBestState; //當前最好的狀態


    public State(int i, int j, int[][] array) {
        this.rowIndex = i;
        this.columnIndex = j;
        this.currentState = array;
    }
}
相關文章
相關標籤/搜索