Pytorch中的切片说明

Pytorch中的切片说明,pytorch,slice,Pytorch,Slice,为什么每次的输出都是一样的 a = torch.tensor([0, 1, 2, 3, 4]) a[-2:] = torch.tensor([[[5, 6]]]) a 张量([0,1,2,5,6]) 张量([0,1,2,5,6]) 张量([0,1,2,5,6])Pytorch在Numpy之后,只要形状兼容,就可以进行切片,这意味着两侧具有相同的形状,或者右侧可以广播到切片的形状。从尾随维度开始,如果两个数组仅在其中一个为1的维度上不同,则两个数组是相同的。所以在这种情况下 a = torch.

为什么每次的输出都是一样的

a = torch.tensor([0, 1, 2, 3, 4])
a[-2:] = torch.tensor([[[5, 6]]])
a
张量([0,1,2,5,6])

张量([0,1,2,5,6])


张量([0,1,2,5,6])

Pytorch在Numpy之后,只要形状兼容,就可以进行切片,这意味着两侧具有相同的形状,或者右侧可以广播到切片的形状。从尾随维度开始,如果两个数组仅在其中一个为1的维度上不同,则两个数组是相同的。所以在这种情况下

a = torch.tensor([0, 1, 2, 3, 4])
b = torch.tensor([[[5, 6]]])
print(a[-2:].shape, b.shape)
>> torch.Size([2]) torch.Size([1, 1, 2])
Pytorch将执行以下比较:

  • a[-2:].shape[-1]
    b.shape[-1]
    相等,因此最后一个维度是兼容的
  • a[-2:].shape[-2]
    不存在,但
    b.shape[-2]
    为1,因此它们是兼容的
  • a[-2:].shape[-3]
    不存在,但
    b.shape[-3]
    为1,因此它们是兼容的
  • 所有维度都兼容,因此
    b
    可以广播到
    a
  • 最后,在执行赋值之前,Pytorch将
    b
    转换为
    张量([5,6])
    ,从而产生结果:

    a[-2:] = b
    print(a)
    >> tensor([0, 1, 2, 5, 6])
    

    Pytorch在这里遵循Numpy,只要形状兼容,就可以进行切片,这意味着两侧具有相同的形状,或者右手侧可以广播到切片的形状。从尾随维度开始,如果两个数组仅在其中一个为1的维度上不同,则两个数组是相同的。所以在这种情况下

    a = torch.tensor([0, 1, 2, 3, 4])
    b = torch.tensor([[[5, 6]]])
    print(a[-2:].shape, b.shape)
    >> torch.Size([2]) torch.Size([1, 1, 2])
    
    Pytorch将执行以下比较:

  • a[-2:].shape[-1]
    b.shape[-1]
    相等,因此最后一个维度是兼容的
  • a[-2:].shape[-2]
    不存在,但
    b.shape[-2]
    为1,因此它们是兼容的
  • a[-2:].shape[-3]
    不存在,但
    b.shape[-3]
    为1,因此它们是兼容的
  • 所有维度都兼容,因此
    b
    可以广播到
    a
  • 最后,在执行赋值之前,Pytorch将
    b
    转换为
    张量([5,6])
    ,从而产生结果:

    a[-2:] = b
    print(a)
    >> tensor([0, 1, 2, 5, 6])