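Tweet with Disaster (Kaggle NLP project walkthrough)
Project introduction (Real or Not? NLP with Disaster Tweets)
Kaggle project link: https://www.kaggle.com/c/nlp-getting-started/overview
Twitter has become an important communication channel in times of emergency. The ubiquity of smartphones lets people announce an emergency they are observing in real time. Because of this, more and more organizations (such as disaster relief organizations and news agencies) are interested in programmatically monitoring Twitter. However, it is not always clear whether a person's words are actually announcing a disaster. Take the example below:
[Image: example tweet]
The author explicitly uses the word "ablaze" but means it metaphorically. This is immediately clear to a human, especially with the visual aid, but it is much less clear to a machine.
In this competition, you are challenged to build a machine learning model that predicts which tweets are about real disasters and which are not.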
EDA
Data preprocessing
1 Importing the data
    import pandas as pd

    train = pd.read_csv('../input/nlp-getting-started/train.csv')
    test = pd.read_csv('../input/nlp-getting-started/test.csv')
    sample_submission = pd.read_csv('../input/nlp-getting-started/sample_submission.csv')

    # Print the shape of the training and test data
    print('{} rows and {} cols in training dataset.'.format(train.shape[0], train.shape[1]))
    print('{} rows and {} cols in test dataset.'.format(test.shape[0], test.shape[1]))

    # Inspect the training data
    train.head(10)
2 Descriptive analysis
Check the distribution of labels 0 and 1
    import matplotlib.pyplot as plt
    from IPython.display import display  # available by default inside a notebook

    # Frequency of the target variable
    count_table = train.target.value_counts()
    display(count_table)

    # Plot class distribution
    plt.figure(figsize=(6, 5))
    plt.bar('False', count_table[0], label='False', width=0.6)
    plt.bar('True', count_table[1], label='True', width=0.6)
    plt.legend()
    plt.ylabel('Count of examples')
    plt.xlabel('Category')
    plt.title('Class Distribution')
    plt.ylim([0, 4700])
    plt.show()
Distribution of tweet length
    # The 'length' column is used below but never created in this post;
    # here it is assumed to be the character count of each tweet.
    train['length'] = train['text'].apply(len)

    # Plot the frequency of tweet lengths
    bins = 150
    plt.figure(figsize=(18, 5))
    plt.hist(train[train['target'] == 0]['length'], label='False', bins=bins, alpha=0.8)
    plt.hist(train[train['target'] == 1]['length'], label='True', bins=bins, alpha=0.8)
    plt.xlabel('Length of text (characters)')
    plt.ylabel('Count')
    plt.title('Frequency of tweets length')
    plt.legend(loc='best')
    plt.show()
Comparing the length distributions of the two classes of tweets
    # Frequency of tweet lengths in the 2 classes
    fg, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    ax1.hist(train[train['target'] == 0]['length'], color='red')
    ax1.set_title('Distribution of fake tweets')
    ax1.set_xlabel('Tweets length (characters)')
    ax1.set_ylabel('Count')
    ax2.hist(train[train['target'] == 1]['length'], color='blue')
    ax2.set_title('Distribution of true tweets')
    ax2.set_xlabel('Tweets length (characters)')
    ax2.set_ylabel('Count')
    fg.suptitle('Characters in classes')
    plt.show()
Distribution of the number of words in the two classes of tweets
    # Plot the distribution of word counts
    words_true = train[train['target'] == 1]['text'].str.split().apply(len)
    words_false = train[train['target'] == 0]['text'].str.split().apply(len)
    plt.figure(figsize=(10, 5))
    plt.hist(words_false, label='False', alpha=0.8, bins=15)
    plt.hist(words_true, label='True', alpha=0.6, bins=15)
    plt.legend(loc='best')
    plt.title('Count of words in tweets')
    plt.xlabel('Count of words')
    plt.ylabel('Count')
    plt.show()
3 Data cleaning
Define helper functions that remove URLs, HTML tags, emojis and punctuation, and that expand common abbreviations
    import re
    import string
    from nltk.tokenize import word_tokenize

    # Define a function to remove URLs
    def remove_url(text):
        url = re.compile(r'https?://\S+|www\.\S+')
        return url.sub(r'', text)

    # Test function (named `sample` so the `test` DataFrame is not overwritten)
    sample = 'Address of this kernel: https://www.kaggle.com/lilstarboy/kernel4d04fe5667/edit'
    print(remove_url(sample))

    # Define a function to remove HTML tags
    def remove_html(text):
        html = re.compile(r'<.*?>')
        return html.sub(r'', text)

    # Test function
    sample = """<div>
    <h1>Real or Fake</h1>
    <p>Kaggle </p>
    <a href="https://www.kaggle.com/c/nlp-getting-started">getting started</a>
    </div>"""
    print(remove_html(sample))

    # Define a function to remove emojis
    def remove_emoji(text):
        emoji_pattern = re.compile(
            "["
            u"\U0001F600-\U0001F64F"  # emoticons
            u"\U0001F300-\U0001F5FF"  # symbols & pictographs
            u"\U0001F680-\U0001F6FF"  # transport & map symbols
            u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
            u"\U00002702-\U000027B0"
            u"\U000024C2-\U0001F251"  # very broad range, also covers CJK characters
            "]+", flags=re.UNICODE)
        return emoji_pattern.sub(r'', text)

    remove_emoji("To test 🚀")

    # Define a function to remove punctuation
    def remove_punct(text):
        table = str.maketrans('', '', string.punctuation)
        return text.translate(table)

    # Define a function to convert abbreviations to text.
    # Keys are lower case because the lookup lower-cases each token.
    abbreviations = {
        "$": " dollar ",
        "€": " euro ",
        "4ao": "for adults only",
        "a.m": "before midday",
        "a3": "anytime anywhere anyplace",
        "aamof": "as a matter of fact",
        "acct": "account",
        "adih": "another day in hell",
        "afaic": "as far as i am concerned",
        "afaict": "as far as i can tell",
        "afaik": "as far as i know",
        "afair": "as far as i remember",
        "afk": "away from keyboard",
        "app": "application",
        "approx": "approximately",
        "apps": "applications",
        "asap": "as soon as possible",
        "asl": "age, sex, location",
        "atk": "at the keyboard",
        "ave.": "avenue",
        "aymm": "are you my mother",
        "ayor": "at your own risk",
        "b&b": "bed and breakfast",
        "b+b": "bed and breakfast",
        "b.c": "before christ",
        "b2b": "business to business",
        "b2c": "business to customer",
        "b4": "before",
        "b4n": "bye for now",
        "b@u": "back at you",
        "bae": "before anyone else",
        "bak": "back at keyboard",
        "bbbg": "bye bye be good",
        "bbc": "british broadcasting corporation",
        "bbias": "be back in a second",
        "bbl": "be back later",
        "bbs": "be back soon",
        "be4": "before",
        "bfn": "bye for now",
        "blvd": "boulevard",
        "bout": "about",
        "brb": "be right back",
        "bros": "brothers",
        "brt": "be right there",
        "bsaaw": "big smile and a wink",
        "btw": "by the way",
        "bwl": "bursting with laughter",
        "c/o": "care of",
        "cet": "central european time",
        "cf": "compare",
        "cia": "central intelligence agency",
        "csl": "can not stop laughing",
        "cu": "see you",
        "cul8r": "see you later",
        "cv": "curriculum vitae",
        "cwot": "complete waste of time",
        "cya": "see you",
        "cyt": "see you tomorrow",
        "dae": "does anyone else",
        "dbmib": "do not bother me i am busy",
        "diy": "do it yourself",
        "dm": "direct message",
        "dwh": "during work hours",
        "e123": "easy as one two three",
        "eet": "eastern european time",
        "eg": "example",
        "embm": "early morning business meeting",
        "encl": "enclosed",
        "encl.": "enclosed",
        "etc": "and so on",
        "faq": "frequently asked questions",
        "fawc": "for anyone who cares",
        "fb": "facebook",
        "fc": "fingers crossed",
        "fig": "figure",
        "fimh": "forever in my heart",
        "ft.": "feet",
        "ft": "featuring",
        "ftl": "for the loss",
        "ftw": "for the win",
        "fwiw": "for what it is worth",
        "fyi": "for your information",
        "g9": "genius",
        "gahoy": "get a hold of yourself",
        "gal": "get a life",
        "gcse": "general certificate of secondary education",
        "gfn": "gone for now",
        "gg": "good game",
        "gl": "good luck",
        "glhf": "good luck have fun",
        "gmt": "greenwich mean time",
        "gmta": "great minds think alike",
        "gn": "good night",
        "g.o.a.t": "greatest of all time",
        "goat": "greatest of all time",
        "goi": "get over it",
        "gps": "global positioning system",
        "gr8": "great",
        "gratz": "congratulations",
        "gyal": "girl",
        "h&c": "hot and cold",
        "hp": "horsepower",
        "hr": "hour",
        "hrh": "his royal highness",
        "ht": "height",
        "ibrb": "i will be right back",
        "ic": "i see",
        "icq": "i seek you",
        "icymi": "in case you missed it",
        "idc": "i do not care",
        "idgadf": "i do not give a damn fuck",
        "idgaf": "i do not give a fuck",
        "idk": "i do not know",
        "ie": "that is",
        "i.e": "that is",
        "ifyp": "i feel your pain",
        "ig": "instagram",
        "iirc": "if i remember correctly",
        "ilu": "i love you",
        "ily": "i love you",
        "imho": "in my humble opinion",
        "imo": "in my opinion",
        "imu": "i miss you",
        "iow": "in other words",
        "irl": "in real life",
        "j4f": "just for fun",
        "jic": "just in case",
        "jk": "just kidding",
        "jsyk": "just so you know",
        "l8r": "later",
        "lb": "pound",
        "lbs": "pounds",
        "ldr": "long distance relationship",
        "lmao": "laugh my ass off",
        "lmfao": "laugh my fucking ass off",
        "lol": "laughing out loud",
        "ltd": "limited",
        "ltns": "long time no see",
        "m8": "mate",
        "mf": "motherfucker",
        "mfs": "motherfuckers",
        "mfw": "my face when",
        "mofo": "motherfucker",
        "mph": "miles per hour",
        "mr": "mister",
        "mrw": "my reaction when",
        "ms": "miss",
        "mte": "my thoughts exactly",
        "nagi": "not a good idea",
        "nbc": "national broadcasting company",
        "nbd": "not big deal",
        "nfs": "not for sale",
        "ngl": "not going to lie",
        "nhs": "national health service",
        "nrn": "no reply necessary",
        "nsfl": "not safe for life",
        "nsfw": "not safe for work",
        "nth": "nice to have",
        "nvr": "never",
        "nyc": "new york city",
        "oc": "original content",
        "og": "original",
        "ohp": "overhead projector",
        "oic": "oh i see",
        "omdb": "over my dead body",
        "omg": "oh my god",
        "omw": "on my way",
        "p.a": "per annum",
        "p.m": "after midday",
        "pm": "prime minister",
        "poc": "people of color",
        "pov": "point of view",
        "pp": "pages",
        "ppl": "people",
        "prw": "parents are watching",
        "ps": "postscript",
        "pt": "point",
        "ptb": "please text back",
        "pto": "please turn over",
        "qpsa": "what happens",  # "que pasa"
        "ratchet": "rude",
        "rbtl": "read between the lines",
        "rlrt": "real life retweet",
        "rofl": "rolling on the floor laughing",
        "roflol": "rolling on the floor laughing out loud",
        "rotflmao": "rolling on the floor laughing my ass off",
        "rt": "retweet",
        "ruok": "are you ok",
        "sfw": "safe for work",
        "sk8": "skate",
        "smh": "shake my head",
        "sq": "square",
        "srsly": "seriously",
        "ssdd": "same stuff different day",
        "tbh": "to be honest",
        "tbs": "tablespoonful",
        "tbsp": "tablespoonful",
        "tfw": "that feeling when",
        "thks": "thank you",
        "tho": "though",
        "thx": "thank you",
        "tia": "thanks in advance",
        "til": "today i learned",
        "tl;dr": "too long i did not read",
        "tldr": "too long i did not read",
        "tmb": "tweet me back",
        "tntl": "trying not to laugh",
        "ttyl": "talk to you later",
        "u": "you",
        "u2": "you too",
        "u4e": "yours for ever",
        "utc": "coordinated universal time",
        "w/": "with",
        "w/o": "without",
        "w8": "wait",
        "wassup": "what is up",
        "wb": "welcome back",
        "wtf": "what the fuck",
        "wtg": "way to go",
        "wtpa": "where the party at",
        "wuf": "where are you from",
        "wuzup": "what is up",
        "wywh": "wish you were here",
        "yd": "yard",
        "ygtr": "you got that right",
        "ynk": "you never know",
        "zzz": "sleeping bored and tired"
    }

    def convert_abbrev(word):
        # Return the expansion if the lower-cased token is known, else the token itself
        return abbreviations.get(word.lower(), word)

    def convert_abbrev_in_text(text):
        tokens = word_tokenize(text)
        tokens = [convert_abbrev(word) for word in tokens]
        text = ' '.join(tokens)
        return text

    # Test function
    sample = 'This is very complex!!!!!??'
    print(remove_punct(sample))
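The helpers above are defined and tested one at a time, but the post does not show them being chained. Below is a minimal, self-contained sketch of how they might be combined into a single cleaning pass. The name `clean_tweet`, the sample tweet and the order of the steps (URLs, then HTML tags, then punctuation) are assumptions of this sketch, not something taken from the original notebook.

```python
import re
import string

URL_RE = re.compile(r'https?://\S+|www\.\S+')
HTML_RE = re.compile(r'<.*?>')
PUNCT_TABLE = str.maketrans('', '', string.punctuation)

def clean_tweet(text):
    """Hypothetical helper: strip URLs, then HTML tags, then punctuation."""
    text = URL_RE.sub('', text)
    text = HTML_RE.sub('', text)
    text = text.translate(PUNCT_TABLE)
    # Collapse the extra whitespace left behind by the removals
    return ' '.join(text.split())

cleaned = clean_tweet('Forest fire near <b>La Ronge</b>!!! https://t.co/abc123')
print(cleaned)
```

The order matters: if punctuation were removed first, the `://` and dots of a URL would already be gone and the URL pattern would no longer match. On a DataFrame, such a helper could be applied with something like `train['text'].apply(clean_tweet)`.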
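4 Visualizing with word clouds
    from wordcloud import WordCloud

    # Assumption: `df` and `create_corpus` are not defined anywhere in this post.
    # Here `df` is taken to be the training DataFrame, and `create_corpus` a
    # helper that returns every word of the tweets with the given target.
    df = train

    def create_corpus(df, target):
        return [word for text in df[df['target'] == target]['text'] for word in text.split()]

    # Word cloud for non-disaster tweets
    corpus_all_0 = create_corpus(df, 0)

    # Plot the word cloud
    plt.figure(figsize=(15, 8))
    word_cloud = WordCloud(
        background_color='white',
        max_font_size=80
    ).generate(" ".join(corpus_all_0))
    plt.imshow(word_cloud)
    plt.axis('off')
    plt.show()

    # Word cloud for disaster tweets
    corpus_all_1 = create_corpus(df, 1)

    # Plot the word cloud
    plt.figure(figsize=(15, 8))
    word_cloud = WordCloud(
        background_color='white',
        max_font_size=80
    ).generate(" ".join(corpus_all_1))
    plt.imshow(word_cloud)
    plt.axis('off')
    plt.show()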
Word cloud of tweets that do not mention a real disaster:
Word cloud of tweets that mention a real disaster:
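Loading the pretrained BERT model
A brief introduction to the pretrained BERT model:
For the principles behind transfer learning and fine-tuning with BERT, see the paper https://arxiv.org/abs/1810.04805
Here we use the BERT-base uncased model: a 12-layer network with a hidden size of 768 and about 110M parameters, which makes it a small model (small by BERT standards, at least).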
    from transformers import BertTokenizer, BertConfig, TFBertModel

    # Define hyperparameters
    MAXLEN = 128
    BATCH_SIZE = 32
    NUM_EPOCHS = 5
    LEARNING_RATE = 3e-6

    # Import the BERT tokenizer, config and model by their hub identifier
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    config = BertConfig.from_pretrained("bert-base-uncased")
    bert_model = TFBertModel.from_pretrained("bert-base-uncased", config=config)
Next we use BERT's own tokenizer to convert a sentence into token IDs and look at the result
    # Convert the first sentence in the 'text' column into token IDs
    text = train['text'][0]
    print(text)
    input_ids = tokenizer.encode(text, max_length=MAXLEN, truncation=True)
    print(input_ids)
    print(tokenizer.convert_ids_to_tokens(input_ids))
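Building the BERT model inputs
Next we build the input layer for the BERT model.
The pretrained BERT model used here takes three inputs, each a 2D array of shape (batch_size, input_length):
- the vocabulary index of each token (input_ids)
- the corresponding attention_mask
- the corresponding token_type_ids
The model returns two outputs:
- the hidden state at every position, a 3D array of shape (batch_size, input_length, hidden_size)
- a vector representation of each sentence, of shape (batch_size, hidden_size), derived from the last hidden state of the first ([CLS]) token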
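To make the two output shapes concrete, here is a toy sketch in plain Python of how a sentence vector can be derived from the per-token hidden states: a dense layer followed by tanh, applied to the first position only. The sizes, the random weights and the names `hidden`, `W`, `b` and `pool` are all illustrative assumptions; in the real model the weights are learned and hidden_size is 768.

```python
import math
import random

random.seed(0)
batch_size, input_length, hidden_size = 2, 4, 3  # toy sizes for illustration

# Fake "last hidden state": one vector per token position
hidden = [[[random.uniform(-1, 1) for _ in range(hidden_size)]
           for _ in range(input_length)]
          for _ in range(batch_size)]

# Hypothetical pooler weights (random here, learned in the real model)
W = [[random.uniform(-1, 1) for _ in range(hidden_size)]
     for _ in range(hidden_size)]
b = [0.0] * hidden_size

def pool(sequence):
    """Dense layer + tanh applied to the first ([CLS]) position only."""
    cls = sequence[0]
    return [math.tanh(sum(w * h for w, h in zip(row, cls)) + bias)
            for row, bias in zip(W, b)]

pooled = [pool(seq) for seq in hidden]

print(len(hidden), len(hidden[0]), len(hidden[0][0]))  # 2 4 3
print(len(pooled), len(pooled[0]))                     # 2 3
```

The second output drops the input_length axis entirely, which is why it can be fed straight into a single Dense classification layer.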
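For the detailed settings and default parameters, see BERT's official GitHub repository: https://github.com/google-research/bert
Here we build the BERT inputs in a simple way: each sentence's token ID sequence is padded with 0 up to the maximum length, and because every example is a single sentence, all token type IDs are 0.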
    import numpy as np

    def build_inputs(texts):
        """Encode each text and pad it with 0 up to MAXLEN."""
        all_input_ids, all_attension_mask, all_token_type_ids = [], [], []
        for text in texts:
            input_ids = tokenizer.encode(text, max_length=MAXLEN, truncation=True)
            padding_length = MAXLEN - len(input_ids)
            all_input_ids.append(input_ids + [0] * padding_length)
            # 1 marks real tokens, 0 marks padding
            all_attension_mask.append([1] * len(input_ids) + [0] * padding_length)
            # Single-sentence inputs: every token belongs to segment 0
            all_token_type_ids.append([0] * MAXLEN)
        return (np.array(all_input_ids),
                np.array(all_attension_mask),
                np.array(all_token_type_ids))

    # Build input values on the training data
    train_input_ids, train_attension_mask, train_token_type_ids = build_inputs(train['text'])

    # Build input values on the testing data
    test_input_ids, test_attension_mask, test_token_type_ids = build_inputs(test['text'])

    y_train = np.array(train['target'])
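The padding logic can be checked without downloading a tokenizer. The sketch below applies the same steps to a hand-written list of token IDs. The helper name `pad_inputs` and the toy MAXLEN of 8 are assumptions for illustration; 101 and 102 are the [CLS] and [SEP] IDs in the bert-base-uncased vocabulary, and the two IDs between them are arbitrary.

```python
MAXLEN = 8  # toy value; the post uses 128

def pad_inputs(token_ids, maxlen=MAXLEN, pad_id=0):
    """Hypothetical helper: truncate, then pad one sequence to maxlen."""
    token_ids = token_ids[:maxlen]
    padding = maxlen - len(token_ids)
    input_ids = token_ids + [pad_id] * padding
    attention_mask = [1] * len(token_ids) + [0] * padding  # 1 = real token
    token_type_ids = [0] * maxlen                          # single segment
    return input_ids, attention_mask, token_type_ids

ids, mask, types = pad_inputs([101, 2256, 15616, 102])
print(ids)    # [101, 2256, 15616, 102, 0, 0, 0, 0]
print(mask)   # [1, 1, 1, 1, 0, 0, 0, 0]
print(types)  # [0, 0, 0, 0, 0, 0, 0, 0]
```

The attention mask is what tells the model to ignore the padded positions, so the choice of 0 as the padding ID does not affect the real tokens.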
Building and training the model
Next we build the model on top of BERT. Because this is a binary classification task, the output layer uses a sigmoid activation with binary cross-entropy loss, and the optimizer is Adam. There is not much else to say.
    from tensorflow import keras

    # Build the BERT-base uncased model
    input_ids = keras.layers.Input(shape=(MAXLEN,), dtype='int32')
    attension_mask = keras.layers.Input(shape=(MAXLEN,), dtype='int32')
    token_type_ids = keras.layers.Input(shape=(MAXLEN,), dtype='int32')

    # Index 1 of the BERT outputs is the pooled sentence vector
    x = bert_model([input_ids, attension_mask, token_type_ids])[1]
    outputs = keras.layers.Dense(1, activation='sigmoid')(x)

    model = keras.models.Model(inputs=[input_ids, attension_mask, token_type_ids], outputs=outputs)
    model.compile(loss='binary_crossentropy',
                  optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
                  metrics=['accuracy'])
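As a small worked example of what the output layer and loss compute, the sketch below evaluates a sigmoid and the binary cross-entropy for a single made-up logit of 2.0. The function names and the clipping constant are assumptions of this sketch, not Keras internals.

```python
import math

def sigmoid(z):
    """Squash a real-valued logit into a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

def binary_cross_entropy(y_true, p, eps=1e-7):
    """Loss for one example; p is the predicted probability of class 1."""
    p = min(max(p, eps), 1 - eps)  # clip to avoid log(0)
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1 - p))

p = sigmoid(2.0)
print(round(p, 4))                           # 0.8808
print(round(binary_cross_entropy(1, p), 4))  # 0.1269
print(round(binary_cross_entropy(0, p), 4))  # 2.1269
```

A confident prediction is cheap when it is right and expensive when it is wrong, which is the behaviour that pushes the single Dense unit toward well-separated probabilities.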
Next, train the model
    from sklearn.model_selection import train_test_split

    # Hold out 10% of the training data for validation (stratified by label)
    (train_input_ids, valid_input_ids,
     train_attension_mask, valid_attension_mask,
     train_token_type_ids, valid_token_type_ids,
     y_train, y_valid) = train_test_split(train_input_ids, train_attension_mask,
                                          train_token_type_ids, y_train,
                                          test_size=0.1, stratify=y_train, random_state=0)

    # Fit the BERT-base uncased model
    early_stopping = keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)
    model.fit([train_input_ids, train_attension_mask, train_token_type_ids], y_train,
              validation_data=([valid_input_ids, valid_attension_mask, valid_token_type_ids], y_valid),
              batch_size=BATCH_SIZE, epochs=NUM_EPOCHS, callbacks=[early_stopping])
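The early-stopping callback used above stops training once the validation loss has failed to improve for `patience` consecutive epochs, and `restore_best_weights=True` rolls the model back to the best epoch. The sketch below mimics only that bookkeeping on a made-up list of validation losses; the function name and the loss values are hypothetical.

```python
def early_stopping_epoch(val_losses, patience=3):
    """Return (stop_epoch, best_epoch) as 0-based epoch indices."""
    best_loss = float('inf')
    best_epoch = -1
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # Improvement: remember this epoch and reset the counter
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # Training ran to completion without triggering early stopping
    return len(val_losses) - 1, best_epoch

losses = [0.52, 0.45, 0.43, 0.44, 0.46, 0.47, 0.41]
result = early_stopping_epoch(losses)
print(result)  # (5, 2)
```

In this toy run training stops at epoch 5 and the weights from epoch 2 would be restored, so the later, lower loss at epoch 6 is never reached. With only 5 epochs and a patience of 3, as in the post, the callback can trigger only if the loss stops improving within the first two epochs.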
Inspect the model summary
    model.summary()
Submitting the results
    # Use the model to make predictions
    y_pred = model.predict([test_input_ids, test_attension_mask, test_token_type_ids],
                           batch_size=BATCH_SIZE, verbose=1).ravel()
    y_pred = (y_pred >= 0.5).astype(int)

    # Export to submission
    submission = pd.read_csv("../input/nlp-getting-started/sample_submission.csv")
    submission['target'] = y_pred
    submission.to_csv('nlp_prediction.csv', index=False)
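We will not go into the hyperparameter tuning process in detail here. After several submissions, the best score we obtained was 0.83742 (reported here as accuracy; note that the competition leaderboard is scored with F1).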
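The prediction step thresholds the sigmoid probabilities at 0.5 before submission. The sketch below shows that thresholding together with accuracy and F1, the metric this Kaggle competition scores on, computed on a tiny made-up validation set. All names and numbers are illustrative assumptions, not results from the notebook.

```python
def threshold(probs, cutoff=0.5):
    """Turn predicted probabilities into 0/1 labels."""
    return [int(p >= cutoff) for p in probs]

def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def f1_score(y_true, y_pred):
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
    fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

y_true = [1, 0, 1, 1, 0, 0]              # made-up labels
probs = [0.9, 0.2, 0.4, 0.7, 0.6, 0.1]   # made-up model outputs
y_pred = threshold(probs)

print(y_pred)                                # [1, 0, 0, 1, 1, 0]
print(round(accuracy(y_true, y_pred), 4))    # 0.6667
print(round(f1_score(y_true, y_pred), 4))    # 0.6667
```

Accuracy and F1 happen to coincide in this toy case, but they generally differ when the classes are imbalanced, so it is worth tracking F1 on the validation split as well.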
The full workflow is available in our Kaggle notebook: https://www.kaggle.com/lilstarboy/pig-budt758b-project-notebook?scriptVersionId=33280711