Caffe源碼-DataTransformer類

DataTransformer類簡介

DataTransformer類中主要用於圖像預處理操做,layer中可設置TransformationParameter類型的消息來對輸入圖像進行減均值、隨機鏡像、隨機裁剪或縮放。DataTransformer類中主要包含重載函數Transform(),能夠對各類類型的圖像數據進行預處理,並存入到Blob類型的數據中。類中還包含了如下變量。

TransformationParameter param_; //預處理參數
shared_ptr<Caffe::RNG> rng_;    //隨機數生成器
Phase phase_;                   //網絡狀態,TRAIN仍是TEST
Blob<Dtype> data_mean_;         //數據的均值,從mean_file中讀取到的均值數據
vector<Dtype> mean_values_;     //均值數值,以mean_value形式設置一系列數據

其中TransformationParameter消息中包含的內容以下。

// Message that stores parameters used to apply transformation to the data layer's data
message TransformationParameter {
  // For data pre-processing, we can do simple scaling and subtracting the data mean,
  // if provided. Note that the mean subtraction is always carried out before scaling.
  optional float scale = 1 [default = 1];       // numeric scaling factor; applied after mean subtraction
  // Specify if we want to randomly mirror data.
  optional bool mirror = 2 [default = false];   // whether to randomly mirror images during preprocessing
  // Specify if we would like to randomly crop an image.
  optional uint32 crop_size = 3 [default = 0];  // output size after cropping; a non-zero value enables cropping
  // mean_file and mean_value cannot be specified at the same time
  optional string mean_file = 4;                // path to the mean file (a binary proto file)
  // if specified can be repeated once (would subtract it from all the channels)
  // or can be repeated the same number of times as channels
  // (would subtract them from the corresponding channel)
  // mean_file and mean_value must not both be set
  repeated float mean_value = 5;                // per-channel mean values; count is 1 or equal to the channel count
  // Force the decoded image to have 3 color channels.
  optional bool force_color = 6 [default = false];  // force decoding encoded data to a 3-channel color image
  // Force the decoded image to have 1 color channels.
  optional bool force_gray = 7 [default = false];   // force decoding encoded data to a single-channel gray image
}

data_transformer.cpp源碼

// Constructor: stores the transformation parameters and the phase, then
// loads mean data — either a full mean blob from a binary proto mean file,
// or a list of per-channel mean values. Specifying both is a fatal error.
template<typename Dtype>
DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param, Phase phase)
    : param_(param), phase_(phase) {
  const bool use_mean_file = param_.has_mean_file();
  const int num_mean_values = param_.mean_value_size();
  if (use_mean_file) {
    // mean_file and mean_value are mutually exclusive settings.
    CHECK_EQ(num_mean_values, 0) << "Cannot specify mean_file and mean_value at the same time";
    const string& mean_file = param.mean_file();
    if (Caffe::root_solver()) {
      // Only the root solver logs, to avoid duplicate messages.
      LOG(INFO) << "Loading mean file from: " << mean_file;
    }
    BlobProto blob_proto;
    // Read the binary proto file into blob_proto (aborts on failure).
    ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);
    data_mean_.FromProto(blob_proto);  // copy the proto contents into the mean blob
  }
  if (num_mean_values > 0) {
    CHECK(use_mean_file == false) <<
      "Cannot specify mean_file and mean_value at the same time";
    // Keep a local copy of every configured mean value.
    for (int i = 0; i < num_mean_values; ++i) {
      mean_values_.push_back(param_.mean_value(i));
    }
  }
}

//Preprocess the image stored in a Datum (mean subtraction / crop / mirror /
//numeric scaling) and write the transformed pixels into the caller buffer.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const Datum& datum, Dtype* transformed_data) {
  const string& data = datum.data();            // raw uint8 image bytes (empty when float_data is used instead)
  const int datum_channels = datum.channels();  // input image channel count
  const int datum_height = datum.height();      // input image height
  const int datum_width = datum.width();        // input image width

  const int crop_size = param_.crop_size();     // crop size; non-zero enables cropping
  const Dtype scale = param_.scale();           // numeric scaling factor
  const bool do_mirror = param_.mirror() && Rand(2);    // mirror only if enabled AND the coin flip Rand(2) (0 or 1) says so
  const bool has_mean_file = param_.has_mean_file();    // was a mean file configured?
  const bool has_uint8 = data.size() > 0;               // datum carries uint8 data (vs float_data)
  const bool has_mean_values = mean_values_.size() > 0; // were scalar mean values configured?

  CHECK_GT(datum_channels, 0);        // sanity: at least one channel
  CHECK_GE(datum_height, crop_size);  // image must be at least as tall as the crop
  CHECK_GE(datum_width, crop_size);   // image must be at least as wide as the crop

  Dtype* mean = NULL;
  if (has_mean_file) {
    // With a mean file, the mean blob's channels/height/width must match the input image exactly.
    CHECK_EQ(datum_channels, data_mean_.channels());
    CHECK_EQ(datum_height, data_mean_.height());
    CHECK_EQ(datum_width, data_mean_.width());
    mean = data_mean_.mutable_cpu_data();   // raw pointer into the mean blob
  }
  if (has_mean_values) {
    // With scalar mean values, there must be either exactly one value (shared
    // by all channels) or one value per channel.
    CHECK(mean_values_.size() == 1 || mean_values_.size() == datum_channels) <<
     "Specify either 1 mean_value or as many as channels: " << datum_channels;
    if (datum_channels > 1 && mean_values_.size() == 1) {
      // Replicate the mean_value for simplicity
      for (int c = 1; c < datum_channels; ++c) {  // one value but several channels:
        mean_values_.push_back(mean_values_[0]);  // replicate mean_values_[0] for every channel
      }
    }
  }

  int height = datum_height;    // output height/width; start at the input size
  int width = datum_width;

  int h_off = 0;  // crop offsets along h/w
  int w_off = 0;
  if (crop_size) {
    height = crop_size; // cropping enabled: output size is the crop size
    width = crop_size;
    // We only do random crop when we do training.
    if (phase_ == TRAIN) {    // training: random crop offsets
      h_off = Rand(datum_height - crop_size + 1); // uniform in 0 .. datum_height - crop_size
      w_off = Rand(datum_width - crop_size + 1);
    } else {                  // testing: deterministic center crop
      h_off = (datum_height - crop_size) / 2;     // center-crop offsets
      w_off = (datum_width - crop_size) / 2;
    }
  }

  // A Datum holds a single image (num = 1, n = 0).
  // top_index is the destination index of a pixel in the output buffer;
  // data_index is that pixel's index in the input image;
  // datum_element holds the pixel's input value.
  Dtype datum_element;
  int top_index, data_index;
  for (int c = 0; c < datum_channels; ++c) {
    for (int h = 0; h < height; ++h) {
      for (int w = 0; w < width; ++w) {
        // Input pixel (0, c, h_off + h, w_off + w) in CHW layout.
        data_index = (c * datum_height + h_off + h) * datum_width + w_off + w;
        if (do_mirror) {                                          // mirroring flips along the width axis only
          top_index = (c * height + h) * width + (width - 1 - w); // maps to output pixel (0, c, h, width - 1 - w)
        } else {
          top_index = (c * height + h) * width + w;       // no mirror: maps to output pixel (0, c, h, w)
        }
        if (has_uint8) {    // uint8 payload present
          datum_element = static_cast<Dtype>(static_cast<uint8_t>(data[data_index])); // input pixel value
        } else {            // otherwise read the float payload
          datum_element = datum.float_data(data_index);   // input pixel value
        }
        if (has_mean_file) {
          // Mean file: each pixel has its own mean mean[data_index]; subtract then scale.
          transformed_data[top_index] = (datum_element - mean[data_index]) * scale;
        } else {
          if (has_mean_values) {
            // Scalar means: one mean per channel; subtract then scale.
            transformed_data[top_index] = (datum_element - mean_values_[c]) * scale;
          } else {
            transformed_data[top_index] = datum_element * scale;    // no mean configured: scale only
          }
        }
      }
    }
  }
}

//Preprocess the image stored in a Datum (mean subtraction / crop / mirror /
//numeric scaling) and write the result into a Blob.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const Datum& datum, Blob<Dtype>* transformed_blob) {
  // If datum is encoded, decode and transform the cv::image.
  if (datum.encoded()) {    // encoded payload must be decoded with OpenCV first
#ifdef USE_OPENCV
    // force_color decodes to a 3-channel color image, force_gray to a
    // single-channel gray image; the two options are mutually exclusive.
    CHECK(!(param_.force_color() && param_.force_gray())) << "cannot set both force_color and force_gray";
    cv::Mat cv_img;
    if (param_.force_color() || param_.force_gray()) {
    // If force_color then decode in color otherwise decode in gray.
      cv_img = DecodeDatumToCVMat(datum, param_.force_color()); // decode the in-memory image
    } else {
      cv_img = DecodeDatumToCVMatNative(datum); // neither flag set: decode in the image's native format
    }
    // Transform the cv::image into blob.
    return Transform(cv_img, transformed_blob); // delegate to the cv::Mat overload
#else
    LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";
#endif  // USE_OPENCV
  } else {
    // Non-encoded data: force_color/force_gray make no sense, report an error.
    if (param_.force_color() || param_.force_gray()) {
      LOG(ERROR) << "force_color and force_gray only for encoded datum";
    }
  }

  const int crop_size = param_.crop_size();     // crop size (0 = no cropping)
  const int datum_channels = datum.channels();  // input channels/height/width
  const int datum_height = datum.height();
  const int datum_width = datum.width();

  // Check dimensions.
  const int channels = transformed_blob->channels();  // output blob channels/height/width/num
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();
  const int num = transformed_blob->num();

  CHECK_EQ(channels, datum_channels); // input and output shapes must be compatible
  CHECK_LE(height, datum_height);
  CHECK_LE(width, datum_width);
  CHECK_GE(num, 1);

  if (crop_size) {
    CHECK_EQ(crop_size, height);    // cropping: output height/width must equal the crop size
    CHECK_EQ(crop_size, width);
  } else {
    CHECK_EQ(datum_height, height); // no cropping: output must equal the input size
    CHECK_EQ(datum_width, width);
  }

  Dtype* transformed_data = transformed_blob->mutable_cpu_data(); // raw output pointer
  Transform(datum, transformed_data); // do the actual work via the buffer overload
}

// Preprocess every Datum in datum_vector and store the results in the
// corresponding items of transformed_blob.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const vector<Datum> & datum_vector,
                                       Blob<Dtype>* transformed_blob) {
  const int datum_num = datum_vector.size();    // number of input images
  const int num = transformed_blob->num();      // output blob dimensions
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();

  // The input count must be positive and must fit into the output blob.
  CHECK_GT(datum_num, 0) << "There is no datum to add";
  CHECK_LE(datum_num, num) << "The size of datum_vector must be no greater than transformed_blob->num()";
  // Single-item blob reused as a window into the output buffer.
  Blob<Dtype> uni_blob(1, channels, height, width);
  for (int i = 0; i < datum_num; ++i) {
    // Point uni_blob at item i of the output blob, then transform into it.
    uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data() + transformed_blob->offset(i));
    Transform(datum_vector[i], &uni_blob);
  }
}

//對mat_vector中的多張圖像進行預處理,並將結果存入Blob類型的數據中
#ifdef USE_OPENCV
// Preprocess every cv::Mat in mat_vector and store the results in the
// corresponding items of transformed_blob. Note: unlike the Datum-vector
// overload, the image count must EQUAL transformed_blob->num().
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const vector<cv::Mat> & mat_vector,
                                       Blob<Dtype>* transformed_blob) {
  const int mat_num = mat_vector.size();            // number of input images
  const int num = transformed_blob->num();          // output blob dimensions
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();

  // The input count must be positive and match the output num exactly.
  CHECK_GT(mat_num, 0) << "There is no MAT to add";
  CHECK_EQ(mat_num, num) << "The size of mat_vector must be equals to transformed_blob->num()";
  // Single-item blob reused as a window into the output buffer.
  Blob<Dtype> uni_blob(1, channels, height, width);
  for (int i = 0; i < mat_num; ++i) {
    // Point uni_blob at item i of the output blob, then transform into it.
    uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data() + transformed_blob->offset(i));
    Transform(mat_vector[i], &uni_blob);
  }
}

//Preprocess a single cv::Mat image (mean subtraction / crop / mirror /
//numeric scaling) and store the result in a Blob. Converts OpenCV's
//interleaved (h, w, c) layout to the blob's planar (c, h, w) layout.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
                                       Blob<Dtype>* transformed_blob) {
  const int crop_size = param_.crop_size();     // crop size (0 = no cropping)
  const int img_channels = cv_img.channels();   // input image channels/height/width
  const int img_height = cv_img.rows;
  const int img_width = cv_img.cols;

  // Check dimensions.
  const int channels = transformed_blob->channels();  // output blob channels/height/width/num
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();
  const int num = transformed_blob->num();

  CHECK_EQ(channels, img_channels); // input and output shapes must be compatible
  CHECK_LE(height, img_height);
  CHECK_LE(width, img_width);
  CHECK_GE(num, 1);

  // The cv::Mat must carry uint8 pixel data.
  CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";

  const Dtype scale = param_.scale();                   // numeric scaling factor
  const bool do_mirror = param_.mirror() && Rand(2);    // mirror only if enabled AND the coin flip says so
  const bool has_mean_file = param_.has_mean_file();    // was a mean file configured?
  const bool has_mean_values = mean_values_.size() > 0; // were scalar mean values configured?

  CHECK_GT(img_channels, 0);        // sanity: channels positive, image no smaller than the crop
  CHECK_GE(img_height, crop_size);
  CHECK_GE(img_width, crop_size);

  Dtype* mean = NULL;
  if (has_mean_file) {
    // With a mean file, the mean blob's shape must match the input image exactly.
    CHECK_EQ(img_channels, data_mean_.channels());
    CHECK_EQ(img_height, data_mean_.height());
    CHECK_EQ(img_width, data_mean_.width());
    mean = data_mean_.mutable_cpu_data();   // raw pointer into the mean blob
  }
  if (has_mean_values) {
    // With scalar means, the count must be 1 or equal to the channel count.
    CHECK(mean_values_.size() == 1 || mean_values_.size() == img_channels) <<
     "Specify either 1 mean_value or as many as channels: " << img_channels;
    if (img_channels > 1 && mean_values_.size() == 1) {
      // Replicate the mean_value for simplicity
      for (int c = 1; c < img_channels; ++c) {
        // One value but several channels: use mean_values_[0] for every channel.
        mean_values_.push_back(mean_values_[0]);
      }
    }
  }

  int h_off = 0;    // crop offsets along h/w
  int w_off = 0;
  cv::Mat cv_cropped_img = cv_img;  // cropped view; initially the whole image
  if (crop_size) {
    CHECK_EQ(crop_size, height);    // output blob size must equal the crop size
    CHECK_EQ(crop_size, width);
    // We only do random crop when we do training.
    if (phase_ == TRAIN) {          // training: random crop offsets
      h_off = Rand(img_height - crop_size + 1);
      w_off = Rand(img_width - crop_size + 1);
    } else {                        // testing: deterministic center crop
      h_off = (img_height - crop_size) / 2;
      w_off = (img_width - crop_size) / 2;
    }
    cv::Rect roi(w_off, h_off, crop_size, crop_size); // region of interest for the crop
    cv_cropped_img = cv_img(roi);   // cropped view (shares data with cv_img)
  } else {
    CHECK_EQ(img_height, height);   // no cropping: input and output sizes must match
    CHECK_EQ(img_width, width);
  }

  CHECK(cv_cropped_img.data);   // cropped image data must not be null

  // Note: OpenCV stores pixels interleaved as (h, w, c).
  Dtype* transformed_data = transformed_blob->mutable_cpu_data(); // raw output pointer
  int top_index;
  for (int h = 0; h < height; ++h) {
    const uchar* ptr = cv_cropped_img.ptr<uchar>(h);  // pointer to row h of the cropped image
    int img_index = 0;
    for (int w = 0; w < width; ++w) {
      for (int c = 0; c < img_channels; ++c) {
        if (do_mirror) {    // mirroring flips along the width axis only
          top_index = (c * height + h) * width + (width - 1 - w); // pixel (h,w,c) maps to output (c, h, width - 1 - w)
        } else {
          top_index = (c * height + h) * width + w;         // pixel (h,w,c) maps to output (c, h, w)
        }
        // int top_index = (c * height + h) * width + w;
        Dtype pixel = static_cast<Dtype>(ptr[img_index++]); // value of pixel (h, w, c) in the cropped image
        if (has_mean_file) {
          // Cropped pixel (h,w,c) corresponds to mean-file pixel (c, h_off + h, w_off + w).
          int mean_index = (c * img_height + h_off + h) * img_width + w_off + w;
          transformed_data[top_index] = (pixel - mean[mean_index]) * scale;   // subtract mean, then scale
        } else {
          if (has_mean_values) {
            // Scalar means: channel c uses mean_values_[c].
            transformed_data[top_index] = (pixel - mean_values_[c]) * scale;  // subtract mean, then scale
          } else {
            transformed_data[top_index] = pixel * scale;  // no mean configured: scale only
          }
        }
      }
    }
  }
}
#endif  // USE_OPENCV

//Preprocess every image in input_blob (mean subtraction / crop / mirror /
//numeric scaling) and store the results in transformed_blob.
//NOTE(review): mean subtraction happens in-place on input_blob's data.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
                                       Blob<Dtype>* transformed_blob) {
  const int crop_size = param_.crop_size();           // crop size (0 = no cropping)
  const int input_num = input_blob->num();            // number of input images
  const int input_channels = input_blob->channels();  // input channels/height/width
  const int input_height = input_blob->height();
  const int input_width = input_blob->width();

  if (transformed_blob->count() == 0) {   // empty output blob: shape it to the expected output size first
    // Initialize transformed_blob with the right shape.
    if (crop_size) {    // cropping enabled: output spatial size is the crop size
      transformed_blob->Reshape(input_num, input_channels, crop_size, crop_size);
    } else {
      transformed_blob->Reshape(input_num, input_channels, input_height, input_width);
    }
  }

  const int num = transformed_blob->num();    // output num/channels/height/width
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();
  const int size = transformed_blob->count(); // total element count of the output blob

  CHECK_LE(input_num, num);           // inputs must fit in the output blob
  CHECK_EQ(input_channels, channels); // channel counts must match
  CHECK_GE(input_height, height);     // input size must be at least the output size
  CHECK_GE(input_width, width);

  const Dtype scale = param_.scale();                   // numeric scaling factor
  const bool do_mirror = param_.mirror() && Rand(2);    // mirror only if enabled AND the coin flip says so
  const bool has_mean_file = param_.has_mean_file();    // was a mean file configured?
  const bool has_mean_values = mean_values_.size() > 0; // were scalar mean values configured?

  int h_off = 0;    // crop offsets along h/w
  int w_off = 0;
  if (crop_size) {  // cropping enabled
    CHECK_EQ(crop_size, height);  // output size must equal the crop size
    CHECK_EQ(crop_size, width);
    // We only do random crop when we do training.
    if (phase_ == TRAIN) {        // training: random crop offsets
      h_off = Rand(input_height - crop_size + 1);
      w_off = Rand(input_width - crop_size + 1);
    } else {                      // testing: deterministic center crop
      h_off = (input_height - crop_size) / 2;
      w_off = (input_width - crop_size) / 2;
    }
  } else {
    CHECK_EQ(input_height, height); // no cropping: input and output sizes must match
    CHECK_EQ(input_width, width);
  }

  Dtype* input_data = input_blob->mutable_cpu_data();   // raw input pointer (mutated below)
  if (has_mean_file) {
    CHECK_EQ(input_channels, data_mean_.channels());  // mean blob must match the input's c/h/w
    CHECK_EQ(input_height, data_mean_.height());
    CHECK_EQ(input_width, data_mean_.width());
    for (int n = 0; n < input_num; ++n) {
      int offset = input_blob->offset(n);   // start of image n inside the input blob
      caffe_sub(data_mean_.count(), input_data + offset,
            data_mean_.cpu_data(), input_data + offset);  // in-place: (input_data + offset)[] -= data_mean_cpu_data[]
    }
  }

  if (has_mean_values) {    // scalar means configured
    // The count must be 1 or equal to the channel count.
    CHECK(mean_values_.size() == 1 || mean_values_.size() == input_channels) <<
     "Specify either 1 mean_value or as many as channels: " << input_channels;
    if (mean_values_.size() == 1) {
      caffe_add_scalar(input_blob->count(), -(mean_values_[0]), input_data);  // input_data[i] += -(mean_values_[0])
    } else {
      for (int n = 0; n < input_num; ++n) {
        for (int c = 0; c < input_channels; ++c) {
          int offset = input_blob->offset(n, c);    // start of channel c of image n; whole channel shares one mean
          // (input_data + offset)[i] += -(mean_values_[c])
          caffe_add_scalar(input_height * input_width, -(mean_values_[c]), input_data + offset);
        }
      }
    }
  }

  Dtype* transformed_data = transformed_blob->mutable_cpu_data(); // raw output pointer

  for (int n = 0; n < input_num; ++n) {
    int top_index_n = n * channels;     // partial index: prefix for item n in the output blob
    int data_index_n = n * channels;    // partial index: prefix for item n in the input blob
    for (int c = 0; c < channels; ++c) {
      int top_index_c = (top_index_n + c) * height;                   // prefix for output (n, c, ?, ?)
      int data_index_c = (data_index_n + c) * input_height + h_off;   // prefix for input (n, c, h_off, ?)
      for (int h = 0; h < height; ++h) {
        int top_index_h = (top_index_c + h) * width;                  // offset of output (n, c, h, 0)
        int data_index_h = (data_index_c + h) * input_width + w_off;  // offset of input (n, c, h_off + h, w_off)
        if (do_mirror) {  // mirroring flips along the width axis
          int top_index_w = top_index_h + width - 1;                  // offset of output (n, c, h, width - 1)
          for (int w = 0; w < width; ++w) {
            // Output (n, c, h, width - 1 - w) takes input (n, c, h_off + h, w_off + w).
            transformed_data[top_index_w-w] = input_data[data_index_h + w];
          }
        } else {
          for (int w = 0; w < width; ++w) {
            // Output (n, c, h, w) takes input (n, c, h_off + h, w_off + w).
            transformed_data[top_index_h + w] = input_data[data_index_h + w];
          }
        }
      }
    }
  }
  if (scale != Dtype(1)) {    // scale factor other than 1: scale the whole output in place
    DLOG(INFO) << "Scale: " << scale;
    caffe_scal(size, scale, transformed_data);  // transformed_data[] *= scale
  }
}

// Infer the shape (num, channels, height, width) a single Datum will have
// after transformation. Encoded datums are decoded with OpenCV first.
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const Datum& datum) {
  if (datum.encoded()) {    // encoded payload: decode, then measure the image
#ifdef USE_OPENCV
    // force_color and force_gray are mutually exclusive decode options.
    CHECK(!(param_.force_color() && param_.force_gray()))
        << "cannot set both force_color and force_gray";
    cv::Mat cv_img;
    if (param_.force_color() || param_.force_gray()) {
    // If force_color then decode in color otherwise decode in gray.
      cv_img = DecodeDatumToCVMat(datum, param_.force_color());
    } else {
      cv_img = DecodeDatumToCVMatNative(datum);  // decode in the image's native format
    }
    // Delegate to the cv::Mat overload for the decoded image.
    return InferBlobShape(cv_img);
#else
    LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";
#endif  // USE_OPENCV
  }
  // Non-encoded data: compute the shape directly from the datum fields.
  const int crop = param_.crop_size();
  const int c = datum.channels();
  const int h = datum.height();
  const int w = datum.width();
  // Sanity checks: positive channel count, image no smaller than the crop.
  CHECK_GT(c, 0);
  CHECK_GE(h, crop);
  CHECK_GE(w, crop);
  vector<int> shape(4);
  shape[0] = 1;  // a single datum: num is 1
  shape[1] = c;
  shape[2] = crop ? crop : h;  // cropped size wins when crop_size is set
  shape[3] = crop ? crop : w;
  return shape;
}

// Infer the post-transform shape for a vector of Datums: the shape of the
// first datum with the num dimension replaced by the vector length.
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const vector<Datum> & datum_vector) {
  const int count = datum_vector.size();
  CHECK_GT(count, 0) << "There is no datum to in the vector";  // must have at least one datum
  // Shape of the first datum: (1, channels, height, width).
  vector<int> shape = InferBlobShape(datum_vector[0]);
  shape[0] = count;  // one slot per datum
  return shape;
}

#ifdef USE_OPENCV
// Infer the post-transform shape (1, channels, height, width) of a single
// cv::Mat image.
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const cv::Mat& cv_img) {
  const int crop = param_.crop_size();
  const int c = cv_img.channels();
  const int h = cv_img.rows;
  const int w = cv_img.cols;
  // Sanity checks: positive channel count, image no smaller than the crop.
  CHECK_GT(c, 0);
  CHECK_GE(h, crop);
  CHECK_GE(w, crop);
  vector<int> shape(4);
  shape[0] = 1;  // a single image: num is 1
  shape[1] = c;
  shape[2] = crop ? crop : h;  // cropped size wins when crop_size is set
  shape[3] = crop ? crop : w;
  return shape;
}

// Infer the post-transform shape for a vector of cv::Mat images: the first
// image's shape with the num dimension replaced by the vector length.
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(
    const vector<cv::Mat> & mat_vector) {
  const int count = mat_vector.size();
  CHECK_GT(count, 0) << "There is no cv_img to in the vector";  // must have at least one image
  // Shape of the first image: (1, channels, height, width).
  vector<int> shape = InferBlobShape(mat_vector[0]);
  shape[0] = count;  // one slot per image
  return shape;
}
#endif  // USE_OPENCV

// Create (or release) the random number generator. Randomness is needed
// only for mirroring, or for random cropping during training.
template <typename Dtype>
void DataTransformer<Dtype>::InitRand() {
  const bool needs_rand = param_.mirror() ||
      (phase_ == TRAIN && param_.crop_size());
  if (!needs_rand) {
    rng_.reset();  // no randomness required: release the generator
    return;
  }
  const unsigned int rng_seed = caffe_rng_rand();  // draw a random seed
  rng_.reset(new Caffe::RNG(rng_seed));            // seed a fresh generator
}

// Return a random integer in [0, n). Requires InitRand() to have created
// the generator and n > 0.
template <typename Dtype>
int DataTransformer<Dtype>::Rand(int n) {
  CHECK(rng_);      // generator must exist
  CHECK_GT(n, 0);   // modulus must be positive
  caffe::rng_t* const rng =
      static_cast<caffe::rng_t*>(rng_->generator());
  return static_cast<int>((*rng)() % n);  // reduce the raw draw into [0, n)
}

小結

  1. 注意opencv中圖像是以(height, width, channel)形式存放的,與caffe中的(num, channel, height,width)形式不一樣。
  2. caffe::RNG類中封裝了boost庫和CUDA的CURAND庫的隨機數函數,實現了跨平臺編譯。CURAND庫的函數可參考官方提供的文檔。

參考

https://docs.nvidia.com/cuda/pdf/CURAND_Library.pdf

Caffe的源碼筆者是第一次閱讀,一邊閱讀一邊記錄,對代碼的理解和分析可能會存在錯誤或遺漏,但願各位讀者批評指正,謝謝支持!

相關文章
相關標籤/搜索