pytorch softmax的使用

torch.nn.functional.softmax(input, dim=None)
在這裏插入圖片描述html

import torch
import torch.nn as nn
m = nn.Softmax(dim=1) #注意是沿着那個維度計算
input = torch.randn(2,2)
print("input:")
print(input)
output = m(input)
print(output)

#注意區分如下結果,這是兩個不一樣size的tensor
input1=torch.randn(2,1)
print("input1:")
print(input1)
print(m(input1))
input2=torch.randn(1,2)
print("input2:")
print(input2)
print(m(input2))

在這裏插入圖片描述

相關文章
相關標籤/搜索