PyTorch 拼接与拆分-Tensor基本操作

作者 : admin 本文共631个字,预计阅读时间需要2分钟 发布时间: 2024-06-16 共1人阅读

拼接: cat, stack …

  • 使用 cat 在指定维度 dim 上拼接: torch.cat(element_list, dim)

    >>> a = torch.rand(2,3) 
    >>> b = torch.rand(1,3) 
    >>> c = torch.cat([a,b], dim=0) 
    >>> c.shape
    torch.Size([3, 3])
    
  • 使用 stack 在新增维度 dim 上拼接: torch.cat(element_list, dim),

    • 注:element_list 中 element 的 shape 必须完全一致
    >>> a = torch.rand(2,3) 
    >>> b = torch.rand(2,3) 
    >>> c = torch.stack([a,b], dim=0)
    >>> c.shape
    torch.Size([2, 2, 3])
    

拆分:split,chunk …

  • 使用 split 根据长度拆分:a.split(l, dim)
    • 注:长度不一样时:a.split(l_list, dim)
    >>> a.split(1, dim=0)  # 或 a.split([1,1], dim=0)
    (tensor([[0.7967, 0.5056, 0.7963]]), tensor([[0.8603, 0.7029, 0.7590]]))
    
  • 使用 chunk根据数量拆分:a.chunk(n, dim)
    >>> a.chunk(2, dim=0) 
    (tensor([[0.7967, 0.5056, 0.7963]]), tensor([[0.8603, 0.7029, 0.7590]]))
    

  • B站视频参考资料
本站无任何商业行为
个人在线分享 » PyTorch 拼接与拆分-Tensor基本操作
E-->