Training on the MNIST handwritten digit dataset with Keras

Network: two convolutional layers, two fully connected layers, and one softmax layer
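To see where this network's weights live, here is a small pure-Python sketch that traces the output shape and trainable-parameter count of each layer in the code below. It assumes the Keras `Conv2D` defaults that code relies on (stride 1, `'valid'` padding, with bias). The helper names `conv2d_valid`, `dense` and `summarize` are hypothetical and not part of Keras.

```python
# Hypothetical helpers: trace shapes and parameter counts through the network,
# assuming 3x3 kernels, stride 1, 'valid' padding and biases (Keras Conv2D defaults).

def conv2d_valid(h, w, c_in, filters, k=3):
    """Return ((out_h, out_w, filters), params) for a stride-1 'valid' convolution."""
    out = (h - k + 1, w - k + 1, filters)
    params = k * k * c_in * filters + filters  # kernel weights + one bias per filter
    return out, params


def dense(n_in, units):
    """Return (units, params) for a fully connected layer with bias."""
    return units, n_in * units + units


def summarize():
    rows = [("reshape", (28, 28, 1), 0)]  # only adds a channel axis, no weights
    shape, p = conv2d_valid(28, 28, 1, 32)
    rows.append(("conv1", shape, p))
    shape, p = conv2d_valid(*shape, 64)
    rows.append(("conv2", shape, p))
    flat = shape[0] * shape[1] * shape[2]
    rows.append(("flatten", (flat,), 0))
    units, p = dense(flat, 256)
    rows.append(("dense1", (units,), p))
    units, p = dense(units, 10)
    rows.append(("dense2", (units,), p))
    rows.append(("softmax", (units,), 0))  # softmax has no weights
    return rows


if __name__ == "__main__":
    rows = summarize()
    for name, shape, params in rows:
        print(f"{name:8s} {str(shape):16s} {params:>10,d}")
    print("total params:", sum(p for _, _, p in rows))
```

The sketch shows that nearly all of the roughly 9.46 million parameters sit in the `Dense(256)` layer, because flattening the 24×24×64 feature map produces 36,864 inputs; `model.summary()` on the Keras model should report the same per-layer counts.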
Code (Python):

import numpy as np
from keras.utils import to_categorical
from keras import Sequential
from keras import layers
from keras import optimizers
from keras.datasets import mnist

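# 60,000 training and 10,000 test images, each a 28x28 grayscale array
(train_x, train_y), (test_x, test_y) = mnist.load_data()

# Scale pixel values from 0-255 to the range [0, 1]
train_x = train_x / 255.0
test_x = test_x / 255.0
# One-hot encode the digit labels (10 classes) for categorical_crossentropy
train_y = to_categorical(train_y)
test_y = to_categorical(test_y)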

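model = Sequential()
# Add a channel axis: (28, 28) -> (28, 28, 1), as Conv2D expects
model.add(layers.Reshape((28, 28, 1), input_shape=(28, 28)))
model.add(layers.Conv2D(32, 3, activation='relu'))
model.add(layers.Conv2D(64, 3, activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(10))
# Softmax's first positional argument is `axis`, not the number of classes,
# so call it without arguments to normalize over the last axis
model.add(layers.Softmax())
# Recent Keras versions expect `learning_rate` instead of the old `lr` argument
model.compile(optimizer=optimizers.RMSprop(learning_rate=1e-4),
              loss='categorical_crossentropy', metrics=['accuracy'])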

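model.fit(train_x, train_y, epochs=5)
# evaluate() returns [loss, accuracy] because one metric was requested
loss, acc = model.evaluate(test_x, test_y)
# acc is a number, so it cannot be joined to a string with +
print(f'The final accuracy is {acc:.4f}')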
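The last two layers and the loss work together: `Dense(10)` outputs one raw score (logit) per digit, the softmax layer turns each row of scores into probabilities, and `categorical_crossentropy` compares those probabilities with the one-hot labels produced by `to_categorical`. The NumPy sketch below re-implements these three steps for illustration only; the function names are hypothetical, and clipping with `1e-7` is an assumption modeled on the Keras default epsilon.

```python
import numpy as np


def one_hot(labels, num_classes=10):
    """One-hot encode a 1-D array of integer labels, like to_categorical does."""
    return np.eye(num_classes, dtype=np.float32)[labels]


def softmax(logits):
    """Row-wise softmax; subtracting the row max keeps exp() from overflowing."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Batch mean of -sum(y_true * log(y_pred)) computed per sample."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=-1)))


if __name__ == "__main__":
    logits = np.array([[2.0, 1.0, 0.1], [0.5, 2.5, 0.0]])
    probs = softmax(logits)
    targets = one_hot(np.array([0, 1]), num_classes=3)
    print(probs.sum(axis=1))  # each row sums to 1
    print(categorical_crossentropy(targets, probs))
```

Because the targets are one-hot, only the log-probability assigned to the correct class contributes to each sample's loss.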

The final accuracy on the test set is about 98%.
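With one-hot labels and categorical cross-entropy, the accuracy Keras reports should be categorical accuracy: the share of test images whose highest-probability class equals the true digit, so 98% means roughly 9,800 of the 10,000 MNIST test images are classified correctly. A minimal NumPy sketch of that computation follows, assuming `y_pred_probs` is an `(N, 10)` array such as the output of `model.predict(test_x)`; the helper name `categorical_accuracy` is hypothetical here.

```python
import numpy as np


def categorical_accuracy(y_true_onehot, y_pred_probs):
    """Fraction of samples whose highest-probability class matches the label."""
    predicted = np.argmax(y_pred_probs, axis=-1)
    actual = np.argmax(y_true_onehot, axis=-1)
    return float(np.mean(predicted == actual))


if __name__ == "__main__":
    y_true = np.eye(10)[[3, 7, 1, 0]]          # true digits 3, 7, 1, 0
    y_pred = np.full((4, 10), 0.01)
    y_pred[[0, 1, 2, 3], [3, 7, 2, 0]] = 0.91  # predicted digits 3, 7, 2, 0
    print(categorical_accuracy(y_true, y_pred))  # 3 of 4 correct -> 0.75
```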
