TensorFlow 實現 MNIST 圖片預測

1. 準備數據：使用佔位符（placeholder），在運行時動態加載訓練數據

# Placeholders fed at run time: a batch of 784-feature MNIST images
# and their 10-column label vectors.
x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
y_true = tf.placeholder(dtype=tf.int32, shape=[None, 10])

2. 初始化參數，創建模型

# Weights drawn from a normal distribution (mean 0, stddev 1); bias starts at 0.
# The logits are x @ weight + bias, one score per digit class.
weight = tf.Variable(tf.random_normal([784, 10], mean=0.0, stddev=1.0))
bias = tf.Variable(tf.constant(0.0, shape=[10]))
y_predict = tf.matmul(x, weight) + bias

3. 求平均交叉熵損失

# Mean (over the batch) of the softmax cross-entropy between labels and logits.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict)
)

4. 梯度下降優化

# Plain gradient descent with learning rate 0.3. The optimizer class lives in
# tf.train (tf.GradientDescentOptimizer does not exist in TF 1.x).
train_op = tf.train.GradientDescentOptimizer(0.3).minimize(loss)

5. 求準確率

# A sample counts as correct when the index of its largest label entry equals
# the index of its largest logit. Accuracy is the mean of those 0/1 hits.
# tf.argmax replaces the deprecated tf.arg_max.
equal_list = tf.equal(tf.argmax(y_true, 1), tf.argmax(y_predict, 1))
accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))

完整代碼：

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os

# Load the MNIST data sets from ./data/MNISI_data/ with one-hot labels.
# NOTE(review): "MNISI" looks like a typo for "MNIST"; the path is kept as-is
# so an already-populated data directory is still found.
mnist = input_data.read_data_sets(
    train_dir='./data/MNISI_data/',
    one_hot=True,
)


def full_connection(is_train=False, train_steps=4000, batch_size=50, num_predictions=100):
    """Build a one-layer softmax classifier for MNIST, then train it or predict with it.

    The graph is a single fully connected layer (784 inputs -> 10 logits),
    trained by minimizing mean softmax cross-entropy with gradient descent
    (learning rate 0.4).

    Args:
        is_train: When True, train for ``train_steps`` mini-batches. Training
            resumes from ``./temp/ckpt/full_conn`` if a checkpoint index file
            exists, writes TensorBoard summaries to ``./temp/summary/test``,
            and saves the checkpoint at the end. When False (the default,
            matching the previously hard-coded setting), restore that
            checkpoint and print the true vs. predicted digit for
            ``num_predictions`` test images.
        train_steps: Number of training iterations.
        batch_size: Number of training samples fetched per iteration.
        num_predictions: Number of single test images to predict.

    Prediction mode needs a checkpoint written by a previous training run;
    otherwise ``saver.restore`` fails.
    """
    ckpt_dir = './temp/ckpt'
    ckpt_path = './temp/ckpt/full_conn'

    # 1. Input placeholders, fed with MNIST batches at run time.
    with tf.variable_scope("data"):
        x = tf.placeholder(tf.float32, [None, 784])
        y_true = tf.placeholder(tf.int32, [None, 10])
    # 2. Model. Variable creation is left exactly as before so that
    #    checkpoints saved by earlier versions of this script still restore.
    with tf.variable_scope('predict_model'):
        weight = tf.Variable(tf.random_normal([784, 10], mean=0.0, stddev=1.0), name='w')
        bias = tf.Variable(tf.constant(0.0, shape=[10]))
        y_predict = tf.matmul(x, weight) + bias
    # 3. Mean softmax cross-entropy loss.
    with tf.variable_scope('loss'):
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict))
    # 4. Gradient-descent optimization.
    with tf.variable_scope('optimizer'):
        train_op = tf.train.GradientDescentOptimizer(0.4).minimize(loss)
    # 5. Accuracy.
    with tf.variable_scope('acc'):
        # Built once here, so the prediction loop reuses it instead of adding
        # new ops to the graph on every iteration.
        predicted_label = tf.argmax(y_predict, 1)
        equal_list = tf.equal(tf.argmax(y_true, 1), predicted_label)
        accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))
    init_op = tf.global_variables_initializer()

    # Values collected for TensorBoard.
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    tf.summary.histogram('weight', weight)
    tf.summary.histogram('bias', bias)
    merged = tf.summary.merge_all()
    saver = tf.train.Saver()

    with tf.Session() as sess:
        if is_train:
            sess.run(init_op)
            file_writer = tf.summary.FileWriter('./temp/summary/test', graph=sess.graph)
            try:
                if os.path.exists('./temp/ckpt/checkpoint'):
                    # Resume from the previously trained model.
                    saver.restore(sess, ckpt_path)
                for i in range(train_steps):
                    mnist_x, mnist_y = mnist.train.next_batch(batch_size)
                    feed = {x: mnist_x, y_true: mnist_y}
                    sess.run(train_op, feed_dict=feed)
                    # Evaluate summaries and accuracy on the updated weights,
                    # in a single run.
                    summary, acc = sess.run([merged, accuracy], feed_dict=feed)
                    file_writer.add_summary(summary, i)
                    print("訓練第%d步,準確率爲:%f" % (i, acc))
                # Saver.save does not create missing parent directories, so
                # make sure the checkpoint directory exists first.
                os.makedirs(ckpt_dir, exist_ok=True)
                saver.save(sess, ckpt_path)
            finally:
                # Flush pending events and release the event file.
                file_writer.close()
        else:
            saver.restore(sess, ckpt_path)
            for i in range(num_predictions):
                # Fetch one test image at a time.
                x_test, y_test = mnist.test.next_batch(1)
                predicted = sess.run(predicted_label, feed_dict={x: x_test})
                # y_test is a host-side array, so its argmax is computed
                # directly. [0] pulls the scalar out of the 1-element batch.
                print('第%d張圖片,手寫數字圖片目標:%d--%d' % (
                    i,
                    y_test.argmax(axis=1)[0],
                    predicted[0],
                ))


if __name__ == "__main__":
    # Script entry point: run full_connection() with its defaults
    # (prediction mode).
    full_connection()
相關文章
相關標籤/搜索