【Tensorflow1.0】訓練結果的保存與加載

訓練完成之後咱們就能夠直接使用訓練好的模板進行預測了python

可是每次在預測以前都要進行訓練,不是一個常規操做,畢竟有些複雜的模型須要訓練好幾天甚至更久git

因此將訓練好的模型進行保存,當有須要的時候從新加載這個模型進行預測或者繼續訓練,這纔是一個常規操做github

咱們依然使用最簡單的例子進行說明,這裏沿用Tensorflow入門——實現最簡單的線性迴歸模型的預測 這個例子進行session

====================================================dom

模型的保存優化

在tensorflow中保存模型使用的是tf.train.Saver對象,咱們須要在保存以前先實例化這個對象this

saver = tf.train.Saver()

對於模型的保存,其實就是保存整個session對象,再給定一個path就實現了模型的保存(對應的path須要存在,若是不存在會報錯)spa

saver.save(sess, SAVE_PATH + 'model')

保存完成之後,能夠看到對應的目錄下面生成了4個文件.net

model.meta中保存的是模型,而這個模型僅僅是計算流和參數的定義,能夠認爲是一個未經訓練的模型rest

model.index和model.data-00000-of-00001中保存的是參數值,也就是真正訓練的結果

checkpoint中保存的是最後幾回保存的信息,從文件名就能夠看出它是一個檢查點,記錄了其餘幾個文件之間的關係,這是一個txt文件,咱們能夠打開看一下(在這個例子中咱們只保存了一次,若是保存屢次的話這個文件中會記錄屢次保存結果的信息)

下面是運行的log

epoch= 0 _loss= 6029.333 _w= [0.005] _n= [0.005]
epoch= 5000 _loss= 10.897877 _w= [4.2031364] _n= [-1.905781]
epoch= 10000 _loss= 112.455055 _w= [4.7837024] _n= [-11.81817]
epoch= 15000 _loss= 6.2376847 _w= [5.1548934] _n= [-19.740992]
epoch= 20000 _loss= 2.9357195 _w= [5.2787647] _n= [-22.662355]
epoch= 25000 _loss= 0.022824269 _w= [5.3112087] _n= [-23.141117]
epoch= 30000 _loss= 1.3711997 _w= [5.326612] _n= [-23.255548]
epoch= 35000 _loss= 0.005477888 _w= [5.3088646] _n= [-23.289743]
epoch= 40000 _loss= 2.8727396 _w= [5.315157] _n= [-23.191956]
epoch= 45000 _loss= 0.009563584 _w= [5.300157] _n= [-23.18857]
訓練完成,開始預測。。。
x= 0.1610020536371326 y預測= [-22.44688] y實際= -22.401859054114084
x= 7.379937860774309 y預測= [16.030691] y實際= 16.075068797927063
x= 5.1744928042152685 y預測= [4.2754745] y實際= 4.320046646467379
x= 10.26990231423617 y預測= [31.434462] y實際= 31.478579334878784
x= 23.219346463697207 y預測= [100.45616] y實際= 100.49911665150611
x= 7.101197776563807 y預測= [14.544985] y實際= 14.589384149085088
x= 3.097841295090581 y預測= [-6.7932644] y實際= -6.7485058971672025
x= 6.474682013005717 y預測= [11.205599] y實際= 11.250055129320469
x= 13.811264369891983 y預測= [50.310234] y實際= 50.35403909152427
x= 29.260954830177415 y預測= [132.65846] y實際= 132.70088924484563

====================================================

模型的加載

由於保存時分紅了模型和參數值兩部分進行保存,因此在加載模型的時候也須要將模型和參數值(訓練結果)兩步分開進行加載

上面講到了meta文件是模型,checkpoint是參數值,這裏分別使用tf.train下的import_meta_graph和latest_checkpoint方法來加載

saver = tf.train.import_meta_graph(SAVE_PATH + 'model.meta')
saver.restore(sess, tf.train.latest_checkpoint(SAVE_PATH))

這樣,以前保存起來的模型就被咱們從新加載成功了,可是在預測或者繼續訓練以前,咱們須要從新定義相關的變量

可是也不是憑空的從新定義,由於這些參數已經在以前保存的模型中定義過了,咱們只須要從已經加載的模型中將相關參數的定義給找出來就能夠了

爲了找回參數的定義,咱們須要稍微修改一下模型,將這些須要在從新加載階段找回的參數定義給上命名(若是是用來預測,咱們須要找回X和OUT,若是是用來繼續訓練,咱們須要找回X、OUT、loss),因此這裏咱們將模型中相關的參數都給上命名

X = tf.placeholder(tf.float32, name='X')
Y = tf.placeholder(tf.float32, name='Y')

W = tf.Variable(tf.zeros([1]), name='W')
B = tf.Variable(tf.zeros([1]), name='B')
OUT = tf.add(tf.multiply(X, W), B, name='OUT')

loss = tf.reduce_mean(tf.square(Y - OUT), name='loss')
optimizer = tf.train.AdamOptimizer(0.005).minimize(loss)

在找回參數以前,須要獲取計算圖對象(關於計算圖的概念,如今能夠沒必要先了解)

graph = tf.get_default_graph()

而後經過get_all_collection_keys,來查看這個模型中的內容

print(graph.get_all_collection_keys())

能夠看到一共有三項,分別是train_op:優化器,trainable_variables:可訓練的變量,variables:全部變量

['train_op', 'trainable_variables', 'variables']

咱們再經過get_collection方法把這些對象也打印出來看一下

print(graph.get_collection('train_op'))
print(graph.get_collection('trainable_variables'))
print(graph.get_collection('variables'))

可是從中發現,咱們須要找回的參數都不在這裏

[<tf.Operation 'Adam' type=NoOp>]
[<tf.Variable 'W:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'B:0' shape=(1,) dtype=float32_ref>]
[<tf.Variable 'W:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'B:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'beta1_power:0' shape=() dtype=float32_ref>, <tf.Variable 'beta2_power:0' shape=() dtype=float32_ref>, <tf.Variable 'W/Adam:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'W/Adam_1:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'B/Adam:0' shape=(1,) dtype=float32_ref>, <tf.Variable 'B/Adam_1:0' shape=(1,) dtype=float32_ref>]

繼續經過get_operations方法來查看全部的操做數

print(graph.get_operations())

從如下內容中咱們發現了須要找回的參數X、Y、OUT等

[<tf.Operation 'X' type=Placeholder>, <tf.Operation 'Y' type=Placeholder>, <tf.Operation 'zeros' type=Const>, <tf.Operation 'W' type=VariableV2>, <tf.Operation 'W/Assign' type=Assign>, <tf.Operation 'W/read' type=Identity>, <tf.Operation 'zeros_1' type=Const>, <tf.Operation 'B' type=VariableV2>, <tf.Operation 'B/Assign' type=Assign>, <tf.Operation 'B/read' type=Identity>, <tf.Operation 'Mul' type=Mul>, <tf.Operation 'OUT' type=Add>, <tf.Operation 'sub' type=Sub>, <tf.Operation 'Square' type=Square>, <tf.Operation 'Rank' type=Rank>, <tf.Operation 'range/start' type=Const>, <tf.Operation 'range/delta' type=Const>, <tf.Operation 'range' type=Range>, <tf.Operation 'loss' type=Mean>, <tf.Operation 'gradients/Shape' type=Const>, <tf.Operation 'gradients/grad_ys_0' type=Const>, <tf.Operation 'gradients/Fill' type=Fill>, <tf.Operation 'gradients/loss_grad/Shape' type=Shape>, <tf.Operation 'gradients/loss_grad/Size' type=Size>, <tf.Operation 'gradients/loss_grad/add' type=Add>, <tf.Operation 'gradients/loss_grad/mod' type=FloorMod>, <tf.Operation 'gradients/loss_grad/Shape_1' type=Shape>, <tf.Operation 'gradients/loss_grad/range/start' type=Const>, <tf.Operation 'gradients/loss_grad/range/delta' type=Const>, <tf.Operation 'gradients/loss_grad/range' type=Range>, <tf.Operation 'gradients/loss_grad/Fill/value' type=Const>, <tf.Operation 'gradients/loss_grad/Fill' type=Fill>, <tf.Operation 'gradients/loss_grad/DynamicStitch' type=DynamicStitch>, <tf.Operation 'gradients/loss_grad/Maximum/y' type=Const>, <tf.Operation 'gradients/loss_grad/Maximum' type=Maximum>, <tf.Operation 'gradients/loss_grad/floordiv' type=FloorDiv>, <tf.Operation 'gradients/loss_grad/Reshape' type=Reshape>, <tf.Operation 'gradients/loss_grad/Tile' type=Tile>, <tf.Operation 'gradients/loss_grad/Shape_2' type=Shape>, <tf.Operation 'gradients/loss_grad/Shape_3' type=Const>, <tf.Operation 'gradients/loss_grad/Const' type=Const>, <tf.Operation 'gradients/loss_grad/Prod' type=Prod>, <tf.Operation 'gradients/loss_grad/Const_1' type=Const>, <tf.Operation 'gradients/loss_grad/Prod_1' type=Prod>, <tf.Operation 'gradients/loss_grad/Maximum_1/y' type=Const>, <tf.Operation 'gradients/loss_grad/Maximum_1' type=Maximum>, <tf.Operation 'gradients/loss_grad/floordiv_1' type=FloorDiv>, <tf.Operation 'gradients/loss_grad/Cast' type=Cast>, <tf.Operation 'gradients/loss_grad/truediv' type=RealDiv>, <tf.Operation 'gradients/Square_grad/Const' type=Const>, <tf.Operation 'gradients/Square_grad/Mul' type=Mul>, <tf.Operation 'gradients/Square_grad/Mul_1' type=Mul>, <tf.Operation 'gradients/sub_grad/Shape' type=Shape>, <tf.Operation 'gradients/sub_grad/Shape_1' type=Shape>, <tf.Operation 'gradients/sub_grad/BroadcastGradientArgs' type=BroadcastGradientArgs>, <tf.Operation 'gradients/sub_grad/Sum' type=Sum>, <tf.Operation 'gradients/sub_grad/Reshape' type=Reshape>, <tf.Operation 'gradients/sub_grad/Sum_1' type=Sum>, <tf.Operation 'gradients/sub_grad/Neg' type=Neg>, <tf.Operation 'gradients/sub_grad/Reshape_1' type=Reshape>, <tf.Operation 'gradients/sub_grad/tuple/group_deps' type=NoOp>, <tf.Operation 'gradients/sub_grad/tuple/control_dependency' type=Identity>, <tf.Operation 'gradients/sub_grad/tuple/control_dependency_1' type=Identity>, <tf.Operation 'gradients/OUT_grad/Shape' type=Shape>, <tf.Operation 'gradients/OUT_grad/Shape_1' type=Const>, <tf.Operation 'gradients/OUT_grad/BroadcastGradientArgs' type=BroadcastGradientArgs>, <tf.Operation 'gradients/OUT_grad/Sum' type=Sum>, <tf.Operation 'gradients/OUT_grad/Reshape' type=Reshape>, <tf.Operation 'gradients/OUT_grad/Sum_1' type=Sum>, <tf.Operation 'gradients/OUT_grad/Reshape_1' type=Reshape>, <tf.Operation 'gradients/OUT_grad/tuple/group_deps' type=NoOp>, <tf.Operation 'gradients/OUT_grad/tuple/control_dependency' type=Identity>, <tf.Operation 'gradients/OUT_grad/tuple/control_dependency_1' type=Identity>, <tf.Operation 'gradients/Mul_grad/Shape' type=Shape>, <tf.Operation 'gradients/Mul_grad/Shape_1' type=Const>, <tf.Operation 'gradients/Mul_grad/BroadcastGradientArgs' type=BroadcastGradientArgs>, <tf.Operation 'gradients/Mul_grad/Mul' type=Mul>, <tf.Operation 'gradients/Mul_grad/Sum' type=Sum>, <tf.Operation 'gradients/Mul_grad/Reshape' type=Reshape>, <tf.Operation 'gradients/Mul_grad/Mul_1' type=Mul>, <tf.Operation 'gradients/Mul_grad/Sum_1' type=Sum>, <tf.Operation 'gradients/Mul_grad/Reshape_1' type=Reshape>, <tf.Operation 'gradients/Mul_grad/tuple/group_deps' type=NoOp>, <tf.Operation 'gradients/Mul_grad/tuple/control_dependency' type=Identity>, <tf.Operation 'gradients/Mul_grad/tuple/control_dependency_1' type=Identity>, <tf.Operation 'beta1_power/initial_value' type=Const>, <tf.Operation 'beta1_power' type=VariableV2>, <tf.Operation 'beta1_power/Assign' type=Assign>, <tf.Operation 'beta1_power/read' type=Identity>, <tf.Operation 'beta2_power/initial_value' type=Const>, <tf.Operation 'beta2_power' type=VariableV2>, <tf.Operation 'beta2_power/Assign' type=Assign>, <tf.Operation 'beta2_power/read' type=Identity>, <tf.Operation 'W/Adam/Initializer/zeros' type=Const>, <tf.Operation 'W/Adam' type=VariableV2>, <tf.Operation 'W/Adam/Assign' type=Assign>, <tf.Operation 'W/Adam/read' type=Identity>, <tf.Operation 'W/Adam_1/Initializer/zeros' type=Const>, <tf.Operation 'W/Adam_1' type=VariableV2>, <tf.Operation 'W/Adam_1/Assign' type=Assign>, <tf.Operation 'W/Adam_1/read' type=Identity>, <tf.Operation 'B/Adam/Initializer/zeros' type=Const>, <tf.Operation 'B/Adam' type=VariableV2>, <tf.Operation 'B/Adam/Assign' type=Assign>, <tf.Operation 'B/Adam/read' type=Identity>, <tf.Operation 'B/Adam_1/Initializer/zeros' type=Const>, <tf.Operation 'B/Adam_1' type=VariableV2>, <tf.Operation 'B/Adam_1/Assign' type=Assign>, <tf.Operation 'B/Adam_1/read' type=Identity>, <tf.Operation 'Adam/learning_rate' type=Const>, <tf.Operation 'Adam/beta1' type=Const>, <tf.Operation 'Adam/beta2' type=Const>, <tf.Operation 'Adam/epsilon' type=Const>, <tf.Operation 'Adam/update_W/ApplyAdam' type=ApplyAdam>, <tf.Operation 'Adam/update_B/ApplyAdam' type=ApplyAdam>, <tf.Operation 'Adam/mul' type=Mul>, <tf.Operation 'Adam/Assign' type=Assign>, <tf.Operation 'Adam/mul_1' type=Mul>, <tf.Operation 'Adam/Assign_1' type=Assign>, <tf.Operation 'Adam' type=NoOp>, <tf.Operation 'init' type=NoOp>, <tf.Operation 'save/filename/input' type=Const>, <tf.Operation 'save/filename' type=PlaceholderWithDefault>, <tf.Operation 'save/Const' type=PlaceholderWithDefault>, <tf.Operation 'save/SaveV2/tensor_names' type=Const>, <tf.Operation 'save/SaveV2/shape_and_slices' type=Const>, <tf.Operation 'save/SaveV2' type=SaveV2>, <tf.Operation 'save/control_dependency' type=Identity>, <tf.Operation 'save/RestoreV2/tensor_names' type=Const>, <tf.Operation 'save/RestoreV2/shape_and_slices' type=Const>, <tf.Operation 'save/RestoreV2' type=RestoreV2>, <tf.Operation 'save/Assign' type=Assign>, <tf.Operation 'save/Assign_1' type=Assign>, <tf.Operation 'save/Assign_2' type=Assign>, <tf.Operation 'save/Assign_3' type=Assign>, <tf.Operation 'save/Assign_4' type=Assign>, <tf.Operation 'save/Assign_5' type=Assign>, <tf.Operation 'save/Assign_6' type=Assign>, <tf.Operation 'save/Assign_7' type=Assign>, <tf.Operation 'save/restore_all' type=NoOp>]

恢復參數:

這裏須要注意的是在後面須要加上「:0」,表明第0個參數(這個涉及到另外一個概念,之後再細講)

X = graph.get_tensor_by_name('X:0')
Y = graph.get_tensor_by_name('Y:0')
W = graph.get_tensor_by_name('W:0')
B = graph.get_tensor_by_name('B:0')
OUT = graph.get_tensor_by_name('OUT:0')
loss = graph.get_tensor_by_name('loss:0')

恢復優化器:

optimizer = graph.get_collection('train_op')

仍然將以前代碼中的預測和訓練相關的邏輯拷過來執行一下

Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from D:/test/tf1/xw_b/model
從新加載,開始預測。。。
x= 26.764991404677083 y預測= [[119.67893]] y實際= 119.39740418692885
x= 25.85141169466281 y預測= [[114.797356]] y實際= 114.52802433255279
x= 17.046457082367727 y預測= [[67.749466]] y實際= 67.59761624901998
x= 5.918111849660451 y預測= [[8.286896]] y實際= 8.283536158690204
x= 7.409698341670607 y預測= [[16.256956]] y實際= 16.233692161104333
x= 15.469762867798304 y預測= [[59.324646]] y實際= 59.19383608536495
x= 11.519144276233455 y預測= [[38.215134]] y實際= 38.13703899232431
x= 27.85137286496477 y預測= [[125.48383]] y實際= 125.18781737026221
x= 26.50150532742774 y預測= [[118.271034]] y實際= 117.99302339518984
x= 15.664275922154658 y預測= [[60.364]] y實際= 60.23059066508432
繼續訓練
epoch= 0 _loss= 16.00476 _w= [5.3422985] _n= [-23.3365]
epoch= 5000 _loss= 19.420956 _w= [5.3203373] _n= [-23.186474]
epoch= 10000 _loss= 0.30325127 _w= [5.3471537] _n= [-23.290209]
epoch= 15000 _loss= 3.018042 _w= [5.32293] _n= [-23.245607]
epoch= 20000 _loss= 12.473472 _w= [5.309146] _n= [-23.24814]
epoch= 25000 _loss= 17.09799 _w= [5.3170156] _n= [-23.342768]
epoch= 30000 _loss= 18.25596 _w= [5.3193855] _n= [-23.225794]
epoch= 35000 _loss= 0.32235628 _w= [5.339825] _n= [-23.196495]
epoch= 40000 _loss= 2.6598516 _w= [5.304051] _n= [-23.248428]
epoch= 45000 _loss= 6.564373 _w= [5.328891] _n= [-23.212101]
繼續訓練完成,開始預測。。。
x= 24.14983880390778 y預測= [[105.329315]] y實際= 105.45864082482846
x= 8.654129156050717 y預測= [[22.795414]] y實際= 22.86650840175032
x= 17.410606725772045 y預測= [[69.434525]] y實際= 69.53853384836499
x= 17.55599000188004 y預測= [[70.20888]] y實際= 70.31342671002061
x= 24.43148021367975 y預測= [[106.82939]] y實際= 106.95978953891309
x= 20.286380740475614 y預測= [[84.751595]] y實際= 84.86640934673503
x= 2.8131286438423353 y預測= [[-8.3151655]] y實際= -8.266024328320354
x= 11.781139561484927 y預測= [[39.450626]] y實際= 39.53347386271466
x= 4.611147529065006 y預測= [[1.2615166]] y實際= 1.3174163299164796
x= 6.625783852577516 y預測= [[11.991955]] y實際= 12.055427934238164

使用恢復之後的模型直接進行預測,匹配程度也很是高,而進行繼續訓練也沒問題

====================================================

完整代碼以下,在python3.6.八、tensorflow1.13環境下成功運行

https://github.com/yukiti2007/sample/blob/master/python/tensorflow/wx_b_save.py

import random

import tensorflow as tf

SAVE_PATH = "D:/test/tf1/xw_b/"


def create_data(for_train=False):
    w = 5.33
    b = -23.26
    x = random.random() * 30
    y = w * x + b

    if for_train:
        noise = (random.random() - 0.5) * 10
        y += noise

    return x, y


def train():
    X = tf.placeholder(tf.float32, name='X')
    Y = tf.placeholder(tf.float32, name='Y')

    W = tf.Variable(tf.zeros([1]), name='W')
    B = tf.Variable(tf.zeros([1]), name='B')
    OUT = tf.add(tf.multiply(X, W), B, name='OUT')

    loss = tf.reduce_mean(tf.square(Y - OUT), name='loss')
    optimizer = tf.train.AdamOptimizer(0.005).minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(50000):
            x_data, y_data = create_data(True)
            _, _loss, _w, _b = sess.run([optimizer, loss, W, B], feed_dict={X: x_data, Y: y_data})
            if 0 == epoch % 5000:
                print("epoch=", epoch, "_loss=", _loss, "_w=", _w, "_n=", _b)

        print("訓練完成,開始預測。。。")
        for step in range(10):
            x_data, y_data = create_data(False)
            prediction_value = sess.run(OUT, feed_dict={X: x_data})
            print("x=", x_data, "y預測=", prediction_value, "y實際=", y_data)

        saver = tf.train.Saver()
        saver.save(sess, SAVE_PATH + 'model')


def predict():
    sess = tf.Session()
    saver = tf.train.import_meta_graph(SAVE_PATH + 'model.meta')
    saver.restore(sess, tf.train.latest_checkpoint(SAVE_PATH))

    graph = tf.get_default_graph()
    X = graph.get_tensor_by_name('X:0')
    Y = graph.get_tensor_by_name('Y:0')
    W = graph.get_tensor_by_name('W:0')
    B = graph.get_tensor_by_name('B:0')
    OUT = graph.get_tensor_by_name('OUT:0')
    loss = graph.get_tensor_by_name('loss:0')
    optimizer = graph.get_collection('train_op')
    # print(graph.get_all_collection_keys())
    # print(graph.get_collection('train_op'))
    # print(graph.get_collection('trainable_variables'))
    # print(graph.get_collection('variables'))
    # print(graph.get_operations())

    print("從新加載,開始預測。。。")
    for step in range(10):
        x_data, y_data = create_data(False)
        prediction_value = sess.run(OUT, feed_dict={X: [[x_data]]})
        print("x=", x_data, "y預測=", prediction_value, "y實際=", y_data)

    print("繼續訓練")
    for epoch in range(50000):
        x_data, y_data = create_data(True)
        _, _loss, _w, _b = sess.run([optimizer, loss, W, B], feed_dict={X: x_data, Y: y_data})
        if 0 == epoch % 5000:
            print("epoch=", epoch, "_loss=", _loss, "_w=", _w, "_n=", _b)

    print("繼續訓練完成,開始預測。。。")
    for step in range(10):
        x_data, y_data = create_data(False)
        prediction_value = sess.run(OUT, feed_dict={X: [[x_data]]})
        print("x=", x_data, "y預測=", prediction_value, "y實際=", y_data)


if __name__ == "__main__":
    train()
    predict()
相關文章
相關標籤/搜索