pytorch-卷積基本網絡結構-提取網絡參數-初始化網絡參數

基本的卷積神經網絡網絡

from torch import nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        layer1 = nn.Sequential() # 將網絡模型進行添加
        layer1.add_module('conv1', nn.Conv2d(3, 32, 3, 1, padding=1)) # nn.Conv
        layer1.add_module('relu1', nn.ReLU(True))
        layer1.add_module('pool1', nn.MaxPool2d(2, 2))
        self.layer1 = layer1

        layer2 = nn.Sequential()
        layer2.add_module('conv2', nn.Conv2d(32, 64, 3, 1, padding=1))
        layer2.add_module('relu2', nn.ReLU(True))
        layer2.add_module('pool2', nn.MaxPool2d(2, 2))
        self.layer2 = layer2

        layer3 = nn.Sequential()
        layer3.add_module('conv3', nn.Conv2d(64, 128, 3, 1, padding=1))
        layer3.add_module('relu3', nn.ReLU(True))
        layer3.add_module('pool3', nn.MaxPool2d(2, 2))
        self.layer3 = layer3

        layer4 = nn.Sequential()
        layer4.add_module('fc1', nn.Linear(2048, 512))
        layer4.add_module('fc_relu1', nn.ReLU(True))
        layer4.add_module('fc2', nn.Linear(512, 64))
        layer4.add_module('fc_relu2', nn.ReLU(True))
        layer4.add_module('fc3', nn.Linear(64, 10))
        self.layer4 = layer4

    def forward(self, x):
        conv1 = self.layer1(x)
        conv2 = self.layer2(conv1)
        conv3 = self.layer3(conv2)
        fc_input = conv3.view(conv3.size(0), -1)
        fc_out = self.layer4(fc_input)

        return fc_out

model = SimpleCNN()
# print(model) # 打印輸出網絡結構

提取前兩層網絡結構 spa

new_model = nn.Sequential(*list(model.children())[:2])  # 提取前兩層的網絡結構, 構造nn.Sequential網絡串接, * 表示將裏面的內容一個個傳進去

提取全部的卷積層網絡code

conv_model = nn.Sequential()
# 提取全部的卷積層操做
for name, layer in model.named_modules():
    if isinstance(layer, nn.Conv2d):
        name = name.replace('.', '_')
        conv_model.add_module(name, layer)
print(conv_model)

打印卷積層的網絡名字orm

for param in model.named_parameters():
    print(param)

對權重參數進行初始化操做blog

from torch.nn import init
# 對權重參數進行初始化操做
for m in model.modules():
    if isinstance(m, nn.Conv2d):
        init.normal(m.weight.data)
        init.xavier_normal(m.weight.data)
        init.kaiming_normal(m.weight.data)
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_()
相關文章
相關標籤/搜索