使用LAP數據集進行年齡訓練及估計

1、背景html

本來是打算按《DEX Deep EXpectation of apparent age from a single image》進行表面年齡的訓練,可因爲IMDB-WIKI的數據集比較龐大,各個年齡段分佈不均勻,難以劃分訓練集及驗證集。後來爲了先跑通整個訓練過程的主要部分,就直接用LAP數據集,參考caffe的finetune_flickr_style,進行一些參數修改,利用bvlc_reference_caffenet.caffemodel完成年齡估計的finetune。ios

2、訓練數據集準備git

一、下載LAP數據集,包括Train、Validation、Test,以及對應的年齡label,http://chalearnlap.cvc.uab.es/dataset/18/description/,須要註冊。github

二、將標註好的csv文件轉換爲caffe識別的txt格式。csv每一行的信息爲:圖片名、年齡、標準差。訓練的時候不須要標準差信息,咱們只要將圖片名和年齡寫入到txt中,並按空格隔開,獲得Train.txt以下:app

image

一樣,完成驗證集cvs文件的轉換,獲得Validation.txt。函數

3、模型及相關文件拷貝工具

一、拷貝預訓練好的vgg16模型caffe\models\bvlc_reference_caffenet\bvlc_reference_caffenet.caffemodel至工做目錄下,該文件約232M;測試

二、拷貝caffe\models\finetune_flickr_style文件夾中deploy.prototxt、solver.prototxt、train_val.prototxt至工做目錄下;spa

三、拷貝imageNet的均值文件caffe\data\ilsvrc12\imagenet_mean.binaryproto至工做目錄下。code

4、參數修改

一、修改train_val.prototxt

image

以及最後的輸出層個數,由於咱們要訓練的爲[0,100]歲的輸出,共101類,因此:

image

二、修改solver.protxt

image

三、修改用於實際測試的部署文件deploy.protxt

image

輸出層的個數也要改:

image

5、開始訓練

一、新建train.bat

caffe train -solver solver.prototxt -weights bvlc_reference_caffenet.caffemodel
rem caffe train --solver solver.prototxt --snapshot snapshot/bvlc_iter_48000.solverstate
pause

雙擊便可開始訓練,當訓練過程當中出現意外中斷,可註釋第一行,關閉第二行註釋,根據實際狀況修改保存,繼續雙擊訓練。

個人電腦CPU是i5 6500,顯卡爲gtx1050Ti,8G內存,大體要訓練10個小時吧,中途也出現了一些內存不足訓練終止的狀況。

二、訓練結束

QQ截圖20181005101028-lap2_2

6、模型評價

年齡估計本來是一個線性問題,不是一個明確的分類問題,人都沒法準確無誤地獲得某人的年齡,更況且是機器呢。因此評價這個年齡分類模型的好壞不能簡單地經過精度來衡量,能夠用MAE(平均絕對偏差)以及ε-error來衡量,其中

image

一、對驗證集Validation.txt的全部圖片進行預測

藉助 https://github.com/eveningglow/age-and-gender-classification ,其環境搭建可參考http://www.javashuo.com/article/p-tkoovjxg-o.html

修改main函數

int split(std::string str, std::string pattern, std::vector<std::string> &words)
{
    std::string::size_type pos;
    std::string word;
    int num = 0;
    str += pattern;
    std::string::size_type size = str.size();
    for (auto i = 0; i < size; i++) {
        pos = str.find(pattern, i);
        if (pos == i) {
            continue;//if first string is pattern
        }
        if (pos < size) {
            word = str.substr(i, pos - i);
            words.push_back(word);
            i = pos + pattern.size() - 1;
            num++;
        }
    }
    return num;
}

//param example: model/deploy_age2.prototxt model/age_net.caffemodel model/mean.binaryproto img/0008.jpg
int main(int argc, char* argv[])
{
    if (argc != 5)
    {
        cout << "Command shoud be like ..." << endl;
        cout << "AgeAndGenderClassification ";
        cout << " \"AGE_NET_MODEL_FILE_PATH\" \"AGE_NET_WEIGHT_FILE_PATH\" \"MEAN_FILE_PATH\" \"TEST_IMAGE\" " << endl;
        std::cout << "argc = " << argc << std::endl;
        getchar();
        return 0;
    }

    // Get each file path
    string age_model(argv[1]);
    string age_weight(argv[2]);
    string mean_file(argv[3]);
    //string test_image(argv[4]);

    // Probability vector
    vector<Dtype> prob_age_vec;

    // Set mode
    Caffe::set_mode(Caffe::GPU);

    // Make AgeNet
    AgeNet age_net(age_model, age_weight, mean_file);

    // Initiailize both nets
    age_net.initNetwork();

    //讀取待測試的圖片名
    std::ifstream fin("E:\\caffe\\DEX_age_gender_predict\\lap2\\Validation.txt");
    std::string line;
    std::vector<std::string> test_images;
    std::vector<int> test_images_age;
    while (!fin.eof()) {
        std::getline(fin, line);
        std::vector<std::string> words;
        split(line, " ", words);
        test_images.push_back(words[0]);
        test_images_age.push_back(atoi(words[1].c_str()));
    }
    std::cout << "test_images size = " << test_images.size() << std::endl;

    std::ofstream fout("E:\\caffe\\DEX_age_gender_predict\\lap2\\Validation_predict.txt");
    for (int k = 0; k < test_images.size(); ++k) {
        std::cout << "k = " << k << std::endl;
        std::string test_image;
        test_image = test_images[k];

        // Classify and get probabilities
        Mat test_img = imread(test_image, CV_LOAD_IMAGE_COLOR);
        int age = age_net.classify(test_img, prob_age_vec);

        // Print result and show image
        //std::cout << "prob_age_vec size = " << prob_age_vec.size() << std::endl;
        //for (int i = 0; i < prob_age_vec.size(); ++i) {
        //    std::cout << "[" << i << "] = " << prob_age_vec[i] << std::endl;
        //}

        //Dtype prob;
        //int index;
        //get_max_value(prob_age_vec, prob, index);
        //std::cout << "prob = " << prob << ", index = " << index << std::endl;

        //imshow("AgeAndGender", test_img);
        //waitKey(0);
        fout << test_images[k] << " " << test_images_age[k] << " " << age << std::endl;


    }

    std::cout << "finish!" << std::endl;
    getchar();
    return 0;
}

個人命令參數爲:E:\caffe\DEX_age_gender_predict\lap2\deploy.prototxt E:\caffe\DEX_age_gender_predict\lap2\snapshot\bvlc_iter_50000.caffemodel model\mean.binaryproto img\0008.jpg

可根據實際狀況修改。可獲得Validation_predict.txt文件。運行過程當中可能會由於內存不足中斷運行,可能要分批次運行屢次。

二、計算MAE及ε-error

(1)將Validation_predict.txt文件及驗證集的標註文件Reference.csv拷貝到新建的vs項目的工做目錄下;

(2)計算

#include <iostream>
#include <string>
#include <fstream>
#include <vector>

int split(std::string str, std::string pattern, std::vector<std::string> &words)
{
    std::string::size_type pos;
    std::string word;
    int num = 0;
    str += pattern;
    std::string::size_type size = str.size();
    for (auto i = 0; i < size; i++) {
        pos = str.find(pattern, i);
        if (pos == i) {
            continue;//if first string is pattern
        }
        if (pos < size) {
            word = str.substr(i, pos - i);
            words.push_back(word);
            i = pos + pattern.size() - 1;
            num++;
        }
    }
    return num;
}

int main(int argc, char** argv)
{
    //u, sigma, x
    std::vector<int> u;
    std::vector<float> sigma;
    std::vector<int> predict;

    std::string line;
    std::ifstream csv_file("Reference.csv");
    while (!csv_file.eof()) {
        std::getline(csv_file, line);
        std::vector<std::string> words;
        split(line, ";", words);
        u.push_back(atoi(words[1].c_str()));
        sigma.push_back(atof(words[2].c_str()));
    }
    std::ifstream predict_file("Validation_predict.txt");
    while (!predict_file.eof()) {
        std::getline(predict_file, line);
        std::vector<std::string> words;
        split(line, " ", words);
        predict.push_back(atoi(words[2].c_str()));
    }
    if (u.size() != predict.size()) {
        std::cout << "u.size() != predict.size()" << std::endl;
        getchar();
        return -1;
    }

    //MAE
    int sum_err = 0;
    float MAE = 0;
    for (int i = 0; i < u.size(); ++i) {
        sum_err += abs(u[i] - predict[i]);
    }
    MAE = static_cast<float>(sum_err) / u.size();
    std::cout << "MAE = " << MAE << std::endl;//11.7184

    //esro-error
    std::vector<float> errors;
    float err = 0;
    float error = 0.0;
    for (int i = 0; i < u.size(); ++i) {
        err = 1.0 - exp(-1.0*(predict[i] - u[i])*(predict[i] - u[i]) / (2 * sigma[i] * sigma[i]));
        errors.push_back(err);
        error += err;
    }
    error /= errors.size();
    std::cout << "error = " << error << std::endl;//0.682652
    

    std::cout << "finish!" << std::endl;
    getchar();
    return 0;
}

最終獲得MAE爲11.7184, ε-error爲0.682652。

7、實際應用中預測

一、可利用caffe提供的classification工具對輸入圖片地進行估計

classification deploy.prototxt snapshot\bvlc_iter_50000.caffemodel imagenet_mean.binaryproto ..\age_labels.txt ..\test_image\test_3.jpg
pause

其中,age_labels.txt爲0-100個label的說明信息,每一個label對應一行,共101行,個人寫法以下:

image

 

end

相關文章
相關標籤/搜索