Understanding the difference and relationship between torch.nn and torch.nn.functional through the various implementations of relu

The relationship among the various implementations of relu

The relu function appears in pytorch in a total of 3 places:

  1. torch.nn.ReLU()
  2. torch.nn.functional.relu() and torch.nn.functional.relu_()
  3. torch.relu() and torch.relu_()

These 3 different implementations are in fact related by a fixed chain of wrapping; moving from top to bottom is a journey from the surface to the core.

The last pair is not actually covered by the official pytorch documentation, and no corresponding Python code can be found for it; it exists only in __init__.pyi, because these functions come from the THNN library, which is written in C++.
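Before diving into the source, here is a minimal sketch (using only the public pytorch API) confirming that the three entry points compute the same thing:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(4)
m = nn.ReLU()                                  # module form
assert torch.equal(m(x), F.relu(x))            # functional form
assert torch.equal(F.relu(x), torch.relu(x))   # low-level form
print("all three agree")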

Let's make this concrete by analyzing the source code:

  1. torch.nn.ReLU()
    The classes in torch.nn represent neural network layers. Here we can see that ReLU(), which appears as a class, actually just calls the relu and relu_ implementations in torch.nn.functional.
class ReLU(Module):
    r"""Applies the rectified linear unit function element-wise:

    :math:`\text{ReLU}(x)= \max(0, x)`

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/ReLU.png

    Examples::

        >>> m = nn.ReLU()
        >>> input = torch.randn(2)
        >>> output = m(input)


      An implementation of CReLU - https://arxiv.org/abs/1603.05201

        >>> m = nn.ReLU()
        >>> input = torch.randn(2).unsqueeze(0)
        >>> output = torch.cat((m(input),m(-input)))
    """
    __constants__ = ['inplace']

    def __init__(self, inplace=False):
        super(ReLU, self).__init__()
        self.inplace = inplace

    @weak_script_method
    def forward(self, input):
        # F comes from the module-level import: from .. import functional as F
        return F.relu(input, inplace=self.inplace)

    def extra_repr(self):
        inplace_str = 'inplace' if self.inplace else ''
        return inplace_str
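As a quick sketch of what the inplace flag buys us: no new tensor is allocated, and the input's storage is reused (the data_ptr check below is just one way to observe this):

import torch
import torch.nn as nn

x = torch.randn(3)
y = nn.ReLU(inplace=True)(x)          # result is written back into x
assert y.data_ptr() == x.data_ptr()   # same underlying memory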
  2. torch.nn.functional.relu() and torch.nn.functional.relu_()
    These two functions, in turn, simply call torch.relu() and torch.relu_():
def relu(input, inplace=False):
    # type: (Tensor, bool) -> Tensor
    r"""relu(input, inplace=False) -> Tensor

    Applies the rectified linear unit function element-wise. See
    :class:`~torch.nn.ReLU` for more details.
    """
    if inplace:
        result = torch.relu_(input)
    else:
        result = torch.relu(input)
    return result


relu_ = _add_docstr(torch.relu_, r"""
relu_(input) -> Tensor

In-place version of :func:`~relu`.
""")

At this point we have a thorough picture of where the relu function shows up in torch. As the two foundational packages, torch.nn and torch.nn.functional are related by reference and wrapping.

The difference and relationship between torch.nn and torch.nn.functional

Combining the above analysis of relu, we can see the relationship between the two libraries more clearly.

Generally speaking, torch.nn.functional calls into the THNN library to carry out the core computation, but it does not manage learnable parameters such as weight and bias, which makes models built on it alone inconvenient to use. The modules implemented in torch.nn wrap torch.nn.functional; in essence they are the official usage examples of torch.nn.functional. By calling these modules directly we can use pytorch quickly and conveniently, but the prepackaged modules cannot cover everyone's needs, so torch.nn.functional is kept around to give those users the flexibility to assemble the models they need themselves. This is how pytorch strikes a balance between flexibility and ease of use.
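As a sketch of this division of labor (TinyNet is a made-up example, not from pytorch itself): the nn layer owns the learnable parameters, while the stateless functional op is called freely in forward:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyNet(nn.Module):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.fc = nn.Linear(20, 10)   # learnable weight/bias live here

    def forward(self, x):
        return F.relu(self.fc(x))    # stateless functional op, no parameters

net = TinyNet()
print(net(torch.randn(2, 20)).shape)  # torch.Size([2, 10])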

Note in particular that not everything in torch.nn is an example built on torch.nn.functional; some modules call functions from other libraries. For example, the commonly used RNN family of networks does not appear in torch.nn.functional at all.
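A quick way to check this from the Python side (the printed values assume a typical pytorch build; treat this as a sketch):

import torch.nn as nn
import torch.nn.functional as F

print(hasattr(nn, 'LSTM'))   # True  -- the module form exists
print(hasattr(F, 'lstm'))    # False -- no functional counterpart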

With this in mind, let's close by looking at one more example:

For Linear, pay attention ⚠️ to how the implementations in the two libraries differ in:

  1. Management of learnable parameters
  2. How the two implementations call each other
  3. The initialization process
class Linear(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.
                If :attr:`bias` is ``True``, the values are initialized from
                :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                :math:`k = \frac{1}{\text{in\_features}}`

    Examples::

        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    __constants__ = ['bias']

    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    @weak_script_method
    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
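Point 1 is visible from the outside: because the class registers weight and bias as Parameter objects, the module tracks them automatically. A minimal check:

import torch.nn as nn

m = nn.Linear(20, 30)
print([name for name, _ in m.named_parameters()])   # ['weight', 'bias']

And here is the corresponding functional implementation, torch.nn.functional.linear, which accepts the weight and bias as plain arguments: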
def linear(input, weight, bias=None):
    # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
    r"""
    Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.

    Shape:

        - Input: :math:`(N, *, in\_features)` where `*` means any number of
          additional dimensions
        - Weight: :math:`(out\_features, in\_features)`
        - Bias: :math:`(out\_features)`
        - Output: :math:`(N, *, out\_features)`
    """
    if input.dim() == 2 and bias is not None:
        # fused op is marginally faster
        ret = torch.addmm(bias, input, weight.t())
    else:
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias
        ret = output
    return ret
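To make the contrast concrete, a hedged sketch of the two styles side by side (the bias initialization below is deliberately simplified relative to reset_parameters above):

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(128, 20)

# nn.Linear creates, initializes and tracks weight/bias for us.
m = nn.Linear(20, 30)
y1 = m(x)

# With F.linear we own the parameters: creation, initialization, bookkeeping.
w = torch.empty(30, 20)
nn.init.kaiming_uniform_(w, a=math.sqrt(5))   # same init as reset_parameters
b = torch.zeros(30)                           # simplified bias init
y2 = F.linear(x, w, b)

print(y1.shape, y2.shape)   # torch.Size([128, 30]) torch.Size([128, 30])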