PyTorch 索引与切片-Tensor基本操作

作者 : admin 本文共629个字,预计阅读时间需要2分钟 发布时间: 2024-06-16 共4人阅读

以如下 tensor a 为例,展示常用的 indxing, slicing 及其他高阶操作

>>> a = torch.rand(4,3,28,28)
>>> a.shape
torch.Size([4, 3, 28, 28])
  • Indexing: 使用索引获取目标对象,[x,x,x,....]

    >>> a[0].shape
    torch.Size([3, 28, 28])
    >>> a[0,0].shape
    torch.Size([28, 28])
    >>> a[0,0,0].shape 
    torch.Size([28])
    >>> a[0,0,0,0].shape 
    torch.Size([])
    
  • Slicing: 使用切片获取一截目标对象,::step

    >>> a[:2].shape
    torch.Size([2, 3, 28, 28])
    >>> a[0, :2].shape 
    torch.Size([2, 28, 28])
    >>> a[0, 0, :2].shape 
    torch.Size([2, 28])
    >>> a[0, 0, 0, :2].shape 
    torch.Size([2])
    
  • 其他汇总:

    >>> a.index_select(dim, torch.tensor([idx_1,idx_2, ...]))  ## by specific idx
    >>> torch.take(a, torch.tensor([idx_1, idx2, ...]))  ## 不指定 dim 先打平 a 再按序提取 
    >>> a[a.ge(0.5)]  ## by mask=a.ge(0.5),该方法没有保持 shape
    

  • B站视频参考资料
本站无任何商业行为
个人在线分享 » PyTorch 索引与切片-Tensor基本操作
E-->