【tf.keras】TensorFlow 1.x 到 2.0 的 API 變化

TensorFlow 2.0 版本將 keras 做爲高級 API,對於 keras boy/girl 來講,這就很友好了。tf.keras 從 1.x 版本遷移到 2.0 版本,須要修改幾個地方。python

1. 設置隨機種子

import tensorflow as tf

# TF 1.x
tf.set_random_seed(args.seed)
# TF 2.0
tf.random.set_seed(args.seed)

2. 設置並行線程數和動態分配顯存

import tensorflow as tf
from tensorflow.python.keras import backend as K

# TF 1.x
config = tf.ConfigProto(intra_op_parallelism_threads=1,
                         inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True  # 不所有佔滿顯存, 按需分配
K.set_session(tf.Session(config=config))

# TF 2.0
config = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1,
                                  inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True  # 不所有佔滿顯存, 按需分配
K.set_session(tf.compat.v1.Session(config=config))

3. model.fit() 生成的 log 中,acc 更名 accuracy,val_acc 更名 val_accuracy。故在 callbacks.ModelCheckpoint 中須要作修改:

from tensorflow.python.keras import callbacks

# TF 1.x
ck_callback = callbacks.ModelCheckpoint('./model.h5', monitor='val_acc', mode='max',
                                            verbose=1, save_best_only=True, save_weights_only=True)

# TF 2.0
ck_callback = callbacks.ModelCheckpoint('./model.h5', monitor='val_accuracy', mode='max',
                                            verbose=1, save_best_only=True, save_weights_only=True)
相關文章
相關標籤/搜索