以前一直本身手寫各類triphard,triplet損失函數, 寫的比較暴力,而後今天一個學長給我在github上看了一個別人的triphard的寫法,一開始沒看懂,用的pytorch函數沒怎麼見過,看懂了以後, 被驚豔到了。。所以在此記錄一下,以及詳細註釋一下python
class TripletLoss(nn.Module): def __init__(self, margin=0.3): super(TripletLoss, self).__init__() self.margin = margin self.ranking_loss = nn.MarginRankingLoss(margin=margin) # 得到一個簡單的距離triplet函數 def forward(self, inputs, labels): n = inputs.size(0) # 獲取batch_size # Compute pairwise distance, replace by the official when merged dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) # 每一個數平方後, 進行加和(經過keepdim保持2維),再擴展成nxn維 dist = dist + dist.t() # 這樣每一個dis[i][j]表明的是第i個特徵與第j個特徵的平方的和 dist.addmm_(1, -2, inputs, inputs.t()) # 而後減去2倍的 第i個特徵*第j個特徵 從而經過徹底平方式獲得 (a-b)^2 dist = dist.clamp(min=1e-12).sqrt() # 而後開方 # For each anchor, find the hardest positive and negative mask = labels.expand(n, n).eq(labels.expand(n, n).t()) # 這裏dist[i][j] = 1表明i和j的label相同, =0表明i和j的label不相同 dist_ap, dist_an = [], [] for i in range(n): dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) # 在i與全部有相同label的j的距離中找一個最大的 dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) # 在i與全部不一樣label的j的距離找一個最小的 dist_ap = torch.cat(dist_ap) # 將list裏的tensor拼接成新的tensor dist_an = torch.cat(dist_an) # Compute ranking hinge loss y = torch.ones_like(dist_an) # 聲明一個與dist_an相同shape的全1tensor loss = self.ranking_loss(dist_an, dist_ap, y) return loss