1 大綱概述html
文本分類這個系列將會有十篇左右,包括基於word2vec預訓練的文本分類,與及基於最新的預訓練模型(ELMo,BERT等)的文本分類。總共有如下系列:python
textCNN 模型github
charCNN 模型json
Bi-LSTM 模型app
RCNN 模型函數
全部代碼均在textClassifier倉庫中。
2 數據集
數據集爲IMDB 電影影評,總共有三個數據文件,在/data/rawData目錄下,包括unlabeledTrainData.tsv,labeledTrainData.tsv,testData.tsv。在進行文本分類時須要有標籤的數據(labeledTrainData),數據預處理如文本分類實戰(一)—— word2vec預訓練詞向量中同樣,預處理後的文件爲/data/preprocess/labeledTrain.csv。
3 BERT預訓練模型
BERT 模型來源於論文BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding。BERT模型是谷歌提出的基於雙向Transformer構建的語言模型。BERT模型和ELMo有大不一樣,在以前的預訓練模型(包括word2vec,ELMo等)都會生成詞向量,這種類別的預訓練模型屬於domain transfer。而近一兩年提出的ULMFiT,GPT,BERT等都屬於模型遷移。
BERT 模型是將預訓練模型和下游任務模型結合在一塊兒的,也就是說在作下游任務時仍然是用BERT模型,並且自然支持文本分類任務,在作文本分類任務時不須要對模型作修改。谷歌提供了下面七種預訓練好的模型文件。
BERT模型在英文數據集上提供了兩種大小的模型,Base和Large。Uncased是意味着輸入的詞都會轉變成小寫,cased是意味着輸入的詞會保存其大寫(在命名實體識別等項目上須要)。Multilingual是支持多語言的,最後一個是中文預訓練模型。
在這裏咱們選擇BERT-Base,Uncased。下載下來以後是一個zip文件,解壓後有ckpt文件,一個模型參數的json文件,一個詞彙表txt文件。
在應用BERT模型以前,咱們須要去github上下載開源代碼,咱們能夠直接clone下來,在這裏有一個run_classifier.py文件,在作文本分類項目時,咱們須要修改這個文件,主要是添加咱們的數據預處理類。clone下來的項目結構以下:
在run_classifier.py文件中有一個基類DataProcessor類,其代碼以下:
class DataProcessor(object): """Base class for data converters for sequence classification data sets.""" def get_train_examples(self, data_dir): """Gets a collection of `InputExample`s for the train set.""" raise NotImplementedError() def get_dev_examples(self, data_dir): """Gets a collection of `InputExample`s for the dev set.""" raise NotImplementedError() def get_test_examples(self, data_dir): """Gets a collection of `InputExample`s for prediction.""" raise NotImplementedError() def get_labels(self): """Gets the list of labels for this data set.""" raise NotImplementedError() @classmethod def _read_tsv(cls, input_file, quotechar=None): """Reads a tab separated value file.""" with tf.gfile.Open(input_file, "r") as f: reader = csv.reader(f, delimiter="\t", quotechar=quotechar) lines = [] for line in reader: lines.append(line) return lines
在這個基類中定義了一個讀取文件的靜態方法_read_tsv,四個分別獲取訓練集,驗證集,測試集和標籤的方法。接下來咱們要定義本身的數據處理的類,咱們將咱們的類命名爲IMDBProcessor
class IMDBProcessor(DataProcessor): """ IMDB data processor """ def _read_csv(self, data_dir, file_name): with tf.gfile.Open(data_dir + file_name, "r") as f: reader = csv.reader(f, delimiter=",", quotechar=None) lines = [] for line in reader: lines.append(line) return lines def get_train_examples(self, data_dir): lines = self._read_csv(data_dir, "trainData.csv") examples = [] for (i, line) in enumerate(lines): if i == 0: continue guid = "train-%d" % (i) text_a = tokenization.convert_to_unicode(line[0]) label = tokenization.convert_to_unicode(line[1]) examples.append( InputExample(guid=guid, text_a=text_a, label=label)) return examples def get_dev_examples(self, data_dir): lines = self._read_csv(data_dir, "devData.csv") examples = [] for (i, line) in enumerate(lines): if i == 0: continue guid = "dev-%d" % (i) text_a = tokenization.convert_to_unicode(line[0]) label = tokenization.convert_to_unicode(line[1]) examples.append( InputExample(guid=guid, text_a=text_a, label=label)) return examples def get_test_examples(self, data_dir): lines = self._read_csv(data_dir, "testData.csv") examples = [] for (i, line) in enumerate(lines): if i == 0: continue guid = "test-%d" % (i) text_a = tokenization.convert_to_unicode(line[0]) label = tokenization.convert_to_unicode(line[1]) examples.append( InputExample(guid=guid, text_a=text_a, label=label)) return examples def get_labels(self): return ["0", "1"]
在這裏咱們沒有直接用基類中的靜態方法_read_tsv,由於咱們的csv文件是用逗號分隔的,所以就本身定義了一個_read_csv的方法,其他的方法就是讀取訓練集,驗證集,測試集和標籤。在這裏標籤就是一個列表,將咱們的類別標籤放入就行。訓練集,驗證集和測試集都是返回一個InputExample對象的列表。InputExample是run_classifier.py中定義的一個類,代碼以下:
class InputExample(object): """A single training/test example for simple sequence classification.""" def __init__(self, guid, text_a, text_b=None, label=None): """Constructs a InputExample. Args: guid: Unique id for the example. text_a: string. The untokenized text of the first sequence. For single sequence tasks, only this sequence must be specified. text_b: (Optional) string. The untokenized text of the second sequence. Only must be specified for sequence pair tasks. label: (Optional) string. The label of the example. This should be specified for train and dev examples, but not for test examples. """ self.guid = guid self.text_a = text_a self.text_b = text_b self.label = label
在這裏定義了text_a和text_b,說明是支持句子對的輸入的,不過咱們這裏作文本分類只有一個句子的輸入,所以text_b能夠不傳參。
另外從上面咱們自定義的數據處理類中能夠看出,訓練集和驗證集是保存在不一樣文件中的,所以咱們須要將咱們以前預處理好的數據提早分割成訓練集和驗證集,並存放在同一個文件夾下面,文件的名稱要和類中方法裏的名稱相同。
到這裏以後咱們已經準備好了咱們的數據集,並定義好了數據處理類,此時咱們須要將咱們的數據處理類加入到run_classifier.py文件中的main函數下面的processors字典中,結果以下:
以後就能夠直接執行run_classifier.py文件,執行腳本以下:
export BERT_BASE_DIR=../modelParams/uncased_L-12_H-768_A-12 export DATASET=../data/ python run_classifier.py \ --data_dir=$MY_DATASET \ --task_name=imdb \ --vocab_file=$BERT_BASE_DIR/vocab.txt \ --bert_config_file=$BERT_BASE_DIR/bert_config.json \ --output_dir=../output/ \ --do_train=true \ --do_eval=true \ --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \ --max_seq_length=200 \ --train_batch_size=16 \ --learning_rate=5e-5\ --num_train_epochs=2.0
在這裏的task_name就是咱們定義的數據處理類的鍵,BERT模型較大,加載時須要較大的內存,若是出現內存溢出的問題,能夠適當的下降batch_size的值。
目前迭代完以後的輸出比較少,並且只有等迭代結束後纔會有結果輸出,不利於觀察損失的變化,後續將修改輸出。目前的輸出結果:
測試集上的準確率達到了90.7% ,這個結果比Bi-LSTM + Attention(87.7%)的結果要好。
4 增長驗證集輸出的指標值
目前驗證集上的輸出指標值只有loss和accuracy,如上圖所示,然而在分類時,咱們可能還須要看auc,recall,precision的值。增長几行代碼就能夠搞定:
在個人代碼中743行這裏有個metric_fn函數,以前這個函數下只有loss和accuracy的計算,咱們在這裏加上auc,recall,precision的計算,而後加入到return的這個字典中就能夠了。如今的輸出結果:
5 關於BERT的問題
在run_classifier.py文件中,訓練模型,驗證模型都是用的tensorflow中的estimator接口,所以咱們沒法實如今訓練迭代100步就用驗證集驗證一次,在run_classifier.py文件中提供的方法是先運行完全部的epochs以後,再加載模型進行驗證。訓練模型時的代碼:
在個人代碼中948行這裏,在這裏咱們加入了幾行代碼,能夠實現訓練時輸出loss,就是上面的:
tensors_to_log = {"train loss": "loss/Mean:0"} logging_hook = tf.train.LoggingTensorHook( tensors=tensors_to_log, every_n_iter=100)
這是咱們添加進去的,加入了一個hooks的參數,讓訓練的時候沒迭代100步就輸出一次loss。然而這樣的意義並非很大。
下面的日誌能夠看到驗證時是加載訓練完的模型來進行驗證的,見下圖第一行:Restoring xxx
這種沒法在訓練時輸出驗證集上的結果,會致使咱們很難直觀的看到損失函數的變化。就沒法很方便的肯定模型是否收斂,這也是tensorflow中這些高級API的問題,高級封裝雖然讓書寫代碼更容易,但也讓代碼更死板。
bert的其餘應用在NLP-Project中的pre_trained_model中,包括bert+bilstm+crf作命名實體識別,bert+cnn作文本分類。