pytorch中的广播元素乘法

pytorch中的广播元素乘法,pytorch,matrix-multiplication,array-broadcasting,Pytorch,Matrix Multiplication,Array Broadcasting,我在pytorch中有一个张量,大小torch.size([1443747128])。让我们把它命名为张量A。在这个张量中,128表示批量大小。我还有一个1D张量,大小torch.size([1443747])。我们叫它B。我想对B与A进行元素相乘,这样B与张量A的所有128列相乘(显然是以元素相乘的方式)。换句话说,我想沿着dimension=1广播元素相乘。 我如何在pytorch中实现这一点 如果张量a中没有涉及批量大小(batchsize=1),那么普通的*操作符将很容易进行乘法A*B则

我在pytorch中有一个张量,大小
torch.size([1443747128])
。让我们把它命名为张量
A
。在这个张量中,128表示批量大小。我还有一个1D张量,大小
torch.size([1443747])
。我们叫它
B
。我想对B与A进行元素相乘,这样B与张量
A
的所有128列相乘(显然是以元素相乘的方式)。换句话说,我想沿着
dimension=1
广播元素相乘。 我如何在pytorch中实现这一点

如果张量a中没有涉及批量大小(
batchsize=1
),那么普通的
*
操作符将很容易进行乘法
A*B
则会生成大小为
torch.size([1443747])
的合成张量。然而,我不明白为什么pytorch不沿着维度1传播张量乘法?有没有办法做到这一点


我想要的是,
B
应该以元素方式与
A
的所有128列相乘。因此,合成张量的大小将是
torch.size([1443747128])

尺寸应该匹配,如果你转置A或取消B,它应该工作:

C = A.transpose(1,0) * B    # shape: [128, 1443747]


请注意,这两种解决方案的形状是不同的。

看起来效果不错。我认为B和A的形状是沿着维度0匹配的,所以,
B*A
应该可以工作。乘法的顺序不会影响结果!(除非你做了适当的矩阵乘法B@A)。广播实际上是从最后一个维度开始工作的。
C = A * B.unsqueeze(dim=1)  # shape: [1443747, 128]