pytorch 踩坑筆記之w.grad.data.zero_()

  在使用pytorch實現多項線性迴歸中,在grad更新時,每一次運算後都須要將上一次的梯度記錄清空,運用以下方法:spa

 w.grad.data.zero_() b.grad.data.zero_() 

   可是,運行程序就會報以下錯誤:code

  報錯,grad沒有data這個屬性,blog

  緣由是,在系統將w的grad值初始化爲none,第一次求梯度計算是在none值上進行報錯,天然會沒有data屬性get

  修改方法:添加一個判斷語句,從第二次循環開始執行求導運算class

for i in range(100): y_pred = multi_linear(x_train) loss = getloss(y_pred,y_train) if i != 0: w.grad.data.zero_() b.grad.data.zero_() loss.backward() w.data = w.data - 0.001 * w.grad.data b.data = b.data - 0.001 * b.grad.data
相關文章
相關標籤/搜索