[PyTorch] Computing and Optimizing the GPU Memory Usage of Deep Models

Original link: https://oldpan.me/archives/how-to-calculate-gpu-memoryphp

A Brief Discussion of Deep Learning: How to Calculate the GPU Memory Usage of Models and Intermediate Variables

Preface

Hey, your GPU memory just blew up and your graphics card is about to smoke!

torch.FatalError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1524590031827/work/aten/src/THC/generic/THCStorage.cu:58

This must be the error that deep learning "alchemists" least want to see, bar none.

OUT OF MEMORY: clearly the GPU cannot hold all of your model weights plus the intermediate variables, so the program crashes. What can you do? Quite a lot, actually: free intermediate variables promptly, optimize your code, shrink the batch size, and so on; all of these reduce the risk of running out of memory.

But this article is about the foundation underlying all of those optimizations: how to calculate the GPU memory we are using. Once you can work out how much memory your model and its intermediate variables occupy, you will have a much firmer grip on your memory budget.

How to Calculate

First, some basic facts about data sizes:

  • 1 GB = 1000 MB
  • 1 MB = 1000 KB
  • 1 KB = 1000 Bytes
  • 1 Byte = 8 bits

Some will surely ask why 1000 rather than 1024. We won't dwell on that here; both conventions are correct, they just suit slightly different contexts. This article sticks with the decimal convention above.

Next, the space occupied by the values we use day to day, taking PyTorch's official data types as the example (all deep learning frameworks follow the same standard):

[Figure: table of PyTorch data types]

We only need the information on the left. In everyday training, the two types we use most are:

  • float32, single-precision floating point
  • int32, 32-bit integer

An 8-bit integer occupies 1 B, i.e. 8 bits, while a 32-bit float occupies 4 B, i.e. 32 bits. Double-precision floats (double) and long integers (long) are rarely used in everyday training.

P.S. Consumer GPUs are optimized for single-precision compute, while server-class GPUs are optimized for double precision as well.

So, given an RGB true-color image of size 500 × 500 stored as single-precision floats, the memory it occupies is: 500 × 500 × 3 × 4 B = 3 MB.

And a FloatTensor of shape (256, 3, 100, 100), i.e. (N, C, H, W), occupies 256 × 3 × 100 × 100 × 4 B ≈ 31 MB.
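As a sanity check, PyTorch tensors can report this directly; a minimal sketch (element count times bytes per element):

import torch

t = torch.zeros(256, 3, 100, 100)
# numel() is the element count, element_size() the bytes per element
print(t.numel() * t.element_size() / 1000 / 1000, 'MB')  # 30.72, i.e. ~31 MB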

Not much, right? Don't worry, the show is only getting started.

Where Did the Memory Go?

A single image (3×256×256) or a convolutional feature map (256×100×100) doesn't seem to take much space, so why do we still use so much memory? The reason is simple: most of it is occupied not by the input images but by the network's intermediate variables and by the large amount of state the optimizer produces.

Let's start with a rough calculation of the memory that the VGG16 network needs:

The memory a model occupies comes in just two parts:

  • the model's own parameters (params)
  • the intermediate variables produced during computation (memory)

[Figure: per-layer memory and params of VGG16, from cs231n]

The figure is from cs231n. This is a typical sequential net, flowing smoothly from top to bottom. The input is a 224×224×3 three-channel image, labeled 150K in the figure; that 150K is the element count (224 × 224 × 3 ≈ 150K), not a byte count, so with 32-bit floats (4 bytes each, rather than the 1 byte the label implies) the final result has to be multiplied by 4 to get the size in bytes.

The memory column on the left covers the input image plus the intermediate convolutional feature maps it produces. These layered feature maps are, in effect, the deep network's "thinking" process:

[Figure: VGG16 feature maps deepening from 3 channels to 512]

The image goes from 3 channels to 64 → 128 → 256 → 512 …; these are all convolutional feature maps, and they are the main consumers of our memory.

Then there is the params column on the right: the sizes of the network's weights. The first convolution uses 3×3 kernels with 3 input channels and 64 output channels, so its weights clearly occupy (3 × 3 × 3) × 64 elements.
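We can confirm that count directly in PyTorch; a small sketch (bias disabled so only the kernel weights are counted):

import torch.nn as nn

conv1 = nn.Conv2d(3, 64, kernel_size=3, bias=False)
print(conv1.weight.size())   # torch.Size([64, 3, 3, 3])
print(conv1.weight.numel())  # (3 x 3 x 3) x 64 = 1728 parameters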

One more thing to note: the intermediate variables roughly double during backward!

Why? Here is an example. Below is a computation graph: input x passes through an intermediate result z to reach the final variable L:

[Figure: computation graph x → z → L]

During backward we need the intermediate values saved from the forward pass. The output is L and the input is x; to compute the gradient of L with respect to x we need the intermediate z that sits on the chain between L and x:

dL/dx = dL/dz × dz/dx

The intermediate dz/dx must of course be kept around for this computation, so as a rough estimate, the intermediate variables during backward occupy about twice what the forward pass alone does!
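You can watch this happen with torch.cuda.memory_allocated(); a rough experiment (it assumes a CUDA device is available, and the exact numbers vary with PyTorch version):

import torch
import torch.nn as nn

model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(10)]).cuda()
x = torch.randn(64, 1024, device='cuda', requires_grad=True)

print(torch.cuda.memory_allocated())  # parameters + input
y = model(x).sum()
print(torch.cuda.memory_allocated())  # larger: activations cached for backward
y.backward()
print(torch.cuda.memory_allocated())  # activations freed, gradients now held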

Optimizers and Momentum

Note that the optimizer occupies memory too!

Why? Look at this update rule:

W_{t+1} = W_t − α ∇F(W_t)

This is the basic update formula of SGD (stochastic gradient descent). When the weights W are updated, the gradient ∇F(W_t) has to be computed and stored as an intermediate variable, so during optimization the memory taken by the model's params roughly doubles.

And that is just plain SGD; more complex optimizers that need more intermediate state per parameter (momentum buffers, or Adam's first and second moment estimates) occupy correspondingly more memory.
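A quick way to see this extra state is to inspect optimizer.state after one update; a sketch (momentum SGD keeps one buffer per parameter, the same size as the parameter itself):

import torch
import torch.nn as nn

model = nn.Linear(100, 100)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

model(torch.randn(1, 100)).sum().backward()
opt.step()

# one momentum_buffer per parameter, with the parameter's own shape
extra = sum(s['momentum_buffer'].numel() for s in opt.state.values())
print(extra)  # equals the model's parameter count: the params memory doubles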

Which Layers in a Model Occupy Memory

Layers with parameters are the layers that occupy memory. Our usual convolutional layers all do; the commonly used ReLU activation has no parameters, so it occupies none.

Layers that occupy memory are typically:

  • convolutional layers, the usual conv2d
  • fully connected layers, i.e. Linear layers
  • BatchNorm layers
  • Embedding layers

While those that occupy none are:

  • activation layers such as the ReLU just mentioned
  • pooling layers
  • Dropout layers

The parameter counts (ignoring bias terms) work out as follows; a small verification sketch follows the list:

  • Conv2d(Cin, Cout, K): parameter count: Cin × Cout × K × K
  • Linear(M → N): parameter count: M × N
  • BatchNorm(N): parameter count: 2N
  • Embedding(N, W): parameter count: N × W
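Each formula can be checked against PyTorch directly; a minimal sketch (bias disabled where needed so the counts match the formulas exactly):

import torch.nn as nn

count = lambda m: sum(p.numel() for p in m.parameters())

print(count(nn.Conv2d(3, 64, kernel_size=3, bias=False)))  # 3 * 64 * 3 * 3 = 1728
print(count(nn.Linear(100, 10, bias=False)))               # 100 * 10 = 1000
print(count(nn.BatchNorm2d(64)))                           # 2 * 64 = 128 (weight + bias)
print(count(nn.Embedding(1000, 50)))                       # 1000 * 50 = 50000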

Extra Memory

To sum up, over the course of training, memory usage falls roughly into the following categories:

  • the model's parameters (convolutional and other parameterized layers)
  • the intermediate variables the model produces during computation (each layer's inputs and outputs as the input image passes through)
  • the extra intermediate variables produced during backward
  • the extra model-sized state the optimizer produces while optimizing

Even so, the memory we actually occupy tends to be larger than the theoretical figure, mostly because of extra overhead inside the deep learning framework; with the formulas above, though, the theoretical estimate won't be far from reality.

How to Optimize

Beyond algorithm-level changes, the most basic optimizations come down to the following:

  • reduce the size of the input images
  • reduce the batch size, i.e. feed fewer images per step
  • use more downsampling and pooling layers
  • apply small per-layer optimizations, e.g. set inplace in ReLU layers
  • buy a GPU with more memory
  • optimize at the deep-learning-framework level

In the next article I will show how to track memory usage inside the PyTorch framework, and then how to optimize memory specifically for PyTorch.

Reference:
https://blog.csdn.net/liusandian/article/details/79069926

How to Make Fine-Grained Use of GPU Memory in PyTorch

Original link: https://ptorch.com/news/181.html


Preface


In the previous article, "A Brief Discussion of Deep Learning: How to Calculate the GPU Memory Usage of Models and Intermediate Variables", we explored how to calculate the memory occupied by various kinds of variables. This article focuses on using some features of the PyTorch framework to see how much memory the variables we are currently using occupy, along with some optimization work. All code below is written against PyTorch.


Optimizing Memory


In PyTorch, optimizing memory is a necessity whenever we handle large amounts of data, since we do not have unlimited memory. Memory is finite while data is effectively unlimited; only by optimizing memory usage can we make the most of our data and implement all kinds of algorithms.


Estimating the Memory a Model Occupies


As the previous article said, the memory a model occupies comes in just two kinds:



  • the model's weight parameters
  • the intermediate variables the model stores


In fact the weight parameters generally do not take much space; the bulk goes to the intermediate variables produced during computation. Once we have defined a model, we can compute the number of weight-parameter elements with the following simple code:


import numpy as np

# model is a neural network defined in PyTorch
# model.parameters() yields all of the model's weight parameters
para = sum([np.prod(list(p.size())) for p in model.parameters()])

Suppose we have a model like this:


Sequential(
(conv_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu_1): ReLU(inplace)
(conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu_2): ReLU(inplace)
(pool_2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv_3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

The resulting para is 112576, but that is only the count of weight-parameter elements, not a byte size; we still have to convert it:


# type_size is 4 below because our parameters are float32, i.e. 4 bytes each
print('Model {} : params: {:4f}M'.format(model._get_name(), para * type_size / 1000 / 1000))

This prints:


Model Sequential : params: 0.450304M

But as we said earlier, a model is more than its weight parameters; we also have to measure the intermediate variables. How? We can make up an input, push it through the model, and actively collect the sizes of the intermediate results it computes:


import torch.nn as nn

# model is the loaded model
# input is the actual input Tensor used in practice

# clone() the input so the original is unaffected
input_ = input.clone()
# make sure no gradient is required; we only want the intermediate sizes
input_.requires_grad_(requires_grad=False)

mods = list(model.modules())
out_sizes = []

for i in range(1, len(mods)):
    m = mods[i]
    # note: an inplace ReLU allocates nothing, so skip it
    if isinstance(m, nn.ReLU):
        if m.inplace:
            continue
    out = m(input_)
    out_sizes.append(np.array(out.size()))
    input_ = out

total_nums = 0
for i in range(len(out_sizes)):
    s = out_sizes[i]
    nums = np.prod(np.array(s))
    total_nums += nums

The value above is the element count of all intermediate variables the model produces at run time; again we convert it:


# print both cases: forward only, and forward plus backward
print('Model {} : intermediate variables: {:3f} M (without backward)'
      .format(model._get_name(), total_nums * type_size / 1000 / 1000))
print('Model {} : intermediate variables: {:3f} M (with backward)'
      .format(model._get_name(), total_nums * type_size*2 / 1000 / 1000))

Because all of the intermediate variables must be kept around for the backward computation, the intermediate-variable count is multiplied by 2 for the backward case.


From this we get the memory the model's intermediate variables occupy. Clearly they take far more than the model's own weights, and a single backward pass needs even more.


Model Sequential : intermediate variables: 336.089600 M (without backward)
Model Sequential : intermediate variables: 672.179200 M (with backward)

Putting the code above together:


# memory-usage estimation helper
# model: the model to inspect
# input: the actual input Tensor
# type_size: element size in bytes, 4 by default (float32)

def modelsize(model, input, type_size=4):
    para = sum([np.prod(list(p.size())) for p in model.parameters()])
    print('Model {} : params: {:4f}M'.format(model._get_name(), para * type_size / 1000 / 1000))

    input_ = input.clone()
    input_.requires_grad_(requires_grad=False)

    mods = list(model.modules())
    out_sizes = []

    for i in range(1, len(mods)):
        m = mods[i]
        if isinstance(m, nn.ReLU):
            if m.inplace:
                continue
        out = m(input_)
        out_sizes.append(np.array(out.size()))
        input_ = out

    total_nums = 0
    for i in range(len(out_sizes)):
        s = out_sizes[i]
        nums = np.prod(np.array(s))
        total_nums += nums

    print('Model {} : intermediate variables: {:3f} M (without backward)'
          .format(model._get_name(), total_nums * type_size / 1000 / 1000))
    print('Model {} : intermediate variables: {:3f} M (with backward)'
          .format(model._get_name(), total_nums * type_size*2 / 1000 / 1000))

Of course, the number we compute is only a reference, because PyTorch incurs extra memory overhead at run time, so the actual usage will be somewhat larger than our estimate.


About inplace=False

As we all know, the ReLU() activation has an inplace parameter, which defaults to False. When it is set to True, the new values computed by relu() take no new space; they directly overwrite the original values, which is why setting inplace=True can save some memory.
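A minimal illustration of the difference (the out-of-place call allocates a fresh tensor the same size as x; the in-place forms overwrite x):

import torch
import torch.nn as nn

x = torch.randn(64, 128, 56, 56)
y = torch.relu(x)             # allocates a new tensor the same size as x
x.relu_()                     # in-place: overwrites x, no new allocation
relu = nn.ReLU(inplace=True)  # the module form used inside networks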




Trading Compute Speed for Lower Memory Usage


PyTorch 0.4.0 introduced a new feature that can split a computation in two: if a model needs too much memory, we can compute the first half, keep only the intermediate results the second half needs, and then compute the second half.


In other words, the new checkpoint feature lets us store only what backpropagation needs. If an output is missing (dropped to save memory), checkpoint recomputes the intermediate outputs from the nearest checkpoint, reducing memory usage (at the cost of extra compute time):


import torch
import torch.nn as nn

# input
input = torch.rand(1, 10)
# suppose we have a very deep network
layers = [nn.Linear(10, 10) for _ in range(1000)]
model = nn.Sequential(*layers)
output = model(input)

The model above needs a lot of memory because the computation produces many intermediate variables. This is where checkpoint can help us economize:


# first set the input to requires_grad=True
# otherwise the resulting gradients may be 0

input = torch.rand(1, 10, requires_grad=True)
layers = [nn.Linear(10, 10) for _ in range(1000)]

# define the segment functions; here we define two:
# one runs the first 500 layers, the other runs the rest except the last

def run_first_half(*args):
    x = args[0]
    for layer in layers[:500]:
        x = layer(x)
    return x

def run_second_half(*args):
    x = args[0]
    for layer in layers[500:-1]:
        x = layer(x)
    return x

# import the newly added checkpoint utility
from torch.utils.checkpoint import checkpoint

x = checkpoint(run_first_half, input)
x = checkpoint(run_second_half, x)
# run the last layer separately, outside the checkpoints
x = layers[-1](x)
x.sum().backward()  # and that's it

For a Sequential model, since Sequential() can contain many blocks, an official helper is provided for this:


input = torch.rand(1, 10, requires_grad=True)
layers = [nn.Linear(10, 10) for _ in range(1000)]
model = nn.Sequential(*layers)

from torch.utils.checkpoint import checkpoint_sequential

# split the model into two segments
num_segments = 2
x = checkpoint_sequential(model, num_segments, input)
x.sum().backward()  # and that's it

Tracking Memory Usage


We may not be able to compute memory usage precisely while writing a program, but with pynvml (Nvidia's Python management library) together with Python's garbage-collection facilities, we can print in real time how much memory we are using and which Tensors are using it.


The report looks something like this:


# 08-Jun-18-17:56:51-gpu_mem_prof

At __main__ <module>: line 39                        Total Used Memory:399.4  Mb
At __main__ <module>: line 40                        Total Used Memory:992.5  Mb
+ __main__ <module>: line 40                         (1, 1, 682, 700)     1.82 M <class 'torch.Tensor'>
+ __main__ <module>: line 40                         (1, 3, 682, 700)     5.46 M <class 'torch.Tensor'>
At __main__ <module>: line 126                       Total Used Memory:1088.5 Mb
+ __main__ <module>: line 126                        (64, 64, 3, 3)       0.14 M <class 'torch.nn.parameter.Parameter'>
+ __main__ <module>: line 126                        (128, 64, 3, 3)      0.28 M <class 'torch.nn.parameter.Parameter'>
+ __main__ <module>: line 126                        (128, 128, 3, 3)     0.56 M <class 'torch.nn.parameter.Parameter'>
+ __main__ <module>: line 126                        (64, 3, 3, 3)        0.00 M <class 'torch.nn.parameter.Parameter'>
+ __main__ <module>: line 126                        (256, 256, 3, 3)     2.25 M <class 'torch.nn.parameter.Parameter'>
+ __main__ <module>: line 126                        (512, 256, 3, 3)     4.5 M <class 'torch.nn.parameter.Parameter'>
+ __main__ <module>: line 126                        (512, 512, 3, 3)     9.0 M <class 'torch.nn.parameter.Parameter'>
+ __main__ <module>: line 126                        (64,)                0.00 M <class 'torch.nn.parameter.Parameter'>
+ __main__ <module>: line 126                        (1, 3, 682, 700)     5.46 M <class 'torch.Tensor'>
+ __main__ <module>: line 126                        (128,)               0.00 M <class 'torch.nn.parameter.Parameter'>
+ __main__ <module>: line 126                        (256,)               0.00 M <class 'torch.nn.parameter.Parameter'>
+ __main__ <module>: line 126                        (512,)               0.00 M <class 'torch.nn.parameter.Parameter'>
+ __main__ <module>: line 126                        (3,)                 1.14 M <class 'torch.Tensor'>
+ __main__ <module>: line 126                        (256, 128, 3, 3)     1.12 M <class 'torch.nn.parameter.Parameter'>
...

Below is the code. It still needs some fixes here and there; once it is polished I will put the complete code and usage instructions on GitHub at https://github.com/Oldpan/Pytorch-Memory-Utils — please keep an eye out.

import datetime
import linecache
import os

import gc
import pynvml
import torch
import numpy as np

print_tensor_sizes = True
last_tensor_sizes = set()
gpu_profile_fn = f'{datetime.datetime.now():%d-%b-%y-%H:%M:%S}-gpu_mem_prof.txt'

# if 'GPU_DEBUG' in os.environ:
# print('profiling gpu usage to ', gpu_profile_fn)

lineno = None
func_name = None
filename = None
module_name = None

# fram = inspect.currentframe()
# func_name = fram.f_code.co_name
# filename = fram.f_globals["__file__"]
# ss = os.path.dirname(os.path.abspath(filename))
# module_name = fram.f_globals["__name__"]

def gpu_profile(frame, event, arg):
    # trace callback for sys.settrace; fires when a line is _about to_ execute (!)
    global last_tensor_sizes
    global lineno, func_name, filename, module_name

    if event == 'line':
        try:
            # about _previous_ line (!)
            if lineno is not None:
                pynvml.nvmlInit()
                # handle = pynvml.nvmlDeviceGetHandleByIndex(int(os.environ['GPU_DEBUG']))
                handle = pynvml.nvmlDeviceGetHandleByIndex(0)
                meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
                line = linecache.getline(filename, lineno)
                where_str = module_name+' '+func_name+':'+' line '+str(lineno)

                with open(gpu_profile_fn, 'a+') as f:
                    f.write(f"At {where_str:&lt;50}"
                            f"Total Used Memory:{meminfo.used/1024**2:&lt;7.1f}Mb\n")

                    if print_tensor_sizes is True:
                        for tensor in get_tensors():
                            if not hasattr(tensor, 'dbg_alloc_where'):
                                tensor.dbg_alloc_where = where_str
                        new_tensor_sizes = {(type(x), tuple(x.size()), np.prod(np.array(x.size()))*4/1024**2,
                                             x.dbg_alloc_where) for x in get_tensors()}
                        for t, s, m, loc in new_tensor_sizes - last_tensor_sizes:
                            f.write(f'+ {loc:<50} {str(s):<20} {str(m)[:4]} M {str(t):<10}\n')
                        for t, s, m, loc in last_tensor_sizes - new_tensor_sizes:
                            f.write(f'- {loc:<50} {str(s):<20} {str(m)[:4]} M {str(t):<10}\n')
                        last_tensor_sizes = new_tensor_sizes
                pynvml.nvmlShutdown()

            # save details about line _to be_ executed
            lineno = None

            func_name = frame.f_code.co_name
            filename = frame.f_globals["__file__"]
            if (filename.endswith(".pyc") or
                    filename.endswith(".pyo")):
                filename = filename[:-1]
            module_name = frame.f_globals["__name__"]
            lineno = frame.f_lineno

            return gpu_profile

        except Exception as e:
            print('An exception occurred: {}'.format(e))

    return gpu_profile

def get_tensors():
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                tensor = obj
            else:
                continue
            if tensor.is_cuda:
                yield tensor
        except Exception as e:
            print('An exception occurred: {}'.format(e))

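To actually switch the profiler on, the callback is meant to be installed with sys.settrace, which invokes it on every executed line; a hypothetical usage sketch:

import sys

sys.settrace(gpu_profile)   # start tracing: gpu_profile now fires on every line
# ... run the training or inference code you want to profile ...
sys.settrace(None)          # stop tracing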
Note that linecache's getline can only read files that have been cached; if a file has never been run it returns an invalid value. Python's garbage collector reclaims a variable as soon as nothing references it, so why do the intermediate variables computed inside the model still exist after execution finishes? If nothing references them, why do they still take up space?

One likely explanation is that these references live not in our Python code but inside the network's execution: tensors saved for backward in order to compute gradients, held by the computation graph, where our program cannot see them.


Postscript

In practice we sometimes have models that are used only once; to save memory we want to free intermediate variables with del as the computation proceeds. Space doesn't allow a full treatment here (the next article will cover it), but a minimal sketch follows.
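A minimal sketch, with backbone, head, and img as hypothetical modules and input (empty_cache only returns PyTorch's cached blocks to the driver; it cannot free tensors that are still referenced):

feat = backbone(img)        # a large intermediate feature map
score = head(feat)
del feat                    # drop the last Python reference as soon as possible
torch.cuda.empty_cache()    # optionally hand cached blocks back to the driver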

Original article: How to Make Fine-Grained Use of GPU Memory in PyTorch
