【技術博客】Pytorch代碼生成

開發組在開發過程當中,都不可避免地遇到了一些困難或問題,但都最終想出辦法克服了。咱們認爲這樣的經驗是有必要記錄下來的,所以就有了【技術博客】。前端


Pytorch代碼生成經驗文檔

關於模型代碼的生成,主要思路爲從根節點開始進行廣度優先搜索,從而自頂向下依次生成相關層的代碼。這裏和搜索相關的主要有三個數據結構:node

  • Q:隊列,記錄後續繼續搜索的節點,即爲後續的Node。
  • graph:字典,記錄整顆搜索樹,每一個key對應一個Node,Node爲本身封裝的一個類,裏面包含每層的一些信息。記錄搜索樹的目的是爲了後續的正確性驗證,以下爲Node的定義:
class Node:
    def __init__(self, id = None, name = None, in_channels = 1, out_channels = 1, kernel_size = 3, 
        stride = 1, padding = 0, data = None, activity = None, pool_way = None, cat_dim = None):
        self.fa = np.array([], dtype = str)
        self.next = np.array([], dtype = str)
        self.id = id
        self.name = name
        self.data = data
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.pool_way = pool_way
        self.activity = activity
        self.data_shape = np.array([], dtype = int)
        self.cat_dim = cat_dim

    def add_fa(self, f):
        self.fa = np.append(self.fa, f)
        
    def add_next(self, nx):
        self.next = np.append(self.next, nx)
  • done:字典,記錄某節點相關代碼是否已經生成,每一個key對應一個boolean值。


同時還有如下須要關注的地方:python

  • 廣度優先搜索。BFS爲代碼的主要框架。從’start’節點開始搜索,直到遍歷結束,作一個線性的掃描。代碼框架以下(省略了主要代碼):數據結構

    def make_graph(nets, nets_conn, init_func, forward_func):
          #code here
    
          Q = queue.Queue()
          Q.put(‘start’)
    
        #code here
    
        while not Q.empty():
            cur_id = Q.get()
            if GL.done[cur_id]:
                continue
    
            ''''''''''''
    
            Main codes here
    
            ''''''''''''
    
            GL.done[cur_id] = True
    
        return init_func, forward_func
  • 關於全局變量的處理。因爲一開始忽略了python變量的特性(不須要聲明),因此在一開始第一全局變量的時候是直接定義在文件開頭的,可是這樣存在的問題是:若是在局部函數中引用全局變量,則此時則是從新定義了一個變量而不是引用,用global關鍵字代碼看上去又很臃腫。因此採起的辦法是從新定義了一個GLOB模塊,裏面存放着須要的全部全局變量。相似於這樣:app

    class GLOB:
        def __init__(self):
            self.graph = {}
            self.done = {}
            self.layer_used_time = {'view_layer': 0, 'linear_layer': 0, 'conv1d_layer': 0, 'conv2d_layer': 0, 'element_wise_add_layer':0, 'concatenate_layer':0}
            self.nn_linear = 'torch.nn.Linear'
            self.nn_conv1d = 'torch.nn.Conv1d'
            self.nn_conv2d = 'torch.nn.Conv2d'
            self.nn_view = '.view'
            self.nn_sequential = 'torch.nn.Sequential'
            self.start_layer = ['start']
            self.norm_layer = ['conv1d_layer', 'conv2d_layer', 'view_layer', 'linaer_layer']
            self.multi_layer = ['element_wise_add_layer', 'concatenate_layer']
            self.layers_except_start = self.norm_layer + self.multi_layer

    這樣,只須要在代碼裏初始化一個GLOB對象GL,這樣在任何地方引用全局變量都不會形成困擾。框架

  • 關於變量名生成。每層的輸出數據的名字格式爲:層名 + 「data_出現的次數」。有一個數據結構」layer_used_time」(字典)專門負責記錄每一個層出現的次數,同時,會在該層的代碼生成結構後更新layer_used_time和done的值。ide

  • 關於什麼時候初始化和更新graph。在咱們的代碼中,當從隊列中取出一個節點後會執行一個函數:get_next_nodes_and_update_pre_nodes()。該函數的目的是獲取和初始化當前節點的兒子節點,記錄前端傳入該層的其餘參數,更新其父子節點,同時返回當前節點的全部祖先節點代碼是否已經生成完畢。另外,在該函數內部也會作模型的一部分正確性驗證,主要驗證搭建的模型裏除了拼接層和相加層之外的層是否存在多個父節點或沒有節點。該函數實現的功能較多,後期會考慮重構。函數

  • 關於正確性驗證。考慮到用戶在搭建模型時不必定可以保證參數的正確,因此咱們對參數的合理性是「寬容」的,可是也有硬性的要求,好比只能有一個開始節點,同時除了拼接層和相加層能夠有多個父節點之外,其餘層有且僅有一個父節點。code

  • 關於生成的模型NET中forward函數的返回值。因爲搭建的模型容許出現網狀結構,因此不能保證模型的出口只有一個,因此現階段生成的模型會返回全部出度爲0的層的輸出值,具體順序參見代碼。orm

附最終生成的代碼效果圖(例):

相關文章
相關標籤/搜索