解密飛槳多任務學習框架PALM,讓你的模型開啓「學霸」模式

隨着預訓練技術的到來,做爲深度學習重要應用領域之一,天然語言處理也迎來了新的春天。經過使用預訓練模型能夠大大減小模型訓練對數據的依賴,僅須要使用少許數據在下游任務中微調(Fine-tune),就能夠得到效果很是優秀的模型。不過若是但願得到更好的效果,該怎麼辦呢?有人也許會說:多訓練幾個epoch嘛!可是對於這種單一任務且有監督學習的微調方式,單獨增長訓練epoch並非一個好方法,過分的訓練容易損害模型的泛化能力,發生過擬合現象。php

要知道訓練一個模型就像在養育一個孩子同樣。在子女的教育問題上,每一個家長都會投入儘量多的人力和資源,但願把本身孩子教育成才,可以觸類旁通、舉一反三,成爲一個「學霸」。python

可是若是到考試時發現本身的孩子只會作課本上的原題,題目稍微改改就作很差,我想家長必定會欲哭無淚吧。相比模型訓練又未嘗不是呢?開發者不只要投入大量的服務器硬件資源,還要辛辛苦苦寫代碼,結果最後訓練出的模型泛化能力極差,跳出訓練數據的範圍,就啥也幹不了,相信這絕對不是任何一個開發者但願看到的。git

那麼有什麼方法能夠提升模型的泛化能力,讓模型能夠更加聰明呢?其實能夠在微調階段引入輔助任務信號,經過多任務學習的方式,即將多個目標任務場景聯合學習,就能夠顯著提升模型所學到的表徵的通用性,使得模型具有更強的泛化能力。github

可是基於傳統的深度學習框架,多任務學習方式的代碼實現門檻較高,策略調整不夠靈活,成本高,且容易出錯。爲此,飛槳開源深度學習平臺發佈了High-Level的多任務學習框架PALM,該框架靈活且易於使用,旨在幫助用戶快速開發具有強泛化能力的NLP模型,爲模型添加學霸屬性!服務器

下載安裝命令

## CPU版本安裝命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/cpu paddlepaddle

## GPU版本安裝命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/gpu paddlepaddle-gpu

什麼是多任務學習

在瞭解PALM以前,首先咱們來看下什麼是多任務學習。多任務學習是機器學習領域的一個研究分支,經過讓模型在學習階段同時解決多個任務,使其能夠學習到任務之間的共性和差別性。網絡

對於大部分NLP任務來講,都依賴於一個通用的文本表示模塊(Encoder)來完成文本的語義向量表示,這部分每每能夠看做是各任務的共性知識;而要解決不一樣的NLP任務,則須要在任務的輸出層來編碼各個不一樣任務所獨有的強相關的知識,所以輸出層每每能夠表徵任務之間的差別性。架構

圖1 多任務學習網絡示意圖框架

當預訓練模型應用於多任務學習中時,預訓練模型自己每每做爲各個任務的「共有部分」。在訓練過程當中,多個任務同時學習,不一樣任務之間共享預訓練參數,從而最終獲得一個更加魯棒、更強泛化能力的模型。就像讓一個孩子同時學習不一樣學科的知識,將不一樣學科的知識融會貫通,這樣未來考試時不管是考課內的,仍是課外的,單獨學科仍是考文理綜合,都會信手拈來!dom

PALM多任務學習框架概覽

瞭解了什麼是多任務學習後,我們來看看飛槳的PALM多任務學習框架的內部是如何組成的。如圖2所示,PALM的架構包含三層,從下到上依次是組件層(Component Layer)、訓練器層(Trainer Layer)和高級訓練器層(High-level Trainer Layer):機器學習

  • 組件層:PALM提供了6個 解耦的組件來實現NLP任務。每一個組件包含豐富的預約義類和一個基類。預約義類是針對典型的NLP任務的,而基類則是幫助用戶完成該組件的自定義。

  • 訓練器層:經過使用選定的構件創建計算圖,用於進行訓練和推理。該層描述了訓練策略、模型保存和加載、評估和推理過程。一個訓練器只能處理一個任務。

  • 高級訓練器層:用於複雜的學習和推理策略,如多任務學習。經過添加輔助任務來訓練健壯的NLP模型(提升模型的測試集和領域外的性能),或者聯合訓練多個相關任務來得到每一個任務的更高性能。

圖2 PALM的運行原理圖

飛槳PALM涉及的模塊以下表所示。

如今介紹完框架結構和模塊了,至關於演員都到場了,該開始唱戲了!下面我們再來看看如何使用這些模塊實現多任務學習功能的吧!

如何使用PALM?

1. 安裝PALM

PALM的安裝很是簡單,能夠經過pip直接安裝,也能夠經過git clone的方式從github上獲取。

pip install paddlepalm
#或
git clone  https://github.com/PaddlePaddle/PALM.git

PALM同時支持了Python2 和 Python三、Linux 和Windows、CPU 和 GPU等不一樣軟硬件環境。PALM安裝完成後會根據所處環境自動切換模型訓練/推理設備。

此外PALM中還內置了豐富的預訓練模型,用戶能夠輕鬆的切換預訓練模型,探索其做爲多任務學習的模型主幹的有效性。

若是要查看全部可用的預訓練模型並下載,請在python解釋器中運行以下代碼。


>>> from paddlepalm import downloader
>>> downloader.ls('pretrain')
Available pretrain items:
  => RoBERTa-zh-base
  => RoBERTa-zh-large
  => ERNIE-v2-en-base
  => ERNIE-v2-en-large
  => XLNet-cased-base
  => XLNet-cased-large
  => ERNIE-v1-zh-base
  => ERNIE-v1-zh-base-max-len-512
  => BERT-en-uncased-large-whole-word-masking
  => BERT-en-cased-large-whole-word-masking
  => BERT-en-uncased-base
  => BERT-en-uncased-large
  => BERT-en-cased-base
  => BERT-en-cased-large
  => BERT-multilingual-uncased-base
  => BERT-multilingual-cased-base
  => BERT-zh-base

>>> downloader.download('pretrain''BERT-en-uncased-base''./pretrain_models')

2. 參考以下例子編寫代碼

這裏舉一個對話系統構建的例子。在任務完成型的對話系統中,咱們爲了理解用戶的對話輸入,須要完成兩個NLP任務:一個是意圖理解,這個能夠看作是一個文本分類任務;另外一個是槽位填充,即識別出意圖中的相關屬性和屬性值,這個能夠看作是序列標註任務。咱們但願將這兩個NLP任務進行聯合訓練,來獲得更佳的模型。

基於PALM能夠很是輕鬆的實現這個多任務訓練需求。代碼以下所示。(爲了簡化說明,這裏省略了模型組網、迭代訓練等部分的相關代碼,僅體現PALM相關的內容。)

# 建立數據集的讀取與預處理工具
seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed) 
cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, seed=random_seed)

# 加載訓練數據
seq_label_reader.load_data(train_slot, file_format='tsv', num_epochs=None, batch_size=batch_size) 
cls_reader.load_data(train_intent, batch_size=batch_size, num_epochs=None)

# 建立骨幹網絡提取文本特徵
ernie = palm.backbone.ERNIE.from_config(config)

# 在ERNIE的骨幹網絡上註冊數據集讀取與預處理工具
seq_label_reader.register_with(ernie) 
cls_reader.register_with(ernie)

#  建立任務的輸出層
seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob) 
cls_head = palm.head.Classify(num_classes_intent, input_dim, dropout_prob)

# 建立任務訓練單元和多任務訓練模塊
trainer_seq_label = palm.Trainer("slot", mix_ratio=1.0) 
trainer_cls = palm.Trainer("intent", mix_ratio=1.0)
trainer = palm.MultiHeadTrainer([trainer_seq_label, trainer_cls])

# 構建包含主幹網絡和任務頭的前向圖
loss1 = trainer_cls.build_forward(ernie, cls_head) 
loss2 = trainer_seq_label.build_forward(ernie, seq_label_head) 
loss_var = trainer.build_forward()

# 使能warmup策略以獲取更好的微調效果
n_steps = seq_label_reader.num_examples * 1.5 * num_epochs
warmup_steps = int(0.1 * n_steps) 
sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)

# 構建優化器
adam = palm.optimizer.Adam(loss_var, lr, sched)

# 構建反向圖
trainer.build_backward(optimizer=adam, weight_decay=weight_decay)

#將準備好的reader和數據給到訓練單元。
trainer.fit_readers_with_mixratio([seq_label_reader, cls_reader], "slot", num_epochs)

# 加載預訓練模型
trainer.load_pretrain('./pretrain/ERNIE-v2-en-base')

# 設置訓練期間保存模型
trainer.set_saver(save_path='./outputs/', save_steps=300)

# 開始訓練
trainer.train(print_steps=10)

其它實現細節和完整的示例代碼請參見

https://github.com/PaddlePaddle/PALM/tree/master/examples/multi-task

運行代碼後,部分訓練日誌以下所示,能夠看到不一樣的訓練任務都在執行過程當中。

global step: 5, slot: step 3/309 (epoch 0), loss: 68.965, speed: 0.58 steps/s 
global step: 10, intent: step 3/311 (epoch 0), loss: 3.407, speed: 8.76 steps/s 
global step: 15, slot: step 12/309 (epoch 0), loss: 54.611, speed: 1.21 steps/s 
global step: 20, intent: step 7/311 (epoch 0), loss: 3.487, speed: 10.28 steps/s

更多示例

除了上面的示例以外,飛槳PALM還能夠用來幫助復現EMNLP2019 MRQA比賽中的奪冠方案D-Net。經過使用飛槳PALM,能夠幫助機器閱讀理解引入Mask Language Model和段落打分輔助任務的過程變得很是容易。

此外,Github Repo中還提供了情感分析、問題類似度匹配、命名實體識別、機器閱讀理解等更多的NLP示例,在這些單任務示例的基礎上嘗試引入更多相關的輔助任務能夠預期獲得泛化能力更強的模型,快去試試吧!

若是您加入官方QQ羣,您將趕上大批志同道合的深度學習同窗。官方QQ羣:703252161。

若是您想詳細瞭解更多飛槳的相關內容,請參閱如下文檔。

下載安裝命令

## CPU版本安裝命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/cpu paddlepaddle

## GPU版本安裝命令
pip install -f https://paddlepaddle.org.cn/pip/oschina/gpu paddlepaddle-gpu

官網地址:

https://www.paddlepaddle.org.cn

PLAM項目地址:

https://github.com/PaddlePaddle/PALM

飛槳開源框架項目地址: 

GitHub: https://github.com/PaddlePaddle/Paddle

Gitee:  https://gitee.com/paddlepaddle/Paddle

>> 訪問 PaddlePaddle 官網,瞭解更多相關內容

相關文章
相關標籤/搜索