【MMDetection 超全專欄】二,配置類和註冊器&數據處理&訓練pipline

0. 目錄

目錄,第一節和第二節請看上篇推文

第三節 配置類和註冊器

這兩個東西可變爲自用+練習。html

0.3.1 配置類

配置方式支持python/json/yaml,從mmcv的Config解析,其功能同maskrcnn-benchmark的yacs相似,將字典的取值方式屬性化.這裏貼部分代碼,以供學習。python

class Config(object):
...
@staticmethod
def _file2dict(filename):
filename = osp.abspath(osp.expanduser(filename))
check_file_exist(filename)
if filename.endswith('.py'):
with tempfile.TemporaryDirectory() as temp_config_dir:
shutil.copyfile(filename,
osp.join(temp_config_dir, '_tempconfig.py'))
sys.path.insert(0, temp_config_dir)
mod = import_module('_tempconfig')
sys.path.pop(0)
cfg_dict = {
name: value
for name, value in mod.__dict__.items()
if not name.startswith('__')
}
# delete imported module
del sys.modules['_tempconfig']
elif filename.endswith(('.yml', '.yaml', '.json')):
import mmcv
cfg_dict = mmcv.load(filename)
else:
raise IOError('Only py/yml/yaml/json type are supported now!')

cfg_text = filename + '\n'
with open(filename, 'r') as f:
cfg_text += f.read()
# 2.0新增的配置文件的組合繼承
if '_base_' in cfg_dict:
cfg_dir = osp.dirname(filename)
base_filename = cfg_dict.pop('_base_')
base_filename = base_filename if isinstance(
base_filename, list) else [base_filename]

cfg_dict_list = list()
cfg_text_list = list()
for f in base_filename:
# 遞歸,可搜索staticmethod and recursion
# 靜態方法調靜態方法,類方法調靜態方法
_cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
cfg_dict_list.append(_cfg_dict)
cfg_text_list.append(_cfg_text)

base_cfg_dict = dict()
for c in cfg_dict_list:
if len(base_cfg_dict.keys() & c.keys()) > 0:
raise KeyError('Duplicate key is not allowed among bases')
base_cfg_dict.update(c)
# 合併
Config._merge_a_into_b(cfg_dict, base_cfg_dict)
cfg_dict = base_cfg_dict

# merge cfg_text
cfg_text_list.append(cfg_text)
cfg_text = '\n'.join(cfg_text_list)

return cfg_dict, cfg_text

...
# 獲取key值
def __getattr__(self, name):
return getattr(self._cfg_dict, name)
# 序列化
def __getitem__(self, name):
return self._cfg_dict.__getitem__(name)
# 將字典屬性化主要用了__setattr__
def __setattr__(self, name, value):
if isinstance(value, dict):
value = ConfigDict(value)
self._cfg_dict.__setattr__(name, value)
# 更新key值
def __setitem__(self, name, value):
if isinstance(value, dict):
value = ConfigDict(value)
self._cfg_dict.__setitem__(name, value)
# 迭代器
def __iter__(self):
return iter(self._cfg_dict)

主要考慮點是本身怎麼實現相似的東西,核心點就是python的基本魔法函數的應用,可同時參考yacs。web

0.3.2 註冊器

把基本對象放到一個繼承了字典的對象中,實現了對象的靈活管理。算法

import inspect
from functools import partial

import mmcv


class Registry(object):
# 2.0 放到mmcv中

def __init__(self, name):
self._name = name
self._module_dict = dict()

@property
def name(self):
return self._name

@property
def module_dict(self):
return self._module_dict

def get(self, key):
return self._module_dict.get(key, None)

def _register_module(self, module_class, force=False):
"""Register a module.

Args:
module (:obj:`nn.Module`): Module to be registered.
"""
if not inspect.isclass(module_class):
raise TypeError('module must be a class, but got {}'.format(
type(module_class)))
module_name = module_class.__name__
if not force and module_name in self._module_dict:
raise KeyError('{} is already registered in {}'.format(
module_name, self.name))
self._module_dict[module_name] = module_class # 類名:類

def register_module(self, cls=None, force=False):
# 做爲類cls的裝飾器
if cls is None:
# partial函數(類)固定參數,返回新對象,遞歸不是很清楚
return partial(self.register_module, force=force)
self._register_module(cls, force=force) # 將cls裝進當前Registry對象的中_module_dict
return cls # 返回類

def build_from_cfg(cfg, registry, default_args=None):
assert isinstance(cfg, dict) and 'type' in cfg
assert isinstance(default_args, dict) or default_args is None
args = cfg.copy()
obj_type = args.pop('type')
if mmcv.is_str(obj_type):
# 從註冊類中拿出obj_type類
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError('{} is not in the {} registry'.format(
obj_type, registry.name))
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError('type must be a str or valid type, but got {}'.format(
type(obj_type)))
if default_args is not None:
# 增長一些新的參數
for name, value in default_args.items():
args.setdefault(name, value)
return obj_cls(**args) # **args是將字典解析成位置參數(k=v)。

第四節 數據處理

數據處理多是煉丹師接觸最爲密集的了,由於一般狀況,除了數據的離線處理,寫個數據類,就能夠煉丹了。但本節主要涉及數據的在線處理,更進一步應該是檢測分割數據的pytorch處理方式。雖然mmdet將經常使用的數據都實現了,並且也實現了中間通用數據格式,但,這和模型,損失函數,性能評估的實現也相關,好比你想把官網的centernet完整的改爲mmdet風格,就能看到(看起來不必)。json

0.4.1 檢測分割數據

看看配置文件,數據相關的有datadict,裏面包含了train,val,test的路徑信息,用於數據類初始化,有pipeline,將各個函數及對應參數以字典形式放到列表裏,是對pytorch原裝的transforms+compose,在檢測,分割相關數據上的一次封裝,使得形式更加統一。api

從builder.py中build_dataset函數能看到,構建數據有三種方式,ConcatDataset,RepeatDataset和從註冊器中提取。其中dataset_wrappers.py中ConcatDataset和RepeatDataset意義自明,前者繼承自pytorch原始的ConcatDataset,將多個數據集整合到一塊兒,具體爲把不一樣序列(可參考容器的抽象基類https://docs.python.org/zh-cn/3/library/collections.abc.html)的長度相加,__getitem__函數對應index替換一下,後者就是單個數據類(序列)的屢次重複。就功能來講,前者提升數據豐富度,後者可解決數據太少使得loading時間長的問題(見代碼註釋)。而被註冊的數據類在datasets下一些熟知的數據名文件中。其中,基類爲custom.py中的CustomDataset,coco繼承自它,cityscapes繼承自coco,xml_style的XMLDataset繼承CustomDataset,而後wider_face,voc均繼承自XMLDataset。所以這裏先分析一下CustomDataset。微信

CustomDataset 記錄數據路徑等信息,解析標註文件,將每一張圖的全部信息以字典做爲數據結構存在results中,而後進入pipeline:數據加強相關操做,代碼以下:數據結構

self.pipeline = Compose(pipeline)
# Compose是實現了__call__方法的類,其做用是使實例可以像函數同樣被調用,同時不影響實例自己的生命週期
def pre_pipeline(self, results):
# 擴展字典信息
results['img_prefix'] = self.img_prefix
results['seg_prefix'] = self.seg_prefix
results['proposal_file'] = self.proposal_file
results['bbox_fields'] = []
results['mask_fields'] = []
results['seg_fields'] = []

def prepare_train_img(self, idx):
img_info = self.img_infos[idx]
ann_info = self.get_ann_info(idx)
# 基本信息,初始化字典
results = dict(img_info=img_info, ann_info=ann_info)
if self.proposals is not None:
results['proposals'] = self.proposals[idx]
self.pre_pipeline(results)
return self.pipeline(results) # 數據加強

def __getitem__(self, idx):
if self.test_mode:
return self.prepare_test_img(idx)
while True:
data = self.prepare_train_img(idx)
if data is None:
idx = self._rand_another(idx)
continue
return data

這裏數據結構的選取須要注意一下,字典結構,在數據加強庫albu中也是如此處理,所以能夠快速替換爲albu中的算法。另外每一個數據類增長了各自的evaluate函數。evaluate基礎函數在mmdet.core.evaluation中,後作補充。app

mmdet的數據處理,字典結構pipelineevaluate是三個關鍵部分。其餘全部類的文件解析部分,數據篩選等,看看便可。由於咱們知道,pytorch讀取數據,是將序列轉化爲迭代器後進行io操做,因此在dataset下除了pipelines外還有loader文件夾,裏面實現了分組,分佈式分組採樣方法,以及調用了mmcv中的collate函數(此處爲1.x版本,2.0版本將loader移植到了builder.py中),且build_dataloader封裝的DataLoader最後在 train_detector中被調用,這部分將在後面補充,這裏說說pipelines。dom

返回maskrcnn的配置文件(1.x,2.0看base config),能夠看到訓練和測試的不一樣之處:LoadAnnotations,MultiScaleFlipAug,DefaultFormatBundle和Collect。額外提示,雖然測試沒有LoadAnnotations,根據CustomDataset可知,它仍需標註文件,這和inference的pipeline不一樣,也即這裏的test實爲evaluate。

# 序列中的dict能夠隨意刪減,增長,屬於數據加強調參內容
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]

最後這些全部操做被Compose串聯起來,代碼以下:

@PIPELINES.register_module
class Compose(object):

def __init__(self, transforms):
assert isinstance(transforms, collections.abc.Sequence) # 列表是序列結構
self.transforms = []
for transform in transforms:
if isinstance(transform, dict):
transform = build_from_cfg(transform, PIPELINES)
self.transforms.append(transform)
elif callable(transform):
self.transforms.append(transform)
else:
raise TypeError('transform must be callable or a dict')

def __call__(self, data):
for t in self.transforms:
data = t(data)
if data is None:
return None
return data

上面代碼能看到,配置文件中pipeline中的字典傳入build_from_cfg函數,逐一實現了各個加強類(方法)。擴展的加強類均需實現__call__方法,這和pytorch原始方法是一致的。

有了以上認識,從新梳理一下pipelines的邏輯,由三部分組成,load,transforms,和format。load相關的LoadImageFromFile,LoadAnnotations都是字典results進去,字典results出來。具體代碼看下便知,LoadImageFromFile增長了'filename','img','img_shape','ori_shape','pad_shape','scale_factor','img_norm_cfg'字段。其中img是numpy格式。LoadAnnotations從 results['ann_info']中解析出bboxs,masks,labels等信息。注意coco格式的原始解析來自pycocotools,包括其評估方法,這裏關鍵是字典結構(這個和模型損失函數,評估等相關,統一結構,使得代碼統一)。transforms中的類做用於字典的values,也即數據加強。format中的DefaultFormatBundle是將數據轉成mmcv擴展的容器類格式DataContainer。另外Collect會根據不一樣任務的不一樣配置,從results中選取只含keys的信息生成新的字典,具體看下該類幫助文檔。這裏看一下從numpy轉成tensor的代碼:

def to_tensor(data):
"""Convert objects of various python types to :obj:`torch.Tensor`.

Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`, :class:`int` and :class:`float`.
"""
if isinstance(data, torch.Tensor):
return data
elif isinstance(data, np.ndarray):
return torch.from_numpy(data)
elif isinstance(data, Sequence) and not mmcv.is_str(data):
return torch.tensor(data)
elif isinstance(data, int):
return torch.LongTensor([data])
elif isinstance(data, float):
return torch.FloatTensor([data])
else:
raise TypeError('type {} cannot be converted to tensor.'.format(
type(data)))
以上代碼告訴咱們,基本數據類型,需掌握。

那麼DataContainer是什麼呢?它是對tensor的封裝,將results中的tensor轉成DataContainer格式,實際上只是增長了幾個property函數,cpu_only,stack,padding_value,pad_dims,其含義自明,以及size,dim用來獲取數據的維度,形狀信息。 考慮到序列數據在進入DataLoader時,須要以batch方式進入模型,那麼一般的collate_fn會要求tensor數據的形狀一致。可是這樣不是很方便,因而有了DataContainer。它能夠作到載入GPU的數據能夠保持統一shape,並被stack,也能夠不stack,也能夠保持原樣,或者在非batch維度上作pad。固然這個也要對default_collate進行改造,mmcv在parallel.collate中實現了這個。

collate_fn是DataLoader中將序列dataset組織成batch大小的函數,這裏帖三個普通例子:

def collate_fn_1(batch):
# 這是默認的,明顯batch中包含相同形狀的img\_tensor和label
return tuple(zip(*batch))

def coco_collate_2(batch):
# 傳入的batch數據是被albu加強後的(字典結構)
imgs = [s['image'] for s in batch] # tensor, h, w, c->c, h, w , handle at transform in __getitem__
annots = [s['bboxes'] for s in batch]
labels = [s['category_id'] for s in batch]

# 以當前batch中圖片annot數量的最大值做爲標記數據的第二維度值,空出的就補-1。
max_num_annots = max(len(annot) for annot in annots)
annot_padded = np.ones((len(annots), max_num_annots, 5))*-1

if max_num_annots > 0:
for idx, (annot, lab) in enumerate(zip(annots, labels)):
if len(annot) > 0:
annot_padded[idx, :len(annot), :4] = annot
# 不一樣模型,損失值計算可能不一樣,這裏ssd結構須要改成xyxy格式而且要作尺度歸一化
# 這一步徹底能夠放到\_\_getitem\_\_中去,只是albu的格式需求問題。
annot_padded[idx, :len(annot), 2] += annot_padded[idx, :len(annot), 0] # xywh-->x1,y1,x2,y2 for general box,ssd target assigner
annot_padded[idx, :len(annot), 3] += annot_padded[idx, :len(annot), 1] # contains padded -1 label
annot_padded[idx, :len(annot), :] /= 640 # priorbox for ssd primary target assinger
annot_padded[idx, :len(annot), 4] = lab
return torch.stack(imgs, 0), torch.FloatTensor(annot_padded)

def detection_collate_3(batch):
targets = []
imgs = []
for _, sample in enumerate(batch):
for _, img_anno in enumerate(sample):
if torch.is_tensor(img_anno):
imgs.append(img_anno)
elif isinstance(img_anno, np.ndarray):
annos = torch.from_numpy(img_anno).float()
targets.append(annos)
return torch.stack(imgs, 0), targets # 作了stack, DataContainer能夠不作stack

以上就是數據處理的相關內容。最後再用DataLoader封裝拆成迭代器,其相關細節,sampler等暫略。

data_loader = DataLoader(
dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=num_workers,
collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu),
pin_memory=False,
worker_init_fn=init_fn,
**kwargs)

5. 訓練pipeline

訓練流程的包裝過程大體以下:tools/train.py->apis/train.py->mmcv/runner.py->mmcv/hook.py(後面是分散的),其中runner維護了數據信息,優化器,日誌系統,訓練loop中的各節點信息,模型保存,學習率等.另外補充一點,以上包裝過程,在mmdet中無處不在,包括mmcv的代碼也是對平常頻繁使用的函數進行了統一封裝.

0.5.1 訓練邏輯

圖見Figure2:

Figure 2

注意它的四個層級.代碼上,主要查看apis/train.py,mmcv中的runner相關文件.核心圍繞Runner,Hook兩個類.Runner將模型,批處理函數batch_processor,優化器做爲基本屬性,訓練過程當中與訓練狀態,各節點相關的信息被記錄在mode,_hooks,_epoch,_iter,_inner_iter,_max_epochs,_max_iters中,這些信息維護了訓練過程當中插入不一樣hook的操做方式.理清訓練流程只需看Runner的成員函數run.在run裏會根據mode按配置中workflow的epoch循環調用train和val函數,跑完全部的epoch.好比train:

def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train' # 改變模式
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(data_loader) # 最大batch循環次數
self.call_hook('before_train_epoch') # 根據名字獲取hook對象函數
for i, data_batch in enumerate(data_loader):
self._inner_iter = i # 記錄訓練迭代輪數
self.call_hook('before_train_iter') # 一個batch前向開始
outputs = self.batch_processor(
self.model, data_batch, train_mode=True, **kwargs)
self.outputs = outputs
self.call_hook('after_train_iter') # 一個batch前向結束
self._iter += 1 # 方便resume時,知道從哪一輪開始優化

self.call_hook('after_train_epoch') # 一個epoch結束
self._epoch += 1 # 記錄訓練epoch狀態,方便resume

上面須要說明的是自定義hook類,自定義hook類需繼承mmcv的Hook類,其默認了6+8+4個成員函數,也即Figure2所示的6個層級節點,外加2*4個區分train和val的節點記錄函數,以及4個邊界檢查函數.從train.py中容易看出,在訓練以前,已經將須要的hook函數註冊到Runner的self._hook中了,包括從配置文件解析的優化器,學習率調整函數,模型保存,一個batch的時間記錄等(註冊hook算子在self._hook中按優先級升序排列).這裏的call_hook函數定義以下:

def call_hook(self, fn_name):
for hook in self._hooks:
getattr(hook, fn_name)(self)

容易看出,在訓練的不一樣節點,將從註冊列表中調用實現了該節點函數的類成員函數.好比

class OptimizerHook(Hook):

def __init__(self, grad_clip=None):
self.grad_clip = grad_clip

def clip_grads(self, params):
clip_grad.clip_grad_norm_(
filter(lambda p: p.requires_grad, params), **self.grad_clip)

def after_train_iter(self, runner):
runner.optimizer.zero_grad()
runner.outputs['loss'].backward()
if self.grad_clip is not None:
self.clip_grads(runner.model.parameters())
runner.optimizer.step()

將在每一個train_iter後實現反向傳播和參數更新.學習率優化相對複雜一點,其基類LrUpdaterHook,實現了before_run,before_train_epoch, before_train_iter三個hook函數,意義自明.這裏選一個餘弦式變化,稍做說明:

class CosineLrUpdaterHook(LrUpdaterHook):

def __init__(self, target_lr=0, **kwargs):
self.target_lr = target_lr
super(CosineLrUpdaterHook, self).__init__(**kwargs)

def get_lr(self, runner, base_lr):
if self.by_epoch:
progress = runner.epoch
max_progress = runner.max_epochs
else:
progress = runner.iter # runner須要管理各節點信息的緣由之一
max_progress = runner.max_iters
return self.target_lr + 0.5 * (base_lr - self.target_lr) * \
(1 + cos(pi * (progress / max_progress)))

從get_lr能夠看到,學習率變換週期有兩種,epoch->max_epoch,或者更大的iter->max_iter,後者代表一個epoch內不一樣batch的學習率能夠不一樣,由於沒有什麼理論,全部這兩種方式都行.其中base_lr爲初始學習率,target_lr爲學習率衰減的上界,而當前學習率即爲返回值.

留言區


歡迎關注GiantPandaCV, 在這裏你將看到獨家的深度學習分享,堅持原創,天天分享咱們學習到的新鮮知識。( • ̀ω•́ )✧

有對文章相關的問題,或者想要加入交流羣,歡迎添加BBuf微信:

爲了方便讀者獲取資料以及咱們公衆號的做者發佈一些Github工程的更新,咱們成立了一個QQ羣,二維碼以下,感興趣能夠加入。

公衆號QQ交流羣


本文分享自微信公衆號 - GiantPandaCV(BBuf233)。
若有侵權,請聯繫 support@oschina.cn 刪除。
本文參與「OSC源創計劃」,歡迎正在閱讀的你也加入,一塊兒分享。

相關文章
相關標籤/搜索