ML.NET Machine Learning Framework Study Notes [5]: Multiclass Classification for Handwritten Digit Recognition (Continued)

1. Overview

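 In the previous article we used ML.NET's multiclass classification algorithm to build a handwritten digit recognizer. That example had one problem: its input data was preprocessed, which is not very intuitive. This time we train on and classify the image files directly. The idea is simple: write a custom data-processing transform whose input is a file name and whose output is a float array holding the pixel information.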

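 The sample set contains 60,000 training images and 10,000 test images. The images are grayscale with a resolution of 20*20. The train_tags.tsv file labels each image with the digit it shows, one tab-separated row per image (file name, then digit).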
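To make the tag file layout concrete, here is a minimal sketch of reading such rows by hand. In the article's code this job is done by `LoadFromTextFile<InputData>`; the `TagFile` class and its `Parse` method below are hypothetical names used only for illustration, and only the first two columns are assumed.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper, not part of the original project: turns the lines of a
// tab-separated tag file into (FileName, Number) pairs.
public static class TagFile
{
    public static List<(string FileName, string Number)> Parse(IEnumerable<string> lines)
    {
        var rows = new List<(string FileName, string Number)>();
        foreach (var line in lines)
        {
            // Skip blank lines, e.g. a trailing newline at the end of the file.
            if (string.IsNullOrWhiteSpace(line)) continue;

            var parts = line.Split('\t');
            if (parts.Length < 2)
                throw new FormatException($"Expected at least 2 tab-separated columns: '{line}'");

            // Column 0 is the image file name, column 1 is the digit label.
            rows.Add((parts[0], parts[1]));
        }
        return rows;
    }
}
```

A call such as `TagFile.Parse(File.ReadLines(TrainTagsPath))` would yield one pair per labeled image.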

  

2. Source Code

 Full code:

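using System;
using System.IO;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms;

namespace MulticlassClassification_Mnist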
{
    class Program
    {
        //Assets files download from:https://gitee.com/seabluescn/ML_Assets
        static readonly string AssetsFolder = @"D:\StepByStep\Blogs\ML_Assets\MNIST";
        static readonly string TrainTagsPath = Path.Combine(AssetsFolder, "train_tags.tsv");
        static readonly string TrainDataFolder = Path.Combine(AssetsFolder, "train");
        static readonly string ModelPath = Path.Combine(Environment.CurrentDirectory, "Data", "SDCA-Model.zip");

        static void Main(string[] args)
        {
            MLContext mlContext = new MLContext(seed: 1);
          
            TrainAndSaveModel(mlContext);
            TestSomePredictions(mlContext);

            Console.WriteLine("Hit any key to finish the app");
            Console.ReadKey();
        }

        public static void TrainAndSaveModel(MLContext mlContext)
        {
            // STEP 1: Prepare the data
            var fulldata = mlContext.Data.LoadFromTextFile<InputData>(path: TrainTagsPath, separatorChar: '\t', hasHeader: false);
            var trainTestData = mlContext.Data.TrainTestSplit(fulldata, testFraction: 0.1);
            var trainData = trainTestData.TrainSet;
            var testData = trainTestData.TestSet;

            // STEP 2: Configure the data-processing pipeline
            var dataProcessPipeline = mlContext.Transforms.CustomMapping(new LoadImageConversion().GetMapping(), contractName: "LoadImageConversionAction")
               .Append(mlContext.Transforms.Conversion.MapValueToKey("Label", "Number", keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue))
               .Append(mlContext.Transforms.NormalizeMeanVariance( outputColumnName: "FeaturesNormalizedByMeanVar", inputColumnName: "ImagePixels"));


            // STEP 3: Configure the training algorithm (a maximum entropy classification model trained with the L-BFGS method)
            var trainer = mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(labelColumnName: "Label", featureColumnName: "FeaturesNormalizedByMeanVar");
            var trainingPipeline = dataProcessPipeline.Append(trainer)
                 .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictNumber", "Label"));


            // STEP 4: Train the model to fit the data set
            ITransformer trainedModel = trainingPipeline.Fit(trainData);          

            // STEP 5: Evaluate the model's accuracy
            var predictions = trainedModel.Transform(testData);
            var metrics = mlContext.MulticlassClassification.Evaluate(data: predictions, labelColumnName: "Label", scoreColumnName: "Score");
            PrintMultiClassClassificationMetrics(trainer.ToString(), metrics);
          
            // STEP 6: Save the model
            mlContext.Model.Save(trainedModel, trainData.Schema, ModelPath);           
        }

        private static void TestSomePredictions(MLContext mlContext)
        {
            // Load Model           
            ITransformer trainedModel = mlContext.Model.Load(ModelPath, out var modelInputSchema);

            // Create prediction engine 
            var predEngine = mlContext.Model.CreatePredictionEngine<InputData, OutPutData>(trainedModel);
          
            DirectoryInfo TestFolder = new DirectoryInfo(Path.Combine(AssetsFolder, "test"));           
            foreach(var image in TestFolder.GetFiles())
            {
                InputData img = new InputData()
                {
                    FileName = image.Name
                };
                var result = predEngine.Predict(img);
               
                Console.WriteLine($"Current Source={img.FileName},PredictResult={result.GetPredictResult()}");                
            }
        }       
    }

    class InputData
    {
        [LoadColumn(0)]
        public string FileName;

        [LoadColumn(1)]
        public string Number;

        [LoadColumn(1)]
        public float Serial;       
    }

    class OutPutData : InputData
    {
        public float[] Score;
        public int GetPredictResult()
        {
            float max = 0;
            int index = 0;
            for (int i = 0; i < Score.Length; i++)
            {
                if (Score[i] > max)
                {
                    max = Score[i];
                    index = i;
                }
            }
            return index;
        }       
    }   
}

  

3. Analysis

 The overall flow is essentially the same as in the previous article; two differences are explained here.

(1) Custom image-loading transform

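using System;
using System.Drawing;
using System.IO;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms;

namespace MulticlassClassification_Mnist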
{
    public class LoadImageConversionInput
    {
        public string  FileName { get; set; }
    }
 
    public class LoadImageConversionOutput
    {
        [VectorType(400)]
        public float[] ImagePixels { get; set; }
        public string ImagePath;
    }

    [CustomMappingFactoryAttribute("LoadImageConversionAction")]
    public class LoadImageConversion : CustomMappingFactory<LoadImageConversionInput, LoadImageConversionOutput>
    {       
        static readonly string TrainDataFolder = @"D:\StepByStep\Blogs\ML_Assets\MNIST\train";

        public void CustomAction(LoadImageConversionInput input, LoadImageConversionOutput output)
        {  
            string ImagePath = Path.Combine(TrainDataFolder, input.FileName);
            output.ImagePath = ImagePath;

            Bitmap bmp = Image.FromFile(ImagePath) as Bitmap;           

            output.ImagePixels = new float[400];
            for (int x = 0; x < 20; x++)
                for (int y = 0; y < 20; y++)
                {
                    var pixel = bmp.GetPixel(x, y);
                    var gray = (pixel.R + pixel.G + pixel.B) / 3 / 16;
                    output.ImagePixels[x + y * 20] = gray;
                }           
            bmp.Dispose();                     
        }

        public override Action<LoadImageConversionInput, LoadImageConversionOutput> GetMapping()
              => CustomAction;
    }
}

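 As the code shows, our custom transform takes a file name as input and outputs a float array. The array length must be declared explicitly; since the images are 20*20, it is set to 400. The ImagePath output holds the full file path and exists only for debugging; it serves no other purpose. The processing itself is straightforward: iterate over every pixel and compute its gray value. To reduce the workload we scale the gray value down by dividing it by 16; because the data is normalized later, this has little effect on the result.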
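The per-pixel arithmetic can be tried in isolation. The sketch below mirrors the calculation in CustomAction but works on a raw RGB byte buffer instead of a System.Drawing.Bitmap, so it runs without any image files. The `PixelFeatures` class, its members, and the row-major 3-bytes-per-pixel buffer layout are assumptions made for illustration, not part of the original project.

```csharp
using System;

// Illustrative sketch only: same arithmetic as CustomAction, applied to a raw
// RGB buffer. PixelFeatures and its members are hypothetical names.
public static class PixelFeatures
{
    public const int Width = 20;
    public const int Height = 20;

    // Average the three channels, then shrink the 0..255 range to 0..15.
    // Both divisions are integer divisions, so fractions are truncated.
    public static float ToGrayLevel(byte r, byte g, byte b)
        => (r + g + b) / 3 / 16;

    // rgb is assumed to hold Width*Height pixels in row-major order,
    // 3 bytes (R, G, B) per pixel.
    public static float[] Flatten(byte[] rgb)
    {
        if (rgb.Length != Width * Height * 3)
            throw new ArgumentException("Expected a 20x20 RGB buffer.", nameof(rgb));

        var features = new float[Width * Height];
        for (int y = 0; y < Height; y++)
            for (int x = 0; x < Width; x++)
            {
                int p = (x + y * Width) * 3;   // byte offset of pixel (x, y)
                // Same flattened index as the article's code: x + y * 20.
                features[x + y * Width] = ToGrayLevel(rgb[p], rgb[p + 1], rgb[p + 2]);
            }
        return features;
    }
}
```

With this mapping a pure white pixel becomes 15 and a pure black pixel becomes 0, so each of the 400 features takes one of 16 gray levels.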

 

(2) Model testing

            DirectoryInfo TestFolder = new DirectoryInfo(Path.Combine(AssetsFolder, "test"));
            int count = 0;
            int success = 0;
            foreach(var image in TestFolder.GetFiles())
            {
                count++;

                InputData img = new InputData()
                {
                    FileName = image.Name
                };
                var result = predEngine.Predict(img);

                if(int.Parse(image.Name.Substring(0,1))==result.GetPredictResult())
                {
                    success++;
                }                
            }

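 We read every image in the test directory and compared each prediction with the actual digit. In effect this repeats the work already done by Evaluate, and the success rates of the two tests are close.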
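The counting logic in the loop above can be factored into a small function that returns the success rate. This is a sketch under the same convention the loop relies on, namely that the first character of each test file name is the true digit; `AccuracyCheck` and `Accuracy` are hypothetical names, and the file names in the usage note are made up.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper: computes the fraction of correct predictions, taking
// the expected digit from the first character of the file name.
public static class AccuracyCheck
{
    public static double Accuracy(IEnumerable<(string FileName, int Predicted)> results)
    {
        int count = 0;
        int success = 0;
        foreach (var (fileName, predicted) in results)
        {
            count++;
            if (int.Parse(fileName.Substring(0, 1)) == predicted)
            {
                success++;
            }
        }
        // Avoid dividing by zero when the test folder is empty.
        return count == 0 ? 0.0 : (double)success / count;
    }
}
```

For example, passing `("3_a.png", 3)`, `("5_b.png", 5)`, `("7_c.png", 1)` and `("0_d.png", 0)` gives three matches out of four.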

 

4. On Image Feature Extraction

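We use the gray values of all pixels as the features, but it must be stressed that the raw pixel matrix is not a typical image feature. For fairly regular images, pixel-based computation can sometimes give good results, but it stops working on even slightly more complex images. The reason is clear: when we humans analyze the content of an image, the features we see are mostly things like lines, certainly not pixel values. See the figure below: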

We humans easily recognize that these two images express the same thing, yet their pixel-value features differ greatly.
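A small numeric experiment illustrates the point. The sketch below is an assumed example, not the article's figure: it draws the same vertical stroke at two different columns of a 20*20 canvas and measures the Euclidean distance between the flattened pixel vectors. `PixelDistanceDemo` and its methods are hypothetical names.

```csharp
using System;

// Illustrative sketch: the same shape, shifted by a few pixels, yields pixel
// vectors that are far apart. All names here are hypothetical.
public static class PixelDistanceDemo
{
    // Returns a size x size image (flattened row by row) containing a
    // one-pixel-wide vertical stroke in column col.
    public static float[] VerticalStroke(int size, int col)
    {
        var img = new float[size * size];
        for (int y = 0; y < size; y++)
        {
            img[col + y * size] = 1f;
        }
        return img;
    }

    public static double EuclideanDistance(float[] a, float[] b)
    {
        if (a.Length != b.Length)
            throw new ArgumentException("Vectors must have the same length.");

        double sum = 0;
        for (int i = 0; i < a.Length; i++)
        {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.Sqrt(sum);
    }
}
```

Comparing `VerticalStroke(20, 5)` with `VerticalStroke(20, 8)` gives a distance of sqrt(40), while comparing the first stroke with an empty image gives only sqrt(20). In raw pixel space the shifted stroke is therefore further from the original than a blank canvas is, even though a person would call the two strokes the same shape.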

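 There are many traditional image feature extraction methods, such as SIFT, HOG, LBP, and Haar. Nowadays, using a TensorFlow model for feature extraction works very well; the next article covers this in detail when it introduces image classification.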

 

5. Resources

Source code download: https://github.com/seabluescn/Study_ML.NET

Project name: MulticlassClassification_Mnist_Useful

MNIST assets: https://gitee.com/seabluescn/ML_Assets

