神經網絡-手寫字體識別

3層神經網絡,自定義輸入節點、隱藏層、輸出節點的個數,使用sigmoid函數做爲激活函數,梯度降低法進行權重的優化。node

使用MNIST數據集,進行手寫數字識別python

  1 #!/usr/bin/env python
  2 # -*- coding:utf-8 -*-
  3 
  4 #!/usr/bin/env python
  5 # -*- coding:utf-8 -*-
  6 
  7 import numpy
  8 import scipy.special
  9 
 10 
 11 #手寫數字識別神經網絡
 12 class NeuralNetwork():
 13     def __init__(self,inputnodes,hiddennodes,outputnodes,learningrate):
 14         '''
 15         神經網絡初始化
 16         :param inputnodes: 輸入節點的數量
 17         :param hiddennodes: 隱藏層節點的數量
 18         :param outputnodes: 輸出節點的數量
 19         :param learningrate: 學習率
 20         :return:
 21         '''
 22         self.inodes = inputnodes
 23         self.hnodes = hiddennodes
 24         self.onodes = outputnodes
 25         self.learn = learningrate
 26         self.wih = numpy.random.rand(self.hnodes,self.inodes) - 0.5
 27         self.who = numpy.random.rand(self.onodes,self.hnodes) - 0.5
 28         # self.wih = numpy.random.normal(0.0,pow(self.hnodes,-0.5),(self.inodes,self.inodes))
 29         # self.who = numpy.random.normal(0.0,pow(self.onodes,-0.5),(self.hnodes,self.hnodes))
 30         self.activate_function = lambda x : scipy.special.expit(x)
 31         # print(self.who)
 32         # print(self.wih)
 33     def train(self,input_list,target_list):
 34         '''
 35         訓練神經網絡首先計算樣本輸出,而後在與目標值進行對比,更新權重
 36         :param input_list: 輸入值
 37         :param target_list: 目標值
 38         :return:
 39         '''
 40         #針對樣本計算輸出,與query函數同樣
 41         inputs = numpy.array(input_list).T
 42         targets = numpy.array(target_list).T
 43         hidden_inputs = numpy.dot(self.wih,inputs)
 44         hidden_outputs = self.activate_function(hidden_inputs)
 45         final_inputs = numpy.dot(self.who,hidden_outputs)
 46         final_outpust = self.activate_function(final_inputs)
 47 
 48         #將計算獲得的輸出與目標值對比,更新權重
 49         output_error = targets - final_outpust
 50         hidden_error = numpy.dot(self.who.T,output_error)
 51 
 52         # print(output_error.shape)
 53         # print(final_outpust.shape)
 54         # print(hidden_outputs.T.shape)
 55         # self.who += self.learn*numpy.dot((output_error*final_outpust*(1.0-final_outpust)),numpy.transpose(hidden_outputs))
 56         # self.wih += self.learn*numpy.dot((hidden_error*hidden_outputs*(1.0-hidden_outputs)),numpy.transpose(inputs))
 57 
 58         self.who += self.learn*numpy.dot((output_error*final_outpust*(1.0-final_outpust)).reshape((self.onodes,1)),hidden_outputs.reshape((1,self.hnodes)))
 59         self.wih += self.learn*numpy.dot((hidden_error*hidden_outputs*(1.0-hidden_outputs)).reshape((self.hnodes,1)),inputs.reshape((1,self.inodes)))
 60 
 61 
 62 
 63     def query(self,input_list):
 64         '''
 65         計算輸出
 66         :param input_list:
 67         :return:
 68         '''
 69         inputs = numpy.array(input_list).T
 70         hidden_inputs = numpy.dot(self.wih,inputs)
 71         hidden_outputs = self.activate_function(hidden_inputs)
 72         final_inputs = numpy.dot(self.who,hidden_outputs)
 73         final_outpust = self.activate_function(final_inputs)
 74 
 75         return final_outpust
 76 
 77 #初始化一個神經網絡對象
 78 n = NeuralNetwork(784,100,10,0.5)
 79 
 80 #訓練數據
 81 with open('dataset/mnist_train.csv','r') as f:
 82     train_data = f.readlines()
 83 
 84 #訓練神經網絡
 85 for line in train_data:
 86     data = line.split(',')
 87     inputs = (numpy.asfarray(data[1:]) / 255 * 0.99) + 0.01
 88     targets = numpy.zeros(n.onodes)+0.01
 89     targets[int(data[0])] = 0.99
 90 
 91     n.train(inputs,targets)
 92 
 93 
 94 #測試神經網絡
 95 with open('dataset/mnist_test_10.csv','r') as f:
 96     test_data = f.readlines()
 97 
 98 for line in test_data:
 99     label = int(line[0])
100     data = line.split(',')
101     input_list = numpy.asfarray(data[1:])
102     output = n.query(input_list)
103 
104     print(label)
105     print(output)

代碼實現了手寫數字的識別,能夠在此基礎上,進行改進研究,好比調節學習率、初始化權重的方式,激活函數等變化時對結果的影響。網絡

相關文章
相關標籤/搜索