Please credit the source when reposting:
http://www.javashuo.com/article/p-ocbrtjgk-t.html
Code:
https://github.com/darkknightzh/trainEagerMnist
Reference:
https://github.com/tensorflow/models/blob/master/official/mnist/mnist_eager.py
To use eager execution in TensorFlow, the following lines are needed (if the third line is omitted, the static graph can still be used as before):
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()
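A quick way to confirm that eager mode is actually active is tf.executing_eagerly(); a minimal sketch:

print(tf.executing_eagerly())  # True: operations now run immediately
x = tf.constant([[1.0, 2.0]])
print(x.numpy())               # tensors can be inspected directly, no Session needed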
With eager mode enabled, TensorFlow feels as convenient as PyTorch: tf.placeholder is no longer needed, which makes the code noticeably simpler.
At present, tf.keras.layers and tf.layers appear to support eager execution, while slim does not.
The overall training flow is as follows:
initialize optimizer
for epoch in range(epochs):
    for imgs, targets in training_data:
        with tf.GradientTape() as tape:
            logits = model(imgs, training=True)
            loss_value = calc_loss(logits, targets)
        grads = tape.gradient(loss_value, model.variables)
        optimizer.apply_gradients(zip(grads, model.variables), global_step=step_counter)
        update training accuracy and total loss
    test model
    save model
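For reference, training_data above can be any iterable of (image, label) batches; a minimal sketch using tf.data (the names make_dataset, images and labels are illustrative, not from the repository):

import numpy as np

def make_dataset(images, labels, batch_size=32):
    # images: float32 array of shape [N, 784]; labels: one-hot float32 of shape [N, 10]
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    return ds.shuffle(10000).batch(batch_size)

training_data = make_dataset(np.zeros((100, 784), np.float32),
                             np.zeros((100, 10), np.float32))  # dummy data for illustration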
A model can be created in any of the following three ways.
Way 1: subclass tf.keras.Model, define the layers used in __init__, and override the call function to build the network; call is invoked during the model's forward pass, as the code below shows:
class simpleModel(tf.keras.Model):
    def __init__(self, num_classes):
        super(simpleModel, self).__init__()

        input_shape = [28, 28, 1]
        data_format = 'channels_last'
        self.reshape = tf.keras.layers.Reshape(target_shape=input_shape, input_shape=(input_shape[0] * input_shape[1],))

        self.conv1 = tf.keras.layers.Conv2D(16, 5, padding="same", activation='relu')
        self.batch1 = tf.keras.layers.BatchNormalization()
        self.pool1 = tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format)

        self.conv2 = tf.keras.layers.Conv2D(32, 5, padding="same", activation='relu')
        self.batch2 = tf.keras.layers.BatchNormalization()
        self.pool2 = tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format)

        self.conv3 = tf.keras.layers.Conv2D(64, 5, padding="same", activation='relu')
        self.batch3 = tf.keras.layers.BatchNormalization()
        self.pool3 = tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format)

        self.conv4 = tf.keras.layers.Conv2D(64, 5, padding="same", activation='relu')
        self.batch4 = tf.keras.layers.BatchNormalization()
        self.pool4 = tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format)

        self.flat = tf.keras.layers.Flatten()
        self.fc5 = tf.keras.layers.Dense(1024, activation='relu')
        self.batch5 = tf.keras.layers.BatchNormalization()

        self.fc6 = tf.keras.layers.Dense(num_classes)
        self.batch6 = tf.keras.layers.BatchNormalization()

    def call(self, inputs, training=None):
        x = self.reshape(inputs)

        x = self.conv1(x)
        x = self.batch1(x, training=training)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.batch2(x, training=training)
        x = self.pool2(x)

        x = self.conv3(x)
        x = self.batch3(x, training=training)
        x = self.pool3(x)

        x = self.conv4(x)
        x = self.batch4(x, training=training)
        x = self.pool4(x)

        x = self.flat(x)
        x = self.fc5(x)
        x = self.batch5(x, training=training)

        x = self.fc6(x)
        x = self.batch6(x, training=training)
        # x = tf.layers.dropout(x, rate=0.3, training=training)
        return x

    def get_acc(self, target):
        correct_prediction = tf.equal(tf.argmax(self.logits, 1), tf.argmax(target, 1))
        acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        return acc

    def get_loss(self):
        return self.loss

    def loss_fn(self, images, target, training):
        self.logits = self(images, training)  # invokes call(self, inputs, training=None)
        self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=target))
        return self.loss

    def grads_fn(self, images, target, training):  # do not return loss and acc if unnecessary
        with tfe.GradientTape() as tape:
            loss = self.loss_fn(images, target, training)
        return tape.gradient(loss, self.variables)
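A short usage sketch of the subclassed model (the batch below is dummy data for illustration):

model = simpleModel(num_classes=10)
images = tf.zeros((8, 28 * 28))         # dummy batch of flattened MNIST images
logits = model(images, training=False)  # invokes call(); logits has shape (8, 10)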
Way 2: pass the list of layers directly to tf.keras.Sequential, as the code below shows:
def create_model1():
    data_format = 'channels_last'
    input_shape = [28, 28, 1]
    l = tf.keras.layers
    max_pool = l.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format)
    # The model consists of a sequential chain of layers, so tf.keras.Sequential
    # (a subclass of tf.keras.Model) makes for a compact description.
    return tf.keras.Sequential(
        [
            l.Reshape(target_shape=input_shape, input_shape=(28 * 28,)),
            l.Conv2D(16, 5, padding='same', data_format=data_format, activation=tf.nn.relu),
            l.BatchNormalization(),
            max_pool,

            l.Conv2D(32, 5, padding='same', data_format=data_format, activation=tf.nn.relu),
            l.BatchNormalization(),
            max_pool,

            l.Conv2D(64, 5, padding='same', data_format=data_format, activation=tf.nn.relu),
            l.BatchNormalization(),
            max_pool,

            l.Conv2D(64, 5, padding='same', data_format=data_format, activation=tf.nn.relu),
            l.BatchNormalization(),
            max_pool,

            l.Flatten(),
            l.Dense(1024, activation=tf.nn.relu),
            l.BatchNormalization(),

            # l.Dropout(0.4),
            l.Dense(10),
            l.BatchNormalization()
        ])
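Note that create_model1 reuses the same max_pool layer object four times; this is safe because pooling layers carry no trainable weights.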
Way 3: build an empty tf.keras.Sequential and append layers one by one with model.add, as the code below shows:
def create_model2():
    data_format = 'channels_last'
    input_shape = [28, 28, 1]

    model = tf.keras.Sequential()

    model.add(tf.keras.layers.Reshape(target_shape=input_shape, input_shape=(input_shape[0] * input_shape[1],)))

    model.add(tf.keras.layers.Conv2D(16, 5, padding="same", activation='relu'))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format))

    model.add(tf.keras.layers.Conv2D(32, 5, padding="same", activation='relu'))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format))

    model.add(tf.keras.layers.Conv2D(64, 5, padding="same", activation='relu'))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format))

    model.add(tf.keras.layers.Conv2D(64, 5, padding="same", activation='relu'))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same', data_format=data_format))

    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(1024, activation='relu'))
    model.add(tf.keras.layers.BatchNormalization())

    model.add(tf.keras.layers.Dense(10))
    model.add(tf.keras.layers.BatchNormalization())

    return model
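Since the three constructions describe the same architecture, they are interchangeable; a quick sketch (assuming the definitions above are in scope) that checks each one maps a batch of flattened images to 10 logits:

for m in (simpleModel(10), create_model1(), create_model2()):
    out = m(tf.zeros((2, 28 * 28)), training=False)
    print(out.shape)  # (2, 10) for all three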
When updating the gradients, the following lines are needed:
1 with tf.GradientTape() as tape:
2     logits = model(imgs, training=True)
3     loss_value = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labs))
4 grads = tape.gradient(loss_value, model.variables)
5 optimizer.apply_gradients(zip(grads, model.variables), global_step=step_counter)
Line 2 computes the logits, line 3 the loss, and line 4 the gradients; line 5 applies the gradients to the model, updating its parameters.
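The "test model" step of the overall flow can be written the same way, just without the tape. A sketch using tfe.metrics.Accuracy, following the referenced mnist_eager.py (test_data is assumed to be a dataset of (image, one-hot label) batches):

accuracy = tfe.metrics.Accuracy()
for imgs, labs in test_data:
    logits = model(imgs, training=False)  # forward pass only, no gradients needed
    accuracy(tf.argmax(logits, axis=1), tf.argmax(labs, axis=1))
print('test accuracy: %f' % accuracy.result())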
The V1 way saves and restores the model's variables through tfe.Saver; the code is as follows:
import os

def saveModelV1(model_dir, model, global_step, modelname='model1'):
    tfe.Saver(model.variables).save(os.path.join(model_dir, modelname), global_step=global_step)

def restoreModelV1(model_dir, model):
    dummy_input = tf.constant(tf.zeros((1, 28, 28, 1)))  # run the model once to initialize variables
    dummy_pred = model(dummy_input, training=False)

    saver = tfe.Saver(model.variables)  # restore the variables of the model
    saver.restore(tf.train.latest_checkpoint(model_dir))
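A hypothetical usage sketch of the two helpers (the directory name is illustrative):

saveModelV1('./checkpoints', model, global_step=step_counter)

new_model = simpleModel(num_classes=10)     # fresh model; variables are created on first call
restoreModelV1('./checkpoints', new_model)  # weights now match the saved checkpoint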
The V2 way uses a tf.train.Checkpoint object, which can bundle the model, the optimizer, and the step counter; the code is as follows:
step_counter = tf.train.get_or_create_global_step()
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer, step_counter=step_counter)

def saveModelV2(model_dir, checkpoint, modelname='model2'):
    checkpoint_prefix = os.path.join(model_dir, modelname)
    checkpoint.save(checkpoint_prefix)

def restoreModelV2(model_dir, checkpoint):
    checkpoint.restore(tf.train.latest_checkpoint(model_dir))
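Usage is symmetrical; a short sketch (the directory name is illustrative). Note that checkpoint.save appends an automatically incremented counter to the prefix:

saveModelV2('./checkpoints', checkpoint)     # writes model2-1, model2-2, ... on successive calls
restoreModelV2('./checkpoints', checkpoint)  # restores model, optimizer and step_counter together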
The code does not strictly follow the steps of the overall flow above and is for reference only; see https://github.com/darkknightzh/trainEagerMnist
There, eagerFlag selects how eager is used: 0 means no eager (static graph), 1 means the V1 way, and 2 means the V2 way. When using the static graph, do not call tfe.enable_eager_execution(), otherwise an error is raised. See the code for details.