DQN(Deep Q-learning)入門教程(四)之Q-learning Play Flappy Bird

在上一篇博客中,咱們詳細的對Q-learning的算法流程進行了介紹。同時咱們使用了\(\epsilon-貪婪法\)防止陷入局部最優。html

那麼咱們能夠想一下,最後咱們獲得的結果是什麼樣的呢?由於咱們考慮到了全部的(\(\epsilon-貪婪法\)致使的)狀況,所以最終咱們將會獲得一張以下的Q-Table表。python

Q-Table \(a_1\) \(a_2\)
\(s_1\) \(q(s_1,a_1)\) \(q(s_1,a_2)\)
\(s_2\) \(q(s_2,a_1)\) \(q(s_2,a_2)\)
\(s_3\) \(q(s_3,a_1)\) \(q(s_3,a_2)\)

當agent運行到某一個場景\(s\)時,會去查詢已經訓練好的Q-Table,而後從中選擇一個最大的\(q\)對應的action。git

訓練內容

這一次,咱們將對Flappy-bird遊戲進行訓練。這個遊戲的介紹我就很少說了,能夠看一下維基百科的介紹。es6

遊戲就是控制一隻🐦穿越管道,而後能夠得到分數,對於小鳥來講,他只有兩個動做,跳or不跳,而咱們的目標就是使小鳥穿越管道得到更多的分數。github

前置準備

由於咱們的目標是來學習「強化學習」的,因此咱們不可能說本身去弄一個Flappy-bird(固然本身弄也能夠),這裏咱們直接使用一個已經寫好的Flappy-bird。算法

PyGame-Learning-Environment,是一個Python的強化學習環境,簡稱PLE,下面時他Github上面的介紹:app

PyGame Learning Environment (PLE) is a learning environment, mimicking the Arcade Learning Environment interface, allowing a quick start to Reinforcement Learning in Python. The goal of PLE is allow practitioners to focus design of models and experiments instead of environment design.less

PLE hopes to eventually build an expansive library of games.dom

而後關於FlappyBird的文檔介紹在這裏,文檔的介紹仍是蠻清楚的。安裝步驟以下所示,推薦在Pipenv的環境下安裝,不過你也能夠直接clone個人代碼而後而後根據reademe的步驟進行使用。函數

git clone https://github.com/ntasfi/PyGame-Learning-Environment.git
cd PyGame-Learning-Environment/
pip install -e .

須要的庫以下:

  • pygame
  • numpy
  • pillow

函數說明

官方文檔有幾個的函數在這裏說下,由於等下咱們須要用到。

  • getGameState():得到遊戲當前的狀態,返回值爲一個字典:

    1. player y position.
    2. players velocity.
    3. next pipe distance to player
    4. next pipe top y position
    5. next pipe bottom y position
    6. next next pipe distance to player
    7. next next pipe top y position
    8. next next pipe bottom y position

    部分數據表示以下:

  • reset_game():從新開始遊戲

  • act(action):在遊戲中執行一個動做,參數爲動做,返回執行後的分數。

  • game_over():假如遊戲結束,則返回True,否者返回False。

  • getActionSet():得到遊戲的動做集合。

咱們的窗體大小默認是288*512,其中鳥的速度在-20到10之間(最小速度我並不知道,可是通過觀察,並無小於-20的狀況,而最大的速度在源代碼裏面已經說明好了爲10)

Coding Time

在前面咱們說,經過getGameState()函數,咱們能夠得到幾個關於環境的數據,在這裏咱們選擇以下的數據:

  • next_pipe_dist_to_player:
  • player_y與next_pipe_top_y的差值
  • 🐦的速度

可是咱們能夠想想,next_pipe_dist_to_player一共會有多少種的取值:由於窗體大小爲288*512,則取值的範圍大約是0~288,也就是說它大約有288個取值,而關於player_y與next_pipe_top_y的差值,則大概有1024個取值。這樣很難讓模型收斂,所以咱們將數值進行簡化。其中簡化的思路來自:GitHub

首先咱們建立一個Agent類,而後逐漸向裏面添加功能。

class Agent():

    def __init__(self, action_space):
        # 得到遊戲支持的動做集合
        self.action_set = action_space

        # 建立q-table
        self.q_table = np.zeros((6, 6, 6, 2))

        # 學習率
        self.alpha = 0.7
        # 勵衰減因子
        self.gamma = 0.8
        # 貪婪率
        self.greedy = 0.8

至於爲何q-table的大小是(6,6,6,2),其中的3個6分別表明next_pipe_dist_to_playerplayer_y與next_pipe_top_y的差值🐦的速度,其中的2表明動做的個數。也就是說,表格中的state一共有$6 \times6 \times 6 $種,表格的大小爲\(6 \times6 \times 6 \times 2\)

縮小狀態值的範圍

咱們定義一個函數get_state(s),這個函數專門提取遊戲中的狀態,而後返回進行簡化的狀態數據:

def get_state(self, state):
        """
        提取遊戲state中咱們須要的數據
        :param state: 遊戲state
        :return: 返回提取好的數據
        """
        return_state = np.zeros((3,), dtype=int)
        dist_to_pipe_horz = state["next_pipe_dist_to_player"]
        dist_to_pipe_bottom = state["player_y"] - state["next_pipe_top_y"]
        velocity = state['player_vel']
        if velocity < -15:
            velocity_category = 0
        elif velocity < -10:
            velocity_category = 1
        elif velocity < -5:
            velocity_category = 2
        elif velocity < 0:
            velocity_category = 3
        elif velocity < 5:
            velocity_category = 4
        else:
            velocity_category = 5

        if dist_to_pipe_bottom < 8:  # very close or less than 0
            height_category = 0
        elif dist_to_pipe_bottom < 20:  # close
            height_category = 1
        elif dist_to_pipe_bottom < 50:  # not close
            height_category = 2
        elif dist_to_pipe_bottom < 125:  # mid
            height_category = 3
        elif dist_to_pipe_bottom < 250:  # far
            height_category = 4
        else:
            height_category = 5

        # make a distance category
        if dist_to_pipe_horz < 8:  # very close
            dist_category = 0
        elif dist_to_pipe_horz < 20:  # close
            dist_category = 1
        elif dist_to_pipe_horz < 50:  # not close
            dist_category = 2
        elif dist_to_pipe_horz < 125:  # mid
            dist_category = 3
        elif dist_to_pipe_horz < 250:  # far
            dist_category = 4
        else:
            dist_category = 5

        return_state[0] = height_category
        return_state[1] = dist_category
        return_state[2] = velocity_category
        return return_state

更新Q-table

更新的數學公式以下:

\[{\displaystyle Q^{new}(s_{t},a_{t})\leftarrow \underbrace {Q(s_{t},a_{t})} _{\text{舊的值}}+\underbrace {\alpha } _{\text{學習率}}\cdot \overbrace {{\bigg (}\underbrace {\underbrace {r_{t}} _{\text{獎勵}}+\underbrace {\gamma } _{\text{獎勵衰減因子}}\cdot \underbrace {\max _{a}Q(s_{t+1},a)} _{\text{estimate of optimal future value}}} _{\text{new value (temporal difference target)}}-\underbrace {Q(s_{t},a_{t})} _{\text{舊的值}}{\bigg )}} ^{\text{temporal difference}}} \]

下面是更新Q-table的函數代碼:

def update_q_table(self, old_state, current_action, next_state, r):
    """

    :param old_state: 執行動做前的狀態
    :param current_action: 執行的動做
    :param next_state: 執行動做後的狀態
    :param r: 獎勵
    :return:
    """
    next_max_value = np.max(self.q_table[next_state[0], next_state[1], next_state[2]])

    self.q_table[old_state[0], old_state[1], old_state[2], current_action] = (1 - self.alpha) * self.q_table[
        old_state[0], old_state[1], old_state[2], current_action] + self.alpha * (r + next_max_value)

選擇最佳的動做

而後咱們就是根據q-table對應的Q值選擇最大的那一個,其中第一個表明(也就是0)跳躍,第2個表明不執行任何操做。

選擇的示意圖以下:

代碼以下所示:

def get_best_action(self, state, greedy=False):
    """
    得到最佳的動做
    :param state: 狀態
    :是否使用ϵ-貪婪法
    :return: 最佳動做
    """
	
    # 得到q值
    jump = self.q_table[state[0], state[1], state[2], 0]
    no_jump = self.q_table[state[0], state[1], state[2], 1]
    # 是否執行策略
    if greedy:
        if np.random.rand(1) < self.greedy:
            return np.random.choice([0, 1])
        else:
            if jump > no_jump:
                return 0
            else:
                return 1
    else:
        if jump > no_jump:
            return 0
        else:
            return 1

更新\(\epsilon\)

這個比較簡單,從前面的博客中,咱們知道\(\epsilon\)是隨着訓練次數的增長而減小的,有不少種策略能夠選擇,這裏乘以\(0.95\)吧。

def update_greedy(self):
    self.greedy *= 0.95

執行動做

在官方文檔中,若是小鳥沒有死亡獎勵爲0,越過一個管道,獎勵爲1,死亡獎勵爲-1,咱們稍微的對其進行改變:

def act(self, p, action):
    """
    執行動做
    :param p: 經過p來向遊戲發出動做命令
    :param action: 動做
    :return: 獎勵
    """
    # action_set表示遊戲動做集(119,None),其中119表明跳躍
    r = p.act(self.action_set[action])
    if r == 0:
        r = 1
    if r == 1:
        r = 10
    else:
        r = -1000
    return r

main函數

最後咱們就能夠執行main函數了。

if __name__ == "__main__":
    # 訓練次數
    episodes = 2000_000000
    # 實例化遊戲對象
    game = FlappyBird()
    # 相似遊戲的一個接口,能夠爲咱們提供一些功能
    p = PLE(game, fps=30, display_screen=False)
    # 初始化
    p.init()
    # 實例化Agent,將動做集傳進去
    agent = Agent(p.getActionSet())
    max_score = 0
	
    for episode in range(episodes):
        # 重置遊戲
        p.reset_game()
        # 得到狀態
        state = agent.get_state(game.getGameState())
        agent.update_greedy()
        while True:
            # 得到最佳動做
            action = agent.get_best_action(state)
            # 而後執行動做得到獎勵
            reward = agent.act(p, action)
            # 得到執行動做以後的狀態
            next_state = agent.get_state(game.getGameState())
            # 更新q-table
            agent.update_q_table(state, action, next_state, reward)
            # 得到當前分數
            current_score = p.score()
            state = next_state
            if p.game_over():
                max_score = max(current_score, max_score)
                print('Episodes: %s, Current score: %s, Max score: %s' % (episode, current_score, max_score))
                # 保存q-table
                if current_score > 300:
                    np.save("{}_{}.npy".format(current_score, episode), agent.q_table)
                break

部分的訓練的結果以下:

總結

emm,說實話,我也不知道結果會怎麼樣,由於訓練的時間比較長,我不想放在個人電腦上面跑,而後我就放在樹莓派上面跑,可是樹莓派性能比較低,致使訓練的速度比較慢。可是,我仍是以爲個人方法有點問題,get_state()函數中簡化的方法,我感受不是特別的合理,若是各位有好的見解,能夠在評論區留言哦,而後共同窗習。

項目地址:https://github.com/xiaohuiduan/flappy-bird-q-learning

參考

相關文章
相關標籤/搜索