pytorch中的多维张量点积

pytorch中的多维张量点积,pytorch,tensor,dot-product,Pytorch,Tensor,Dot Product,我有两个8,1,128形状的张量,如下所示 q_s.shape Out[161]: torch.Size([8, 1, 128]) p_s.shape Out[162]: torch.Size([8, 1, 128]) 以上两个张量表示一批8个128维向量。我想要批量q_s与批量p_s的点积。我该怎么做?我尝试使用torch.tensordot函数,如下所示。它也像预期的那样工作。但它也会做额外的工作,我不想让它做。请参见下面的示例 dt = torch.tensordot(q_s, p_s

我有两个8,1,128形状的张量,如下所示

q_s.shape
Out[161]: torch.Size([8, 1, 128])

p_s.shape
Out[162]: torch.Size([8, 1, 128])
以上两个张量表示一批8个128维向量。我想要批量q_s与批量p_s的点积。我该怎么做?我尝试使用torch.tensordot函数,如下所示。它也像预期的那样工作。但它也会做额外的工作,我不想让它做。请参见下面的示例

dt = torch.tensordot(q_s, p_s, dims=([1,2], [1,2]))

dt
Out[176]: 
tensor([[0.9051, 0.9156, 0.7834, 0.8726, 0.8581, 0.7858, 0.7881, 0.8063],
        [1.0235, 1.5533, 1.2155, 1.2048, 1.3963, 1.1310, 1.1724, 1.0639],
        [0.8762, 1.3490, 1.2923, 1.0926, 1.4703, 0.9566, 0.9658, 0.8558],
        [0.8136, 1.0611, 0.9131, 1.1636, 1.0969, 0.9443, 0.9587, 0.8521],
        [0.6104, 0.9369, 0.9576, 0.8773, 1.3042, 0.7900, 0.8378, 0.6136],
        [0.8623, 0.9678, 0.8163, 0.9727, 1.1161, 1.6464, 0.9765, 0.7441],
        [0.6911, 0.8392, 0.6931, 0.7325, 0.8239, 0.7757, 1.0456, 0.6657],
        [0.8493, 0.8174, 0.8041, 0.9013, 0.8003, 0.7451, 0.7408, 1.1771]],
       grad_fn=<AsStridedBackward>)

dt.shape
Out[177]: torch.Size([8, 8])
正如我们所看到的,这产生了大小为8,8的张量,点积位于对角线上。有没有不同的方法来获得更小的8,1形状的张量,它只包含上面结果中对角线上的元素。更清楚地说,对角线上的元素是我们想要的两批dot产品所需的正确dot产品。索引[0][0]处的元素是q_s[0]和p_s[0]的点积。索引[1][1]处的元素是q_s[1]和p_s[1]的点积,依此类推


有没有更好的方法在pytorch中获得所需的点积?

您可以直接进行:

a = torch.rand(8, 1, 128)
b = torch.rand(8, 1, 128)

torch.sum(a * b, dim=(1, 2))
# tensor([29.6896, 30.4994, 32.9577, 30.2220, 33.9913, 35.1095, 32.3631, 30.9153])    

torch.diag(torch.tensordot(a, b, dim=([1,2], [1,2])))
# tensor([29.6896, 30.4994, 32.9577, 30.2220, 33.9913, 35.1095, 32.3631, 30.9153])

如果在求和中设置轴=2,则将得到形状为8,1的张量。

您可以直接执行:

a = torch.rand(8, 1, 128)
b = torch.rand(8, 1, 128)

torch.sum(a * b, dim=(1, 2))
# tensor([29.6896, 30.4994, 32.9577, 30.2220, 33.9913, 35.1095, 32.3631, 30.9153])    

torch.diag(torch.tensordot(a, b, dim=([1,2], [1,2])))
# tensor([29.6896, 30.4994, 32.9577, 30.2220, 33.9913, 35.1095, 32.3631, 30.9153])
如果在求和中设置轴=2,将得到形状为8,1的张量