使用TensorFlow訓練模型的基本流程

本文已在公衆號機器視覺與算法建模發佈,轉載請聯繫我。
html

使用TensorFlow的基本流程

本篇文章將介紹使用tensorflow的訓練模型的基本流程,包括製做讀取TFRecord,訓練和保存模型,讀取模型。python

準備

TFRecord

TensorFlow提供了一種統一的格式來存儲數據,這個格式就是TFRecord.git

message Example {  
 Features features = 1;  
};  
  
message Features{  
 map<string,Feature> featrue = 1;  
};  
  
message Feature{  
    oneof kind{  
        BytesList bytes_list = 1;  
        FloatList float_list = 2;  
        Int64List int64_list = 3;  
    }  
};

從代碼中咱們能夠看到, tf.train.Example 包含了一個字典,它的鍵是一個字符串,值爲Feature,Feature能夠取值爲字符串(BytesList)、浮點數列表(FloatList)、整型數列表(Int64List)。github

寫入一個TFRecord通常分爲三步:

  • 讀取須要轉化的數據
  • 將數據轉化爲Example Protocol Buffer,並寫入這個數據結構
  • 經過將數據轉化爲字符串後,經過TFRecordWriter寫出

方法一

此次咱們的數據是分別保存在多個文件夾下的,所以讀取數據最直接的方法是遍歷目錄下全部文件,而後讀入寫出TFRecord文件。該方法對應文件MakeTFRecord.py,咱們來看關鍵代碼算法

filenameTrain = 'TFRecord/train.tfrecords'
    filenameTest = 'TFRecord/test.tfrecords'
    writerTrain = tf.python_io.TFRecordWriter(filenameTrain)
    writerTest = tf.python_io.TFRecordWriter(filenameTest)
    folders = os.listdir(HOME_PATH)
    for subFoldersName in folders:
        label = transform_label(subFoldersName)
        path = os.path.join(HOME_PATH, subFoldersName)  # 文件夾路徑
        subFoldersNameList = os.listdir(path)
        i = 0
        for imageName in subFoldersNameList:
            imagePath = os.path.join(path, imageName)
            images = cv2.imread(imagePath)
            res = cv2.resize(images, (128, 128), interpolation=cv2.INTER_CUBIC)
            image_raw_data = res.tostring()
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': _int64_feature(label),
                'image_raw': _bytes_feature(image_raw_data)
            }))
            if i <= len(subFoldersNameList) * 3 / 4:
                writerTrain.write(example.SerializeToString())
            else:
                writerTest.write(example.SerializeToString())
            i += 1

在作數據的時候,我打算將3/4的數據用作訓練集,剩下的1/4數據做爲測試集,方便起見,將其保存爲兩個文件。

基本流程就是遍歷Fnt目錄下的全部文件夾,再進入子文件夾遍歷其目錄下的圖片文件,而後用OpenCV的imread方法將其讀入,再將圖片數據轉化爲字符串。在TFRecord提供的數據結構中_bytes_feature'是存儲字符串的。<br> 以上將圖片成功讀入並寫入了TFRecord的數據結構中,那圖片對應的標籤怎麼辦呢? ``` Python def transform_label(folderName): label_dict = { 'Sample001': 0, 'Sample002': 1, 'Sample003': 2, 'Sample004': 3, 'Sample005': 4, 'Sample006': 5, 'Sample007': 6, 'Sample008': 7, 'Sample009': 8, 'Sample010': 9, 'Sample011': 10, } return label_dict[folderName] ``` 我創建了一個字典,因爲一個文件下的圖片都是同一類的,因此將圖片對應的文件夾名字與它所對應的標籤,產生映射關係。代碼中label = transform_label(subFoldersName)`經過該方法得到,圖片的標籤。網絡

方法二

在使用方法一產生的數據訓練模型,會發現很是容易產生過擬合。由於咱們在讀數據的時候是將它打包成batch讀入的,雖然可使用tf.train.shuffle_batch方法將隊列中的數據打亂再讀入,可是因爲一個類中的數據過多,會致使即使打亂後也是同一個類中的數據。例如:數字0有1000個樣本,假設你讀取的隊列長達1000個,這樣即使打亂隊列後讀取的圖片任然是0。這在訓練時容易過擬合。爲了不這種狀況發生,個人想法是在作數據時將圖片打亂後寫入。對應文件MakeTFRecord2.py,關鍵代碼以下數據結構

folders = os.listdir(HOME_PATH)
    for subFoldersName in folders:
        path = os.path.join(HOME_PATH, subFoldersName)  # 文件夾路徑
        subFoldersNameList = os.listdir(path)
        for imageName in subFoldersNameList:
            imagePath = os.path.join(path, imageName)
            totalList.append(imagePath)

    # 產生一個長度爲圖片總數的不重複隨機數序列
    dictlist = random.sample(range(0, len(totalList)), len(totalList))  
    print(totalList[0].split('\\')[1].split('-')[0])    # 這是圖片對應的類別

    i = 0
    for path in totalList:
        images = cv2.imread(totalList[dictlist[i]])
        res = cv2.resize(images, (128, 128), interpolation=cv2.INTER_CUBIC)
        image_raw_data = res.tostring()
        label = transform_label(totalList[dictlist[i]].split('\\')[1].split('-')[0])
        print(label)
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': _int64_feature(label),
            'image_raw': _bytes_feature(image_raw_data)
        }))
        if i <= len(totalList) * 3 / 4:
            writerTrain.write(example.SerializeToString())
        else:
            writerTest.write(example.SerializeToString())
        i += 1

基本過程:遍歷目錄下全部的圖片,將它的路徑加入一個大的列表。經過一個不重複的隨機數序列,來控制使用哪張圖片。這就達到隨機的目的。

怎麼獲取標籤呢?圖片文件都是類型-序號這個形式命名的,這裏經過獲取它的類型名,創建字典產生映射關係。app

def transform_label(imgType):
    label_dict = {
        'img001': 0,
        'img002': 1,
        'img003': 2,
        'img004': 3,
        'img005': 4,
        'img006': 5,
        'img007': 6,
        'img008': 7,
        'img009': 8,
        'img010': 9,
        'img011': 10,
    }
    return label_dict[imgType]

原尺寸圖片CNN

對應CNN_train.py文件
訓練的時候怎麼讀取TFRecord數據呢,參考如下代碼dom

# 讀訓練集數據
def read_train_data():
    reader = tf.TFRecordReader()
    filename_train = tf.train.string_input_producer(["TFRecord128/train.tfrecords"])
    _, serialized_example_test = reader.read(filename_train)
    features = tf.parse_single_example(
        serialized_example_test,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
        }
    )

    img_train = features['image_raw']
    images_train = tf.decode_raw(img_train, tf.uint8)
    images_train = tf.reshape(images_train, [128, 128, 3])
    labels_train = tf.cast(features['label'], tf.int64)
    labels_train = tf.cast(labels_train, tf.int64)
    labels_train = tf.one_hot(labels_train, 10)
    return images_train, labels_train

經過features[鍵名]的方式將存入的數據讀取出來,鍵名和數據類型要與寫入的保持一致。

關於這裏的卷積神經網絡,我是參考王學長培訓時的代碼寫的。固然照搬確定不行,會遇到loss NaN的狀況,我解決的方法是仿照AlexNet中,在卷積後加入LRN層,進行局部響應歸一化。在設置參數時,加入l2正則項。關鍵代碼以下測試

def weights_with_loss(shape, stddev, wl):
    var = tf.truncated_normal(stddev=stddev, shape=shape)
    if wl is not None:
        weight_loss = tf.multiply(tf.nn.l2_loss(var), wl, name='weight_loss')
        tf.add_to_collection('losses', weight_loss)
    return tf.Variable(var)

def net(image, drop_pro):
    W_conv1 = weights_with_loss([5, 5, 3, 32], 5e-2, wl=0.0)
    b_conv1 = biasses([32])
    conv1 = tf.nn.relu(conv(image, W_conv1) + b_conv1)
    pool1 = max_pool_2x2(conv1)
    norm1 = tf.nn.lrn(pool1, 4, bias=1, alpha=0.001 / 9.0, beta=0.75)

    W_conv2 = weights_with_loss([5, 5, 32, 64], stddev=5e-2, wl=0.0)
    b_conv2 = biasses([64])
    conv2 = tf.nn.relu(conv(norm1, W_conv2) + b_conv2)
    norm2 = tf.nn.lrn(conv2, 4, bias=1, alpha=0.001 / 9.0, beta=0.75)
    pool2 = max_pool_2x2(norm2)

    W_conv3 = weights_with_loss([5, 5, 64, 128], stddev=0.04, wl=0.004)
    b_conv3 = biasses([128])
    conv3 = tf.nn.relu(conv(pool2, W_conv3) + b_conv3)
    pool3 = max_pool_2x2(conv3)

    W_conv4 = weights_with_loss([5, 5, 128, 256], stddev=1 / 128, wl=0.004)
    b_conv4 = biasses([256])
    conv4 = tf.nn.relu(conv(pool3, W_conv4) + b_conv4)
    pool4 = max_pool_2x2(conv4)

    image_raw = tf.reshape(pool4, shape=[-1, 8 * 8 * 256])

    # 全鏈接層
    fc_w1 = weights_with_loss(shape=[8 * 8 * 256, 1024], stddev=1 / 256, wl=0.0)
    fc_b1 = biasses(shape=[1024])
    fc_1 = tf.nn.relu(tf.matmul(image_raw, fc_w1) + fc_b1)

    # drop-out層
    drop_out = tf.nn.dropout(fc_1, drop_pro)

    fc_2 = weights_with_loss([1024, 10], stddev=0.01, wl=0.0)
    fc_b2 = biasses([10])

    return tf.matmul(drop_out, fc_2) + fc_b2

128x128x3原圖訓練過程
128*128
在驗證集上的正確率
128v
這裏使用的是1281283的圖片,圖片比較大,因此我產生了一個想法。在作TFRecord數據的時候,將圖片尺寸減半。因此就有了第二種方法。

圖片尺寸減半CNN

對應文件CNN_train2.py
與上面那種方法惟一的區別是將圖片尺寸128*128*3改爲了64*64*3因此我這裏就不重複說明了。
64x64x3圖片訓過程
64*64
在驗證集上的正確率
64v

保存模型

CNN_train.py中,對應保存模型的代碼是

def save_model(sess, step):
    MODEL_SAVE_PATH = "./model128/"
    MODEL_NAME = "model.ckpt"
    saver = tf.train.Saver()
    saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=step)

save_model(sess, i)

i是迭代的次數,能夠不填其對應的參數global_step

在測試集上檢驗準確率

對應文件AccuracyTest.py
代碼基本與訓練的代碼相同,這裏直接講怎麼恢復模型。關鍵代碼

ckpt = tf.train.get_checkpoint_state(MODEL_PATH)
    if ckpt and ckpt.model_checkpoint_path:
        #加載模型
        saver.restore(sess, ckpt.model_checkpoint_path)

值得一提的是tf.train.get_checkpoint_state該方法會自動找到文件夾下迭代次數最多的模型,而後讀入。而saver.restore(sess, ckpt.model_checkpoint_path)方法將恢復,模型在訓練時最後一次迭代的變量參數。

查看讀入的TFRecord圖片

對應文件ReadTest.py
若是你想檢查下在製做TFRecord時,圖片是否處理的正確,最簡單的方法就是將圖片顯示出來。關鍵代碼以下

def plot_images(images, labels):
    for i in np.arange(0, 20):
        plt.subplot(5, 5, i + 1)
        plt.axis('off')
        plt.title(labels[i], fontsize=14)
        plt.subplots_adjust(top=1.5)
        plt.imshow(images[i])
    plt.show()

plot_images(image, label

示例

總結

在摸索過程當中遇到不少問題,多虧了王學長耐心幫助,也但願這篇文章能幫助更多人吧。
新手上路,若是有錯,歡迎指正,謝謝。

代碼已上傳github:https://github.com/wmpscc/TensorflowBaseDemo

閱讀原文
wechat.jpg

相關文章
相關標籤/搜索