最近在學習keras框架,不得不感嘆keras的確比pytorch好用。python
那麼,如今就來整理一下深度學習裏最經常使用的數據迭代器的寫法吧。app
# 數據文件一篇就是一個文件 def _read_file(filename): """讀取一個文件並轉換爲一行""" with open(filename, 'r', encoding='utf-8') as f: s = f.read().strip().replace('\n', '。').replace('\t', '').replace('\u3000', '') return re.sub(r'。+', '。', s) # 文章迭代器 def get_data_iterator(data_path): for category in os.listdir(data_path): category_path = os.path.join(data_path, category) for file_name in os.listdir(category_path): yield _read_file(os.path.join(category_path, file_name)), category it = get_data_iterator(data_path) print(next(it)) ''' ('競彩解析:日本美國爭冠死磕 兩巴相逢必有生死。週日受注賽事,女足世界盃決賽、美洲盃兩場1/4決賽毫無疑問是全世界球迷和彩民關注的焦點。本屆女足世界盃的最大黑馬日本隊可否一黑到底,創造亞洲奇蹟?女子足壇霸主美國隊可否再次「滅黑」成功,成就三冠偉業?巴西、巴拉圭冤家路窄,誰又能笑到最後?諸多謎底,在週一凌晨就會揭曉。日本美國爭冠死磕。本屆女足世界盃,是顛覆與反顛覆之爭。奪冠大熱門東道主德國隊1/4決賽被日本隊加時賽一球而「黑」,另外一個奪冠大熱門瑞典隊則在半決賽被日本隊3:1完全打垮。而美國隊則捍衛着女足豪強的尊嚴,在1/4決賽,她們與巴西女足苦戰至點球大戰,最終以5:3淘汰這支迅速崛起的黑馬球隊,而在半決賽,她們更是3:1大勝歐洲黑馬法國隊。美日兩隊這次世界盃進程驚人類似,小組賽前兩輪全勝,最後一輪輸球,1/4決賽一樣與對手90分鐘內戰成平局,半決賽竟一樣3:1大勝對手。這次決戰,不管是日本仍是美國隊奪冠,均將創造女足世界盃新的歷史。兩巴相逢必有生死。本屆美洲盃,讓人大跌眼鏡的事情太多。巴西、巴拉圭冤家路窄彷佛更具傳奇色彩。兩隊小組賽同分在B組,本來兩個出線大熱門,卻雙雙在前兩輪小組賽戰平,兩隊直接交鋒就是2:2平局,結果雙雙面臨出局危險。最後一輪,巴西隊在下半場終於發威,4:2大勝厄瓜多爾後來居上以小組第一齣線,而巴拉圭最後一戰仍是3:3戰平委內瑞拉得到小組第三,僥倖憑藉淨勝球優點擠掉A組第三名的哥斯達黎加,得到一個八強席位。在小組賽,巴西隊是在最後時刻才逼平了巴拉圭,他們的好運氣會在淘汰賽再顯神威嗎?巴拉圭此前3輪小組賽彷佛都缺少運氣,此番又會否被幸運之神補償一下呢?。另外一場美洲盃1/4決賽,智利隊在C組小組賽2勝1平以小組頭名晉級八強;而委內瑞拉在B組是最不被看好的球隊,但居然在與巴西、巴拉圭同組的狀況下,前兩輪就奠基了小組出線權,他們小組3戰1勝2平保持不敗戰績,而入球數跟智利同樣都是4球,只是失球數比智利多了1個。但既然他們面對強大的巴西都能保持球門不失,此番再創佳績也不足爲怪。', '彩票') ''' ''' 通過一堆處理後... ''' # 構建循環的數據迭代器 def get_handled_data_iterator(data_path): pad_sequences_iter = get_pad_sequences_iterator(data_path, sequences_max_length) while True: for pad_sequences, label_one_hot in pad_sequences_iter: yield pad_sequences, label_one_hot # 構建批次迭代器 def batch_iter(data_path, batch_size=64, shuffle=True): """生成批次數據""" handled_data_iter = get_handled_data_iterator(data_path) while True: data_list = [] for _ in range(batch_size): data = next(handled_data_iter) data_list.append(data) if shuffle: random.shuffle(data_list) pad_sequences_list = [] label_one_hot_list = [] for data in data_list: pad_sequences, label_one_hot = data pad_sequences_list.append(pad_sequences.tolist()) label_one_hot_list.append(label_one_hot.tolist()) yield np.array(pad_sequences_list), np.array(label_one_hot_list) it = batch_iter(data_path, batch_size=2) print(next(it)) ''' (array([[ 751, 257, 223, ..., 661, 551, 8], [ 772, 751, 307, ..., 296, 2015, 1169]]), array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])) '''
以後就能夠用框架
model.fit_generator(batch_iter(data_path, batch_size=64), steps_per_epoch, epochs=100, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None)
來訓練模型啦~dom