Training and predicting on the TGS dataset with FCN

1. Background

Kaggle hosts a semantic segmentation competition about predicting salt deposits: TGS Salt Identification Challenge | Kaggle  https://www.kaggle.com/c/tgs-salt-identification-challenge


2. Procedure

Step 1: Download the data from https://www.kaggle.com/c/tgs-salt-identification-challenge/data

Data description:

train.csv
id rle_mask
4000 entries, i.e. 4000 images

depths.csv
id  z
z (depth underground, in feet)
22000 entries (the total number of train and test images)
z ranges over [50, 959]

test
18000 images


sample_submission.csv
5f3b26ac68,1 2626 2628 100
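The mask is run-length encoded as pairs of start pixel and run length. Pixel numbers start from 1, so row and column indices must both be adjusted to start from 1.
Pixels are numbered down each column first, so in Python no transpose is needed, while with OpenCV the image must be transposed.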
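As a rough illustration of the sample_submission.csv row format described above, the following sketch decodes a run-length string into a flat 0/1 mask. The function name decodeRle and the row-major output layout are hypothetical choices made for this example, and the image size is passed in by the caller (the competition images are assumed to be 101x101 here).

```cpp
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Decode "start length start length ..." into a flat 0/1 mask.
// Pixel numbers are 1-based and run down each column first, so pixel p
// sits at column (p - 1) / height and row (p - 1) % height.
// The result is stored row-major: mask[row * width + col].
// decodeRle is a hypothetical helper, not part of the original code.
std::vector<unsigned char> decodeRle(const std::string &rle, int width, int height)
{
    std::vector<unsigned char> mask(static_cast<std::size_t>(width) * height, 0);
    const long pixelCount = static_cast<long>(width) * height;
    std::istringstream in(rle);
    long start = 0, length = 0;
    while (in >> start >> length) {
        for (long p = start; p < start + length; ++p) {
            if (p < 1 || p > pixelCount) {
                continue; // ignore pixel numbers outside the image
            }
            const int col = static_cast<int>((p - 1) / height);
            const int row = static_cast<int>((p - 1) % height);
            mask[static_cast<std::size_t>(row) * width + col] = 1;
        }
    }
    return mask;
}
```

For the sample row, decodeRle("1 2626 2628 100", 101, 101) would mark pixels 1 to 2626 and 2628 to 2727, leaving pixel 2627 unset.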

Data processing:

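(1) Every raw image has matching depth information. To make FCN training easier, we store the depth in the image as well. With OpenCV, for both the labeled training images and the images to be predicted, keep the original gray value in the B channel, store depth/256 (the integer quotient) in the G channel, and store depth%256 (the remainder) in the R channel.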

#include "opencv2/opencv.hpp"
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgproc/imgproc.hpp"
#include <cstdlib>
#include <iostream>
#include <fstream>
#include <map>
#include <string>
#include <vector>

int split(std::string str, std::string pattern, std::vector<std::string> &words)
{
    std::string::size_type pos;
    std::string word;
    int num = 0;
    str += pattern;
    std::string::size_type size = str.size();
    for (std::string::size_type i = 0; i < size; i++) {
        pos = str.find(pattern, i);
        if (pos == i) {
            continue;//if first string is pattern
        }
        if (pos < size) {
            word = str.substr(i, pos - i);
            words.push_back(word);
            i = pos + pattern.size() - 1;
            num++;
        }
    }
    return num;
}

int main(int argc, char** argv)
{
    std::string input_path = "E:\\kaggle\\competition\\tgs\\test_images\\";
    std::string out_path = "E:\\kaggle\\competition\\tgs\\test_images_out\\";
    //read image name
    std::ifstream fin(input_path + "filelist.txt");
    std::string line;
    std::vector<std::string> image_names;
    while (std::getline(fin, line)) {
        if (!line.empty()) { //skip blank lines such as a trailing newline
            image_names.push_back(line);
        }
    }

    std::ifstream depth_file("E:\\kaggle\\competition\\tgs\\depths.csv");
    std::getline(depth_file, line);
    //std::vector<int> depths;
    std::map<std::string, int> depths;
    while (std::getline(depth_file, line)) {
        std::vector<std::string> words;
        int ret = split(line, ",", words);
        if (ret == 2) {
            //depths.push_back(atoi(words[1].c_str()));
            //std::cout << "depth " << depths[depths.size() - 1] << std::endl;
            //getchar();
            depths[words[0]] = atoi(words[1].c_str());
            //std::cout << words[0]  << ": " << depths[words[0]] << std::endl;
            //getchar();
        }
    }
    std::cout << "depths size = " << depths.size() << std::endl;

    //read image
    for (int i = 0; i < image_names.size(); ++i) {
        std::cout << "i = " << i << std::endl;
        cv::Mat src = cv::imread(input_path + image_names[i] + ".png", cv::IMREAD_GRAYSCALE);
        if (src.empty()) {
            std::cout << "fail: " << image_names[i] << std::endl;
            return -1;
        }
        cv::Mat dst = cv::Mat::zeros(src.size(), CV_8UC3);
        for (int r = 0; r < dst.rows; ++r) {
            uchar *sptr = src.ptr<uchar>(r);
            cv::Vec3b *ptr = dst.ptr<cv::Vec3b>(r);
            for (int c = 0; c < dst.cols; ++c) {
                ptr[c][0] = sptr[c];
                ptr[c][1] = depths[image_names[i]] / 256;
                ptr[c][2] = depths[image_names[i]] % 256;
            }
        }
        cv::imwrite(out_path + image_names[i] + ".png", dst);
    }


    std::cout << "finish!" << std::endl;
    getchar();
    return 0;
}

Change the input and output paths and process the labeled training images in the same way.
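The channel packing in (1) amounts to splitting the depth into a high byte and a low byte. A minimal sketch of that arithmetic, and of how the depth could be recovered from the two channels, is shown here. The names DepthBytes, packDepth, and unpackDepth are hypothetical, and the sketch assumes the depth fits in 16 bits, which holds for the [50, 959] range in depths.csv.

```cpp
#include <cstdint>

// Two bytes holding one depth value (hypothetical helper type).
struct DepthBytes {
    std::uint8_t high; // value written to the G channel: depth / 256
    std::uint8_t low;  // value written to the R channel: depth % 256
};

// Split a depth into the two channel values. Assumes 0 <= depth < 65536.
DepthBytes packDepth(int depth)
{
    return {static_cast<std::uint8_t>(depth / 256),
            static_cast<std::uint8_t>(depth % 256)};
}

// Rebuild the depth from the two channel values.
int unpackDepth(DepthBytes bytes)
{
    return bytes.high * 256 + bytes.low;
}
```

For example, a depth of 959 splits into 3 for the G channel and 191 for the R channel, since 3 * 256 + 191 = 959.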

(2) Because our task is binary semantic segmentation, the labels can only be 0 and 1, so every pixel with value 255 in the masks images must be changed to 1.

int main(int argc, char** argv)
{
    //read image name
    std::ifstream fin("filelist.txt");
    std::string line;
    std::vector<std::string> image_names;
    while (std::getline(fin, line)) {
        if (!line.empty()) { //skip blank lines such as a trailing newline
            image_names.push_back(line);
        }
    }

    //modify value
    for (int num = 0; num < image_names.size(); ++num) {
        cv::Mat src;
        src = cv::imread("masks/" + image_names[num], -1);
        if (src.empty()) {
            std::cout << "fail: " << num << image_names[num] << std::endl;
            getchar();
            return -1;
        }

        cv::Mat dst;
        src.convertTo(dst, CV_8UC1);
        for (int j = 0; j < dst.rows; ++j) {
            uchar *ptr = dst.ptr<uchar>(j);
            for (int i = 0; i < dst.cols; ++i) {
                if (ptr[i] >= 128) {
                    ptr[i] = 1;
                }
            }
        }
        cv::imwrite("out/" + image_names[num], dst);
    }
}

The processed images can also be downloaded directly. Link: https://pan.baidu.com/s/1CAPIvQ6PayZ97eqeTpBcow Password: h3t9

Step 2: Download the open-source semantic segmentation code

shelhamer/fcn.berkeleyvision.org: Fully Convolutional Networks for Semantic Segmentation by Jonathan Long*, Evan Shelhamer*, and Trevor Darrell. CVPR 2015 and PAMI 2016.  https://github.com/shelhamer/fcn.berkeleyvision.org

Step 3: Download the modified code

https://github.com/litingpan/fcn

Step 4: Copy tgs-fcn32s, tgs-fcn16s, and tgs-fcn8s into the fcn.berkeleyvision.org folder, copy data/tgs into the fcn.berkeleyvision.org/data folder, and copy the data processed in step 1 into the corresponding folders under tgs.

Step 5: Training

(1) Train fcn32s

fcn.berkeleyvision.org\tgs-fcn32s>python solve.py


(2) Train fcn16s

fcn.berkeleyvision.org\tgs-fcn16s>python solve.py


(3) Train fcn8s

fcn.berkeleyvision.org\tgs-fcn8s>python solve.py


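After the 32x, 16x, and 8x upsampling stages, the final model reaches an overall accuracy of 0.928, a mean accuracy of 0.887, a mean IU (mean intersection over union) of 0.827, and an fwavacc (frequency-weighted intersection over union) of 0.866.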

Step 6: Prediction

We use the model trained with fcn8s for prediction.

fcn.berkeleyvision.org\tgs-fcn8s>python infers.py

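The results are written to the fcn.berkeleyvision.org\data\tgs\predict\masksout folder. Because every pixel value is either 0 or 1, the images look completely black. To visualize them, use OpenCV to change 1 to 255 and save the images again.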

Step 7: Save the prediction results to a CSV file.

int main(int argc, char** argv)
{
    //read image name
    std::ifstream fin("masksout/filelist.txt");
    std::string line;
    std::vector<std::string> image_names;
    while (std::getline(fin, line)) {
        if (!line.empty()) { //skip blank lines such as a trailing newline
            image_names.push_back(line);
        }
    }

    std::ofstream fout("submission.csv");
    fout << "id,rle_mask" << std::endl;
    for (int k = 0; k < image_names.size(); ++k) {
        std::cout << "k = " << k << std::endl;
        cv::Mat src = cv::imread("masksout/" + image_names[k]);

        if (src.empty()) {
            return -1;
        }

        fout << image_names[k].substr(0, image_names[k].size()-4) << ",";
        cv::Mat gray;
        cv::cvtColor(src, gray, cv::COLOR_BGR2GRAY);

        cv::Mat trans;
        cv::transpose(gray, trans);

        //fill hole
        cv::Mat tmp;
        cv::Mat hole;
        trans.convertTo(tmp, CV_8UC1, 255);
        dip::fillHole(tmp, hole);


        bool flag = false;
        int sum = 0;
        std::vector<int> list;
        int start_id = 0;
        for (int j = 0; j < hole.rows; ++j) { //each row of the transposed mask is one column of the original
            uchar *ptr = hole.ptr<uchar>(j);
            for (int i = 0; i < hole.cols; ++i) {

                if (ptr[i] && !flag) {
                    flag = true;
                    start_id = j*hole.cols + i+1; //1-based pixel number
                    sum = 0;
                    sum++;
                }
                else if (ptr[i] && flag) {
                    sum++;
                }
                else if (!ptr[i] && flag){
                    flag = false;
                    list.push_back(start_id);
                    list.push_back(sum);
                    //std::cout << "start_id = " << start_id << ", " << "sum = " << sum << std::endl;
                    //getchar();
                }
            }
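        }//for j
        if (flag) { //a run that reaches the last pixel is still open, so write it out
            list.push_back(start_id);
            list.push_back(sum);
        }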
        for (int n = 0; n < list.size(); ++n) {
            if (n == 0) {
                fout << list[0];
            }
            else {
                fout << " " << list[n];
            }
        }
        fout << std::endl;
        if (list.size() % 2 != 0) {
            std::cout << "error " << image_names[k] << std::endl;
        }

    }
    std::cout << "finish!" << std::endl;
    getchar();
    return 0;
}

The fillHole function is:

namespace dip {

    void fillHole(const cv::Mat &src, cv::Mat &dst)
    {
        cv::Size sz = src.size();
        cv::Mat tmp = cv::Mat::zeros(sz.height + 2, sz.width + 2, src.type());
        src.copyTo(tmp(cv::Range(1, sz.height + 1), cv::Range(1, sz.width + 1)));
        cv::floodFill(tmp, cv::Point(0, 0), cv::Scalar(255));
        cv::Mat cut;
        tmp(cv::Range(1, sz.height + 1), cv::Range(1, sz.width + 1)).copyTo(cut);
        dst = src | (~cut);
    }


}
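The encoding loop in step 7 can also be sketched without OpenCV, which makes the pixel numbering easier to check on small inputs. This is an illustrative version only: encodeRle is a hypothetical name, the mask is assumed to be a row-major buffer of exactly width * height values, and hole filling is left out.

```cpp
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Encode a row-major mask (mask[row * width + col], nonzero = salt) as
// "start length start length ...". Pixels are numbered from 1, going down
// each column first, so walking column by column replaces the transpose.
// encodeRle is a hypothetical helper, not part of the original code.
std::string encodeRle(const std::vector<unsigned char> &mask, int width, int height)
{
    std::ostringstream out;
    long runStart = 0;  // 1-based pixel number where the open run began
    long runLength = 0; // 0 means no run is currently open
    bool first = true;
    auto flush = [&]() {
        if (runLength == 0) {
            return;
        }
        if (!first) {
            out << ' ';
        }
        out << runStart << ' ' << runLength;
        first = false;
        runLength = 0;
    };
    for (int col = 0; col < width; ++col) {
        for (int row = 0; row < height; ++row) {
            const long pixel = static_cast<long>(col) * height + row + 1;
            if (mask[static_cast<std::size_t>(row) * width + col]) {
                if (runLength == 0) {
                    runStart = pixel;
                }
                ++runLength;
            } else {
                flush();
            }
        }
    }
    flush(); // a run that reaches the last pixel must still be written
    return out.str();
}
```

A run is allowed to continue from the bottom of one column into the top of the next, because those pixel numbers are consecutive. The final flush matters for masks whose last pixel is set.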

Step 8: Submit the results

[Screenshot of the submission score: 360截圖20180823233949918]

It seems this result is still far from what the competition requires.


end
