一:數據集數組
採用MNIST數據集:--》官網網絡
數據集被分紅兩部分:60000行的訓練數據集和10000行的測試數據集。函數
其中每一張圖片包含28*28個像素,咱們把這個數組展開成一個向量,長度爲28*28=784.在MNIST訓練數據集中mnist.train.images是一個形狀爲[60000,784]的張量,第一個維度數字用來索引圖片,第二個維度數字用來索引每張圖片中的像素點。圖片裏的某個像素的強度值介於0-1之間。測試
MNIST數據集的標籤是介於0-9的數字,咱們把便籤轉化爲‘one-hot vectors’.一個one-hot向量除了某一位數字1之外,其他維度數字都是0.好比標籤0將表示爲([1,0,0,0,0,0,0,0,0,0,0]),標籤3表示爲([0,0,0,1,0,0,0,0,0,0]).因此標籤至關於[60000,10]的數字矩陣。spa
咱們的結果是0-9,咱們的模型可能推測出一張圖片是數字9的機率爲80%,是數字8的機率爲10%,而後其餘數字的機率更小,整體機率加起來等於1.這至關於一個使用softmax迴歸模型的案例。code
下面使用softmax模型來預測:blog
# MNIST數據集 手寫數字 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 載入數據集,若是沒有下載,程序會自動下載 mnist=input_data.read_data_sets('MNIST_data',one_hot=True) # 每一個批次的大小 batch_size=100 # 計算一共有多少個批次 n_batch=mnist.train.num_examples//batch_size # 定義兩個placeholder x=tf.placeholder(tf.float32,[None,784]) y=tf.placeholder(tf.float32,[None,10]) # 建立一個簡單的神經網絡 W=tf.Variable(tf.zeros([784,10])) b=tf.Variable(tf.zeros([10])) prediction=tf.nn.softmax(tf.matmul(x,W)+b) # 二次代價函數 loss=tf.reduce_mean(tf.square(y-prediction)) # 使用梯度降低法 train_step=tf.train.GradientDescentOptimizer(0.2).minimize(loss) # 初始化變量 init=tf.global_variables_initializer() # 求最大值在哪一個位置,結果存放在一個布爾值列表中 correct_prediction=tf.equal(tf.argmax(y,1),tf.arg_max(prediction,1))# argmax返回一維張量中最大值所在的位置 # 求準確率 accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) # cast做用是將布爾值轉換爲浮點型。 with tf.Session() as sess: sess.run(init) for epoch in range(21): # 訓練20次 for batch in range(n_batch): # 每次喂入必定的數據 batch_xs,batch_ys=mnist.train.next_batch(batch_size) sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys}) #求準確率 acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}) print('Iter:'+str(epoch)+',Testing Accuracy:'+str(acc))
# 結果 # 能夠看出每次訓練準確率都在提升 Iter:0,Testing Accuracy:0.8301 Iter:1,Testing Accuracy:0.8706 Iter:2,Testing Accuracy:0.8811 Iter:3,Testing Accuracy:0.8883 Iter:4,Testing Accuracy:0.8943 Iter:5,Testing Accuracy:0.8966 Iter:6,Testing Accuracy:0.9002 Iter:7,Testing Accuracy:0.9017 Iter:8,Testing Accuracy:0.9043 Iter:9,Testing Accuracy:0.9052 Iter:10,Testing Accuracy:0.9061 Iter:11,Testing Accuracy:0.9071 Iter:12,Testing Accuracy:0.908 Iter:13,Testing Accuracy:0.9096 Iter:14,Testing Accuracy:0.9094 Iter:15,Testing Accuracy:0.9102 Iter:16,Testing Accuracy:0.9116 Iter:17,Testing Accuracy:0.9119 Iter:18,Testing Accuracy:0.9126 Iter:19,Testing Accuracy:0.9134 Iter:20,Testing Accuracy:0.9136