MATLAB practice program (recognizing the MNIST handwritten digit dataset with a neural network)

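I practiced neural network classification in an earlier post, but some parts of that code were probably still wrong.

This time I use a neural network to recognize the MNIST handwritten digit dataset, mainly following some code from the deep learning toolbox.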

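The MNIST training set contains 60000 images of 28*28 pixels each (28*28*60000 pixel values in total), with 60000 labels.

The test set contains 10000 images of 28*28 pixels each, with 10000 labels.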
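
To make the data layout concrete, here is a minimal sketch that uses small synthetic arrays instead of the real file. It assumes what the code further down relies on: each row of the image matrix is one flattened 28*28 image (784 values) and each row of the label matrix is a one-hot vector. The variable names are illustrative only, and the pixel order within a row is not checked here, so a displayed image may need a transpose.

```matlab
% Synthetic stand-ins with the same layout as the real arrays (3 samples only)
x = uint8(randi([0 255], 3, 784));     % one flattened 28x28 image per row
y = zeros(3, 10);                      % one-hot labels, one row per sample
y(1,1) = 1; y(2,4) = 1; y(3,10) = 1;

img = reshape(x(1,:), 28, 28);         % first sample back on a 28x28 pixel grid
[~, cls] = max(y, [], 2);              % column index of the 1 in each label row

disp(size(img));                       % size of the reshaped image
disp(cls');                            % class index of each sample
```

The same `max(..., [], 2)` call is what the test step at the end of the program uses to turn one-hot rows into class indices.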

The network here has an input layer of 784 pixels, one hidden layer of 100 units, and 10 outputs.

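arc defines the network structure. More hidden layers can be added, but in my tests that did not help much, which I attribute to vanishing gradients.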
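
As a quick illustration of how arc determines the model, the sketch below lists the weight-matrix shapes and the parameter count that a given layer-size vector implies. It assumes, as the initialization loop in the program does, that each weight matrix has one extra column for the bias. The second architecture is a hypothetical example of adding a hidden layer, not a configuration this post reports results for, and `archs` and `nParams` are names introduced only for this sketch.

```matlab
archs = {[784 100 10], [784 100 50 10]};   % the second entry is hypothetical
nParams = zeros(1, numel(archs));
for k = 1:numel(archs)
    arc = archs{k};
    for i = 2:numel(arc)
        % W{i-1} maps layer i-1 (plus one bias unit) to layer i
        fprintf('W{%d}: %d x %d\n', i-1, arc(i), arc(i-1) + 1);
    end
    nParams(k) = sum(arc(2:end) .* (arc(1:end-1) + 1));
    fprintf('arc = [%s], parameters: %d\n', num2str(arc), nParams(k));
end
```

For [784 100 10] this gives 100*785 + 10*101 = 79510 trainable parameters.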

Because this is the most basic kind of neural network, the final recognition error rate is about 5%.

[Figure: iteration curve, the training loss of each mini-batch]

The code is as follows:

 
clear all;
close all;
clc;

load mnist_uint8;

train_x = double(train_x) / 255;
test_x  = double(test_x)  / 255;
train_y = double(train_y);
test_y  = double(test_y);

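mu = mean(train_x);                             % per-pixel mean over the training set
sigma = max(std(train_x), eps);                 % per-pixel standard deviation, floored at eps to avoid division by zero
train_x = bsxfun(@minus, train_x, mu);          % subtract the per-pixel mean from every sample
train_x = bsxfun(@rdivide, train_x, sigma);     % divide by the per-pixel standard deviation

test_x = bsxfun(@minus, test_x, mu);            % the test set reuses the training statistics
test_x = bsxfun(@rdivide, test_x, sigma);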

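arc = [784 100 10]; % 784 inputs, 100 hidden units, 10 outputs
n = numel(arc);

W = cell(1,n-1);    % weight matrices; column 1 of each holds the bias weights
for i=2:n
    W{i-1} = (rand(arc(i),arc(i-1)+1)-0.5) * 8 *sqrt(6 / (arc(i)+arc(i-1)));
end

learningRate = 2;   % learning rate
numepochs = 5;      % number of passes over the training set
batchsize = 100;    % samples per mini-batch

m = size(train_x, 1);           % number of training samples
numbatches = m / batchsize;     % number of mini-batches per epoch
assert(rem(numbatches, 1) == 0, 'batchsize must divide the number of training samples');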

%% Training
L = zeros(numepochs*numbatches,1);
ll=1;
for i = 1 : numepochs
    kk = randperm(m);
    for l = 1 : numbatches
        batch_x = train_x(kk((l - 1) * batchsize + 1 : l * batchsize), :);
        batch_y = train_y(kk((l - 1) * batchsize + 1 : l * batchsize), :);

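       %% Forward propagation
        mm = size(batch_x,1);
        x = [ones(mm,1) batch_x];          % prepend the bias input
        a{1} = x;
        for ii = 2 : n-1
            a{ii} = 1.7159*tanh(2/3.*(a{ii - 1} * W{ii - 1}'));   % scaled tanh activation
            a{ii} = [ones(mm,1) a{ii}];    % prepend the bias unit
        end
        
        a{n} = 1./(1+exp(-(a{n - 1} * W{n - 1}')));   % sigmoid output layer
        e = batch_y - a{n};                            % output error
        L(ll) = 1/2 * sum(sum(e.^2)) / mm;             % squared-error loss averaged over the batch
        ll=ll+1;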
       %% Back-propagation
        d{n} = -e.*(a{n}.*(1 - a{n}));     % output delta: loss gradient times the sigmoid derivative
        for ii = (n - 1) : -1 : 2
            d_act = 1.7159 * 2/3 * (1 - 1/(1.7159)^2 * a{ii}.^2);   % derivative of the scaled tanh
            
            if ii+1==n    
                d{ii} = (d{ii + 1} * W{ii}) .* d_act; 
            else 
                d{ii} = (d{ii + 1}(:,2:end) * W{ii}).* d_act;   % drop the bias column of the next layer's delta
            end          
        end
         
        for ii = 1 : n-1
            if ii + 1 == n
                dW{ii} = (d{ii + 1}' * a{ii}) / size(d{ii + 1}, 1);           % gradient averaged over the batch
            else
                dW{ii} = (d{ii + 1}(:,2:end)' * a{ii}) / size(d{ii + 1}, 1);      
            end
        end
         
       %% Update the parameters
        for ii = 1 : n - 1       
            W{ii} = W{ii} - learningRate*dW{ii};
        end
              
    end
end

%% Testing: run the forward pass once more on the test set
mm = size(test_x,1);
x = [ones(mm,1) test_x];
a{1} = x;
for ii = 2 : n-1    
    a{ii} = 1.7159 * tanh( 2/3 .* (a{ii - 1} * W{ii - 1}'));  
    a{ii} = [ones(mm,1) a{ii}];
end
a{n} = 1./(1+exp(-(a{n - 1} * W{n - 1}')));

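[~, labels] = max(a{end},[],2);     % predicted class of each test sample
[~, expected] = max(test_y,[],2);   % true class from the one-hot labels
bad = find(labels ~= expected);     % indices of the misclassified samples
er = numel(bad) / size(x, 1)        % error rate (no semicolon, so the value is printed)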

plot(L);
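
The hidden layers use the scaled activation 1.7159*tanh(2/3*z), and the back-propagation step writes its derivative in terms of the activation value rather than z. The sketch below is a small numerical check of that expression against a central finite difference. It is an illustration added here, not part of the original program; `f`, `df`, `h` and `maxErr` are names chosen for this sketch.

```matlab
f  = @(z) 1.7159 * tanh(2/3 * z);                 % hidden-layer activation
df = @(a) 1.7159 * 2/3 * (1 - a.^2 / 1.7159^2);   % derivative written in terms of a = f(z)

z = linspace(-3, 3, 13);                          % a few test points
h = 1e-6;                                         % finite-difference step
numeric  = (f(z + h) - f(z - h)) / (2 * h);       % central difference
analytic = df(f(z));                              % expression used in back-propagation

maxErr = max(abs(numeric - analytic));
fprintf('max abs difference: %.2e\n', maxErr);
```

A difference close to floating-point noise indicates that the two forms agree, which follows from tanh'(u) = 1 - tanh(u)^2 with u = 2/3*z.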
 

The test data can be downloaded here: https://pan.baidu.com/s/19YPUe9S9xnztg9JGnoXxqw

