import torch索引
import numpy as npimport
#索引select
a = torch.rand(4,3,28,28)numpy
print(a[0].shape,a[0,0].shape,a[0,0,2,4])#torch.Size([3, 28, 28]) torch.Size([28, 28]) tensor(0.4030)im
print(a[:2].shape)#torch.Size([2, 3, 28, 28])數據
print(a[:2,:1,:,:],a[:2,-1,:,:])#-1表明反向索引,通常咱們會使用正向索引,-1表明反向的第一個元素;採集
print(a[:,:,0:28:2,0:28:2])#隔行採樣index
#特殊位置獲取
print(a.index_select(0,torch.tensor([0,2]))) #採集第0個維度上的第0和第2個數據
print(a[2,...])#...表示後面三個維度默認
print(a[2,...,:4])#...是自動推測的並非每一個.表明一個維度
#mask掩碼索引,問題是它會把數據打平
x = torch.randn(3,4)
mask = x.ge(0.5)#大於等於0.5
print(torch.masked_select(x,mask))#取出全部大於等於0.5的值或索引
#
#torch.take(x,mask)