機器學習(2) - KNN識別MNIST

代碼

 https://github.com/s055523/MNISTTensorFlowSharpgit

數據的得到

數據能夠由http://yann.lecun.com/exdb/mnist/下載。以後,儲存在trainDir中,下次就不須要下載了。github

/// <summary>
        /// 若是文件不存在就去下載
        /// </summary>
        /// <param name="urlBase">下載地址</param>
        /// <param name="trainDir">文件目錄地址</param>
        /// <param name="file">文件名</param>
        /// <returns></returns>
        public static Stream MaybeDownload(string urlBase, string trainDir, string file)
        {
            if (!Directory.Exists(trainDir))
            {
                Directory.CreateDirectory(trainDir);
            }

            var target = Path.Combine(trainDir, file);
            if (!File.Exists(target))
            {
                var wc = new WebClient();
                wc.DownloadFile(urlBase + file, target);
            }
            return File.OpenRead(target);
        }
View Code

數據格式處理

下載下來的文件共有四個,都是擴展名爲gz的壓縮包。算法

train-images-idx3-ubyte.gz  55000張訓練圖片和5000張驗證圖片數據庫

train-labels-idx1-ubyte.gz     訓練圖片對應的數字標籤(即答案)數組

t10k-images-idx3-ubyte.gz   10000張測試圖片session

t10k-labels-idx1-ubyte.gz     測試圖片對應的數字標籤(即答案)ide

處理圖片數據壓縮包

每一個壓縮包的格式爲:測試

偏移量this

類型編碼

意義

0

Int32

2051或2049

一個定死的魔術數。用來驗證該壓縮包是訓練集(2051)或測試集(2049)

4

Int32

60000或10000

壓縮包的圖片數

8

Int32

28

每一個圖片的行數

12

Int32

28

每一個圖片的列數

16

Unsigned byte

0 - 255

第一張圖片的第一個像素

17

Unsigned byte

0 - 255

第一張圖片的第二個像素

 

所以,咱們可使用一個統一的方式將數據處理。咱們只須要那些圖片像素。

/// <summary>
        /// 從數據流中讀取下一個int32
        /// </summary>
        /// <param name="s"></param>
        /// <returns></returns>
        int Read32(Stream s)
        {
            var x = new byte[4];
            s.Read(x, 0, 4);
            return DataConverter.BigEndian.GetInt32(x, 0);
        }

        /// <summary>
        /// 處理圖片數據
        /// </summary>
        /// <param name="input"></param>
        /// <param name="file"></param>
        /// <returns></returns>
        MnistImage[] ExtractImages(Stream input, string file)
        {
            //文件是gz格式的
            using (var gz = new GZipStream(input, CompressionMode.Decompress))
            {
                //不是2051說明下載的文件不對
                if (Read32(gz) != 2051)
                {
                    throw new Exception("不是2051說明下載的文件不對: " + file);
                }
                //圖片數
                var count = Read32(gz);
                //行數
                var rows = Read32(gz);
                //列數
                var cols = Read32(gz);

                Console.WriteLine($"準備讀取{count}張圖片。");

                var result = new MnistImage[count];
                for (int i = 0; i < count; i++)
                {
                    //圖片的大小(每一個像素佔一個bit)
                    var size = rows * cols;
                    var data = new byte[size];

                    //從數據流中讀取這麼大的一塊內容
                    gz.Read(data, 0, size);

                    //將讀取到的內容轉換爲MnistImage類型
                    result[i] = new MnistImage(cols, rows, data);
                }
                return result;
            }
        }
View Code

準備一個MnistImage類型:

/// <summary>
    /// 圖片類型
    /// </summary>
    public struct MnistImage
    {
        public int Cols, Rows;
        public byte[] Data;
        public float[] DataFloat;

        public MnistImage(int cols, int rows, byte[] data)
        {
            Cols = cols;
            Rows = rows;
            Data = data;
            DataFloat = new float[data.Length];
            for (int i = 0; i < data.Length; i++)
            {
                //數據歸一化(這裏將0-255除255變成了0-1之間的小數)
                //也能夠歸一爲-0.5到0.5之間
                DataFloat[i] = Data[i] / 255f;
            }
        }
    }
View Code

這樣一來,圖片數據就處理完成了。

處理數字標籤數據壓縮包

數字標籤數據壓縮包和圖片數據壓縮包的格式相似。

偏移量

類型

意義

0

Int32

2051或2049

一個定死的魔術數。用來驗證該壓縮包是訓練集(2051)或測試集(2049)

4

Int32

60000或10000

壓縮包的數字標籤數

5

Unsigned byte

0 - 9

第一張圖片對應的數字

6

Unsigned byte

0 - 9

第二張圖片對應的數字

 

它的處理更加簡單。

/// <summary>
        /// 處理標籤數據
        /// </summary>
        /// <param name="input"></param>
        /// <param name="file"></param>
        /// <returns></returns>
        byte[] ExtractLabels(Stream input, string file)
        {
            using (var gz = new GZipStream(input, CompressionMode.Decompress))
            {
                //不是2049說明下載的文件不對
                if (Read32(gz) != 2049)
                {
                    throw new Exception("不是2049說明下載的文件不對:" + file);
                }
                var count = Read32(gz);
                var labels = new byte[count];

                gz.Read(labels, 0, count);

                return labels;
            }
        }
View Code

將數字標籤轉化爲二維數組:one-hot編碼

因爲咱們的數字爲0-9,因此,能夠視爲有十個class。此時,爲了後續的處理方便,咱們將數字標籤轉化爲數組。所以,一組標籤就轉換爲了一個二維數組。

例如,標籤0變成[1,0,0,0,0,0,0,0,0,0]

標籤1變成[0,1,0,0,0,0,0,0,0,0]

以此類推。

/// <summary>
        /// 將數字標籤一維數組轉爲一個二維數組
        /// </summary>
        /// <param name="labels"></param>
        /// <param name="numClasses">多少個類別,這裏是10(0到9)</param>
        /// <returns></returns>
        byte[,] OneHot(byte[] labels, int numClasses)
        {
            var oneHot = new byte[labels.Length, numClasses];
            for (int i = 0; i < labels.Length; i++)
            {
                oneHot[i, labels[i]] = 1;
            }
            return oneHot;
        }
View Code

到此爲止,數據格式處理就所有結束了。下面的代碼展現了數據處理的全過程。

        /// <summary>
        /// 處理數據集
        /// </summary>
        /// <param name="trainDir">數據集所在文件夾</param>
        /// <param name="numClasses"></param>
        /// <param name="validationSize">拿出多少作驗證?</param>
        public void ReadDataSets(string trainDir, int numClasses = 10, int validationSize = 5000)
        {
            const string SourceUrl = "http://yann.lecun.com/exdb/mnist/";
            const string TrainImagesName = "train-images-idx3-ubyte.gz";
            const string TrainLabelsName = "train-labels-idx1-ubyte.gz";
            const string TestImagesName = "t10k-images-idx3-ubyte.gz";
            const string TestLabelsName = "t10k-labels-idx1-ubyte.gz";

            //得到訓練數據,而後處理訓練數據和測試數據
            TrainImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TrainImagesName), TrainImagesName);
            TestImages = ExtractImages(Helper.MaybeDownload(SourceUrl, trainDir, TestImagesName), TestImagesName);
            TrainLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TrainLabelsName), TrainLabelsName);
            TestLabels = ExtractLabels(Helper.MaybeDownload(SourceUrl, trainDir, TestLabelsName), TestLabelsName);

            //拿出前面的一部分作驗證
            ValidationImages = Pick(TrainImages, 0, validationSize);
            ValidationLabels = Pick(TrainLabels, 0, validationSize);

            //拿出剩下的作訓練(輸入0意味着拿剩下全部的)
            TrainImages = Pick(TrainImages, validationSize, 0);
            TrainLabels = Pick(TrainLabels, validationSize, 0);

            //將數字標籤轉換爲二維數組
            //例如,標籤3 =》 [0,0,0,1,0,0,0,0,0,0]
            //標籤0 =》 [1,0,0,0,0,0,0,0,0,0]
            if (numClasses != -1)
            {
                OneHotTrainLabels = OneHot(TrainLabels, numClasses);
                OneHotValidationLabels = OneHot(ValidationLabels, numClasses);
                OneHotTestLabels = OneHot(TestLabels, numClasses);
            }
        }

        /// <summary>
        /// 得到source集合中的一部分,從first開始,到last結束
        /// </summary>
        /// <typeparam name="T"></typeparam>
        /// <param name="source"></param>
        /// <param name="first"></param>
        /// <param name="last"></param>
        /// <returns></returns>
        T[] Pick<T>(T[] source, int first, int last)
        {
            if (last == 0)
            {
                last = source.Length;
            }

            var count = last - first;
            var ret = source.Skip(first).Take(count).ToArray();
            return ret;
        }

        public static Mnist Load()
        {
            var x = new Mnist();
            x.ReadDataSets(@"D:\人工智能\C#代碼\MNISTTensorFlowSharp\MNISTTensorFlowSharp\data");
            return x;
        }
View Code

在這裏,數據共有下面幾部分:

  1. 訓練圖片數據55000 TrainImages及對應標籤TrainLabels
  2. 驗證圖片數據5000 ValidationImages及對應標籤ValidationLabels
  3. 測試圖片數據10000 TestImages及對應標籤TestLabels

KNN算法的實現

如今,咱們已經有了全部的數據在手。須要實現的是:

  1. 拿出數據中的一部分(例如,5000張圖片)做爲KNN的訓練數據,而後,再從數據中的另外一部分拿一張圖片A
  2. 對這張圖片A,求它和5000張訓練圖片的距離,並找出一張訓練圖片B,它是全部訓練圖片中,和A距離最小的那張(這意味着K=1)
  3. 此時,就認爲A所表明的數字等同於B所表明的數字b
  4. 重複1-3,N次

首先進行數據的收集:

//三個Reader分別從總的數據庫中得到數據
        public BatchReader GetTrainReader() => new BatchReader(TrainImages, TrainLabels, OneHotTrainLabels);
        public BatchReader GetTestReader() => new BatchReader(TestImages, TestLabels, OneHotTestLabels);
        public BatchReader GetValidationReader() => new BatchReader(ValidationImages, ValidationLabels, OneHotValidationLabels);

        /// <summary>
        /// 數據的一部分,包括了全部的有用信息
        /// </summary>
        public class BatchReader
        {
            int start = 0;
            //圖片庫
            MnistImage[] source;
            //數字標籤
            byte[] labels;
            //oneHot以後的數字標籤
            byte[,] oneHotLabels;

            internal BatchReader(MnistImage[] source, byte[] labels, byte[,] oneHotLabels)
            {
                this.source = source;
                this.labels = labels;
                this.oneHotLabels = oneHotLabels;
            }

            /// <summary>
            /// 返回兩個浮點二維數組(C# 7的新語法)
            /// </summary>
            /// <param name="batchSize"></param>
            /// <returns></returns>
            public (float[,], float[,]) NextBatch(int batchSize)
            {
                //一張圖
                var imageData = new float[batchSize, 784];
                //標籤
                var labelData = new float[batchSize, 10];

                int p = 0;
                for (int item = 0; item < batchSize; item++)
                {
                    Buffer.BlockCopy(source[start + item].DataFloat, 0, imageData, p, 784 * sizeof(float));
                    p += 784 * sizeof(float);
                    for (var j = 0; j < 10; j++)
                        labelData[item, j] = oneHotLabels[item + start, j];
                }

                start += batchSize;
                return (imageData, labelData);
            }
        }
View Code

而後,在算法中,獲取數據:

        static void KNN()
        {
            //取得數據
            var mnist = Mnist.Load();

            //拿5000個訓練數據,200個測試數據
            const int trainCount = 5000;
            const int testCount = 200;

            //得到的數據有兩個
            //一個是圖片,它們都是28*28的
            //一個是one-hot的標籤,它們都是1*10的
            (var trainingImages, var trainingLabels) = mnist.GetTrainReader().NextBatch(trainCount);
            (var testImages, var testLabels) = mnist.GetTestReader().NextBatch(testCount);

            Console.WriteLine($"MNIST 1NN");
View Code

下面進行計算。這裏使用了K=1的L1距離。這是最簡單的狀況。

            //創建一個圖表示計算任務
            using (var graph = new TFGraph())
            {
                var session = new TFSession(graph);

                //用來feed數據的佔位符。trainingInput表示N張用來進行訓練的圖片,N是一個變量,因此這裏使用-1
                TFOutput trainingInput = graph.Placeholder(TFDataType.Float, new TFShape(-1, 784));

                //xte表示一張用來測試的圖片
                TFOutput xte = graph.Placeholder(TFDataType.Float, new TFShape(784));

                //計算這兩張圖片的L1距離。這很簡單,實際上就是把784個數字逐對相減,而後取絕對值,最後加起來變成一個總和
                var distance = graph.ReduceSum(graph.Abs(graph.Sub(trainingInput, xte)), axis: graph.Const(1));

                //這裏只是用了最近的那個數據
                //也就是說,最近的那個數據是什麼,那pred(預測值)就是什麼
                TFOutput pred = graph.ArgMin(distance, graph.Const(0));
View Code

最後是開啓Session計算的過程:

                var accuracy = 0f;

                //開始循環進行計算,循環trainCount次
                for (int i = 0; i < testCount; i++)
                {
                    var runner = session.GetRunner();

                    //每次,對一張新的測試圖,計算它和trainCount張訓練圖的距離,並得到最近的那張
                    var result = runner.Fetch(pred).Fetch(distance)
                        //trainCount張訓練圖(數據是trainingImages)
                        .AddInput(trainingInput, trainingImages)
                        //testCount張測試圖(數據是從testImages中拿出來的)
                        .AddInput(xte, Extract(testImages, i))
                        .Run();
                    
                    //最近的點的序號
                    var nn_index = (int)(long)result[0].GetValue();

                    //從trainingLabels中找到答案(這是預測值)
                    var prediction = ArgMax(trainingLabels, nn_index);

                    //正確答案位於testLabels[i]中
                    var real = ArgMax(testLabels, i);

                    //PrintImage(testImages, i);

                    Console.WriteLine($"測試 {i}: " +
                        $"預測: {prediction} " +
                        $"正確答案: {real} (最近的點的序號={nn_index})");
                    //Console.WriteLine(testImages);

                    if (prediction == real)
                    {
                        accuracy += 1f / testCount;
                    }
                }
                Console.WriteLine("準確率: " + accuracy);
View Code

對KNN的改進

本文只是對KNN識別MNIST數據集進行了一個很是簡單的介紹。在實現了最簡單的K=1的L1距離計算以後,正確率約爲91%。你們能夠試着將算法進行改進,例如取K=2或者其餘數,或者計算L2距離等。L2距離的結果比L1好一些,能夠達到93-94%的正確率。

相關文章
相關標籤/搜索