A3C——一種異步強化學習方法

目錄
python

一、簡介二、算法細節三、代碼3.1 主結構3.2 Actor Critic 網絡3.3 Worker3.4 Worker並行工做四、參考git

一、簡介

A3C是Google DeepMind 提出的一種解決Actor-Critic不收斂問題的算法。咱們知道DQN中很重要的一點是他具備經驗池,能夠下降數據之間的相關性,而A3C則提出下降數據之間的相關性的另外一種方法:異步github

簡單來講:A3C會建立多個並行的環境, 讓多個擁有副結構的 agent 同時在這些並行環境上更新主結構中的參數. 並行中的 agent 們互不干擾, 而主結構的參數更新受到副結構提交更新的不連續性干擾, 因此更新的相關性被下降, 收斂性提升.
web

二、算法細節

A3C的算法實際上就是將Actor-Critic放在了多個線程中進行同步訓練. 能夠想象成幾我的同時在玩同樣的遊戲, 而他們玩遊戲的經驗都會同步上傳到一箇中央大腦. 而後他們又從中央大腦中獲取最新的玩遊戲方法。算法

這樣, 對於這幾我的, 他們的好處是: 中央大腦聚集了全部人的經驗, 是最會玩遊戲的一個, 他們能時不時獲取到中央大腦的必殺招, 用在本身的場景中.緩存

對於中央大腦的好處是: 中央大腦最怕一我的的連續性更新, 不僅基於一我的推送更新這種方式能打消這種連續性. 使中央大腦沒必要像DQN,DDPG那樣的記憶庫也能很好的更新。
微信


爲了達到這個目的,咱們要有兩套體系, 能夠看做中央大腦擁有 global net和他的參數, 每位玩家有一個 global net的副本 local net, 能夠定時向 global net推送更新, 而後定時從 global net那獲取綜合版的更新.

若是在 tensorboard 中查看咱們今天要創建的體系, 這就是你會看到的。

W_0就是第0個 worker, 每一個 worker均可以分享 global_net


若是咱們調用 sync中的 pull, 這個 worker就會從 global_net中獲取到最新的參數.


若是咱們調用sync中的push, 這個worker就會將本身的我的更新推送去global_net.
網絡

三、代碼

此次咱們也是使用連續動做環境Pendulum作例子。
app

3.1 主結構


咱們使用了 Normal distribution 來選擇動做, 因此在搭建神經網絡的時候, actor這邊要輸出動做的均值和方差. 而後放入 Normal distribution 去選擇動做. 計算 actor loss的時候咱們還須要使用到 critic提供的 TD error做爲 gradient ascent 的導向.

critic只須要獲得他對於 state的價值就行了. 用於計算 TD error.

3.2 Actor Critic 網絡

這裏由於代碼有點多,有些部分會使用僞代碼,完整代碼最後會附上連接。異步

咱們將ActorCritic合併成一整套系統, 這樣方便運行.

 1# 這個 class 能夠被調用生成一個 global net.
2# 也能被調用生成一個 worker 的 net, 由於他們的結構是同樣的,
3# 因此這個 class 能夠被重複利用.
4class ACNet(object):
5    def __init__(self, globalAC=None):
6        # 當建立 worker 網絡的時候, 咱們傳入以前建立的 globalAC 給這個 worker
7        if 這是 global:   # 判斷當下創建的網絡是 local 仍是 global
8            with tf.variable_scope('Global_Net'):
9                self._build_net()
10        else:
11            with tf.variable_scope('worker'):
12                self._build_net()
13
14            # 接着計算 critic loss 和 actor loss
15            # 用這兩個 loss 計算要推送的 gradients
16
17            with tf.name_scope('sync'):  # 同步
18                with tf.name_scope('pull'):
19                    # 更新去 global
20                with tf.name_scope('push'):
21                    # 獲取 global 參數
22
23    def _build_net(self):
24        # 在這裏搭建 Actor 和 Critic 的網絡
25        return 均值, 方差, state_value
26
27    def update_global(self, feed_dict):
28        # 進行 push 操做
29
30    def pull_global(self):
31        # 進行 pull 操做
32
33    def choose_action(self, s):
34        # 根據 s 選動做

這些只是在建立網絡而已,worker還有屬於本身的class, 用來執行在每一個線程裏的工做.

3.3 Worker

每一個worker有本身的class, class 裏面有他的工做內容work

 1class Worker(object):
2    def __init__(self, name, globalAC):
3        self.env = gym.make(GAME).unwrapped # 建立本身的環境
4        self.name = name    # 本身的名字
5        self.AC = ACNet(name, globalAC) # 本身的 local net, 並綁定上 globalAC
6
7    def work(self):
8        # s, a, r 的緩存, 用於 n_steps 更新
9        buffer_s, buffer_a, buffer_r = [], [], []
10        while not COORD.should_stop() and GLOBAL_EP < MAX_GLOBAL_EP:
11            s = self.env.reset()
12
13            for ep_t in range(MAX_EP_STEP):
14                a = self.AC.choose_action(s)
15                s_, r, done, info = self.env.step(a)
16
17                buffer_s.append(s)  # 添加各類緩存
18                buffer_a.append(a)
19                buffer_r.append(r)
20
21                # 每 UPDATE_GLOBAL_ITER 步 或者回合完了, 進行 sync 操做
22                if total_step % UPDATE_GLOBAL_ITER == 0 or done:
23                    # 得到用於計算 TD error 的 下一 state 的 value
24                    if done:
25                        v_s_ = 0   # terminal
26                    else:
27                        v_s_ = SESS.run(self.AC.v, {self.AC.s: s_[np.newaxis, :]})[00]
28
29                    buffer_v_target = []    # 下 state value 的緩存, 用於算 TD
30                    for r in buffer_r[::-1]:    # 進行 n_steps forward view
31                        v_s_ = r + GAMMA * v_s_
32                        buffer_v_target.append(v_s_)
33                    buffer_v_target.reverse()
34
35                    buffer_s, buffer_a, buffer_v_target = np.vstack(buffer_s), np.vstack(buffer_a), np.vstack(buffer_v_target)
36
37                    feed_dict = {
38                        self.AC.s: buffer_s,
39                        self.AC.a_his: buffer_a,
40                        self.AC.v_target: buffer_v_target,
41                    }
42
43                    self.AC.update_global(feed_dict)    # 推送更新去 globalAC
44                    buffer_s, buffer_a, buffer_r = [], [], []   # 清空緩存
45                    self.AC.pull_global()   # 獲取 globalAC 的最新參數
46
47                s = s_
48                if done:
49                    GLOBAL_EP += 1  # 加一回合
50                    break   # 結束這回合

3.4 Worker並行工做

這裏是重點,也就是Worker並行工做的計算

 1    GLOBAL_AC = ACNet(GLOBAL_NET_SCOPE)  # 創建 Global AC
2    workers = []
3    for i in range(N_WORKERS):  # 建立 worker, 以後在並行
4        workers.append(Worker(GLOBAL_AC))   # 每一個 worker 都有共享這個 global AC
5
6COORD = tf.train.Coordinator()  # Tensorflow 用於並行的工具
7
8worker_threads = []
9for worker in workers:
10    job = lambda: worker.work()
11    t = threading.Thread(target=job)    # 添加一個工做線程
12    t.start()
13    worker_threads.append(t)
14COORD.join(worker_threads)  # tf 的線程調度

電腦裏CPU有幾個核就能夠創建多少個worker, 也就能夠把它們放在CPU核數個線程中並行探索更新. 最後的學習結果能夠用這個獲取 moving average 的 reward 的圖來歸納.

完整代碼連接:

https://github.com/cristianoc20/RL_learning/tree/master/A3C

四、參考

  1. https://medium.com/emergent-future/simple-reinforcement-learning-with-tensorflow-part-8-asynchronous-actor-critic-agents-a3c-c88f72a5e9f2

  2. https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/6-3-A3C/


本文分享自微信公衆號 - 計算機視覺漫談()。
若有侵權,請聯繫 support@oschina.cn 刪除。
本文參與「OSC源創計劃」,歡迎正在閱讀的你也加入,一塊兒分享。

相關文章
相關標籤/搜索