Pytorch數據集讀入——Dataset類,實現數據集打亂Shuffle

在進行相關平臺的練習過程當中,因爲要本身導入數據集,而導入方法在市面上五花八門,各類庫均可以應用,在這個過程當中我準備嘗試torchvision的庫dataset
torchvision.datasets.ImageFolder
簡單應用起來很是簡單,用torchvision.datasets.ImageFolder實現圖片的導入,在隨後訓練過程當中用Datalodar處理後可按批次取出訓練集git

class ImageFolder(root, transform=None, target_transform=None, loader=default_loader, is_valid_file=None)
ImageFolder有這麼幾個參數,其中root指的是數據所在的文件夾,其中該文件夾的存儲方式應爲
root/labels/xxx.jpg
即根據自身分類標籤存儲在對應標籤名的文件夾內
ImageFolder在讀入的過程當中會自行加好標籤,最後造成一對對的數據
另外比較經常使用的就是transform,表示對於傳入圖片的預處理,如剪裁,顏色選擇等等
好比github

transform_t = transforms.Compose([
    transforms.Resize([64, 64]), 
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor()]
    )

具體參數能夠上網查看
在以後用DataLodar處理後雖然的確有Shuffle的參數,可是卻只是在一個小批次內進行打亂,本來是按照類別存儲的,這樣的話會致使很嚴重的過擬合,爲了不這個,我決定常識改寫一下Dataset的類(主要是看起來Dataset看起來改寫比較順手...ImageFolder尚未看源碼並沒要對此下手)
可是Dataset須要讀入一個個的訓練數據的位置,怎麼辦呢?我就先寫了一個小腳本,生成一個txt文件來存儲全部數據的名稱(相對路徑),同時在這一步就進行打亂操做【一眼看下去甚至會發現init的classnum參數徹底沒用上(捂臉app

import os
import numpy as np
'''
self.target     順序存儲數據集
self.DataFile   存儲根目錄
self.s          存儲全部數據
self.label      存儲全部標籤及其對應的值
'''
class create_list():
    def __init__(self,root,classnum=2):
        self.target=open("./Data.txt",'w')
        self.DataFile=root
        self.s=[]
        self.label={}
        self.datanum=0
    
    def create(self):
        files=os.listdir(self.DataFile)
        for labels in files:
            tempdata=os.listdir(self.DataFile+"/"+labels)
            self.label[labels]=len(self.label)
            for img in tempdata:
                self.datanum+=1
                self.target.write(self.DataFile+"/"+labels+"/"+img+" "+labels+"\n")
                self.s.append([self.DataFile+"/"+labels+"/"+img,labels])
    
    def detail(self):
        #查看數據數量以及標籤對應
        print(self.datanum)
        print(self.label)
    
    def get_all(self):
        #查看全部數據
        print(self.s)

    def get_root(self):
        #得到根目錄
        return self.DataFile

    def shuffle(self):
        #得到打亂的存儲txt
        shuffle_file=open("./Shuffle_Data.txt",'w')
        temp=self.s
        np.random.shuffle(temp)
        for i in temp:
            shuffle_file.write(i[0]+" "+str(i[1])+"\n")
        return self.DataFile+"/Shuffle_Data.txt"

    def label_id(self,label):
        #得到該標籤對應的值
        return self.label[label]

數據集的存儲方式上的要求跟以前的ImageFolder同樣
最終會生成一個這樣的txt文件
image
數據集來源於某x光胸片判斷...
而Shuffle操做就是爲了生成打亂後的txt文件,我寫的比較簡單粗暴...先將就看吧,生成後大概就是這個樣子
image
至少真正的作到打亂數據了
完成這個之後,就能夠用此來幫助DataLodar了
接下來的代碼或許比較辣眼睛...可是事實證實是有用的,可是可能Python技巧不太熟練因此就會顯得很生澀...
我重現的Dataset類:dom

from PIL import Image
import torch

class cDataset(torch.utils.data.Dataset):
    def __init__(self, datatxt, root="", transform=None, target_transform=None, LabelDic=None):
        super(cDataset,self).__init__()
        files = open(root + "/" + datatxt, 'r')
        self.img=[]
        for i in files:
            i = i.rstrip()
            temp = i.split()
            if LabelDic!=None:
                self.img.append((temp[0],LabelDic[temp[1]]))
            else:
                self.img.append((temp[0],temp[0]))
            
        self.transform = transform
        self.target_transform = target_transform
    
    def __getitem__(self, index):
        files, label = self.img[index]
        img = Image.open(files).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img,label
    
    def __len__(self):
        return len(self.img)

其實直接看就能大概看明白,主要也就是要實現類裏面的幾個方法code

class cDataset(torch.utils.data.Dataset):
    def __init__():
    def __getitem__(self, index):
    def __len__(self):

其中getitm相似一次次的取出數據,len就是返回數據集數目
其中init的參數我作了稍許調整,因爲我以前的txt內標籤是字符串,而爲了能讓對應生成的tag是所要求的,能夠傳入一個字典,如:
LabelDic={"NORMAL":0,"PNEUMONIA":1}
這樣就能夠在以後轉化爲數字的標籤,onehot或者怎麼怎麼樣了,,,orm

相關文章
相關標籤/搜索