今天使用hiddenlayer測試了下retinanet網絡的可視化。
首先,安裝hiddlayer,直接pip pip install git+https://github.com/waleedka/hiddenlayer.git
而後在終端加載模型並顯示:python
import model, torch import hiddenlayer as hl retinanet = model.resnet18(num_classes=100, pretrained=True).cuda() x = torch.rand((1, 3, 224, 224)).cuda().float() ann = torch.tensor([[[20.0, 30.0, 53.2, 33.3, 32.0]]]).cuda().float() hl.build_graph(retinanet, [x, ann]) hl.save('/home/willer/model.pdf')
模型太複雜了,放在這裏了。
昨天晚上對比着模型結構的pdf和代碼又看了下,發現仍是頗有用的,起碼對網絡的數據流動的認識更加清晰了。git