TensorFlow(四):手寫數字識別

一:數據集數組

採用MNIST數據集:--》官網網絡

數據集被分紅兩部分:60000行的訓練數據集和10000行的測試數據集。函數

其中每一張圖片包含28*28個像素,咱們把這個數組展開成一個向量,長度爲28*28=784.在MNIST訓練數據集中mnist.train.images是一個形狀爲[60000,784]的張量,第一個維度數字用來索引圖片,第二個維度數字用來索引每張圖片中的像素點。圖片裏的某個像素的強度值介於0-1之間。測試

MNIST數據集的標籤是介於0-9的數字,咱們把便籤轉化爲‘one-hot vectors’.一個one-hot向量除了某一位數字1之外,其他維度數字都是0.好比標籤0將表示爲([1,0,0,0,0,0,0,0,0,0,0]),標籤3表示爲([0,0,0,1,0,0,0,0,0,0]).因此標籤至關於[60000,10]的數字矩陣。spa

咱們的結果是0-9,咱們的模型可能推測出一張圖片是數字9的機率爲80%,是數字8的機率爲10%,而後其餘數字的機率更小,整體機率加起來等於1.這至關於一個使用softmax迴歸模型的案例。code

下面使用softmax模型來預測:blog

# MNIST數據集 手寫數字
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# 載入數據集,若是沒有下載,程序會自動下載
mnist=input_data.read_data_sets('MNIST_data',one_hot=True)
# 每一個批次的大小
batch_size=100
# 計算一共有多少個批次
n_batch=mnist.train.num_examples//batch_size

# 定義兩個placeholder
x=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,10])

# 建立一個簡單的神經網絡
W=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([10]))
prediction=tf.nn.softmax(tf.matmul(x,W)+b)

# 二次代價函數
loss=tf.reduce_mean(tf.square(y-prediction))
# 使用梯度降低法
train_step=tf.train.GradientDescentOptimizer(0.2).minimize(loss)

# 初始化變量
init=tf.global_variables_initializer()
# 求最大值在哪一個位置,結果存放在一個布爾值列表中
correct_prediction=tf.equal(tf.argmax(y,1),tf.arg_max(prediction,1))# argmax返回一維張量中最大值所在的位置
# 求準確率
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) # cast做用是將布爾值轉換爲浮點型。
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(21): # 訓練20次
        for batch in range(n_batch): # 每次喂入必定的數據
            batch_xs,batch_ys=mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
        #求準確率
        acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print('Iter:'+str(epoch)+',Testing Accuracy:'+str(acc))
        
# 結果
# 能夠看出每次訓練準確率都在提升

Iter:0,Testing Accuracy:0.8301
Iter:1,Testing Accuracy:0.8706
Iter:2,Testing Accuracy:0.8811
Iter:3,Testing Accuracy:0.8883
Iter:4,Testing Accuracy:0.8943
Iter:5,Testing Accuracy:0.8966
Iter:6,Testing Accuracy:0.9002
Iter:7,Testing Accuracy:0.9017
Iter:8,Testing Accuracy:0.9043
Iter:9,Testing Accuracy:0.9052
Iter:10,Testing Accuracy:0.9061
Iter:11,Testing Accuracy:0.9071
Iter:12,Testing Accuracy:0.908
Iter:13,Testing Accuracy:0.9096
Iter:14,Testing Accuracy:0.9094
Iter:15,Testing Accuracy:0.9102
Iter:16,Testing Accuracy:0.9116
Iter:17,Testing Accuracy:0.9119
Iter:18,Testing Accuracy:0.9126
Iter:19,Testing Accuracy:0.9134
Iter:20,Testing Accuracy:0.9136
相關文章
相關標籤/搜索