A Reinforcement Learning Toolbox for Training Simple Mini-Games

Source code:
http://www.demodashi.com/demo/14072.html

Details

First, some screenshots:

  • Splash screen

  • Main window

  • Settings window

  • Server page (the score of every episode is plotted with a Highcharts template)

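The two bundled mini-games and their training results:

  • Snake

  • 「是男人就下一百層」 (NS-SHAFT), modified

Jump Man (跳跳人)
*The original image was too large and had to be resized.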

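Usage:

【Settings window】

→ In the main window shown above, click the inverted-triangle button and a black settings window pops up. There you can set the model parameters either by dragging a slider or by typing a value into the box next to it. Each slider is linked to its edit box, so changing one updates the other.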
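The link between a slider and its edit box is a fixed scale factor, because Qt sliders only hold integers. The sketch below illustrates that arithmetic, assuming the ranges used in the settings window code later in this article: Initial and Final use a scale of 1000 (positions 0 to 1000 map to 0.0 to 1.0), Gamma uses 100, and Explore, Replay and Batch are unscaled. The names SCALES, slider_to_text and text_to_slider are hypothetical, chosen only for illustration.

```python
# Hypothetical helpers showing how a slider position maps to the value in its
# edit box; the scale factors mirror the settings window code.
SCALES = {
    "Explore": 1,
    "Initial": 1000,
    "Final": 1000,
    "Gamma": 100,
    "Replay": 1,
    "Batch": 1,
}


def slider_to_text(name: str, position: int) -> str:
    """Convert an integer slider position into the text for its edit box."""
    scale = SCALES[name]
    if scale == 1:
        return str(position)
    return str(position / scale)


def text_to_slider(name: str, text: str) -> int:
    """Convert edit-box text back into an integer slider position.

    The text is parsed first and the number is then scaled; multiplying the
    string itself (text * 1000) would repeat the string instead. round() is
    used because a float product can land just below the intended integer,
    and int() alone would truncate it.
    """
    return int(round(float(text) * SCALES[name]))
```

For example, text_to_slider("Gamma", "0.99") gives 99 and slider_to_text("Gamma", 99) gives "0.99", so the two widgets can update each other without drifting.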

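【Viewing training results on the server】

→ Clicking the minimize button copies the server address to the clipboard. Paste it into a browser to monitor training in real time. Every five seconds, the line chart on that page fetches new data from the temp.db database and appends it to the chart, giving a live visualization.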
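The server code (flask_tk.py) is not listed in this article, so the following is only a sketch of the polling idea using the standard-library sqlite3 module. It assumes the scores (time integer, score integer) table that DQL.py creates, with time stored in seconds and multiplied by 1000 for the chart (the main window applies the same factor when it saves results). The function name fetch_new_points and the last_time cursor are hypothetical. A Flask route could return the points as JSON, and the page would call it every five seconds and append them to the Highcharts series.

```python
import sqlite3


def fetch_new_points(conn: sqlite3.Connection, last_time: int):
    """Return scores newer than last_time as [x, y] pairs for a chart series.

    Assumes the scores(time, score) table from DQL.py with time in seconds;
    Highcharts datetime axes expect milliseconds, hence the * 1000.
    """
    rows = conn.execute(
        "SELECT time, score FROM scores WHERE time > ? ORDER BY time",
        (last_time,),
    ).fetchall()
    points = [[t * 1000, s] for t, s in rows]
    # The caller keeps the newest timestamp (in seconds) for the next poll.
    new_last = rows[-1][0] if rows else last_time
    return points, new_last
```

Keeping a cursor on the last timestamp means each five-second poll transfers only the new rows instead of re-sending the whole history.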

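【Close button】

→ When the close button is clicked and training has run for more than 1000 timesteps (frames), a dialog asks whether to save the record. With fewer timesteps the run is too short to be useful, so the program exits immediately without saving.

→ Click Confirm

→ The database is saved successfully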


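【Training in new mode】

→ Select the game to train

→ Start training (click the play button)

→ Hover the mouse over the progress bar to see the exact values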
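The tooltip shows the timestep, the training state and epsilon, and the bar's fill is timestep / Explore × 100, capped at 100 (see updateState in the main window code). DQLBrain.py is not shown here, so the epsilon schedule below is an assumption: a linear anneal from Initial to Final over Explore timesteps, as in common DQN implementations, ignoring any observation-only warm-up phase. Both function names are hypothetical.

```python
def progress_percent(timestep: int, explore: int) -> float:
    """Fill level of the progress bar: share of the exploration phase completed."""
    return min(timestep / explore * 100, 100.0)


def linear_epsilon(timestep: int, explore: int, initial: float, final: float) -> float:
    """Assumed schedule: epsilon moves linearly from initial to final over explore steps."""
    if timestep >= explore:
        return final
    return initial + (final - initial) * timestep / explore
```

Under this assumption, the settings window's defaults (Initial = Final = 0) keep epsilon at 0, i.e. the agent always takes the greedy action.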

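【Training in load mode】

→ Click the mode toggle button

→ Clicking the play button now opens a dialog for choosing the model to load

→ Click the start button to begin training. While it runs, the settings button and the mode toggle button are disabled so that training is not interrupted.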

1. Environment

  • Python 3
  • TensorFlow-gpu
  • pygame
  • OpenCV-Python
  • PyQt5
  • sys
  • threading
  • multiprocessing
  • shelve
  • os
  • sqlite3
  • socket
  • pyperclip
  • flask
  • glob
  • shutil
  • numpy
  • pandas
  • time
  • importlib

2. File structure

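|————MyLibrary.py classes for the in-game characters and other objects
|————run_window.py launcher, including the splash screen
|————mainwindow.py main window
|————setting.py parameter-tuning (settings) window
|————message_box.py message box window
|————DQL.py main AI program: selects and launches the game and starts the deep reinforcement learning core
|————DQLBrain.py deep reinforcement learning core
|————game_setting.py stores each game's number of actions, module name and similar information; a new game must add its entry here
|————flask_tk.py web server
|————jumpMan.py Jump Man game
|————greedySnake.py Snake game
|————resource folder of image resources for the windows
|————save_networks saved model files
|————templates
   |————index.html web front-end template
|————static
   |————exporting.js
   |————highcharts-zh_CN.js
   |————highstock.js
   |————jquery.js
|————temp.db temporary database used to exchange data between the server and the AI
|————greedy_snake.data-00000-of-00001
|————greedy_snake.index
|————greedy_snake.meta these three files form one trained model
|————greedy_snake.db.bak
|————greedy_snake.db.dat
|————greedy_snake.db.dir these three files form one project file (a shelve database)
|————setting_resource.py resource file for the settings window
|————resource_message_box.py resource file for the message box window
|————resource.py resource file for the main window
|————document.py generates a report automatically from the database file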
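game_setting.py itself is not listed in the article. Judging from how DQL.py uses it (self.games_info[title]["action"] for the number of actions, and importlib.import_module(self.games_info[title]['class']) to load a module whose Game(screen) class provides frameStep(action)), a registry could look roughly like the hypothetical sketch below. The titles and action counts are placeholders, not values taken from the project, and register_game is an illustrative helper.

```python
# Hypothetical sketch of game_setting.py: one entry per game, keyed by the
# title shown in the main window's drop-down list.
GAMES = {
    "Snake": {"action": 3, "class": "greedySnake"},    # placeholder action count
    "Jump Man": {"action": 3, "class": "jumpMan"},     # placeholder action count
}


def getSetting():
    """Return the registry that the main window and DQL.py read."""
    return GAMES


def register_game(title: str, n_actions: int, module_name: str) -> None:
    """Hypothetical helper: add a game whose module defines Game(screen)."""
    if title in GAMES:
        raise ValueError(f"game {title!r} is already registered")
    GAMES[title] = {"action": n_actions, "class": module_name}
```

Because the main window fills its drop-down list from the dictionary keys, a newly registered game would appear there without further UI changes.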

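3. Implementation

The demo consists of four main parts: the main window, the algorithm and game core, the server, and the management of versioned project database files.

(Figure: relationships between the modules)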

  • Splash screen

    import sys
    from mainWindow import MAINWINDOW
    from PyQt5.QtWidgets import QApplication,QSplashScreen
    from PyQt5 import QtCore,QtGui,QtWidgets
    if __name__ == '__main__':
        app = QApplication(sys.argv)

        # Create the splash screen
        splash=QtWidgets.QSplashScreen(QtGui.QPixmap("啓動界面.png"))

        # Show the splash screen
        splash.show()

        # Create a timer
        timer = QtCore.QElapsedTimer()

        # Start the timer
        timer.start()

        # Keep the splash screen visible for 3 s while still processing events
        while timer.elapsed() < 3000:
            app.processEvents()

        # Create the main window
        MainWindow = MAINWINDOW()

        # Show the main window
        MainWindow.show()

        # Close the splash screen once the main window has fully loaded
        splash.finish(MainWindow)

        sys.exit(app.exec_())
  • Main window (laid out with Qt Designer)

    import gameSetting
    import resource
    from PyQt5 import QtWidgets,QtCore,QtGui
    from collections import deque
    from threading import Thread
    from multiprocessing import Process
    import shelve
    import sqlite3
    import socket
    import pyperclip
    from DQL import AI
    import setting
    import messageBox
    import webServers
    import glob
    import shutil
    
    game_start=False
    
    class myThread(Thread):
        def __init__(self,game,model,replay_memory,timestep,setting):
            Thread.__init__(self)
            self.game=game
            self.model=model
            self.setting=setting
            self.replay_memory=replay_memory
            self.timestep=timestep
    
        def run(self):
            self.AI = AI(self.game,self.model,self.replay_memory,self.timestep,int(self.setting["Explore"]),float(self.setting["Initial"]),float(self.setting["Final"]),float(self.setting["Gamma"]),int(self.setting["Replay"]),int(self.setting["Batch"]),)
            self.AI.playGame()
    
        def stop(self):
            self.AI.closeGame()
    
    class MAINWINDOW(QtWidgets.QWidget):
        def __init__(self, parent=None):
    
            # Initialize the parent class
            super().__init__()
    
            # Initialize the main window object
            self.setObjectName("Form")
            self.setEnabled(True)
            self.resize(681, 397)
            self.setStyleSheet("background-color: rgb(255, 255, 255);")
            self.setWindowFlags(QtCore.Qt.FramelessWindowHint)
    
            # Progress bar
            self.progressBar = QtWidgets.QProgressBar(self)
            self.progressBar.setEnabled(True)
            self.progressBar.setGeometry(QtCore.QRect(140, 348, 291, 23))
            self.progressBar.setProperty("value", 0)
            self.progressBar.setTextVisible(False)
            self.progressBar.setObjectName("progressxzBar")
    
            # Start/stop button
            self.control = QtWidgets.QPushButton(self)
            self.control.setGeometry(QtCore.QRect(10, 325, 71, 71))
            self.control.setStyleSheet("border-image: url(:/bottom/resource/開始按鈕.png);")
            self.control.setText("")
            self.control.setObjectName("control")
            self.control_state=False
    
            # Game selection drop-down list
            self.game_selection = QtWidgets.QComboBox(self)
            self.game_selection.setEnabled(True)
            self.game_selection.setGeometry(QtCore.QRect(530, 343, 141, 31))
            self.game_selection.setAutoFillBackground(False)
            self.game_selection.setStyleSheet("QComboBox{border-image: url(:/list/resource/下拉框.png)} \n""QComboBox::drop-down {image: url(:/bottom/resource/下拉框按鈕.png)  }")
            self.game_selection.setEditable(False)
            self.game_selection.setInsertPolicy(QtWidgets.QComboBox.NoInsert)
            self.game_selection.setIconSize(QtCore.QSize(0, 0))
            self.game_selection.setFrame(False)
            self.game_selection.setObjectName("game_selection")
    
            # Mode toggle button
            self.mode = QtWidgets.QPushButton(self)
            self.mode.setGeometry(QtCore.QRect(440, 340, 71, 41))
            self.mode.setStyleSheet("border-image: url(:/bottom/resource/空白模式.png);\n""")
            self.mode.setText("")
            self.mode.setObjectName("mode")
            self.mode_state = False
    
            # Background image
            self.label = QtWidgets.QLabel(self)
            self.label.setGeometry(QtCore.QRect(0, 0, 681, 331))
            self.label.setStyleSheet("border-image: url(:/image/resource/Background.png);")
            self.label.setText("")
            self.label.setObjectName("label")
    
            # Settings button
            self.setting = QtWidgets.QPushButton(self)
            self.setting.setGeometry(QtCore.QRect(570, 10, 31, 21))
            self.setting.setStyleSheet("border-image: url(:/bottom/resource/菜單.png);")
            self.setting.setText("")
            self.setting.setObjectName("setting")
    
            # Button that copies the server address (drawn with the minimize icon)
            self.pushButton_3 = QtWidgets.QPushButton(self)
            self.pushButton_3.setGeometry(QtCore.QRect(610, 10, 31, 23))
            self.pushButton_3.setStyleSheet("border-image: url(:/bottom/resource/最小化.png);")
            self.pushButton_3.setText("")
            self.pushButton_3.setObjectName("pushButton_3")
    
            # Close button
            self.bottom_close = QtWidgets.QPushButton(self)
            self.bottom_close.setGeometry(QtCore.QRect(650, 10, 21, 23))
            self.bottom_close.setStyleSheet("border-image: url(:/bottom/resource/關閉.png);")
            self.bottom_close.setText("")
            self.bottom_close.setObjectName("bottom_close") 
            
            # Finish setting up the window
            self.init_window(self)
    
            # Connect the buttons to their slots
            self.connectBottom()
            QtCore.QMetaObject.connectSlotsByName(self)
    
        # Initialize the window
        def init_window(self, Form):
            _translate = QtCore.QCoreApplication.translate
            Form.setWindowTitle(_translate("Form", "Deep Reinforcement Learning Toolbox"))
    
            # Create the child windows
            self.setting_form = setting.SETTING()
            self.message_box=messageBox.MESSAGE_BOX()
    
            # Populate the game list
            game_setting_dict = gameSetting.getSetting()
            for i,game in enumerate(game_setting_dict.keys()):
                self.game_selection.addItem("")
                self.game_selection.setItemText(i, _translate("Form", game))
            self.game_selection.setCurrentText(_translate("Form", list(game_setting_dict.keys())[0]))
            self.game_selection.setCurrentIndex(0)
    
            # Start the web server in a daemon process
            flask_process = Process(target=webServers.start)
            flask_process.daemon = True
            flask_process.start()
    
        # Connect every button to its handler
        def connectBottom(self):
            self.control.clicked.connect(self.loadGame)
            self.bottom_close.clicked.connect(self.closeWindow)
            self.mode.clicked.connect(self.setMode)
            self.setting.clicked.connect(self.openSetting)
            self.pushButton_3.clicked.connect(self.getIp)
    
        # Make the frameless window draggable
        def mousePressEvent(self, event):
            if event.button() == QtCore.Qt.LeftButton:
                self.m_drag = True
                self.m_DragPosition = event.globalPos() - self.pos()
                event.accept()
                self.setCursor(QtGui.QCursor(QtCore.Qt.OpenHandCursor))
    
        def mouseMoveEvent(self, QMouseEvent):
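            # Move only while the left button is held and a drag started in this window
            if (QMouseEvent.buttons() & QtCore.Qt.LeftButton) and getattr(self, "m_drag", False):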
                self.move(QMouseEvent.globalPos() - self.m_DragPosition)
                QMouseEvent.accept()
    
        def mouseReleaseEvent(self, QMouseEvent):
            self.m_drag = False
            self.setCursor(QtGui.QCursor(QtCore.Qt.ArrowCursor))
    
        # Handler for the start/stop button
        def loadGame(self):
            self.mode.setEnabled(False)
            self.setting.setEnabled(False)
    
            # Set the game-started flag
            global game_start
            game_start=True
    
            # control_state: False = game not started yet, True = game running; the button image follows the state
            if self.control_state:
                self.closeWindow()
            else:
                # Switch the button to its "stop" state
                self.control.setStyleSheet("border-image: url(:/bottom/resource/終止按鈕.png);")
                self.control_state =True
    
                # Initialize the variables the AI needs
                self.program_name = ""
                game=self.game_selection.currentText()
                model = ""
                replay_memory = deque()
                self.actual_timestep=0
                setting=self.setting_form.getSetting()
    
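                # When loading an existing project, overwrite the variables above
                if self.mode_state:
                    program_path = QtWidgets.QFileDialog.getOpenFileName(self, "Select the project to load",
                                                                   "../",
                                                                   "Model File (*.dat)")
                    try:
                        # Project path without the ".db.dat" suffix (directory included)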
                        self.program_name=program_path[0][:-7]
    
                        # Open the project file
                        with shelve.open(self.program_name+'.db') as f:
                            # Load the project information
                            game=f["game"]
                            model = self.program_name
                            replay_memory = f["replay"]
                            setting=f["setting"]
                            self.actual_timestep = int(f["timestep"])
                            self.setting_form.updateSetting(setting)
                            self.updateDataset(f["result"])
                    except:
                        pass
    
                # Start the game thread
                self.game_thread = myThread(game,model,replay_memory,self.actual_timestep,setting)
                self.game_thread.start()
    
                # Start a timer that refreshes the state every 5 s
                self.state_Timer = QtCore.QTimer()
                self.state_Timer.timeout.connect(self.updateState)
                self.state_Timer.start(5000)
    
        # Close the window
        def closeWindow(self):
            timestep=0
    
            # If the game never started, or ran only briefly, exit directly.
            # try is needed because the game sometimes takes more than five seconds to start.
            try:
                timestep=self.state["TIMESTEP"]
            except:
                pass
    
            if timestep>1000:
                # Show the save dialog
                reply = self.message_box.exec_()
                if reply:
                    # Close the game window
                    try:
                        self.game_thread.AI.closeGame()
                    except:
                        pass
                    # New mode
                    if not self.program_name:
                        save_program_path = QtWidgets.QFileDialog.getSaveFileName(self, "Choose where to save the project",
                                                                             "../",
                                                                             "Program File(*.db)")
    
                        # getSaveFileName returns (path, filter); the path is empty if the user cancelled
                        if save_program_path[0]:
    
                            # Project path and name without the .db extension
                            path = save_program_path[0]
                            program_name = path[:-3] if path.endswith('.db') else path
    
                            # Write the project file
                            self.saveProgram(path, False)
    
                            # Save the model
                            self.saveModel(program_name)
    
                    # Load mode
                    else:
                        program_name=self.program_name
                        try:
                            self.saveProgram(program_name+'.db', True)
                        except:
                            pass
    
                        # Save the model
                        self.saveModel(program_name)
    
            # Clear the temporary database
            with sqlite3.connect('temp.db', check_same_thread=False) as f:
                c = f.cursor()
                c.execute('delete from scores')
                f.commit()
    
            # Close the main window (the timer and the daemon server process end with it)
            self.close()
    
        # Save the project file (shared by new mode and load mode)
        def saveProgram(self,save_program_path,load_mode):
            with shelve.open(save_program_path) as f:
                # AI settings
                f["setting"] = self.setting_form.getSetting()
    
                # AI runtime state
                ai_state = self.game_thread.AI.getState()
    
                f["game"] = self.game_selection.currentText()
                f["epsilon"] = ai_state["EPSILON"]
                f["result"] = [[i[0] * 1000, i[1]] for i in
                               sqlite3.connect('temp.db', check_same_thread=False).cursor().execute(
                                   'select * from scores').fetchall()]
                f["replay"] = self.game_thread.AI.getReplay()
    
                # In load mode, add this session's timesteps to the stored total
                if load_mode and "timestep" in f:
                    f["timestep"]=int(ai_state["TIMESTEP"]) + int(f["timestep"])
                else:
                    f["timestep"] = int(ai_state["TIMESTEP"])
    
        # Periodically refresh the main window state
        def updateState(self):
            # Try to read the game state; skip this round if the game has not started yet
            try:
                self.state = self.game_thread.AI.getState()
            except:
                pass
            else:
                actual_timestep=self.state["TIMESTEP"]
                self.progressBar.setToolTip("Timestep:"+str(actual_timestep)+"    STATE:"+str(self.state["STATE"])+"     EPSILON:"+str(self.state["EPSILON"]))
                self.progressBar.setValue(int(min(float(actual_timestep)/float(self.setting_form.getSetting()["Explore"])*100,100)))
    
            # Commit the AI's database writes only once every 5 s, to keep training fast
            try:
                self.game_thread.AI.data_base.commit()
            except:
                pass
    
    
        # Toggle the AI mode with the button
        def setMode(self):
            if not self.mode_state:
                self.mode_state = True
                self.mode.setStyleSheet("border-image: url(:/bottom/resource/加載模式.png);\n""")
            else:
                self.mode_state = False
                self.mode.setStyleSheet("border-image: url(:/bottom/resource/空白模式.png);\n""")
    
    
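        # Get this machine's LAN IP address and copy "<ip>:9090" to the clipboard
        def getIp(self):
            # Connecting a UDP socket sends no packets; it only selects the outgoing interface
            sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            try:
                sock.connect(('8.8.8.8', 80))
                ip = sock.getsockname()[0]
            finally:
                sock.close()
            pyperclip.copy(ip + ':9090')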
    
    
        # Copy the results loaded from a project file into the temporary database
        def updateDataset(self,results):
            with sqlite3.connect('temp.db', check_same_thread=False) as f:
                c=f.cursor()
                for result in results:
                    c.execute("insert into scores values (?, ?)", (result[0], result[1]))
                f.commit()
    
    
        # Save the model
        def saveModel(self, program_name):
            for file in glob.glob("./saved_networks/network-dqn-*"):
                postfix = file.split('.')[-1]
                try:
                    shutil.copy(file, program_name + '.' + postfix)
                except:
                    pass
    
    
        # Handler for the settings button
        def openSetting(self):
            self.setting_form.show()
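The project file written by saveProgram is a shelve database with the keys setting, game, epsilon, result, replay and timestep. The .db.bak / .db.dat / .db.dir triple in the file tree is what shelve produces when it falls back to the dbm.dumb backend (typical on Windows), which is why loadGame strips seven characters (".db.dat") from the chosen path. The standalone sketch below mirrors that round trip without Qt, accumulating the timestep across sessions in load mode as saveProgram intends. save_project, load_project and project_prefix are hypothetical names.

```python
import shelve


def save_project(path: str, setting: dict, game: str, replay, new_steps: int,
                 results, epsilon: float, load_mode: bool) -> None:
    """Write a project file; in load mode add new_steps to the stored total."""
    with shelve.open(path) as f:
        previous = int(f["timestep"]) if load_mode and "timestep" in f else 0
        f["setting"] = setting
        f["game"] = game
        f["epsilon"] = epsilon
        f["result"] = list(results)
        f["replay"] = replay
        f["timestep"] = previous + int(new_steps)


def load_project(path: str) -> dict:
    """Read every stored field back into a plain dict."""
    with shelve.open(path) as f:
        return {key: f[key] for key in f.keys()}


def project_prefix(chosen_file: str) -> str:
    """Strip the .db.dat suffix from the file picked in the load dialog."""
    suffix = ".db.dat"
    return chosen_file[: -len(suffix)] if chosen_file.endswith(suffix) else chosen_file
```

Checking the suffix with endswith, rather than slicing a fixed seven characters, keeps the prefix intact if the dialog returns a file with a different extension.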
  • Settings window

    from PyQt5 import QtCore, QtGui, QtWidgets
    import setting_resource
    
    class SETTING(QtWidgets.QWidget):
        def __init__(self):
    
            # Initialize the parent class
            super().__init__()
    
            # Initialize the dialog window
            self.setObjectName("Dialog")
            self.resize(547, 402)
            self.setStyleSheet("")
    
            # OK button
            self.pushButton = QtWidgets.QPushButton(self)
            self.pushButton.setGeometry(QtCore.QRect(160, 320, 75, 23))
            self.pushButton.setStyleSheet("color: rgb(255, 255, 255);\n""border-image: url(:/image/resource/設定肯定按鈕.png);")
            self.pushButton.setText("")
            self.pushButton.setObjectName("pushButton")
    
            # Cancel button
            self.pushButton_2 = QtWidgets.QPushButton(self)
            self.pushButton_2.setGeometry(QtCore.QRect(320, 320, 75, 23))
            self.pushButton_2.setStyleSheet("color: rgb(255, 255, 255);\n""border-image: url(:/image/resource/設定取消按鈕.png);")
            self.pushButton_2.setText("")
            self.pushButton_2.setObjectName("pushButton_2")
    
            # Edit boxes, sliders and labels for each parameter
            self.line_explore = QtWidgets.QLineEdit(self)
            self.line_explore.setGeometry(QtCore.QRect(450, 60, 61, 20))
            self.line_explore.setStyleSheet("color: rgb(0, 0, 0);")
            self.line_explore.setObjectName("line_explore")
            self.line_initial = QtWidgets.QLineEdit(self)
            self.line_initial.setGeometry(QtCore.QRect(450, 100, 61, 20))
            self.line_initial.setStyleSheet("color: rgb(0, 0, 0);")
            self.line_initial.setObjectName("line_Initial")
            self.line_final = QtWidgets.QLineEdit(self)
            self.line_final.setGeometry(QtCore.QRect(450, 140, 61, 20))
            self.line_final.setStyleSheet("color: rgb(0, 0, 0);")
            self.line_final.setObjectName("line_final")
            self.line_gamma = QtWidgets.QLineEdit(self)
            self.line_gamma.setGeometry(QtCore.QRect(450, 180, 61, 20))
            self.line_gamma.setStyleSheet("color: rgb(0, 0, 0);")
            self.line_gamma.setObjectName("line_gamma")
            self.line_replay = QtWidgets.QLineEdit(self)
            self.line_replay.setGeometry(QtCore.QRect(450, 220, 61, 20))
            self.line_replay.setStyleSheet("color: rgb(0, 0, 0);")
            self.line_replay.setObjectName("line_replay")
            self.line_batch = QtWidgets.QLineEdit(self)
            self.line_batch.setGeometry(QtCore.QRect(450, 260, 61, 20))
            self.line_batch.setStyleSheet("color: rgb(0, 0, 0);")
            self.line_batch.setObjectName("line_batch")
            self.exploreSlider = QtWidgets.QSlider(self)
            self.exploreSlider.setGeometry(QtCore.QRect(120, 60, 300, 19))
            self.exploreSlider.setStyleSheet("QSlider::handle:horizontal {     \n""    image: url(:/image/resource/Handle.png);\n""}\n""QSlider::groove:horizontal {        \n""    image: url(:/image/resource/Base.png);\n""}\n""")
            self.exploreSlider.setMinimum(200000)
            self.exploreSlider.setMaximum(10000000)
            self.exploreSlider.setProperty("value", 200000)
            self.exploreSlider.setOrientation(QtCore.Qt.Horizontal)
            self.exploreSlider.setObjectName("exploreSlider")
            self.label = QtWidgets.QLabel(self)
            self.label.setGeometry(QtCore.QRect(50, 60, 48, 19))
            self.label.setStyleSheet("color: rgb(255, 255, 255);")
            self.label.setObjectName("label")
            self.label_2 = QtWidgets.QLabel(self)
            self.label_2.setGeometry(QtCore.QRect(50, 100, 48, 19))
            self.label_2.setStyleSheet("color: rgb(255, 255, 255);")
            self.label_2.setObjectName("label_2")
            self.initialSlider = QtWidgets.QSlider(self)
            self.initialSlider.setGeometry(QtCore.QRect(120, 100, 300, 19))
            self.initialSlider.setStyleSheet("QSlider::handle:horizontal {     \n""    image: url(:/image/resource/Handle.png);\n""}\n""QSlider::groove:horizontal {        \n""    image: url(:/image/resource/Base.png);\n""}\n""")
            self.initialSlider.setMaximum(1000)
            self.initialSlider.setProperty("value", 0)
            self.initialSlider.setOrientation(QtCore.Qt.Horizontal)
            self.initialSlider.setObjectName("initialSlider")
            self.label_3 = QtWidgets.QLabel(self)
            self.label_3.setGeometry(QtCore.QRect(50, 140, 42, 19))
            self.label_3.setStyleSheet("color: rgb(255, 255, 255);")
            self.label_3.setObjectName("label_3")
            self.finalSlider = QtWidgets.QSlider(self)
            self.finalSlider.setGeometry(QtCore.QRect(120, 140, 300, 19))
            self.finalSlider.setStyleSheet("QSlider::handle:horizontal {     \n""    image: url(:/image/resource/Handle.png);\n""}\n""QSlider::groove:horizontal {        \n""    image: url(:/image/resource/Base.png);\n""}\n""")
            self.finalSlider.setMaximum(1000)
            self.finalSlider.setProperty("value", 0)
            self.finalSlider.setOrientation(QtCore.Qt.Horizontal)
            self.finalSlider.setObjectName("finalSlider")
            self.label_4 = QtWidgets.QLabel(self)
            self.label_4.setGeometry(QtCore.QRect(50, 180, 42, 19))
            self.label_4.setStyleSheet("color: rgb(255, 255, 255);")
            self.label_4.setObjectName("label_4")
            self.gammaSlider = QtWidgets.QSlider(self)
            self.gammaSlider.setGeometry(QtCore.QRect(120, 180, 300, 19))
            self.gammaSlider.setStyleSheet("QSlider::handle:horizontal {     \n""    image: url(:/image/resource/Handle.png);\n""}\n""QSlider::groove:horizontal {        \n""    image: url(:/image/resource/Base.png);\n""}\n""")
            self.gammaSlider.setMaximum(100)
            self.gammaSlider.setProperty("value", 99)
            self.gammaSlider.setOrientation(QtCore.Qt.Horizontal)
            self.gammaSlider.setObjectName("gammaSlider")
            self.label_6 = QtWidgets.QLabel(self)
            self.label_6.setGeometry(QtCore.QRect(50, 220, 42, 19))
            self.label_6.setStyleSheet("color: rgb(255, 255, 255);")
            self.label_6.setObjectName("label_6")
            self.replaySlider = QtWidgets.QSlider(self)
            self.replaySlider.setGeometry(QtCore.QRect(120, 220, 300, 19))
            self.replaySlider.setStyleSheet("QSlider::handle:horizontal {     \n""    image: url(:/image/resource/Handle.png);\n""}\n""QSlider::groove:horizontal {        \n""    image: url(:/image/resource/Base.png);\n""}\n""")
            self.replaySlider.setMaximum(100000)
            self.replaySlider.setProperty("value", 50000)
            self.replaySlider.setOrientation(QtCore.Qt.Horizontal)
            self.replaySlider.setObjectName("replaySlider")
            self.label_7 = QtWidgets.QLabel(self)
            self.label_7.setGeometry(QtCore.QRect(50, 260, 36, 19))
            self.label_7.setStyleSheet("color: rgb(255, 255, 255);")
            self.label_7.setObjectName("label_7")
            self.batchSlider = QtWidgets.QSlider(self)
            self.batchSlider.setGeometry(QtCore.QRect(120, 260, 300, 19))
            self.batchSlider.setStyleSheet("QSlider::handle:horizontal {     \n""    image: url(:/image/resource/Handle.png);\n""}\n""QSlider::groove:horizontal {        \n""    image: url(:/image/resource/Base.png);\n""}\n""")
            self.batchSlider.setMaximum(100)
            self.batchSlider.setProperty("value", 32)
            self.batchSlider.setOrientation(QtCore.Qt.Horizontal)
            self.batchSlider.setObjectName("batchSlider")
            self.label_5 = QtWidgets.QLabel(self)
            self.label_5.setGeometry(QtCore.QRect(0, 0, 551, 411))
            self.label_5.setStyleSheet("background-image: url(:/background/resource/設定背景.png);")
            self.label_5.setText("")
            self.label_5.setObjectName("label_5")
    
            # Stacking order: raise the background first, then every control above it
            self.label_5.raise_()
            self.pushButton.raise_()
            self.pushButton_2.raise_()
            self.line_explore.raise_()
            self.line_initial.raise_()
            self.line_final.raise_()
            self.line_gamma.raise_()
            self.line_replay.raise_()
            self.line_batch.raise_()
            self.exploreSlider.raise_()
            self.label.raise_()
            self.label_2.raise_()
            self.initialSlider.raise_()
            self.label_3.raise_()
            self.finalSlider.raise_()
            self.label_4.raise_()
            self.gammaSlider.raise_()
            self.label_6.raise_()
            self.replaySlider.raise_()
            self.label_7.raise_()
            self.batchSlider.raise_()
    
            # Set the texts and default values
            self.retranslateUi(self)
    
            # Link the edit boxes and sliders
            self.connect()
    
            # Connect the button signals to their slots
            self.pushButton.clicked.connect(self.saveSetting)
            self.pushButton_2.clicked.connect(self.cancel)
            QtCore.QMetaObject.connectSlotsByName(self)
    
        def retranslateUi(self, Dialog):
            _translate = QtCore.QCoreApplication.translate
            Dialog.setWindowTitle(_translate("Dialog", "Settings"))
    
            # Default text of each edit box and label
            self.line_explore.setText(_translate("Dialog", "200000"))
            self.line_initial.setText(_translate("Dialog", "0"))
            self.line_final.setText(_translate("Dialog", "0"))
            self.line_gamma.setText(_translate("Dialog", "0.99"))
            self.line_replay.setText(_translate("Dialog", "50000"))
            self.line_batch.setText(_translate("Dialog", "32"))
            self.label.setText(_translate("Dialog", "Explore:"))
            self.label_2.setText(_translate("Dialog", "Initial:"))
            self.label_3.setText(_translate("Dialog", "Final:"))
            self.label_4.setText(_translate("Dialog", "Gamma:"))
            self.label_6.setText(_translate("Dialog", "Replay:"))
            self.label_7.setText(_translate("Dialog", "Batch:"))
    
            # Default settings
            self.setting={"Explore":200000,"Initial":0,"Final":0,"Gamma":0.99,"Replay":50000,"Batch":32}
    
        # Link each edit box with its slider
        def connect(self):
    
            self.exploreSlider.valueChanged.connect(self.changeLineExplore)
            self.line_explore.textChanged.connect(self.changeSliderExplore)
    
            self.initialSlider.valueChanged.connect(self.changeLineInitial)
            self.line_initial.textChanged.connect(self.changeSliderInitial)
    
            self.finalSlider.valueChanged.connect(self.changeLineFinal)
            self.line_final.textChanged.connect(self.changeSliderFinal)
    
            self.gammaSlider.valueChanged.connect(self.changeLineGamma)
            self.line_gamma.textChanged.connect(self.changeSliderGamma)
    
            self.replaySlider.valueChanged.connect(self.changeLineReplay)
            self.line_replay.textChanged.connect(self.changeSliderReplay)
    
            self.batchSlider.valueChanged.connect(self.changeLineBatch)
            self.line_batch.textChanged.connect(self.changeSliderBatch)
    
        def changeLineExplore(self):
            try:
                self.line_explore.setText(str(self.exploreSlider.value()))
            except:
                pass
    
        def changeSliderExplore(self):
            try:
                self.exploreSlider.setValue(int(self.line_explore.text()))
            except:
                pass
    
        def changeLineInitial(self):
            try:
                self.line_initial.setText(str(self.initialSlider.value()/1000))
            except:
                pass
    
        def changeSliderInitial(self):
            try:
                self.initialSlider.setValue(int(float(self.line_initial.text())*1000))
            except:
                pass
    
        def changeLineFinal(self):
            try:
                self.line_final.setText(str(self.finalSlider.value()/1000))
            except:
                pass
    
        def changeSliderFinal(self):
            try:
                self.finalSlider.setValue(int(float(self.line_final.text())*1000))
            except:
                pass
    
        def changeLineGamma(self):
            try:
                self.line_gamma.setText(str(self.gammaSlider.value()/100))
            except:
                pass
    
        def changeSliderGamma(self):
            try:
                self.gammaSlider.setValue(int(100*float(self.line_gamma.text())))
            except:
                pass
    
        def changeLineReplay(self):
            try:
                self.line_replay.setText(str(self.replaySlider.value()))
            except:
                pass
    
        def changeSliderReplay(self):
            try:
                self.replaySlider.setValue(int(self.line_replay.text()))
            except:
                pass
    
        def changeLineBatch(self):
            try:
                self.line_batch.setText(str(self.batchSlider.value()))
            except:
                pass
    
        def changeSliderBatch(self):
            try:
                self.batchSlider.setValue(int(self.line_batch.text()))
            except:
                pass
    
        # Let other windows read the AI settings
        def getSetting(self):
            return self.setting
    
        # Save the settings
        def saveSetting(self):
            self.setting={"Explore":self.line_explore.text(),"Initial":self.line_initial.text(),"Final":self.line_final.text(),"Gamma":self.line_gamma.text(),"Replay":self.line_replay.text(),"Batch":self.line_batch.text()}  # TODO: check that every value is a valid number
            self.hide()
    
        # Close without saving
        def cancel(self):
            self.hide()
            return 0
    
        # Update the settings from a loaded project file
        def updateSetting(self,setting):
            self.setting={"Explore":setting["Explore"],"Initial":setting["Initial"],"Final":setting["Final"],"Gamma":setting["Gamma"],"Replay":setting["Replay"],"Batch":setting["Batch"]}  # TODO: check that every value is a valid number
            self.line_explore.setText(str(setting["Explore"]))
            self.line_final.setText(str(setting["Final"]))
            self.line_initial.setText(str(setting["Initial"]))
            self.line_gamma.setText(str(setting["Gamma"]))
            self.line_replay.setText(str(setting["Replay"]))
            self.line_batch.setText(str(setting["Batch"]))
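saveSetting stores the raw edit-box strings, and its comment notes that numeric validation is still missing; myThread.run later converts them with int() and float(), so a bad entry only fails once training starts. A possible check is sketched below as a hypothetical parse_setting function with assumed rules (Explore, Replay and Batch are positive integers; Initial, Final and Gamma lie in [0, 1]). saveSetting could call it inside a try block and keep the window open when it raises.

```python
INT_KEYS = ("Explore", "Replay", "Batch")
UNIT_KEYS = ("Initial", "Final", "Gamma")


def parse_setting(raw: dict) -> dict:
    """Convert the edit-box strings to numbers, raising ValueError on bad input.

    Assumed rules: Explore, Replay and Batch are positive integers;
    Initial, Final and Gamma lie in the closed interval [0, 1].
    """
    parsed = {}
    for key in INT_KEYS:
        value = int(raw[key])  # raises ValueError for text such as "abc" or "3.5"
        if value <= 0:
            raise ValueError(f"{key} must be a positive integer, got {value}")
        parsed[key] = value
    for key in UNIT_KEYS:
        value = float(raw[key])
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"{key} must be between 0 and 1, got {value}")
        parsed[key] = value
    return parsed
```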
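  • Deep reinforcement learning
    This part of the code is based on https://blog.csdn.net/songrotek/article/details/50951537. I won't repeat the principles of deep reinforcement learning here; that blog post explains them in detail.
    It consists of two parts: DQL.py manages the game and the algorithm together, while DQLBrain.py is the core of the deep reinforcement learning algorithm. Both are shown below: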
    • DQL.py
      import cv2
      from DQLBrain import Brain
      import numpy as np
      from collections import deque
      import sqlite3
      import pygame
      import time
      import gameSetting
      import importlib

      # Settings shared by all games
      SCREEN_X = 288
      SCREEN_Y = 512
      FPS = 60
      
      class AI:
          def __init__(self, title,model_path,replay_memory,current_timestep,explore,initial_epsilon,final_epsilon,gamma,replay_size,batch_size):
              # Initialize constants
              self.scores = deque()
              self.games_info = gameSetting.getSetting()
      
              # Connect to the temporary database (creating the scores table if it does not exist yet)
              self.data_base = sqlite3.connect('temp.db', check_same_thread=False)
              self.c = self.data_base.cursor()
              self.c.execute('create table if not exists scores (time integer, score integer)')
      
              # Create the deep reinforcement learning object
              self.brain = Brain(self.games_info[title]["action"],model_path,replay_memory,current_timestep,explore,initial_epsilon,final_epsilon,gamma,replay_size,batch_size)
      
              # Create the game window
              self.startGame(title,SCREEN_X,SCREEN_Y)
      
              # Import the module of the selected game and create the game
              game=importlib.import_module(self.games_info[title]['class'])
              self.game=game.Game(self.screen)
      
          def startGame(self,title,SCREEN_X, SCREEN_Y):
              # Initialize pygame and the window
              pygame.init()
              screen_size = (SCREEN_X, SCREEN_Y)
              pygame.display.set_caption(title)
      
              # Create the screen
              self.screen = pygame.display.set_mode(screen_size)
      
              # Create the game clock
              self.clock = pygame.time.Clock()
      
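          # Preprocess each frame to reduce its complexity
          def preProcess(self, observation):
      
              # Resize the 288x512 frame to 80x80 and convert it from 3-channel color to 1-channel grayscale
              observation = cv2.cvtColor(cv2.resize(observation, (80, 80)), cv2.COLOR_BGR2GRAY)
      
              # Binarize: pixels brighter than 1 become 255 (white), all others 0 (black)
              threshold, observation = cv2.threshold(observation, 1, 255, cv2.THRESH_BINARY)
      
              # Return shape (80, 80, 1); the trailing channel axis makes the frame a tensor that can be fed to TensorFlow
              return np.reshape(observation, (80, 80, 1))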
      
          # Main game loop
          def playGame(self):
      
              # Feed an arbitrary action first to start the game and get the initial frame
              observation0, reward0, terminal, score = self.game.frameStep(np.array([1, 0, 0]))
              observation0 = self.preProcess(observation0)
              self.brain.setInitState(observation0[:, :, 0])
      
              # Start the actual game loop
              while True:
                  action = self.brain.getAction()
                  next_observation, reward, terminal, score = self.game.frameStep(action)
      
                  # The game window was closed
                  if terminal == -1:
                      self.closeGame()
                      return
      
                  # Otherwise keep playing: preprocess the frame and store the transition
                  next_observation = self.preProcess(next_observation)
                  self.brain.setPerception(next_observation, action, reward, terminal)
      
                  # Record the score of each finished game; commit so the server's connection can see it
                  if terminal:
                      t = int(time.time())
                      self.c.execute("insert into scores values (?, ?)", (t, score))
                      self.data_base.commit()
      
          # Shut down the game
          def closeGame(self):
              pygame.quit()
              self.brain.close()
              time.sleep(0.5)  # Make sure the brain's database writes have finished
              self.data_base.commit()
              self.data_base.close()
      
          # Get the current training state
          def getState(self):
              return self.brain.getState()
      
          # Get the current replay memory so it can be saved into the project file
          def getReplay(self):
              return self.brain.replay_memory
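AI.preProcess relies on OpenCV to shrink each frame, turn it gray and binarize it. To show what happens to the pixels, here is a NumPy-only approximation. `pre_process_numpy` is a hypothetical name, and the input is assumed to be an (H, W, 3) uint8 array in BGR order. Nearest-neighbour sampling replaces cv2.resize's default bilinear interpolation, so the result can differ from the OpenCV version near edges. The gray weights are the BT.601 ones that COLOR_BGR2GRAY uses, and the threshold rule matches THRESH_BINARY (value > 1 becomes 255).

```python
import numpy as np


def pre_process_numpy(observation, size=80, threshold=1):
    """NumPy-only approximation of AI.preProcess.

    Assumes `observation` is an (H, W, 3) uint8 array in BGR channel order.
    Nearest-neighbour sampling replaces cv2.resize's default bilinear
    interpolation, so edge pixels may differ from the OpenCV result.
    """
    h, w = observation.shape[:2]
    # Nearest-neighbour resize to size x size
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    small = observation[rows][:, cols].astype(np.float64)
    # BGR -> gray with the BT.601 luma weights used by cv2.COLOR_BGR2GRAY
    gray = 0.114 * small[:, :, 0] + 0.587 * small[:, :, 1] + 0.299 * small[:, :, 2]
    # Same rule as cv2.THRESH_BINARY: value > threshold -> 255, otherwise 0
    binary = np.where(gray > threshold, 255, 0).astype(np.uint8)
    # Keep a trailing channel axis: shape (size, size, 1)
    return binary.reshape(size, size, 1)


# Synthetic 512x288 BGR frame: black background with one red block
frame = np.zeros((512, 288, 3), dtype=np.uint8)
frame[100:200, 50:150] = (0, 0, 255)
processed = pre_process_numpy(frame)

# Stack four copies into the 80x80x4 network input, as Brain.setInitState does
state = np.stack((processed[:, :, 0],) * 4, axis=2)
print(processed.shape, state.shape, np.unique(processed))
```

Binarizing throws away colour and shading, so the network only sees object silhouettes. The four stacked frames then carry motion information.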
    • DQLBrain.py
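      import random
      from collections import deque
      
      import numpy as np
      import tensorflow as tf  # TensorFlow 1.x API (placeholder, Session)
      
      # Number of steps spent only observing (filling the replay buffer) before training starts
      observe = 100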

      class Brain:
          def __init__(self, actions, model_path, replay_memory=None, current_timestep=0, explore=200000., initial_epsilon=0.0, final_epsilon=0.0, gamma=0.99, replay_size=50000, batch_size=32):
      
              # Hyperparameters:
      
              # Discount factor for future rewards
              self.gamma = gamma
      
              # Number of steps to observe before training
              self.observe = observe
      
              # Number of steps over which epsilon is annealed
              self.explore = explore
      
              # Initial epsilon (probability of taking a random action)
              self.initial_epsilon = initial_epsilon
      
              # Final epsilon
              self.final_epsilon = final_epsilon
      
              # Capacity of the replay buffer
              self.replay_size = replay_size
      
              # Minibatch size
              self.batch_size = batch_size
      
              # Copy the online network into the target network every update_time steps
              self.update_time = 100
      
              self.whole_state = dict()
      
              # Replay buffer (a fresh deque per Brain instead of a shared mutable default)
              self.replay_memory = replay_memory if replay_memory is not None else deque()
      
              # Other parameters
              self.timestep = 0
              self.initial_timestep = current_timestep
              self.accual_timestep = self.initial_timestep + self.timestep
      
              # In load mode, recompute epsilon from the step count recorded in the project file
              self.epsilon = self.initial_epsilon - (self.initial_epsilon - self.final_epsilon) / self.explore * self.accual_timestep
              if self.epsilon < self.final_epsilon:
                  self.epsilon = self.final_epsilon
              self.actions = actions
      
              # Online Q-network: trained every step and used to choose actions
              self.state_input, self.QValue, self.conv1_w, self.conv1_b, self.conv2_w, self.conv2_b, self.conv3_w, self.conv3_b, self.fc1_w, self.fc1_b, self.fc2_w, self.fc2_b = self.createQNetwork()
      
              # Target Q-network: computes the TD targets and is periodically synced from the online network
              self.state_inputT, self.QValueT, self.conv1_wT, self.conv1_bT, self.conv2_wT, self.conv2_bT, self.conv3_wT, self.conv3_bT, self.fc1_wT, self.fc1_bT, self.fc2_wT, self.fc2_bT = self.createQNetwork()
              self.copyTargetQNetwork = [self.conv1_wT.assign(self.conv1_w), self.conv1_bT.assign(self.conv1_b), self.conv2_wT.assign(self.conv2_w), self.conv2_bT.assign(self.conv2_b), self.conv3_wT.assign(self.conv3_w), self.conv3_bT.assign(self.conv3_b), self.fc1_wT.assign(self.fc1_w), self.fc1_bT.assign(self.fc1_b), self.fc2_wT.assign(self.fc2_w), self.fc2_bT.assign(self.fc2_b)]
      
              # Loss: mean squared TD error on the Q-value of the action actually taken
              self.action_input = tf.placeholder("float", [None, self.actions])
              self.y_input = tf.placeholder("float", [None])
              Q_Action = tf.reduce_sum(tf.multiply(self.QValue, self.action_input), axis=1)
              self.cost = tf.reduce_mean(tf.square(self.y_input - Q_Action))
              self.optimizer = tf.train.AdamOptimizer(1e-6).minimize(self.cost)
      
              # Checkpoint saver (keeps only the latest checkpoint) and session
              self.saver = tf.train.Saver(max_to_keep=1)
              self.session = tf.InteractiveSession()
              self.session.run(tf.global_variables_initializer())
      
          def createQNetwork(self):
      
              # Weights and biases
              # Conv layer 1: 8x8 kernel, 4 input channels (stacked frames), 32 filters
              W_conv1 = self.weightVariable([8, 8, 4, 32])
              b_conv1 = self.biasVariable([32])
      
              # Conv layer 2: 4x4 kernel, 32 -> 64 channels
              W_conv2 = self.weightVariable([4, 4, 32, 64])
              b_conv2 = self.biasVariable([64])
      
              # Conv layer 3: 3x3 kernel, 64 -> 64 channels
              W_conv3 = self.weightVariable([3, 3, 64, 64])
              b_conv3 = self.biasVariable([64])
      
              # Fully connected layer: 1600 -> 512
              W_fc1 = self.weightVariable([1600, 512])
              b_fc1 = self.biasVariable([512])
      
              # Output layer: 512 -> one Q-value per action
              W_fc2 = self.weightVariable([512, self.actions])
              b_fc2 = self.biasVariable([self.actions])
      
              # Input layer: a batch of 4 stacked 80x80 frames
              stateInput = tf.placeholder("float", [None, 80, 80, 4])
      
              # Build the network
              # Hidden layers
      
              # stride 4: 80*80*4 to 20*20*32
              h_conv1 = tf.nn.relu(self.conv2d(stateInput, W_conv1, 4) + b_conv1)
      
              # 2x2 max pooling: 20*20*32 to 10*10*32
              h_pool1 = self.maxPool_2x2(h_conv1)
      
              # stride 2: 10*10*32 to 5*5*64
              h_conv2 = tf.nn.relu(self.conv2d(h_pool1, W_conv2, 2) + b_conv2)
      
              # stride 1: 5*5*64 to 5*5*64
              h_conv3 = tf.nn.relu(self.conv2d(h_conv2, W_conv3, 1) + b_conv3)
      
              # Flatten 5*5*64 to 1*1600
              h_conv3_flat = tf.reshape(h_conv3, [-1, 1600])
              h_fc1 = tf.nn.relu(tf.matmul(h_conv3_flat, W_fc1) + b_fc1)
      
              # Output layer
              QValue = tf.matmul(h_fc1, W_fc2) + b_fc2
      
              return stateInput, QValue, W_conv1, b_conv1, W_conv2, b_conv2, W_conv3, b_conv3, W_fc1, b_fc1, W_fc2, b_fc2
      
          def trainQNetwork(self):
      
              # Sample a minibatch from the replay buffer
              minibatch = random.sample(self.replay_memory, self.batch_size)
              state_batch = [data[0] for data in minibatch]
              action_batch = [data[1] for data in minibatch]
              reward_batch = [data[2] for data in minibatch]
              nextState_batch = [data[3] for data in minibatch]
      
              # Compute the TD targets used in the loss
              y_batch = []
              QValue_batch = self.QValueT.eval(feed_dict={self.state_inputT: nextState_batch})
              for i in range(0, self.batch_size):
                  terminal = minibatch[i][4]
                  if terminal:
                      y_batch.append(reward_batch[i])
                  else:
                      y_batch.append(reward_batch[i] + self.gamma * np.max(QValue_batch[i]))
              self.optimizer.run(feed_dict={self.y_input: y_batch, self.action_input: action_batch, self.state_input: state_batch})
      
              # Save a checkpoint every 1000 training steps
              if self.timestep % 1000 == 0:
                  self.saver.save(self.session, './saved_networks/network' + '-dqn', global_step=self.timestep+self.initial_timestep)
      
              # Sync the target network with the online network every update_time steps
              if self.timestep % self.update_time == 0:
                  self.session.run(self.copyTargetQNetwork)
      
          def setPerception(self, nextObservation, action, reward, terminal):
      
              new_state = np.append(self.current_state[:, :, 1:], nextObservation, axis=2)
              self.replay_memory.append((self.current_state, action, reward, new_state, terminal))
      
              # Keep the replay buffer within replay_size
              if len(self.replay_memory) > self.replay_size:
                  self.replay_memory.popleft()
              if self.timestep > self.observe:
                  self.trainQNetwork()
      
              # Report training progress to the main window
              if self.timestep <= self.observe:
                  state = "observe"
              elif self.timestep <= self.observe + self.explore:
                  state = "explore"
              else:
                  state = "train"
      
              self.whole_state = {"TIMESTEP": self.timestep + self.initial_timestep, "STATE": state, "EPSILON": self.epsilon, "ACTUAL": int(self.timestep + self.initial_timestep)}
      
              self.current_state = new_state
              self.timestep += 1
      
          def getAction(self):
              QValue = self.QValue.eval(feed_dict={self.state_input: [self.current_state]})[0]
              action = np.zeros(self.actions)
      
              # Epsilon-greedy policy
              if random.random() <= self.epsilon:
                  action_index = random.randrange(self.actions)
                  action[action_index] = 1
              else:
                  action_index = np.argmax(QValue)
                  action[action_index] = 1
      
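              # Anneal epsilon linearly with the total step count (including steps from a loaded model)
              actual_timestep = self.initial_timestep + self.timestep
              if self.epsilon > self.final_epsilon and actual_timestep > self.observe:
                  self.epsilon = max(self.final_epsilon, self.initial_epsilon - (self.initial_epsilon - self.final_epsilon) / self.explore * actual_timestep)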
      
              return action
      
          def setInitState(self, observation):
              self.current_state = np.stack((observation, observation, observation, observation), axis=2)
      
          def weightVariable(self, shape):
              initial = tf.truncated_normal(shape, stddev=0.01)
              return tf.Variable(initial)
      
          def biasVariable(self, shape):
              initial = tf.constant(0.01, shape=shape)
              return tf.Variable(initial)
      
          def conv2d(self, x, W, stride):
              return tf.nn.conv2d(x, W, strides=[1, stride, stride, 1], padding="SAME")
      
          def maxPool_2x2(self, x):
              return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
      
          def close(self):
              self.session.close()
      
          def getState(self):
              return self.whole_state
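Two formulas sit at the heart of DQLBrain.py: the linear epsilon schedule in `__init__`/`getAction`, and the Q-learning target computed in `trainQNetwork`. Both can be checked without TensorFlow. The NumPy sketch below restates them, and it also recomputes the layer sizes behind the hard-coded 1600 using the rule that a layer with `padding="SAME"` outputs ceil(input / stride). The function names are hypothetical. Clamping epsilon at `final_epsilon` is an assumption added so epsilon never drops below its floor.

```python
import math

import numpy as np


def epsilon_at(step, initial_epsilon, final_epsilon, explore):
    """Linear annealing from initial_epsilon to final_epsilon over `explore` steps."""
    eps = initial_epsilon - (initial_epsilon - final_epsilon) / explore * step
    # Clamp so epsilon stays at final_epsilon once annealing is over
    return max(final_epsilon, eps)


def td_targets(rewards, next_q_values, terminals, gamma):
    """y = r for terminal transitions, else y = r + gamma * max_a Q_target(s', a).

    next_q_values plays the role of QValueT evaluated on the next states,
    with shape (batch, actions).
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    terminals = np.asarray(terminals, dtype=bool)
    bootstrap = gamma * np.max(np.asarray(next_q_values, dtype=np.float64), axis=1)
    return np.where(terminals, rewards, rewards + bootstrap)


def same_conv_out(size, stride):
    """Spatial size after a conv or pooling layer with padding="SAME"."""
    return math.ceil(size / stride)


# Layer sizes of createQNetwork: conv1 (stride 4), pool (2), conv2 (2), conv3 (1)
side = 80
for stride in (4, 2, 2, 1):
    side = same_conv_out(side, stride)
print(side, side * side * 64)  # 5 1600

print(epsilon_at(100000, 0.1, 0.0001, 200000.0))
print(td_targets([1.0, -1.0], [[0.5, 2.0], [3.0, 1.0]], [False, True], 0.99))
```

The target uses the separate target network (the `T` variables) rather than the online one, so the values it regresses toward only change every `update_time` steps. This tends to make training more stable.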
  • Server

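  The page mainly uses the Highcharts API. After putting the four files listed above into the static folder, write the server page index.html in the templates folder (to keep it easy to follow, the page is deliberately bare-bones):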

    <html>
    <head>
        <script src='/static/jquery.js'></script>
        <script src='/static/highstock.js'></script>
        <script src='/static/exporting.js'></script>
    </head>
    <body>

        <div id="container" style="min-width:310px;height:400px"></div>

        <script>
    $(function () {
        // Use the local time zone; otherwise times shown in UTC+8 are eight hours off
        Highcharts.setOptions({
            global: {
                useUTC: false
            }
        });
        $.getJSON('/data', function (data) {
            // Create the chart
            $('#container').highcharts('StockChart', {
            chart:{
            events:{
            
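                load: function () {
                    var series = this.series[0];
                    // /data returns every score so far, so only append the rows not yet plotted
                    var plotted = data.length;
                    setInterval(function () {
                        $.getJSON('/data', function (res) {
                            for (var i = plotted; i < res.length; i++) {
                                series.addPoint(res[i]);
                            }
                            plotted = res.length;
                        });
                    }, 3000);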
                }
            }
            },
                rangeSelector : {
                    selected : 1
                },
                title : {
                    text : 'Score per game'
                },
                series : [{
                    name : 'Training performance',
                    data : data,
                    tooltip: {
                        valueDecimals: 2
                    }
                }]
            });
        });
    });
    </script>
    </body>
    </html>

  You also need a Python file, Webservice.py, that serves this template and supplies its live data:

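    from flask import Flask, render_template
    import sqlite3
    import json
    
    app = Flask(__name__)
    
    # Temporary database shared with the AI process
    DB_PATH = 'temp.db'
    
    # Front-end template
    @app.route('/')
    def index():
        return render_template("index.html")
    
    
    # Data source: every recorded score as [milliseconds, score] pairs
    @app.route('/data')
    def data():
        # Open a connection per request, since Flask may serve requests on different threads
        conn = sqlite3.connect(DB_PATH)
        try:
            rows = conn.execute('select time, score from scores order by rowid').fetchall()
        finally:
            conn.close()
        # Highcharts expects JavaScript timestamps in milliseconds
        return json.dumps([[t * 1000, score] for t, score in rows])
    
    # Start the server on port 9090; host 0.0.0.0 listens on all network interfaces
    def start():
        app.run(host='0.0.0.0', port=9090)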
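DQL.py writes each finished game's score into temp.db, and the server reads it back through a separate connection. With Python's sqlite3 module, an INSERT opens a transaction, and other connections only see the row after `commit()`. The sketch below checks this on a throwaway database rather than the real temp.db; the file name and the `count_scores` helper are hypothetical. It also builds the `[[milliseconds, score], ...]` payload that the `/data` route returns, because Highcharts expects timestamps in milliseconds while `time.time()` gives seconds.

```python
import json
import os
import sqlite3
import tempfile
import time

# Throwaway file standing in for temp.db (hypothetical path, not the project's)
db_path = os.path.join(tempfile.mkdtemp(), "temp_demo.db")


def count_scores(path):
    """Read like the /data route: open a fresh connection, query, close."""
    conn = sqlite3.connect(path)
    try:
        return conn.execute("select count(*) from scores").fetchone()[0]
    finally:
        conn.close()


# The writer plays the role of DQL.py
writer = sqlite3.connect(db_path)
writer.execute("create table if not exists scores (time integer, score integer)")
writer.commit()

writer.execute("insert into scores values (?, ?)", (int(time.time()), 3))
before = count_scores(db_path)  # insert not committed yet, so the reader sees 0 rows
writer.commit()
after = count_scores(db_path)   # now visible to other connections

# Same shape as the /data response: [[milliseconds, score], ...]
reader = sqlite3.connect(db_path)
rows = reader.execute("select time, score from scores order by rowid").fetchall()
reader.close()
writer.close()
payload = json.dumps([[t * 1000, s] for t, s in rows])
print(before, after, payload)
```

Committing after every finished game keeps the chart close to real time. Since games end at most a few times per second, the extra disk writes should be cheap.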

Conclusion

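  It seems that PyQt5 and TensorFlow can conflict, so the program occasionally crashes at runtime. Also, the server cannot be reached from machines outside the local network. If you know how to solve these problems, please let me know in the comments below. Thanks! One more time, the GitHub repository is https://github.com/qq303067814/DQLearning-Toolbox. If you want to dig deeper into any part of this walkthrough, read the source code directly or ask in the comments.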

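Note: The copyright of this article belongs to the author. It is posted on the author's behalf by demo大师 (Demo Master); reposting is not allowed without the author's permission.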
