TensorFlow

官網

Tensorflow源碼分析

A、基本概念

  1. Graph

  2. Tensor 

  3. Session

B、Tools

  1. Checkpoint  .Ckpt

  2. Pb

  3. .Ckpt To .Pb

  4. TensorBoard

 B.1  .Ckpt 模型加載

1. 模型的保存

import tensorflow as tf

def store_model_ckpt(ckpt_file_path):
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    #模型的保存必須有變量
    c = tf.Variable(1, name='c')
    a = tf.add(x, y, name='op')
    result = tf.add(a, c)

    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
    
        saver = tf.train.Saver()
    
        #若是隻保存其中一部分變量,則使用下面代碼,用列表或者字典均可以
        #saver = tf.train.Saver([x, y])
    
        #這裏面有參數global_step=50,當訓練50步便保存模型
        saver.save(sess, ckpt_file_path)
        # test
        feed_dict = {x: 2, y: 3}
        print(sess.run(result, feed_dict))

def main():
    ckpt_file_path = "./ckpt/model.ckpt"
    store_model_ckpt(ckpt_file_path)

if __name__ == '__main__':
    main()

結果:6node

程序生成並保存四個文件python

  1. checkpoint 文本文件,記錄了模型文件的路徑信息列表
  2. model.ckpt.data-00000-of-00001 網絡權重信息
  3. model.ckpt.index .data和.index這兩個文件是二進制文件,保存了模型中的變量參數(權重)信息
  4. model.ckpt.meta 二進制文件,保存了模型的計算圖結構信息(模型的網絡結構)protobuf

2. 模型恢復加載

針對上面的模型保存例子,還原模型的過程以下:git

import tensorflow as tf

def restore_model_ckpt():
    with tf.Session() as sess:
        #step1:加載模型結構
        saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta')
        #step2:只須要指定目錄就能夠恢復全部變量信息
        saver.restore(sess,tf.train.latest_checkpoint('./ckpt'))
        
        #直接獲取保存的變量
        print(sess.run('c:0'))
        
        #獲取placeholder變量,經過get_tensor_by_name
        x = sess.graph.get_tensor_by_name('x:0')
        y = sess.graph.get_tensor_by_name('y:0')
        
        #獲取須要進行計算的op算子,此op爲加法
        op = sess.graph.get_tensor_by_name('op:0')
        
        #加入新的op操做,新的op爲乘法
        new_op = tf.multiply(op, 2)
        
        #test
        feed_dict = {x:2, y:3}
        
        result = sess.run(new_op,feed_dict)
        print(result)

def main():
    restore_model_ckpt()
    
if __name__ == '__main__':
    main()

結果:10瀏覽器

  1. 首先還原模型結構網絡

  2. 而後還原變量(參數)信息架構

  3. 最後咱們就能夠得到已訓練的模型中的各類信息了(保存的變量、placeholder變量、operator等),同時能夠對獲取的變量添加各類新的操做(見以上代碼註釋)。
  而且,咱們也能夠加載部分模型,在此基礎上加入其它操做,具體能夠參考官方文檔和demo。dom

  針對ckpt模型文件的保存與還原,stackoverflow上有一個回答解釋比較清晰,能夠參考。函數

  同時cv-tricks.com上面的TensorFlow模型保存與恢復的教程也很是好,能夠參考。源碼分析

B. 2 Pb模型文件

 1. pb模型保存

import tensorflow as tf
from tensorflow.python.framework import graph_util

def store_model_pb(pb_file_path):
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    a = tf.add(x, y)
    #該op算子應該加上name
    op = tf.add(a, b, name='op')
    
    with tf.Session() as sess:
        init = tf.initialize_all_variables()
        sess.run(init)
        
        #導出當前計算圖的GraphDef部分,只須要這一部分就能夠完成從輸入層到輸出層的計算
        graph_def = tf.get_default_graph().as_graph_def()
        
        #將圖中的變量及其取值轉化爲常量,同時將圖中的沒必要要的節點去掉
        output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['op'])
        
        with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
            f.write(output_graph_def.SerializeToString())
        
        #test
        feed_dict = {x: 2, y: 3}
        print(sess.run(op, feed_dict))

def main():
    pb_file_path = "model.pb"
    store_model_pb(pb_file_path)
    
if __name__ == '__main__':
    main()
    

結果:6 測試

  在當前文件下面生成model.pb文件

2. pb模型加載

import tensorflow as tf
from tensorflow.python.platform import gfile
    
def restore_model_pb(pb_file_path):
    with tf.Session() as sess:
        with gfile.FastGFile(pb_file_path, 'rb') as f:
            graph_def = tf.GraphDef()
            #轉換成字符串形式
            graph_def.ParseFromString(f.read())
            sess.graph.as_default()
            tf.import_graph_def(graph_def, name='')
       
        #獲取placeholder的變量
        x = sess.graph.get_tensor_by_name('x:0')
        y = sess.graph.get_tensor_by_name('y:0')
        
        #獲取op算子
        op = sess.graph.get_tensor_by_name('op:0')
        
        feed_dict = {x: 2, y:3}
        result = sess.run(op,feed_dict)
        print(result)
          
def main():
    pb_file_path = "model.pb"
    restore_model_pb(pb_file_path)
    
if __name__ == '__main__':
    main()

結果:5

B 3. 將.Ckpt 轉換爲.Pb

  但不少時候,咱們須要將TensorFlow的模型導出爲單個文件(同時包含模型結構的定義與權重),方便在其餘地方使用(如在Android中部署網絡)。利用tf.train.write_graph()默認狀況下只導出了網絡的定義(沒有權重),而利用tf.train.Saver().save()導出的文件graph_def與權重是分離的,所以須要採用別的方法。 咱們知道,graph_def文件中沒有包含網絡中的Variable值(一般狀況存儲了權重),可是卻包含了constant值,因此若是咱們能把Variable轉換爲constant,便可達到使用一個文件同時存儲網絡架構與權重的目標。

    TensoFlow爲咱們提供了convert_variables_to_constants()方法,該方法能夠固化模型結構,將計算圖中的變量取值以常量的形式保存,並且保存的模型能夠移植到Android平臺。

1、CKPT 轉換成 PB格式

  將CKPT 轉換成 PB格式的文件的過程可簡述以下:

    1. 經過傳入 CKPT 模型的路徑獲得模型的圖和變量數據
    2. 經過 import_meta_graph 導入模型中的圖
    3. 經過 saver.restore 從模型中恢復圖中各個變量的數據
    4. 經過 graph_util.convert_variables_to_constants 將模型持久化

Code:freeze_graph.py

import tensorflow as tf
from tensorflow.python.framework import graph_util

def freeze_graph(ckpt_file_path, pb_file_path):
    #「input:0」是張量的名稱,而"input"表示的是節點的名稱。
    #此處輸入的應該是節點的名稱
    output_node_names = "op"
    #首先恢復圖結構
    saver = tf.train.import_meta_graph(ckpt_file_path+'.meta',clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    
    with tf.Session() as sess:
        #恢復圖並獲得數據
        saver.restore(sess,ckpt_file_path)
        output_graph_def = graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=input_graph_def,
                #若是有多個輸出節點
                output_node_names=output_node_names.split(","))
        with tf.gfile.GFile(pb_file_path,"wb") as f:
            f.write(output_graph_def.SerializeToString())
            print("%d ops in the final graph." % len(output_graph_def.node)) 
                     
def main():
    # 輸入ckpt模型路徑
    model_folder = "D:\AI\Ckpt\TestCkpt\ckpt"
    #檢查目錄下ckpt文件狀態是否可用
    checkpoint = tf.train.get_checkpoint_state(model_folder) 
    #得ckpt文件路徑
    ckpt_file_path = checkpoint.model_checkpoint_path 
    
    # 輸出pb模型的路徑
    pb_file_path="frozen_model.pb"
    
    # 調用freeze_graph將ckpt轉爲pb
    freeze_graph(ckpt_file_path,pb_file_path)
    
if __name__ == '__main__':
    main()

結果:生成 frozen_model.pb文件,能夠採用上面pb模型加載的方法測試該pb文件

說明:

一、函數freeze_graph中,最重要的就是要肯定「指定輸出的節點名稱」,這個節點名稱必須是原模型中存在的節點,對於freeze操做,咱們須要定義輸出結點的名字。由於網絡實際上是比較複雜的,定義了輸出結點的名字,那麼freeze的時候就只把輸出該結點所須要的子圖都固化下來,其餘無關的就捨棄掉。由於咱們freeze模型的目的是接下來作預測。因此,output_node_names通常是網絡模型最後一層輸出的節點名稱,或者說就是咱們預測的目標。

 二、在保存的時候,經過convert_variables_to_constants函數來指定須要固化的節點名稱,對於鄙人的代碼,須要固化的節點只有一個:output_node_names。注意節點名稱與張量的名稱的區別,例如:「input:0」是張量的名稱,而"input"表示的是節點的名稱。

三、源碼中經過graph = tf.get_default_graph()得到默認的圖,這個圖就是由saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)恢復的圖,所以必須先執行tf.train.import_meta_graph,再執行tf.get_default_graph() 。

四、上面以及說明:在保存的時候,經過convert_variables_to_constants函數來指定須要固化的節點名稱,對於鄙人的代碼,須要固化的節點只有一個:output_node_names。所以,其餘網絡模型,也能夠經過簡單的修改輸出的節點名稱output_node_names,將ckpt轉爲pb文件 。

       PS:注意節點名稱,應包含name_scope 和 variable_scope命名空間,並用「/」隔開,如"InceptionV3/Logits/SpatialSqueeze"

B.4 TensorBoard

  1. 生成graph

# -*- coding: utf-8 -*-
"""
Created on Sat Dec 22 09:49:04 2018

@author: weilong
"""

import tensorflow as tf

#定義簡單的計算圖,實現向量加法的操做
with tf.name_scope("imput1"):
    input1 = tf.constant([1.0, 2.0, 3.0], name="input1")
with tf.name_scope("input2"):
    input2 = tf.Variable(tf.random_uniform([3]), name="input2")
output = tf.add_n([input1, input2], name="add")

#生成寫日誌的writer,並將當前的tensorflow計算圖寫入日誌
writer = tf.summary.FileWriter("./log", tf.get_default_graph())
writer.close()

 2. 將訓練好的model.pb文件在tensorboard中展現其網絡結構

import tensorflow as tf

model = 'model.pb' #請將這裏的pb文件路徑改成本身的
graph = tf.get_default_graph()
graph_def = graph.as_graph_def()
graph_def.ParseFromString(tf.gfile.FastGFile(model, 'rb').read())
tf.import_graph_def(graph_def, name='graph')
summaryWriter = tf.summary.FileWriter('log/', graph)

執行以上代碼就會生成文件在log/events.out.tfevents.1535079670.DESKTOP-5IRM000。

 在tensorboard中加載:

tensorboard --logdir=\path\to\log

在瀏覽器中

拷貝網站連接在瀏覽器中便可。

參考:https://blog.csdn.net/guyuealian/article/details/82218092

相關文章
相關標籤/搜索