給深度學習模型構建數據迭代器

最近在學習keras框架,不得不感嘆keras的確比pytorch好用。python

那麼,如今就來整理一下深度學習裏最經常使用的數據迭代器的寫法吧。app

# 數據文件一篇就是一個文件
def _read_file(filename):
    """讀取一個文件並轉換爲一行"""
    with open(filename, 'r', encoding='utf-8') as f:
        s = f.read().strip().replace('\n', '。').replace('\t', '').replace('\u3000', '')
        return re.sub(r'。+', '。', s)

# 文章迭代器
def get_data_iterator(data_path):
    for category in os.listdir(data_path):
        category_path = os.path.join(data_path, category)
        for file_name in os.listdir(category_path):
            yield _read_file(os.path.join(category_path, file_name)), category

it = get_data_iterator(data_path)
print(next(it))
'''
('競彩解析:日本美國爭冠死磕 兩巴相逢必有生死。週日受注賽事,女足世界盃決賽、美洲盃兩場1/4決賽毫無疑問是全世界球迷和彩民關注的焦點。本屆女足世界盃的最大黑馬日本隊可否一黑到底,創造亞洲奇蹟?女子足壇霸主美國隊可否再次「滅黑」成功,成就三冠偉業?巴西、巴拉圭冤家路窄,誰又能笑到最後?諸多謎底,在週一凌晨就會揭曉。日本美國爭冠死磕。本屆女足世界盃,是顛覆與反顛覆之爭。奪冠大熱門東道主德國隊1/4決賽被日本隊加時賽一球而「黑」,另外一個奪冠大熱門瑞典隊則在半決賽被日本隊3:1完全打垮。而美國隊則捍衛着女足豪強的尊嚴,在1/4決賽,她們與巴西女足苦戰至點球大戰,最終以5:3淘汰這支迅速崛起的黑馬球隊,而在半決賽,她們更是3:1大勝歐洲黑馬法國隊。美日兩隊這次世界盃進程驚人類似,小組賽前兩輪全勝,最後一輪輸球,1/4決賽一樣與對手90分鐘內戰成平局,半決賽竟一樣3:1大勝對手。這次決戰,不管是日本仍是美國隊奪冠,均將創造女足世界盃新的歷史。兩巴相逢必有生死。本屆美洲盃,讓人大跌眼鏡的事情太多。巴西、巴拉圭冤家路窄彷佛更具傳奇色彩。兩隊小組賽同分在B組,本來兩個出線大熱門,卻雙雙在前兩輪小組賽戰平,兩隊直接交鋒就是2:2平局,結果雙雙面臨出局危險。最後一輪,巴西隊在下半場終於發威,4:2大勝厄瓜多爾後來居上以小組第一齣線,而巴拉圭最後一戰仍是3:3戰平委內瑞拉得到小組第三,僥倖憑藉淨勝球優點擠掉A組第三名的哥斯達黎加,得到一個八強席位。在小組賽,巴西隊是在最後時刻才逼平了巴拉圭,他們的好運氣會在淘汰賽再顯神威嗎?巴拉圭此前3輪小組賽彷佛都缺少運氣,此番又會否被幸運之神補償一下呢?。另外一場美洲盃1/4決賽,智利隊在C組小組賽2勝1平以小組頭名晉級八強;而委內瑞拉在B組是最不被看好的球隊,但居然在與巴西、巴拉圭同組的狀況下,前兩輪就奠基了小組出線權,他們小組3戰1勝2平保持不敗戰績,而入球數跟智利同樣都是4球,只是失球數比智利多了1個。但既然他們面對強大的巴西都能保持球門不失,此番再創佳績也不足爲怪。',
 '彩票')
 '''

'''
通過一堆處理後...
'''

# 構建循環的數據迭代器
def get_handled_data_iterator(data_path):
    pad_sequences_iter = get_pad_sequences_iterator(data_path, sequences_max_length)
    while True:
        for pad_sequences, label_one_hot in pad_sequences_iter:
            yield pad_sequences, label_one_hot

# 構建批次迭代器
def batch_iter(data_path, batch_size=64, shuffle=True):
    """生成批次數據"""
    handled_data_iter = get_handled_data_iterator(data_path)
    while True:
        data_list = []
        for _ in range(batch_size):
            data = next(handled_data_iter)
            data_list.append(data)
        if shuffle:
            random.shuffle(data_list)
        
        pad_sequences_list = []
        label_one_hot_list = []
        for data in data_list:
            pad_sequences, label_one_hot = data
            pad_sequences_list.append(pad_sequences.tolist())
            label_one_hot_list.append(label_one_hot.tolist())

        yield np.array(pad_sequences_list), np.array(label_one_hot_list)

it = batch_iter(data_path, batch_size=2)
print(next(it))
'''
(array([[ 751,  257,  223, ...,  661,  551,    8],
        [ 772,  751,  307, ...,  296, 2015, 1169]]),
 array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]))
'''

以後就能夠用框架

model.fit_generator(batch_iter(data_path, batch_size=64),
                    steps_per_epoch,
                    epochs=100,
                    verbose=1,
                    callbacks=None,
                    validation_data=None,
                    validation_steps=None,
                    class_weight=None)

來訓練模型啦~dom

相關文章
相關標籤/搜索