tensorflow筆記(四)之MNIST手寫識別系列一

tensorflow筆記(四)之MNIST手寫識別系列一
html

版權聲明:本文爲博主原創文章,轉載請指明轉載地址python

http://www.cnblogs.com/fydeblog/p/7436310.html
git

前言

這篇博客將利用神經網絡去訓練MNIST數據集,經過學習到的模型去分類手寫數字。github

我會將本篇博客的jupyter notebook放在最後,方便你下載在線調試!推薦結合官方的tensorflow教程來看這個notebook!ubuntu

1. MNIST數據集的導入

這裏介紹一下MNIST,MNIST是在機器學習領域中的一個經典問題。該問題解決的是把28x28像素的灰度手寫數字圖片識別爲相應的數字,其中數字的範圍從0到9.網絡

首先咱們要導入MNIST數據集,這裏須要用到一個input_data.py文件,在你安裝tensorflow的examples/tutorials/MNIST目錄下,若是tensorflow的目錄下沒有這個文件夾(通常是你的tensorflow版本不夠新,1.2版本有的),還請本身導入或者更新一下tensorflow的版本,導入的方法是在tensorflow的github(https://github.com/tensorflow/tensorflow/tree/master/tensorflow  )下下載examples文件夾,粘貼到tensorflow的根目錄下。更新tensorflow版本的話,請在ubuntu終端下運行pip install --upgrade tensorflow就能夠了dom

好了,咱們仍是一步步來進行整個過程機器學習

首先咱們先導入咱們須要用到的模塊函數

import tensorflow as tf 
from tensorflow.examples.tutorials.mnist  import  input_data

而後咱們用input_data模塊導入MNIST數據集學習

mnist = input_data.read_data_sets('MNIST_data',one_hot = True)

上面總共下載了四個壓縮文件,內容分別以下:
train-images-idx3-ubyte.gz    訓練集圖片 - 55000 張 訓練圖片, 5000 張 驗證圖片
train-labels-idx1-ubyte.gz      訓練集圖片對應的數字標籤
t10k-images-idx3-ubyte.gz   測試集圖片 - 10000 張 圖片
t10k-labels-idx1-ubyte.gz      測試集圖片對應的數字標籤

圖片數據將被解壓成2維的tensor:[image index, pixel index] 其中每一項表示某一圖片中特定像素的強度值, 範圍從 [0, 255] 到 [-0.5, 0.5]。 "image index"表明數據集中圖片的編號, 從0到數據集的上限值。"pixel index"表明該圖片中像素點得個數, 從0到圖片的像素上限值。

以train-*開頭的文件中包括60000個樣本,其中分割出55000個樣本做爲訓練集,其他的5000個樣本做爲驗證集。由於全部數據集中28x28像素的灰度圖片的尺寸爲784,因此訓練集輸出的tensor格式爲[55000, 784]

執行read_data_sets()函數將會返回一個DataSet實例,其中包含了如下三個數據集。 數據集 目的 data_sets.train 55000 組 圖片和標籤, 用於訓練。 data_sets.validation 5000 組 圖片和標籤, 用於迭代驗證訓練的準確性。 data_sets.test 10000 組 圖片和標籤, 用於最終測試訓練的準確性。

具體的MNIST數據集的解壓和重構咱們能夠不瞭解,會用這個數據集就能夠了。(固然別問我這個東西,這個過程我也不知道,嘿嘿)

這裏說一下上述代碼中的one_hot,MNIST的標籤數據是"one-hot vectors"。 一個one-hot向量除了某一位的數字是1之外其他各維度數字都是0。因此在此教程中,數字n將表示成一個只有在第n維度(從0開始)數字爲1的10維向量。好比,標籤0將表示成([1,0,0,0,0,0,0,0,0,0,0])。

2.實踐

咱們首先定義兩個佔位符,來表示訓練數據及其相應標籤數據,將會在訓練部分進行feed進去

xs = tf.placeholder(tf.float32,[None,784]) # 784 = 28X28
ys = tf.placeholder(tf.float32,[None,10]) # 10 = (0~9) one_hot

如今咱們再來定義神經網絡的權重和誤差

Weights = tf.Variable(tf.random_normal([784,10]))
biases = tf.Variable(tf.zeros([1,10])+0.2)

先說一下,這個神經網絡是輸入直接映射到輸出,沒有隱藏層,輸入是每張圖像28X28的像素,也就是784,輸出是10個長度的向量,也就是10,因此權重是[784,10],誤差是[1,10].

y_pre = tf.nn.softmax(tf.matmul(xs,Weights)+biases) 

咱們知道雖然最後的輸出結果是10個長度的向量,但他們的值可能不太直觀,打個比方,好比都是0.015之類的數,僅僅是打比方哈

爲了顯示輸出結果對每一個數的相應機率,咱們加了一個softmax函數,它的原理很簡單,拿10個單位的向量[x0,x1,...,x9]爲例,若是想知道數字0的機率是多少,用exp(x0)/(exp(x0)+exp(x1)+...+exp(x9)),其餘數字的機率相似推導,你也能夠參考我放在博客上的圖片,很直觀。

cross_entropy =tf.reduce_mean( -tf.reduce_sum(ys*tf.log(y_pre),reduction_indices=[1]))#compute cross_entropy

此次的損失表示形式跟以前都不太同樣哈,此次是計算交叉熵,交叉熵是用來衡量咱們的預測用於描述真相的有效性。咱們能夠想想,以一張圖片爲例,y_pre和ys都是一個10個長度的向量,不一樣的是y_pre每一個序號對應的值不爲0,而ys是one_hot向量,只有一個爲1,其他全爲0,那麼按照上述公式,只有1對應序號i(假如是i)的log(y_pre(i))保留下來了,並且y_pre(i)越大(也就是機率越大),log(y_pre(i))越小(注意計算交叉熵前面有負號的),反之越大,符合咱們對損失的概念。

我試過用官方教程的交叉熵公式,打印交叉熵時出現nan,溢出了,建議用這個好一些

train = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

上面是用優化器最小化交叉熵,這裏學習率的選取很重要,官方的0.01過小,收斂得慢,還沒達到訓練損失最小就中止了,結果就是測試集偏差較大,推薦選大點,0.5左右差很少了,再大反而會發散了。

init = tf.global_variables_initializer()

上面是生出初始化init

sess  = tf.Session()

創建一個會話

sess.run(init)

初始化變量

for i in xrange(1000):
    batch_xs,batch_ys = mnist.train.next_batch(100)
    sess.run(train,feed_dict={xs:batch_xs,ys:batch_ys})
    if i %50==0:
        print sess.run(cross_entropy,feed_dict={xs:batch_xs,ys:batch_ys})

上面是程序訓練過程,這裏說一下xrange和range的區別,它們兩個的用法基本相同,但返回的類型不一樣,xrange返回的是生成器,range返回的是列表,全部xrange更節省內存,推薦用xrange,python3當中已經沒有xrange了,只有range,但它的功能和python2當中的xrange同樣

下面咱們來計算計算精度

correct_prediction = tf.equal(tf.argmax(ys,1), tf.argmax(y_pre,1))

tf.argmax 是一個很是有用的函數,它能給出某個tensor對象在某一維上的其數據最大值所在的索引值。tf.argmax(y_pre,1)返回的是模型對於任一輸入x預測到的標籤值,而 tf.argmax(ys,1) 表明正確的標籤,咱們能夠用 tf.equal 來檢測咱們的預測是否真實標籤匹配,這行代碼返回的是匹配的布爾值,成功1,失敗0

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

tf.cast將布爾類型的correct_prediction轉化成float型,而後取平均獲得精確度

print sess.run(accuracy, feed_dict={xs: mnist.test.images, ys: mnist.test.labels})

精確度87.79%,官方說的91%我是沒達到過,我訓練最高不超過89%。

3.結尾

但願這篇博客能對你的學習有所幫助,謝謝觀看!同時,有興趣的朋友能夠多改改參數試試不一樣的結果,好比學習率,batch_size等等,這對你的理解也是有幫助的!

下一篇筆記將寫用cnn去分類MNIST數據集,敬請期待!

連接: https://pan.baidu.com/s/1oWXk2Iai5f7I4U411XP8hQ

相關文章
相關標籤/搜索