樣本放置格式如下:假設有 n 個分類,每個分類的樣本放在以其標籤命名的子目錄中:
|---- XXX/sample/root/0
|         |---- a0.jpg
|         |---- b0.jpg
|         |---- c0.jpg
|---- XXX/sample/root/1
|         |---- a1.jpg
|         |---- b1.jpg
|         |---- c1.jpg
|---- XXX/sample/root/2
|         |---- ...
|---- ...
|---- XXX/sample/root/n-1
例如,若共有 11 個分類,則應有 0 到 10 共 11 個標籤目錄。
工具用法:
>>python collect_samples_path.py /where/the/samples /where/to/save/sample_info.txt [--shuffle=True|False] [--traintestRate=0.3]
其中 --traintestRate 為每個分類中劃入測試集的比例(預設 0.3)。腳本會生成 sample_info.txt.train 與 sample_info.txt.test 兩個文件。
或使用
>>python collect_samples_path.py --help
獲取幫助。
Python 代碼如下:
import os
import sys
import random
import argparse

# CMD FORMAT:
#  >>collect_samples_path.py /where/the/samples /where/to/save/sample_info.txt [--shuffle=[True,False],--traintestRate=[0.3] ]

# Absorbs floating-point error in n * (1 - rate) before truncating, e.g.
# 10 * (1 - 0.9) == 0.9999999999999998, which int() alone would turn into 0.
_SPLIT_EPS = 1e-9


def _str2bool(value):
    """Convert a command-line string such as 'True' or 'false' to a bool.

    Without an explicit type, argparse stores the raw string, and every
    non-empty string -- including 'False' -- is truthy, so --shuffle=False
    used to be ignored.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected True or False, got %r' % (value,))


parser = argparse.ArgumentParser('\n collect_samples_path.py')
parser.add_argument('path', help='specify samples path')
parser.add_argument('filepath', help='specify filename to save')
parser.add_argument('--shuffle', help='shuffle each class sample file',
                    type=_str2bool, default=True)
parser.add_argument('--traintestRate', help='rate of each samples for test',
                    type=float, default=0.3)


def _split_train_test(names, test_rate):
    """Return (train, test): the first floor(len * (1 - test_rate)) names, then the rest."""
    ntrain = int(len(names) * (1.0 - test_rate) + _SPLIT_EPS)
    return names[:ntrain], names[ntrain:]


def _format_entry(label, name):
    """Format one list line: '/<label>/<file name> <label>' followed by a newline."""
    return '/' + label + '/' + name + ' ' + label + '\n'


def collect_sample_path(root, filename, shuffle=True, trainTestRate=0.3):
    '''Split the samples under root into a train list file and a test list file.

    Every directory directly under root whose name consists only of digits is
    one class, and its name is used as the label. The files directly inside
    each such directory are split per class: the first
    floor(count * (1 - trainTestRate)) go to the train list and the rest to
    the test list. Each output line has the form '/<label>/<file name> <label>'.

    root: sample root path (a trailing path separator is accepted)
    filename: base path to save to; writes filename + '.train' and
        filename + '.test'
    shuffle: shuffle each class's files before splitting; otherwise files are
        taken in sorted order
    trainTestRate: fraction of each class's samples that goes to the test
        list, within [0, 1]

    Raises ValueError if trainTestRate is outside [0, 1] (also for NaN).
    '''
    if not 0.0 <= trainTestRate <= 1.0:
        raise ValueError('trainTestRate must be within [0, 1], got %r' % (trainTestRate,))

    train_data = []
    test_data = []
    for parent, dirs, filenames in os.walk(root):
        # Visit class directories in a stable order; os.walk order is arbitrary.
        dirs.sort()
        # relpath tolerates a trailing separator on root ('samples/'), which
        # the old slice parent[len(root) + 1:] turned into an empty label.
        label = os.path.relpath(parent, root)
        if not label.isdigit():
            continue  # the root itself ('.'), nested or non-numeric directories
        names = sorted(filenames)
        if shuffle:
            random.shuffle(names)
        train, test = _split_train_test(names, trainTestRate)
        train_data.extend(_format_entry(label, name) for name in train)
        test_data.extend(_format_entry(label, name) for name in test)

    with open(filename + '.train', 'w') as f:
        f.writelines(train_data)
    with open(filename + '.test', 'w') as f:
        f.writelines(test_data)


if __name__ == '__main__':
    args = parser.parse_args()
    collect_sample_path(args.path, args.filepath, args.shuffle, args.traintestRate)