使用tensorflow實現cnn進行mnist識別

第一個CNN代碼,暫時對於CNN的BP(反向傳播)還不熟悉。可是經過這個代碼,對於tensorflow的運行機制有了初步的理解。

 

'''
CNN classifier for MNIST (softmax cross-entropy on the output logits)

created on 2019.9.28
author: vince
'''
 7 import math  8 import logging  9 import numpy 10 import random 11 import matplotlib.pyplot as plt 12 import tensorflow as tf 13 from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets 14 from sklearn.metrics import accuracy_score 15 
def weight_bais_variable(shape):
    """Return a trainable variable initialised from a truncated normal.

    The initial values are drawn with a standard deviation of 0.01.
    cnn() uses this helper for the weight tensors of every layer.

    Args:
        shape: shape of the variable to create.

    Returns:
        A ``tf.Variable`` of the requested shape.
    """
    return tf.Variable(tf.random.truncated_normal(shape=shape, stddev=0.01))

def bais_variable(shape):
    """Return a trainable variable filled with the constant 0.1.

    cnn() uses this helper for the bias vector of every layer.

    Args:
        shape: shape of the variable to create.

    Returns:
        A ``tf.Variable`` of the requested shape.
    """
    return tf.Variable(tf.constant(0.1, shape=shape))

def conv2d(x, w):
    """Convolve ``x`` with filter ``w`` using stride 1 and "SAME" padding.

    Args:
        x: input tensor passed to ``tf.nn.conv2d``.
        w: filter tensor passed to ``tf.nn.conv2d``.

    Returns:
        The result of ``tf.nn.conv2d``.
    """
    unit_stride = [1, 1, 1, 1]
    return tf.nn.conv2d(x, w, strides=unit_stride, padding="SAME")

def max_pool_2x2(x):
    """Apply max pooling with a 2x2 window, stride 2 and "SAME" padding.

    Args:
        x: input tensor passed to ``tf.nn.max_pool2d``.

    Returns:
        The pooled tensor.
    """
    # The same [1, 2, 2, 1] spec is used for both the window and the stride.
    window = [1, 2, 2, 1]
    return tf.nn.max_pool2d(x, ksize=window, strides=window, padding="SAME")

def cnn(x, rate):
    """Build the conv-pool-conv-pool-fc-dropout-fc network and return logits.

    Args:
        x: input tensor; it is reshaped to ``[-1, 28, 28, 1]``.
        rate: dropout rate (fraction of units dropped) applied to the output
            of the first fully connected layer. Feed 0.0 to disable dropout.

    Returns:
        The un-normalised output of the second fully connected layer
        (10 units per example).
    """
    with tf.name_scope('reshape'):
        x_image = tf.reshape(x, [-1, 28, 28, 1])

    # First layer: conv & pool.
    with tf.name_scope('conv1'):
        w_conv1 = weight_bais_variable([5, 5, 1, 32])
        b_conv1 = bais_variable([32])
        h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)  # 28 * 28 * 32
    with tf.name_scope('pool1'):
        h_pool1 = max_pool_2x2(h_conv1)  # 14 * 14 * 32

    # Second layer: conv & pool.
    with tf.name_scope('conv2'):
        w_conv2 = weight_bais_variable([5, 5, 32, 64])
        b_conv2 = bais_variable([64])
        h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)  # 14 * 14 * 64
    with tf.name_scope('pool2'):
        h_pool2 = max_pool_2x2(h_conv2)  # 7 * 7 * 64

    # First fully connected layer: feature maps -> feature vector.
    with tf.name_scope('fc1'):
        w_fc1 = weight_bais_variable([7 * 7 * 64, 1024])
        b_fc1 = bais_variable([1024])
        h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
        h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)
    with tf.name_scope("dropout1"):
        # Pass `rate` by keyword: in TF 1.x the second positional parameter
        # of tf.nn.dropout is the deprecated `keep_prob`, not `rate`.
        h_fc1_drop = tf.nn.dropout(h_fc1, rate=rate)

    # Second fully connected layer, fed from the dropout output so that the
    # `rate` argument actually takes effect.
    with tf.name_scope('fc2'):
        w_fc2 = weight_bais_variable([1024, 10])
        b_fc2 = bais_variable([10])
        h_fc2 = tf.matmul(h_fc1_drop, w_fc2) + b_fc2
    return h_fc2

67 
def main():
    """Train the CNN on MNIST and log test-set accuracy every 10 steps.

    Runs 300 steps of gradient descent (learning rate 0.5) on batches of 128
    training examples. Dropout rate 0.5 is fed while training and 0.0 while
    evaluating.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S')

    # The first argument is the directory that holds the MNIST data;
    # one_hot=True encodes the labels as one-hot vectors.
    mnist = read_data_sets('../data/MNIST', one_hot=True)

    x = tf.placeholder(tf.float32, [None, 784])
    y_real = tf.placeholder(tf.float32, [None, 10])
    rate = tf.placeholder(tf.float32)

    y_pre = cnn(x, rate)

    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=y_pre, labels=y_real))
    train_op = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

    correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(y_real, 1))
    prediction_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Build the whole graph first, then initialise, so every variable that
    # exists is covered. The context manager closes the session on exit.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(300):
            batch_xs, batch_ys = mnist.train.next_batch(128)
            sess.run(train_op,
                     feed_dict={x: batch_xs, y_real: batch_ys, rate: 0.5})
            if step % 10 == 0:
                accuracy = sess.run(
                    prediction_op,
                    feed_dict={x: mnist.test.images,
                               y_real: mnist.test.labels,
                               rate: 0.0})
                logging.info("%s : %s", step, accuracy)

# Run the training loop only when this file is executed as a script.
if __name__ == "__main__":
    main()
相關文章
相關標籤/搜索