pytorch 網絡可視化

今天使用hiddenlayer測試了下retinanet網絡的可視化。
首先,安裝hiddlayer,直接pip pip install git+https://github.com/waleedka/hiddenlayer.git
而後在終端加載模型並顯示:python

import model, torch
import hiddenlayer as hl

retinanet = model.resnet18(num_classes=100, pretrained=True).cuda()
x = torch.rand((1, 3, 224, 224)).cuda().float()
ann = torch.tensor([[[20.0, 30.0, 53.2, 33.3, 32.0]]]).cuda().float()
hl.build_graph(retinanet, [x, ann])
hl.save('/home/willer/model.pdf')

模型太複雜了,放在這裏了。
昨天晚上對比着模型結構的pdf和代碼又看了下,發現仍是頗有用的,起碼對網絡的數據流動的認識更加清晰了。git

相關文章
相關標籤/搜索