pytorch 在測試階段,進行前向推斷時,顯存累加溢出問題

一、問題描述:測試

  pytorch中,在測試階段進行前向推斷運行時,隨着for循環次數的增長,顯存不斷累加變大,最終致使顯存溢出。spa


二、解決方法:
    使用以下代碼處理輸入數據:input

  假設X爲模型的輸入for循環

  X = X.cuda()model

  input_blobs = Variable(X, volatile=True)循環

  output = model(input_blobs)方法

  注意: 必定要設置 volatile=True 該參數,不然在for循環過程當中,顯存會不斷累加。
數據

相關文章
相關標籤/搜索