This article introduces the concept of reinforcement learning (Reinforcement Learning, RL) and uses DQN to train a model that learns to play Flappy Bird.
Many people have played this game, and it is notoriously punishing. Below is a reimplementation of Flappy Bird in pygame: https://github.com/sourabhv/FlapPyBird
If pygame is not installed, install it first:
pip install pygame
Run flappy.py to start the game. If the keyboard does not respond, run the script with pythonw instead:
pythonw flappy.py
Unsupervised learning has no labels (clustering, for example); supervised learning has labels (classification, for example); reinforcement learning sits in between: its "labels" are accumulated through repeated trial and error.
RL involves several components: the state S the agent observes, the action A it takes, and the reward R it receives.
A game session is then simply a sequence that starts from an initial state S, takes an action A, receives a reward R, and moves to the next state S, over and over until a terminal state is reached:
$$ s_0,a_0,r_1,s_1,a_1,r_2,s_2,...,s_{n-1},a_{n-1},r_n,s_n $$
Define a function that computes the total reward accumulated over the course of the game:
$$ R=r_1+r_2+r_3+...+r_n $$
as well as the total reward from some time step t onward:
$$ R_t=r_t+r_{t+1}+r_{t+2}+...+r_n $$
However, the rewards of future steps are not fully certain, so each future term is multiplied by a discount factor between 0 and 1:
$$ R_t=r_t+\gamma r_{t+1}+\gamma^2 r_{t+2}+...+\gamma^{n-t} r_n $$
This yields a recursive relation between the total rewards of two adjacent steps:
$$ R_t=r_t+\gamma R_{t+1} $$
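As a quick sanity check, here is a minimal Python sketch (with a made-up list of rewards, purely for illustration) showing that summing the discounted terms directly and applying the recursion R_t = r_t + γ·R_{t+1} backwards give the same values:

# Illustrative only: compare the direct discounted sum with the backward recursion
gamma = 0.99
rewards = [0.1, 0.1, 1.0, 0.1, -1.0]   # hypothetical rewards r_1 ... r_n

# Direct sum: R_t = r_t + gamma*r_{t+1} + gamma^2*r_{t+2} + ...
direct = [sum(gamma ** k * r for k, r in enumerate(rewards[t:])) for t in range(len(rewards))]

# Backward recursion: R_t = r_t + gamma * R_{t+1}
recursive = [0.0] * len(rewards)
R = 0.0
for t in reversed(range(len(rewards))):
    R = rewards[t] + gamma * R
    recursive[t] = R

print(direct)      # same values as the recursion below
print(recursive)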
DQN is a widely used reinforcement learning algorithm. Its key ingredient is the Q function (Q for Quality, a value function), which gives the maximum total reward obtainable by taking action A in state S:
$$ Q(s_t,a_t)=\max R_{t+1} $$
Given the Q function, for the current state S we only need to compute the Q value of each possible action A and pick the action with the largest Q value; this is the optimal policy (the policy function):
$$ \pi(s)=\arg\max_a Q(s,a) $$
Once the Q function has converged, it also satisfies the following recursive formula:
$$ Q(s_t,a_t)=r_t+\gamma \max_{a_{t+1}} Q(s_{t+1},a_{t+1}) $$
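To make the recursion concrete before adding a neural network, here is a small tabular Q-learning sketch on a toy 5-state chain. The environment, the learning rate alpha, and the random exploration scheme are all made up for illustration; they are not part of the Flappy Bird project. Each update nudges Q(s, a) toward r + gamma * max Q(s', a'):

# Illustrative tabular Q-learning on a toy chain: action 1 moves right, action 0 moves left,
# and moving right from the last state ends the episode with reward 1.
import numpy as np

n_states, n_actions = 5, 2
alpha, gamma = 0.1, 0.9
Q = np.zeros((n_states, n_actions))

def step(s, a):
    if a == 1:
        if s == n_states - 1:
            return s, 1.0, True
        return s + 1, 0.0, False
    return max(s - 1, 0), 0.0, False

for episode in range(500):
    s, done = 0, False
    while not done:
        a = np.random.randint(n_actions)                       # random exploration, for simplicity
        s_next, r, done = step(s, a)
        target = r if done else r + gamma * np.max(Q[s_next])  # Q(s,a) = r + gamma * max Q(s',a')
        Q[s, a] += alpha * (target - Q[s, a])
        s = s_next

print(np.argmax(Q, axis=1))   # greedy policy pi(s) = argmax_a Q(s,a); should prefer "right" everywhere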
The Q function can be implemented as a neural network and trained, which is exactly what we do below.
For an introduction to the principles of reinforcement learning and DQN, see the following article: https://ai.intel.com/demystifying-deep-reinforcement-learning/
The code below is adapted from the following project: https://github.com/yenchenlin/DeepLearningFlappyBird
The code in the game/ directory is a simplified and modified version of the earlier flappy.py: the background image is removed, the bird and pipe colors are fixed, the game starts automatically, and it restarts automatically after the bird dies. This makes it easier for the model to play continuously and to collect data.
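The only interface the training code relies on is GameState.frame_step, which takes a one-hot action vector (index 0 means do nothing, the other index makes the bird flap) and returns the raw screen frame, the reward for that frame, and a terminal flag. A minimal interaction sketch, assuming the wrapper from the project above is in game/, looks like this:

# Minimal interaction sketch with the game wrapper (same interface as used by the training code below)
import sys
import numpy as np
sys.path.append('game/')
import wrapped_flappy_bird as fb

game = fb.GameState()
action = np.zeros(2)
action[0] = 1                          # one-hot action: index 0 = do nothing
image, reward, terminal = game.frame_step(action)
print(image.shape, reward, terminal)   # raw frame array, per-frame reward, episode-over flag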
Load the libraries:
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import random
import cv2
import sys
sys.path.append('game/')
import wrapped_flappy_bird as fb
from collections import deque
Define some hyperparameters:
ACTIONS = 2                 # number of actions: do nothing / flap
GAMMA = 0.99                # discount factor
OBSERVE = 10000             # steps that only fill the replay memory before training starts
EXPLORE = 3000000           # steps over which epsilon is annealed
INITIAL_EPSILON = 0.1       # initial exploration rate
FINAL_EPSILON = 0.0001      # final exploration rate
REPLAY_MEMORY = 50000       # maximum size of the replay memory
BATCH = 32                  # minibatch size
IMAGE_SIZE = 80             # frames are resized to 80x80
Define the network inputs and some helper functions; each state S consists of four consecutive game frames.
S = tf.placeholder(dtype=tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 4], name='S')  # four stacked frames
A = tf.placeholder(dtype=tf.float32, shape=[None, ACTIONS], name='A')                    # one-hot action taken
Y = tf.placeholder(dtype=tf.float32, shape=[None], name='Y')                             # training target

k_initializer = tf.truncated_normal_initializer(0, 0.01)
b_initializer = tf.constant_initializer(0.01)

def conv2d(inputs, kernel_size, filters, strides):
    return tf.layers.conv2d(inputs, kernel_size=kernel_size, filters=filters, strides=strides,
                            padding='same', kernel_initializer=k_initializer, bias_initializer=b_initializer)

def max_pool(inputs):
    return tf.layers.max_pooling2d(inputs, pool_size=2, strides=2, padding='same')

def relu(inputs):
    return tf.nn.relu(inputs)
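Here A holds the one-hot action actually taken at each step, and Y holds the training target R + GAMMA * max Q(S', a') computed from replayed transitions, as built in the training loop further below.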
Define the network structure, a typical stack of convolution, pooling, and fully connected layers:
h0 = max_pool(relu(conv2d(S, 8, 32, 4)))
h0 = relu(conv2d(h0, 4, 64, 2))
h0 = relu(conv2d(h0, 3, 64, 1))
h0 = tf.contrib.layers.flatten(h0)
h0 = tf.layers.dense(h0, units=512, activation=tf.nn.relu, bias_initializer=b_initializer)

Q = tf.layers.dense(h0, units=ACTIONS, bias_initializer=b_initializer, name='Q')  # Q value of each action
Q_ = tf.reduce_sum(tf.multiply(Q, A), axis=1)                                     # Q value of the action actually taken
loss = tf.losses.mean_squared_error(Y, Q_)
optimizer = tf.train.AdamOptimizer(1e-6).minimize(loss)
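For reference, with 'same' padding the feature-map sizes implied by the kernel and stride parameters above work out as follows:

# S:                           80 x 80 x 4
# conv 8x8, 32 filters, /4:    20 x 20 x 32
# max pool 2x2, /2:            10 x 10 x 32
# conv 4x4, 64 filters, /2:     5 x  5 x 64
# conv 3x3, 64 filters, /1:     5 x  5 x 64
# flatten:                      1600
# dense:                        512
# Q (dense):                    ACTIONS = 2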
Use a queue as the replay memory, start the game, and do nothing for the initial state:
game_state = fb.GameState()
D = deque()                 # replay memory

# For the very first frame, choose "do nothing"
do_nothing = np.zeros(ACTIONS)
do_nothing[0] = 1
img, reward, terminal = game_state.frame_step(do_nothing)

# Resize to 80x80, convert to grayscale, binarize, then stack the same frame four times as the initial state
img = cv2.cvtColor(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2GRAY)
_, img = cv2.threshold(img, 1, 255, cv2.THRESH_BINARY)
S0 = np.stack((img, img, img, img), axis=2)
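The resize/grayscale/threshold sequence above is repeated for every new frame in the loops that follow; purely as an illustration, it could be factored into a hypothetical helper (the original code keeps the steps inline):

def preprocess(frame):
    # Resize to 80x80, convert to grayscale, then binarize the image
    frame = cv2.cvtColor(cv2.resize(frame, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2GRAY)
    _, frame = cv2.threshold(frame, 1, 255, cv2.THRESH_BINARY)
    return frame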
Keep playing the game and train the model:
sess = tf.Session()
sess.run(tf.global_variables_initializer())

t = 0
success = 0
saver = tf.train.Saver()
epsilon = INITIAL_EPSILON
while True:
    # Anneal epsilon linearly once the observation phase is over
    if epsilon > FINAL_EPSILON and t > OBSERVE:
        epsilon = INITIAL_EPSILON - (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE * (t - OBSERVE)

    # Epsilon-greedy action selection
    Qv = sess.run(Q, feed_dict={S: [S0]})[0]
    Av = np.zeros(ACTIONS)
    if np.random.random() <= epsilon:
        action_index = np.random.randint(ACTIONS)
    else:
        action_index = np.argmax(Qv)
    Av[action_index] = 1

    # Execute the action and preprocess the new frame
    img, reward, terminal = game_state.frame_step(Av)
    if reward == 1:
        success += 1
    img = cv2.cvtColor(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2GRAY)
    _, img = cv2.threshold(img, 1, 255, cv2.THRESH_BINARY)
    img = np.reshape(img, (IMAGE_SIZE, IMAGE_SIZE, 1))
    S1 = np.append(S0[:, :, 1:], img, axis=2)

    # Store the transition in the replay memory
    D.append((S0, Av, reward, S1, terminal))
    if len(D) > REPLAY_MEMORY:
        D.popleft()

    # After the observation phase, train on a random minibatch of transitions
    if t > OBSERVE:
        minibatch = random.sample(D, BATCH)
        S_batch = [d[0] for d in minibatch]
        A_batch = [d[1] for d in minibatch]
        R_batch = [d[2] for d in minibatch]
        S_batch_next = [d[3] for d in minibatch]
        T_batch = [d[4] for d in minibatch]

        Y_batch = []
        Q_batch_next = sess.run(Q, feed_dict={S: S_batch_next})
        for i in range(BATCH):
            # Target: r for terminal transitions, r + gamma * max Q(s', a') otherwise
            if T_batch[i]:
                Y_batch.append(R_batch[i])
            else:
                Y_batch.append(R_batch[i] + GAMMA * np.max(Q_batch_next[i]))

        sess.run(optimizer, feed_dict={S: S_batch, A: A_batch, Y: Y_batch})

    S0 = S1
    t += 1

    # Save a checkpoint every 10000 steps once training has started
    if t > OBSERVE and t % 10000 == 0:
        saver.save(sess, './flappy_bird_dqn', global_step=t)

    if t <= OBSERVE:
        state = 'observe'
    elif t <= OBSERVE + EXPLORE:
        state = 'explore'
    else:
        state = 'train'
    print('Current Step %d Success %d State %s Epsilon %.6f Action %d Reward %f Q_MAX %f' % (
        t, success, state, epsilon, action_index, reward, np.max(Qv)))
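During the first OBSERVE steps the agent only plays and fills the replay memory without any training; over the following EXPLORE steps epsilon is annealed linearly from INITIAL_EPSILON down to FINAL_EPSILON, so the agent gradually shifts from random exploration to acting greedily on its Q values. Once training has started, a checkpoint is saved every 10,000 steps.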
Run dqn_flappy.py to train the model from scratch. At first the bird flaps around erratically and cannot clear a single pipe, but as training proceeds it learns to play more and more steadily.
You can also run a trained model directly with the following code:
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import cv2
import sys
sys.path.append('game/')
import wrapped_flappy_bird as fb

ACTIONS = 2
IMAGE_SIZE = 80

# Restore the trained graph and weights from the checkpoint files
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.import_meta_graph('./flappy_bird_dqn-8500000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
graph = tf.get_default_graph()

# Recover the input and output tensors by name
S = graph.get_tensor_by_name('S:0')
Q = graph.get_tensor_by_name('Q/BiasAdd:0')

game_state = fb.GameState()
do_nothing = np.zeros(ACTIONS)
do_nothing[0] = 1
img, reward, terminal = game_state.frame_step(do_nothing)
img = cv2.cvtColor(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2GRAY)
_, img = cv2.threshold(img, 1, 255, cv2.THRESH_BINARY)
S0 = np.stack((img, img, img, img), axis=2)

# Always act greedily with respect to the learned Q values
while True:
    Qv = sess.run(Q, feed_dict={S: [S0]})[0]
    Av = np.zeros(ACTIONS)
    Av[np.argmax(Qv)] = 1
    img, reward, terminal = game_state.frame_step(Av)
    img = cv2.cvtColor(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2GRAY)
    _, img = cv2.threshold(img, 1, 255, cv2.THRESH_BINARY)
    img = np.reshape(img, (IMAGE_SIZE, IMAGE_SIZE, 1))
    S0 = np.append(S0[:, :, 1:], img, axis=2)
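The meta-graph filename above assumes a checkpoint saved at step 8,500,000 is present in the working directory; substitute the step number of whichever checkpoint you actually have (the training script writes them as flappy_bird_dqn-&lt;step&gt;), and tf.train.latest_checkpoint('./') will restore the matching weights. The tensor name 'Q/BiasAdd:0' is simply the output of the dense layer named 'Q' defined during training.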