動手實現並行版AlphaZero五子棋

前言

項目連接:github.com/hijkzzz/alp…html

AlphaZero算法已經發布了一年多了,GitHub也有各類各樣的實現,有一千行Python代碼單線程低性能版,也有數萬行C++代碼的分佈式版本。可是這些實現都不能知足通常的算法愛好者的需求,即一個簡單的而且單機的可運行的高性能AlphaZero算法。前端

一圖解密AlphaZero

首先咱們經過一張圖瞭解一下AlphaZero算法的原理git

能夠看到AlaphaGo Zero的算法流程分爲:github

  1. 自對弈(利用蒙特卡洛樹搜索)N局生成棋譜
  2. 利用生成的棋譜訓練網絡
  3. 評估新訓練的網絡

分析

對於Python版本的AlphaZero算法,一般受限制於GIL,過程當中最耗時間的自對弈階段(見下圖)沒法並行化,因此最直接的優化方式是使用C++這種高性能語言實現底層運算細節。算法

解決方法

線程池

github.com/hijkzzz/alp…網絡

爲了並行化自對弈過程,首先咱們須要實現一個C++的線程池。關於線程池網上有不少的資料能夠參考,這裏就很少作敘述。多線程

Root Parallelization

從算法流程圖中能夠看到,自對弈過程使用蒙特卡洛樹搜索實現,因此有兩個維度能夠並行化自對弈:Root Parallelization和Tree Parallelization。其中Root Parallelization指的是同時開啓N局對弈,每一個線程負責一局遊戲。Tree Parallelization指的是把單局遊戲中的蒙特卡洛樹搜索(MCTS)並行化。因而用N個線程就很容易實現Root Parallelization,下面咱們討論Tree Parallelization。分佈式

Tree Parallelization

首先分析一下蒙特卡洛樹搜索(MCTS)的運行過程:性能

每執行一步棋子,MCTS要執行M次落子模擬,每次模擬就是一次遞歸過程,以下:測試

  1. Select,若是當前節點不是葉子節點則經過特定的UCT算法(探索-利用算法,經過神經網絡預測的勝率值(q值)以及先驗機率計算選擇機率,勝率/先驗機率越高選擇概率越大)找出最優的下一個落子位置,搜索進入下一層,直到當前節點是葉子節點。

  2. Expand and evaluate,若是當前節點是葉子節點,這裏分爲兩種狀況:

    • 當前節點遊戲結束,某一方獲勝,則進行Backup向上回溯更新父節點的勝率值
    • 若是遊戲沒有結束,則用神經網絡預測當前節點的勝率和下一層的先驗機率,用這個先驗機率展開此節點,而後進行Backup向上回溯更新父節點的勝率值(q值)
  3. Backup,每一個節點保存一個勝率值(q值),q值等於贏的次數/訪問次數,backup從結束狀態向上更新這個值以及訪問次數。

  4. Play,實際遊戲中落子的時候選擇根節點下訪問次數最多的子節點便可(由於q值越大的節點select的機率越大,訪問次數也越多)。

因此咱們能夠同時進行M'(小於M)次模擬,因此對一些關鍵數據就要加鎖,好比蒙特卡洛樹的父子節點關係,訪問次數,q值等。也有人研發出了一些無鎖的算法[5],可是由於預先分配樹節點的關係,對內存的佔用量極大,通常的機器跑不起來,因此這裏用的是加鎖版的並行蒙特卡洛樹搜索。

Virtual Loss

對於Tree Parallelization,若是咱們簡單的把蒙特卡洛搜索(MCTS)並行化,那麼會遇到一個問題:M'個線程常常會搜索同一個節點,這樣咱們的並行化就失去了意義,由於搜索同一個節點意味着重複工做。因此在UCT算法中,當一個節點被一個線程訪問時,咱們加入一個Virtual Loss的懲罰,這樣其它線程就不太可能會選擇這個節點進行搜索。

LibTorch

由於MCTS的過程當中須要用到神經網絡預測勝率和先驗機率,因此C++須要調用Python實現的神經網絡預測方法,可是這樣又會回到原點。即Pyhton的GIL限制會致使並行化的自對弈被強制串行化執行。因此咱們使用PyTorch的C++前端LibTorch實現神經網絡。

CUDA Stream

工做後對於運行在GPU上的神經網絡來講,實際上咱們的程序仍是沒有真正的並行化。這是由於LibTorch的預測執行受限制於Default CUDA Steam,即默認是串行的,會致使多線程調用預測被阻塞。有兩個方法來避免這個問題:1. 用多個CUDA Stream 2.合併預測請求。這裏咱們使用的方法是用緩衝隊列合併多個預測,一次性推送到GPU,這樣就防止了GPU工做流的爭用致使線程阻塞。

SWIG

www.swig.org/tutorial.ht…

最後咱們把上述相關的C++代碼用SWIG封裝成Python接口,以供主程序調用。雖然這會致使一部分性能開銷,可是大大提升了開發的效率。

測試

通過測試,並行化後的訓練效率至少提高了10倍。簡單的計算一下,假設每一個MCTS用4個線程,同時玩4局遊戲,即4x4=16倍,考慮鎖和緩衝隊列以及Python接口的開銷,提高數量級是合理的。此外只要GPU足夠強悍,提高線程數還能繼續提升性能。最後我用了一天時間在一塊GTX1070上訓練了一個標準的15x15的五子棋算法,已經能夠完敗普通玩家。

參考文獻

  1. Mastering the Game of Go without Human Knowledge
  2. Mastering Chess and Shogi by Self-Play with a General Reinforcement Learning Algorithm
  3. Parallel Monte-Carlo Tree Search
  4. An Analysis of Virtual Loss in Parallel MCTS
  5. A Lock-free Multithreaded Monte-Carlo Tree Search Algorithm
相關文章
相關標籤/搜索