Caffe utility: collecting sample paths

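Samples are laid out as follows. Assuming there are n classes, the samples of each class go into the directory named after that class's label:  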
|----XXX/sample/root/0   
      a0.jpg  
      b0.jpg  
      c0.jpg  
|----XXX/sample/root/1  
      a1.jpg  
      b1.jpg  
      c1.jpg  
|----XXX/sample/root/2  
...  
|----XXX/sample/root/n  

For example, a sample root may hold 11 classes.
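Before running the tool it can help to confirm that a sample root really follows this layout. The sketch below is an illustration, not part of the original tool: `count_samples_per_label` and `make_demo_root` are hypothetical helpers, and the demo builds a throwaway directory with labels 0 and 1 instead of touching real data.

```python
import os
import tempfile


def count_samples_per_label(root):
    """Return {label: file count} for every digit-named subdirectory of root."""
    counts = {}
    for entry in sorted(os.listdir(root)):
        class_dir = os.path.join(root, entry)
        # Only directories whose name is a numeric label count as classes.
        if os.path.isdir(class_dir) and entry.isdigit():
            counts[entry] = sum(
                os.path.isfile(os.path.join(class_dir, name))
                for name in os.listdir(class_dir)
            )
    return counts


def make_demo_root(base_dir):
    """Create labels 0 and 1 with three empty files each, plus a non-label directory."""
    for label in ("0", "1"):
        class_dir = os.path.join(base_dir, label)
        os.makedirs(class_dir)
        for stem in ("a", "b", "c"):
            open(os.path.join(class_dir, stem + label + ".jpg"), "w").close()
    os.makedirs(os.path.join(base_dir, "notes"))  # ignored: name is not a digit
    return base_dir


with tempfile.TemporaryDirectory() as tmp:
    demo_counts = count_samples_per_label(make_demo_root(tmp))
    print(demo_counts)
```

A class directory that is missing from the result, or that reports zero files, is worth fixing before the lists are generated.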


Tool syntax:

>>python  collect_samples_path.py /where/the/samples /where/to/save/sample_info.txt [--shuffle=True|False] [--traintestRate=0.3]


>>python  collect_samples_path.py --help  (show help)
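One detail of a flag such as `--shuffle=False` deserves care: argparse hands the value over as the string `'False'`, and any non-empty string is truthy in Python unless it is converted. The sketch below shows one possible conversion. `str2bool` and `build_parser` are hypothetical names introduced only for illustration, and the accepted spellings ("true", "1", "yes") are an assumption.

```python
import argparse


def str2bool(value):
    """Map a command-line string to a bool; anything not listed counts as False."""
    return str(value).strip().lower() in ("true", "1", "yes")


def build_parser():
    parser = argparse.ArgumentParser(prog="collect_samples_path.py")
    parser.add_argument("path")
    parser.add_argument("filepath")
    # type=str2bool converts the raw string; the non-string default is used as-is.
    parser.add_argument("--shuffle", type=str2bool, default=True)
    parser.add_argument("--traintestRate", type=float, default=0.3)
    return parser


args = build_parser().parse_args(
    ["/where/the/samples", "/where/to/save/sample_info.txt", "--shuffle=False"]
)
print(args.shuffle, args.traintestRate)
print(bool("False"))  # a plain string would have been treated as true
```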

The Python code is as follows:

import os
import sys
import random
import argparse
# Command format:
#   python collect_samples_path.py /where/the/samples /where/to/save/sample_info.txt [--shuffle=True|False] [--traintestRate=0.3]

def str2bool(value):
    '''Convert a command-line string to bool; without this, "False" would be truthy.'''
    return str(value).strip().lower() in ('true', '1', 'yes')

parser = argparse.ArgumentParser(prog='collect_samples_path.py')
parser.add_argument('path',help='root directory of the samples')
parser.add_argument('filepath',help='output path prefix; .train and .test are appended')
parser.add_argument('--shuffle',help='shuffle the files of each class before splitting',type=str2bool,default=True)
parser.add_argument('--traintestRate',help='fraction of each class used for testing',type=float,default=0.3)

def collect_sample_path(root,filename,shuffle = True,trainTestRate=0.3):
    '''
    root: sample root path
    filename: output path prefix; writes filename.train and filename.test
    shuffle: whether to shuffle the files of each class before splitting
    trainTestRate: fraction of each class's samples used for testing
    '''
    train_data = []
    test_data = []
    for parent,dirs,filenames in os.walk(root):
        if parent != root:
            # The label is the directory path relative to root; relpath also
            # copes with a trailing slash on root.
            label = os.path.relpath(parent,root)
            if label.isdigit():
                if shuffle:
                    random.shuffle(filenames)
                # The first ntrain files go to the training list, the rest to the test list
                ntrain = int(len(filenames) * (1-trainTestRate))
                for i,name in enumerate(filenames):
                    line = '/'+label+'/'+name+' '+label+'\n'
                    if i < ntrain:
                        train_data.append(line)
                    else:
                        test_data.append(line)
    with open(filename+'.train','w') as f:
        f.writelines(train_data)
    with open(filename+'.test','w') as f:
        f.writelines(test_data)
if __name__ == '__main__':
    args = parser.parse_args()
    collect_sample_path(args.path,args.filepath,args.shuffle,args.traintestRate)
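
Each line written to the .train and .test files has the form `/<label>/<filename> <label>`, with the path given relative to the sample root. As a rough illustration of what comes out, the sketch below repeats the split arithmetic and parses one such line. `split_counts` and `parse_list_line` are hypothetical helpers, not part of the tool.

```python
def split_counts(n_files, test_rate=0.3):
    """Mirror the script's arithmetic: the training count is truncated toward zero."""
    n_train = int(n_files * (1 - test_rate))
    return n_train, n_files - n_train


def parse_list_line(line):
    """Split '/<label>/<filename> <label>' into (relative_path, int_label)."""
    path, label = line.rstrip("\n").rsplit(" ", 1)
    return path, int(label)


# Three files per class (as in the layout at the top) with the default rate 0.3.
n_train, n_test = split_counts(3)
print(n_train, n_test)
print(parse_list_line("/0/a0.jpg 0\n"))
```

Because the training count is truncated, small classes give the test list a larger share than the nominal rate: with three files and a rate of 0.3, two go to training and one to testing. The paths start with `/<label>`, so whatever reads the lists has to prepend the sample root.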