nn.Module vs nn.functional
The former (nn.Module) stores state such as weights; the latter (nn.functional) only performs the computation.
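To make the difference concrete, here is a minimal sketch that computes the same linear map both ways. The tensor shapes and variable names are chosen purely for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.rand(4, 3)

# Module form: the layer object creates and stores its own weight and bias.
linear = nn.Linear(3, 2)
y_module = linear(x)

# Functional form: a stateless function; the caller has to supply the weights.
y_functional = F.linear(x, linear.weight, linear.bias)

print(torch.allclose(y_module, y_functional))
print(len(list(linear.parameters())))  # weight and bias live on the module
```

Operations without learnable state (for example ReLU) are available in both forms, so which one to use is mostly a matter of style.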
parameters()
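Returns an iterator over the parameters registered on the module, i.e. the tensors that are trainable by default.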
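As a small illustration (a sketch, with layer sizes picked arbitrarily), parameters() can be used to count parameters or to hand them to an optimizer. Note that it yields every registered parameter, including ones you have frozen, so filtering on requires_grad is a common pattern.

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)

# Count every registered parameter: 10*2 weights + 2 biases.
n_params = sum(p.numel() for p in model.parameters())

# Freeze the weight; it is still returned by parameters().
model.weight.requires_grad_(False)
trainable = [p for p in model.parameters() if p.requires_grad]
n_trainable = sum(p.numel() for p in trainable)

# Only the still-trainable parameters are given to the optimizer.
optimizer = torch.optim.SGD(trainable, lr=0.1)

print(n_params, n_trainable)
```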
nn.ModuleList vs. nn.ParameterList vs. nn.Sequential
    import torch.nn as nn

    layer_list = [nn.Conv2d(5, 5, 3), nn.BatchNorm2d(5), nn.Linear(5, 2)]

    class myNet(nn.Module):
        def __init__(self):
            super().__init__()
            # A plain Python list: the modules inside are not registered.
            self.layers = layer_list

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return x

    net = myNet()
    print(list(net.parameters()))  # Parameters of modules in the layer_list don't show up.
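nn.ModuleList wraps a Python list so that the modules inside it are registered, which means their parameters are returned by parameters(). nn.ParameterList plays the same role for a list of nn.Parameter objects.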
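A minimal sketch of the fix for the list-based example above: the same three layers, wrapped in nn.ModuleList. The class name MyNetWithModuleList is made up for this illustration.

```python
import torch.nn as nn

class MyNetWithModuleList(nn.Module):  # hypothetical name, for illustration only
    def __init__(self):
        super().__init__()
        # ModuleList registers each layer as a submodule of this module.
        self.layers = nn.ModuleList(
            [nn.Conv2d(5, 5, 3), nn.BatchNorm2d(5), nn.Linear(5, 2)]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

net = MyNetWithModuleList()
param_names = [name for name, _ in net.named_parameters()]
print(param_names)  # weight and bias of each of the three layers
```

The parameter names are prefixed with the attribute name and the index in the list, e.g. layers.0.weight.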
nn.Sequential works as follows:
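    import torch
    import torch.nn as nn

    class myNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.Sequential(
                nn.ReLU(inplace=True),
                nn.Linear(10, 10)
            )

        def forward(self, x):
            # Sequential is itself callable and runs its layers in order.
            x = self.layers(x)
            return x

    x = torch.rand(10)
    net = myNet()
    print(net(x).shape)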
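As you can see, Sequential builds the network structure in the specified order and yields a complete, callable module, whereas ModuleList merely collects its elements the way a list does.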
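The contrast can be checked directly. The sketch below (layer sizes are arbitrary) puts the same layers into both containers: the Sequential can be called on an input, while calling the ModuleList raises NotImplementedError because it defines no forward of its own.

```python
import torch
import torch.nn as nn

layers = [nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2)]

# Sequential chains the layers: the output of one feeds the next.
seq = nn.Sequential(*layers)
out = seq(torch.rand(4, 10))
print(out.shape)

# ModuleList only stores the layers; it has no forward pass.
mlist = nn.ModuleList(layers)
try:
    mlist(torch.rand(4, 10))
    modulelist_callable = True
except NotImplementedError:
    modulelist_callable = False
print(modulelist_callable)
```

ModuleList is the better fit when the forward pass is not a simple chain, for example when intermediate outputs are reused or layers are selected by index.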
nn.modules vs. nn.children
    import torch.nn as nn

    class myNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.convBN1 = nn.Sequential(nn.Conv2d(10, 10, 3), nn.BatchNorm2d(10))
            self.linear = nn.Linear(10, 2)

        def forward(self, x):
            pass

    Net = myNet()

    print("Printing children\n------------------------------")
    print(list(Net.children()))
    print("\n\nPrinting Modules\n------------------------------")
    print(list(Net.modules()))
The output is as follows:
    Printing children
    ------------------------------
    [Sequential(
      (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
      (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), Linear(in_features=10, out_features=2, bias=True)]


    Printing Modules
    ------------------------------
    [myNet(
      (convBN1): Sequential(
        (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
        (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (linear): Linear(in_features=10, out_features=2, bias=True)
    ), Sequential(
      (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
      (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    ), Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1)), BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True), Linear(in_features=10, out_features=2, bias=True)]
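As you can see, children() returns only the direct child elements; a child may be a single operation such as Linear, or a container such as Sequential. modules() returns more: it starts with the module itself, includes everything children() returns, and also recurses, so it additionally returns the sub-elements contained in the Sequential.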
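For a more compact view of the same difference, the sketch below collects only the class names of what each method yields. The class name MyNet mirrors the example above; forward is omitted because the model is never called here.

```python
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.convBN1 = nn.Sequential(nn.Conv2d(10, 10, 3), nn.BatchNorm2d(10))
        self.linear = nn.Linear(10, 2)

net = MyNet()

# Direct children only.
child_types = [type(m).__name__ for m in net.children()]
# The module itself, then all descendants, depth first.
module_types = [type(m).__name__ for m in net.modules()]

print(child_types)   # ['Sequential', 'Linear']
print(module_types)  # ['MyNet', 'Sequential', 'Conv2d', 'BatchNorm2d', 'Linear']
```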
named_*
- named_parameters: returns an iterator that yields, on each step, a tuple of the parameter's name and the parameter itself.
    In [27]: x = torch.nn.Linear(2,3)

    In [28]: x_name_params = x.named_parameters()

    In [29]: next(x_name_params)
    Out[29]:
    ('weight',
     Parameter containing:
     tensor([[-0.5262,  0.3480],
             [-0.6416, -0.1956],
             [ 0.5042,  0.6732]], requires_grad=True))

    In [30]: next(x_name_params)
    Out[30]:
    ('bias',
     Parameter containing:
     tensor([ 0.0595, -0.0386,  0.0975], requires_grad=True))
- named_modules: this is essentially the modules() described above, returned as an iterator of (name, module) tuples. As before, each item is read with next(). Example:
    In [46]: class myNet(nn.Module):
        ...:     def __init__(self):
        ...:         super().__init__()
        ...:         self.convBN1 = nn.Sequential(nn.Conv2d(10,10,3), nn.BatchNorm2d(10))
        ...:         self.linear = nn.Linear(10,2)
        ...:
        ...:     def forward(self, x):
        ...:         pass
        ...:

    In [47]: net = myNet()

    In [48]: net_named_modules = net.named_modules()

    In [49]: next(net_named_modules)
    Out[49]:
    ('',
     myNet(
       (convBN1): Sequential(
         (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
         (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       )
       (linear): Linear(in_features=10, out_features=2, bias=True)
     ))

    In [50]: next(net_named_modules)
    Out[50]:
    ('convBN1',
     Sequential(
       (0): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
       (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     ))

    In [51]: next(net_named_modules)
    Out[51]: ('convBN1.0', Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1)))

    In [52]: next(net_named_modules)
    Out[52]:
    ('convBN1.1',
     BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))

    In [53]: next(net_named_modules)
    Out[53]: ('linear', Linear(in_features=10, out_features=2, bias=True))

    In [54]: next(net_named_modules)
    ---------------------------------------------------------------------------
    StopIteration                             Traceback (most recent call last)
    <ipython-input-54-05e848b071b8> in <module>
    ----> 1 next(net_named_modules)

    StopIteration:
- named_children: works the same way as named_modules, but yields (name, module) tuples for the direct children only.
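In practice these iterators are usually consumed with a for loop or a comprehension rather than with repeated next() calls. The sketch below collects just the names from each named_* method on a network like the one above; the class name MyNet and the variable names are only illustrative.

```python
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.convBN1 = nn.Sequential(nn.Conv2d(10, 10, 3), nn.BatchNorm2d(10))
        self.linear = nn.Linear(10, 2)

net = MyNet()

# Direct children only.
child_names = [name for name, _ in net.named_children()]
# The module itself (empty name) plus all descendants, with dotted paths.
module_names = [name for name, _ in net.named_modules()]
# Every parameter, prefixed with the path of the module that owns it.
param_names = [name for name, _ in net.named_parameters()]

print(child_names)   # ['convBN1', 'linear']
print(module_names)  # ['', 'convBN1', 'convBN1.0', 'convBN1.1', 'linear']
print(param_names)
```

The dotted names are handy for selecting parts of a model, for instance freezing every parameter whose name starts with a given prefix.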