1. 準備數據:使用佔位符,動態加載訓練數據
# Input placeholder: a batch of rows with 784 float features each
# (batch size left open with None).
x = tf.placeholder(tf.float32, [None, 784])
# Label placeholder: one row of 10 integer values per example.
y_true = tf.placeholder(tf.int32, [None, 10])
2. 初始化參數,創建模型
# Weights drawn from a normal distribution (mean 0, stddev 1); biases start at 0.
weight = tf.Variable(tf.random_normal([784, 10], mean=0.0, stddev=1.0))
bias = tf.Variable(tf.constant(0.0, shape=[10]))
# Linear model producing one logit per class: x @ weight + bias.
y_predict = tf.matmul(x, weight) + bias
3. 求平均交叉熵損失
# Softmax cross-entropy between the labels and the logits, then averaged
# into a single scalar loss.
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
    labels=y_true,
    logits=y_predict,
)
loss = tf.reduce_mean(cross_entropy)
4. 梯度下降優化
# Gradient-descent step (learning rate 0.3) that minimizes the loss.
# The optimizer lives in the tf.train namespace, not directly under tf.
train_op = tf.train.GradientDescentOptimizer(0.3).minimize(loss)
5. 求準確率
# Element-wise comparison of the true class index with the predicted one.
equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1))
# Cast the booleans to floats and average them to get the fraction correct.
accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))
完整代碼:
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Loaded once at import time, with one-hot encoded labels.
# NOTE(review): 'MNISI_data' looks like a typo for 'MNIST_data'; it is left
# unchanged because it is the on-disk location the data is read from.
mnist = input_data.read_data_sets('./data/MNISI_data/', one_hot=True)


def full_connection(is_train=False):
    """Train or evaluate a single fully connected softmax layer on MNIST.

    Builds the graph (placeholders, linear model, mean cross-entropy loss,
    gradient-descent optimizer, accuracy and TensorBoard summaries), then
    either trains it or runs it on test images.

    Args:
        is_train: When true, run 4000 training steps on batches of 50,
            write summaries to ``./temp/summary/test`` and save the model to
            ``./temp/ckpt/full_conn`` (resuming from an existing checkpoint
            if one is present). When false (the default), restore that
            checkpoint and print target vs. predicted digit for 100 test
            images.
    """
    ckpt_dir = './temp/ckpt'
    ckpt_path = './temp/ckpt/full_conn'

    # 1. Input data placeholders.
    with tf.variable_scope("data"):
        x = tf.placeholder(tf.float32, [None, 784])
        y_true = tf.placeholder(tf.int32, [None, 10])

    # 2. Model: logits = x @ weight + bias.
    with tf.variable_scope('predict_model'):
        weight = tf.Variable(
            tf.random_normal([784, 10], mean=0.0, stddev=1.0), name='w')
        bias = tf.Variable(tf.constant(0.0, shape=[10]))
        y_predict = tf.matmul(x, weight) + bias

    # 3. Mean softmax cross-entropy loss.
    with tf.variable_scope('loss'):
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(
                labels=y_true, logits=y_predict))

    # 4. Gradient-descent optimization.
    with tf.variable_scope('optimizer'):
        train_op = tf.train.GradientDescentOptimizer(0.4).minimize(loss)

    # 5. Accuracy: fraction of rows whose predicted class matches the label.
    with tf.variable_scope('acc'):
        equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1))
        accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))

    init_op = tf.global_variables_initializer()

    # Collect tensors for TensorBoard.
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    tf.summary.histogram('weight', weight)
    tf.summary.histogram('bias', bias)
    merged = tf.summary.merge_all()

    saver = tf.train.Saver()

    with tf.Session() as sess:
        if is_train:
            sess.run(init_op)
            # Saver.save fails if the parent directory does not exist.
            os.makedirs(ckpt_dir, exist_ok=True)
            file_writer = tf.summary.FileWriter(
                './temp/summary/test', graph=sess.graph)
            try:
                # Resume from a previously saved model when one exists.
                if os.path.exists(os.path.join(ckpt_dir, 'checkpoint')):
                    saver.restore(sess, ckpt_path)
                for i in range(4000):
                    # Fetch a batch of 50 training examples.
                    mnist_x, mnist_y = mnist.train.next_batch(50)
                    feed = {x: mnist_x, y_true: mnist_y}
                    sess.run(train_op, feed_dict=feed)
                    # Evaluate summaries and accuracy in one pass, after the
                    # update, on the same batch.
                    summary, batch_acc = sess.run(
                        [merged, accuracy], feed_dict=feed)
                    file_writer.add_summary(summary, i)
                    print("訓練低%d步,準確率爲:%f" % (i, batch_acc))
                # Save the trained model.
                saver.save(sess, ckpt_path)
            finally:
                # Flush pending summaries and release the event file.
                file_writer.close()
        else:
            saver.restore(sess, ckpt_path)
            for i in range(100):
                # Fetch one test example at a time.
                x_test, y_test = mnist.test.next_batch(1)
                logits = sess.run(y_predict, feed_dict={x: x_test})
                # Take argmax on the NumPy arrays so no new graph ops are
                # created on each iteration.
                print('低%d張圖片,手寫數字圖片目標:%d--%d' % (
                    i,
                    y_test.argmax(axis=1)[0],
                    logits.argmax(axis=1)[0],
                ))


if __name__ == '__main__':
    full_connection()