The torch.narrow() function in PyTorch

torch.narrow(input, dim, start, length) → Tensor

Returns a new tensor that is a narrowed version of the input tensor: along dimension dim, only the elements at indices start through start + length - 1 are kept. The returned tensor and the input tensor share the same underlying storage.
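Because the result shares storage with the input, narrow() copies no data, and writes through the narrowed tensor show up in the original. The minimal sketch below shows this with the same 3 x 3 tensor used in the example further down; the variable names are illustrative.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
rows = torch.narrow(x, 0, 1, 2)  # rows 1 and 2 of x, shape (2, 3)

# Writing through the narrowed tensor modifies x as well.
rows[0, 0] = 100
print(x[1, 0].item())  # 100

# The view starts one full row (3 elements) into x's storage.
print(rows.storage_offset())  # 3
```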

Parameters
  • input (Tensor) – the tensor to narrow

  • dim (int) – the dimension along which to narrow

  • start (int) – the index along dim at which the kept range begins

  • length (int) – the number of elements to keep along dim, so the range ends just before start + length
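To see how the four parameters interact on a tensor with more than two dimensions, here is a small sketch on a 2 x 3 x 4 tensor. It also shows that start + length may not exceed the size of dim; the exact exception type and message can differ between PyTorch versions, so the sketch catches both RuntimeError and IndexError.

```python
import torch

t = torch.arange(24).reshape(2, 3, 4)  # values 0..23

# Along dim=2 keep length=2 elements starting at start=1 (indices 1 and 2).
part = torch.narrow(t, 2, 1, 2)
print(part.shape)        # torch.Size([2, 3, 2])
print(part[0].tolist())  # [[1, 2], [5, 6], [9, 10]]

# start + length must not exceed t.size(2) == 4.
raised = False
try:
    torch.narrow(t, 2, 3, 2)  # would need indices 3 and 4; 4 is out of range
except (RuntimeError, IndexError) as err:
    raised = True
    print("error:", err)
```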

Example:

>>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> torch.narrow(x, 0, 0, 2)
tensor([[ 1,  2,  3],
        [ 4,  5,  6]])
>>> torch.narrow(x, 1, 1, 2)
tensor([[ 2,  3],
        [ 5,  6],
        [ 8,  9]])
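The same example as a standalone script, printing plain lists so the selected elements are easy to read. It also uses the method form x.narrow(dim, start, length), which takes the same arguments minus input.

```python
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

first_two_rows = torch.narrow(x, 0, 0, 2)  # dim 0, indices 0..1
last_two_cols = x.narrow(1, 1, 2)          # dim 1, indices 1..2 (method form)

print(first_two_rows.tolist())  # [[1, 2, 3], [4, 5, 6]]
print(last_two_cols.tolist())   # [[2, 3], [5, 6], [8, 9]]
```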

From the definition, this function returns the slice start:start+length of the tensor along dimension dim. For the example above:

x.size() = (3, 3)

torch.narrow(x, 0, 0, 2) is equivalent to x[0:0+2, :]

torch.narrow(x, 1, 2, 1) is equivalent to x[:, 2:2+1]
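The equivalence generalizes to any dimension: use start:start+length along dim and a full slice (:) everywhere else. The helper narrow_via_slicing below is hypothetical, written only to check this against torch.narrow for every valid non-empty range of the 3 x 3 example.

```python
import torch

def narrow_via_slicing(t: torch.Tensor, dim: int, start: int, length: int) -> torch.Tensor:
    """Hypothetical helper: slice start:start+length along dim, ':' elsewhere."""
    index = [slice(None)] * t.dim()
    index[dim] = slice(start, start + length)
    return t[tuple(index)]

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Compare both forms for every (dim, start, length) with length >= 1.
for dim in range(x.dim()):
    size = x.size(dim)
    for start in range(size):
        for length in range(1, size - start + 1):
            assert torch.equal(torch.narrow(x, dim, start, length),
                               narrow_via_slicing(x, dim, start, length))
print("all combinations match")
```

Basic slicing with slice objects also returns a view, so both forms share storage with x.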
