y_hat混淆點

2.python

課程中的損失函數定義爲:函數

def squared_loss(y_hat, y):
    return (y_hat - y.view(y_hat.size())) ** 2 / 2

將返回結果替換爲下面的哪個會致使會致使模型沒法訓練:(閱讀材料:code

(y_hat.view(-1) - y) ** 2 / 2class

(y_hat - y.view(-1)) ** 2 / 2im

(y_hat - y.view(y_hat.shape)) ** 2 / 2view

(y_hat - y.view(-1, 1)) ** 2 / 2vi

答案解釋co

y_hat的形狀是[n, 1],而y的形狀是[n],二者相減獲得的結果的形狀是[n, n],至關於用y_hat的每個元素分別減去y的全部元素,因此沒法獲得正確的損失值。對於第一個選項,y_hat.view(-1)的形狀是[n],與y一致,能夠相減;對於第二個選項,y.view(-1)的形狀還是[n],因此沒有解決問題;對於第三個選項和第四個選項,y.view(y_hat.shape)y.view(-1, 1)的形狀都是[n, 1],與y_hat一致,能夠相減。如下是一段示例代碼:oss

x = torch.arange(3)
y = torch.arange(3).view(3, 1)
print(x)
print(y)
print(x + y)

假如你正在實現一個全鏈接層,全鏈接層的輸入形狀是7×8,輸出形狀是 7×1,其中7是批量大小,則權重參數ww和偏置參數bb的形狀分別是____和____time

1×8, 1×1

1×8, 7×1

8×1, 1×1

8×1, 7×1

答案解釋:

設輸入批量爲X in mathbb{R}^{7 times 8}X∈R7×8,對應的輸出爲Y in \mathbb{R}^{7 times 1}Y∈R7×1,令權重參數爲w in mathbb{R}^{8 times 1}w∈R8×1,則Xw in mathbb{R}^{7 times 1}Xw∈R7×1,而後咱們給XwXw中的每一個元素加上的偏置是同樣的,因此偏置參數b in mathbb{R} ^{1 times 1}b∈R1×1,基於加法的廣播機制,能夠完成獲得輸出Y = Xw + bY=Xw+b。參數的形狀與批量大小沒有關係,也正是由於如此,對同一個模型,咱們能夠選擇不一樣的批量大小。

相關文章
相關標籤/搜索