"""Two-layer fully-connected MNIST classifier (TensorFlow 1.x).

Trains with softmax cross-entropy plus L2 regularization, an
exponentially decaying learning rate, and an exponential moving
average of the trainable variables used for evaluation.
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Network structure
input_nodes = 784    # 28*28 pixels per MNIST image
output_nodes = 10    # 10 digit classes
layer1_nodes = 500   # hidden-layer width

# Hyperparameters: learning-rate schedule
learning_rate_base = 0.8
learning_decay = 0.99
decay_step = 100  # NOTE(review): unused — train() derives decay steps from dataset size

# Moving average / regularization / training loop
moving_average__decay = 0.99  # name kept as-is for compatibility (note double underscore)
regularizer_rate = 0.0001
train_step = 30000
batch_size = 100


def inference(tensor1, weight1, bias1, weight2, bias2, average_class=None):
    """Forward pass: input -> ReLU hidden layer -> linear logits.

    When ``average_class`` (a tf.train.ExponentialMovingAverage) is
    given, the shadow (averaged) copies of the variables are used
    instead of the raw variables.
    """
    if average_class is None:
        layer1 = tf.nn.relu(tf.matmul(tensor1, weight1) + bias1)
        return tf.matmul(layer1, weight2) + bias2
    layer1 = tf.nn.relu(
        tf.matmul(tensor1, average_class.average(weight1))
        + average_class.average(bias1))
    return (tf.matmul(layer1, average_class.average(weight2))
            + average_class.average(bias2))


def get_weight(shape):
    """Create a truncated-normal weight variable and register its L2 loss.

    BUG FIX: the original passed ``tf.float32`` as the second positional
    argument of ``tf.Variable``, which is ``trainable``, not ``dtype``;
    it only worked because ``tf.float32`` is truthy (trainable defaults
    to True anyway). ``truncated_normal`` already yields float32, so no
    dtype argument is needed.
    """
    weight = tf.Variable(tf.truncated_normal(shape=shape, stddev=0.1))
    tf.add_to_collection(
        'losses', tf.contrib.layers.l2_regularizer(regularizer_rate)(weight))
    return weight


def get_bias(shape):
    """Create a zero-initialized bias variable."""
    return tf.Variable(tf.zeros(shape))


def train(mnist):
    """Build the graph and run the training loop on the given dataset."""
    # Placeholders for flattened images and one-hot labels.
    train_x = tf.placeholder(tf.float32, shape=[None, input_nodes], name='train_x')
    train_y = tf.placeholder(tf.float32, shape=[None, output_nodes], name='train_y')

    weight1 = get_weight([input_nodes, layer1_nodes])
    bias1 = get_bias([layer1_nodes])
    weight2 = get_weight([layer1_nodes, output_nodes])
    bias2 = get_bias([output_nodes])
    results = inference(train_x, weight1, bias1, weight2, bias2, None)

    # Exponentially decaying learning rate; staircase=True decays once
    # per epoch (num_examples / batch_size steps).
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(
        learning_rate_base, global_step,
        mnist.train.num_examples / batch_size,
        learning_decay, staircase=True)

    # Loss = mean cross-entropy + L2 regularization collected by get_weight.
    ce = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=results, labels=tf.argmax(train_y, 1)))
    loss = ce + tf.add_n(tf.get_collection('losses'))
    tf.summary.scalar('lost', loss)  # tag 'lost' kept for log compatibility

    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # Maintain shadow (moving-average) copies of all trainable variables;
    # the grouped train_op applies gradients AND updates the averages.
    ema = tf.train.ExponentialMovingAverage(moving_average__decay, global_step)
    maintain_average_op = ema.apply(tf.trainable_variables())
    with tf.control_dependencies([optimizer, maintain_average_op]):
        train_op = tf.no_op(name='train')

    # Accuracy measured with the averaged weights.
    average_y = inference(train_x, weight1, bias1, weight2, bias2, ema)
    correction_prediction = tf.equal(tf.argmax(average_y, 1),
                                     tf.argmax(train_y, 1))
    accuracy = tf.reduce_mean(tf.cast(correction_prediction, tf.float32))

    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        validate_feed = {train_x: mnist.validation.images,
                         train_y: mnist.validation.labels}
        test_feed = {train_x: mnist.test.images,
                     train_y: mnist.test.labels}

        # Summaries for TensorBoard.
        merged_summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(
            './log/mnist_with_summaries', sess.graph)

        for i in range(train_step):
            if i % 1000 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                print('After %d training steps,using average model is %g '
                      % (i, validate_acc))

            xt, yt = mnist.train.next_batch(batch_size)
            # Run the training step and the summary op in a SINGLE call:
            # the original issued two sess.run calls on the same batch,
            # doubling the forward pass per iteration.
            _, summary_str = sess.run([train_op, merged_summary_op],
                                      feed_dict={train_x: xt, train_y: yt})
            summary_writer.add_summary(summary_str, i)

        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print('accuracy is %g' % test_acc)


def main():
    """Load the MNIST dataset and run training."""
    mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)
    train(mnist)


if __name__ == '__main__':
    main()