五分鐘理解:BCELoss 和 BCEWithLogitsLoss的區別

總體來講,這個區別,有沒有with logit就是看模型的最後一層有沒有加上sigmoid層。node

BCEwithlogitsloss = BCELoss + Sigmoidgit

1 舉個例子

導入必要庫,設置預測數值和模型真實類別(二分類問題)微信

import torchpred = torch.tensor([[-0.2],[0.2],[0.8]])target = torch.FloatTensor([[0],[0],[1]])

2 BCELoss

先把pred轉換成Sigmoid的0~1的機率。機器學習

sigmoid = torch.nn.Sigmoid()print(sigmoid(pred))

而後計算BCELoss:  svg

上面的pred(i)表示第i個樣本的通過sigmoid的預測機率。學習

  • 第一個樣本的loss:    spa

  • 第二個樣本的loss:    .net

  • 第三個樣本的loss:    code

求相反數的均值:   orm

而後用PyTorch的BCELoss來計算一下:

import torchpred = torch.tensor([[-0.2],[0.2],[0.8]])target = torch.FloatTensor([[0],[0],[1]])sigmoid = torch.nn.Sigmoid()loss = torch.nn.BCELoss()print('BCELoss:',loss(sigmoid(pred),target))

獲得答案:

3 BCEWithLogitsLoss

直接來看這個結果就好了:

import torchpred = torch.FloatTensor([[-0.2],[0.2],[0.8]])target = torch.FloatTensor([[0],[0],[1]])sigmoid = torch.nn.Sigmoid()loss = torch.nn.BCEWithLogitsLoss()print('BCEWithLogitsLoss:',loss(pred,target))

同樣,因此就相差一個Sigmoid罷了。


喜歡的話請關注咱們的微信公衆號~【機器學習煉丹術】。

  • 公衆號主要講統計學,數據科學,機器學習,深度學習,以及一些參加Kaggle競賽的經驗。

  • 公衆號內容建議做爲課後的一些相關知識的補充,飯後甜點。

微信搜索公衆號:【機器學習煉丹術】。期待您的關注。

本文分享自微信公衆號 - 機器學習煉丹術(liandanshu)。
若有侵權,請聯繫 support@oschina.cn 刪除。
本文參與「OSC源創計劃」,歡迎正在閱讀的你也加入,一塊兒分享。

相關文章
相關標籤/搜索