如何在PyTorch中为批处理数据添加“点”权重?

如何在PyTorch中为批处理数据添加“点”权重?,pytorch,Pytorch,我有批处理数据,希望dot()添加到数据中。W是可训练参数。 如何在批处理数据和权重之间添加点 hid_dim = 32 data = torch.randn(10, 2, 3, hid_dim) data = data.view(10, 2*3, hid_dim) W = torch.randn(hid_dim) # assume trainable parameters via nn.Parameter result = torch.bmm(data, W).squeeze() # erro

我有批处理数据,希望
dot()
添加到数据中。W是可训练参数。 如何在批处理数据和权重之间添加点

hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = data.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim) # assume trainable parameters via nn.Parameter
result = torch.bmm(data, W).squeeze() # error, want (N, 6)
result = result.view(10, 2, 3)
更新 这个怎么样

hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = tdata.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim, 1) # assume trainable parameters via nn.Parameter
W = W.unsqueeze(0).expand(10, hid_dim, 1)
result = torch.bmm(data, W).squeeze() # error, want (N, 6)
result = result.view(10, 2, 3)

展开
W
tensor以匹配
data
tensor的形状。以下几点应该行得通

hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = data.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim)
W = W.unsqueeze(0).unsqueeze(0).expand(*data.size())
result = torch.sum(data * W, 2)
result = result.view(10, 2, 3)


编辑:您更新的代码是正确的。由于您正在将
W
转换为
Bxhid\u dimx1
,并且您的数据的形状为
bxxhid\u dim
,因此执行批处理矩阵乘法将导致
Bxdx1
,它本质上是
W
参数与
数据中所有行向量之间的点积(
dxhid\u dim
).

我更新了我的帖子。但是你的代码看起来比我好。我的代码也正确吗?