Pytorch 排列后如何进行张量点运算

Pytorch 排列后如何进行张量点运算,pytorch,permute,tensordot,Pytorch,Permute,Tensordot,我有两个张量,A和B: A = torch.randn([32,128,64,12],dtype=torch.float64) B = torch.randn([64,12,64,12],dtype=torch.float64) C = torch.tensordot(A,B,([2,3],[0,1])) D = C.permute(0,2,1,3) # shape:[32,64,128,12] 张量D来自于“tensordot->permute”操作。如何实现新的操作f(),使f()之后的t

我有两个张量,A和B:

A = torch.randn([32,128,64,12],dtype=torch.float64)
B = torch.randn([64,12,64,12],dtype=torch.float64)
C = torch.tensordot(A,B,([2,3],[0,1]))
D = C.permute(0,2,1,3) # shape:[32,64,128,12]
张量D来自于“tensordot->permute”操作。如何实现新的操作f(),使f()之后的tensordot操作类似于:

您是否考虑过使用非常灵活的

D=torch.einsum('ijab,abkl->ikjl',A,B)
tensordot
的问题是,它先输出
A
的所有维度,然后输出
B
的所有维度,而您要寻找的(在排列时)是将
A
B
的维度“交错”在一起!我终于用上了“torch.einsum”。
A_2 = f(A)
B_2 = f(B)
D = torch.tensordot(A_2,B_2)