TensorFlow2.0(12):模型保存與序列化

 

注:本系列全部博客將持續更新併發布在github上,您能夠經過github下載本系列全部文章筆記文件。javascript

 

模型訓練好以後,咱們就要想辦法將其持久化保存下來,否則關機或者程序退出後模型就不復存在了。本文介紹兩種持久化保存模型的方法:css

 

在介紹這兩種方法以前,咱們得先建立並訓練好一個模型,仍是以mnist手寫數字識別數據集訓練模型爲例:html

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, Sequential
In [2]:
model = Sequential([  # 建立模型
    layers.Dense(256, activation=tf.nn.relu),
    layers.Dense(128, activation=tf.nn.relu),
    layers.Dense(64, activation=tf.nn.relu),
    layers.Dense(32, activation=tf.nn.relu),
    layers.Dense(10)
    ]
)
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop())
history = model.fit(x_train, y_train,  # 進行簡單的1次迭代訓練
                    batch_size=64,
                    epochs=1)
 
Train on 60000 samples
60000/60000 [==============================] - 3s 46us/sample - loss: 2.3700
 

方法一:model.save()

 

經過模型自帶的save()方法能夠將模型保存到一個指定文件中,保存的內容包括:html5

  • 模型的結構
  • 模型的權重參數
  • 經過compile()方法配置的模型訓練參數
  • 優化器及其狀態
In [3]:
model.save('mymodels/mnist.h5')
 

使用save()方法保存後,在mymodels目錄下就會有一個mnist.h5文件。須要使用模型時,經過keras.models.load_model()方法從文件中再次加載便可。java

In [4]:
new_model = keras.models.load_model('mymodels/mnist.h5')
 
WARNING:tensorflow:Sequential models without an `input_shape` passed to the first layer cannot reload their optimizer state. As a result, your model isstarting with a freshly initialized optimizer.
 

新加載出來的new_model在結構、功能、參數各方面與model是同樣的。node

 

經過save()方法,也能夠將模型保存爲SavedModel 格式。SavedModel格式是TensorFlow所特有的一種序列化文件格式,其餘編程語言實現的TensorFlow中一樣支持:python

In [5]:
model.save('mymodels/mnist_model', save_format='tf')  # 將模型保存爲SaveModel格式
 
WARNING:tensorflow:From /home/chb/anaconda3/envs/study_python/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1781: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: mymodels/mnist_model/assets
In [6]:
new_model = keras.models.load_model('mymodels/mnist_model')  # 加載模型
 

方法二:model.save_weights()

 

save()方法會保留模型的全部信息,但有時候,咱們僅對部分信息感興趣,例如僅對模型的權重參數感興趣,那麼就能夠經過save_weights()方法進行保存。jquery

In [14]:
model.save_weights('mymodels/mnits_weights')  # 保存模型權重信息
In [15]:
new_model = Sequential([  # 建立新的模型
    layers.Dense(256, activation=tf.nn.relu),
    layers.Dense(128, activation=tf.nn.relu),
    layers.Dense(64, activation=tf.nn.relu),
    layers.Dense(32, activation=tf.nn.relu),
    layers.Dense(10)
    ]
)
new_model.compile(loss='sparse_categorical_crossentropy',
              optimizer=keras.optimizers.RMSprop())
new_model.load_weights('mymodels/mnits_weights')  # 將保存好的權重信息加載的新的模型中
Out[15]:
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f49c42b87d0>
相關文章
相關標籤/搜索