轉載請註明做者:夢裏風林
Google Machine Learning Recipes 7
官方中文博客 - 視頻地址
Github工程地址 https://github.com/ahangchen/GoogleML
歡迎Star,也歡迎到Issue區討論html
mnist = learn.datasets.load_dataset('mnist')
恩,就是這麼簡單,一行代碼下載解壓mnist數據,每一個img已經灰度化成長784的數組,每一個label已經one-hot成長度10的數組python
在個人深度學習筆記看One-hot是什麼東西git
data = mnist.train.images labels = np.asarray(mnist.train.labels, dtype=np.int32) test_data = mnist.test.images test_labels = np.asarray(mnist.test.labels, dtype=np.int32) max_examples = 10000 data = data[:max_examples] labels = labels[:max_examples]
def display(i): img = test_data[i] plt.title('Example %d. Label: %d' % (i, test_labels[i])) plt.imshow(img.reshape((28, 28)), cmap=plt.cm.gray_r) plt.show()
用matplotlib展現灰度圖github
feature_columns = learn.infer_real_valued_columns_from_input(data)
classifier = learn.LinearClassifier(feature_columns=feature_columns, n_classes=10) classifier.fit(data, labels, batch_size=100, steps=1000)
注意要制定n_classes爲labels的數量web
最後可能性最高的label就會做爲預測輸出chrome
傳入測試集,預測,評估分類效果docker
result = classifier.evaluate(test_data, test_labels) print result["accuracy"]
速度很是快,並且準確率達到91.4%數組
能夠只預測某張圖,並查看預測是否跟實際圖形一致瀏覽器
# here's one it gets right print ("Predicted %d, Label: %d" % (classifier.predict(test_data[0]), test_labels[0])) display(0) # and one it gets wrong print ("Predicted %d, Label: %d" % (classifier.predict(test_data[8]), test_labels[8])) display(8)
weights = classifier.weights_ a.imshow(weights.T[i].reshape(28, 28), cmap=plt.cm.seismic)