Python 两个三维张量之间的点积

Python 两个三维张量之间的点积,python,numpy,tensorflow,tensordot,Python,Numpy,Tensorflow,Tensordot,我有两个三维张量,张量A具有形状[B,N,S]和张量B也具有形状[B,N,S]。我想要得到的是第三个张量C,我希望它具有[B,B,N]形状,其中元素C[I,j,k]=np.dot(a[I,k,:],B[j,k,:]。我还想实现这是一种矢量化的方法 一些进一步的信息:两个张量A和B具有形状[批大小、数量向量、向量大小]。张量C应表示A批次中的每个元素与B批次中的每个元素之间的点积,即所有不同向量之间的点积 希望它足够清晰,并期待您的回答!尝试: C = np.diagonal( np.tensor

我有两个三维张量,张量
A
具有形状
[B,N,S]
和张量
B
也具有形状
[B,N,S]
。我想要得到的是第三个张量
C
,我希望它具有
[B,B,N]
形状,其中元素
C[I,j,k]=np.dot(a[I,k,:],B[j,k,:]
。我还想实现这是一种矢量化的方法

一些进一步的信息:两个张量
A
B
具有形状
[批大小、数量向量、向量大小]
。张量
C
应表示
A
批次中的每个元素与
B
批次中的每个元素之间的点积,即所有不同向量之间的点积

希望它足够清晰,并期待您的回答!

尝试:

C = np.diagonal( np.tensordot(A,B, axes=(2,2)), axis1=1, axis2=3)

解释

该解决方案由两个操作组成。首先,根据需要,a和B之间的张量积在其第三轴上。这将输出一个秩4张量,通过在轴1和轴3上取相等的索引,将其减少为秩3张量(你的
k
在你的符号中,请注意
tensordot
给出的轴顺序与你的数学不同)。这可以通过采用对角线来实现,就像你将矩阵缩减为其对角线项的向量时所做的那样。

我认为你可以使用如下方法:

使用下标
'ikm,jkm->ijk'
,您可以指定使用爱因斯坦约定减少的维度。此处命名为
'm'
的数组A和B的第三维将减少,就像
操作对向量所做的那样

In [331]: A=np.random.rand(100,200,300)                                                              
In [332]: B=A
建议的
einsum
,直接从

C[i,j,k] = np.dot(A[i,k,:], B[j,k,:] 
表达方式:

In [333]: np.einsum( 'ikm, jkm-> ijk', A, B).shape                                                   
Out[333]: (100, 100, 200)
In [334]: timeit np.einsum( 'ikm, jkm-> ijk', A, B).shape                                            
800 ms ± 25.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
matmul
在最后两个维度上做一个
dot
,并处理前导维度作为批次。在您的案例中,“k”是批次维度,“m”应该遵守
最后一个A和第二个到最后一个B
规则。因此重写
ikm,jkm…
以适应,并相应地转置
A
B

In [335]: np.einsum('kim,kmj->kij', A.transpose(1,0,2), B.transpose(1,2,0)).shape                     
Out[335]: (200, 100, 100)
In [336]: timeit np.einsum('kim,kmj->kij',A.transpose(1,0,2), B.transpose(1,2,0)).shape              
774 ms ± 22.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
性能差别不大。但现在使用
matmul

In [337]: (A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0).shape                             
Out[337]: (100, 100, 200)
In [338]: timeit (A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0).shape                      
64.4 ms ± 1.17 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
并验证值是否匹配(但如果形状匹配,则值通常不匹配)

我不会尝试测量内存使用情况,但时间的改善表明它也更好

在某些情况下,
einsum
被优化为使用
matmul
。这里的情况似乎不是这样,尽管我们可以使用它的参数。我有点惊讶于
matmul
做得这么好

===

我模模糊糊地回忆起另一个关于
matmul
的例子,当两个数组是相同的东西时,抄近路,
A@A
。我在这些测试中使用了
B=A

In [350]: timeit (A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0).shape                      
60.6 ms ± 1.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [352]: B2=np.random.rand(100,200,300)                                                             
In [353]: timeit (A.transpose(1,0,2)@B2.transpose(1,2,0)).transpose(1,2,0).shape                     
97.4 ms ± 164 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
但这只是一个微小的区别

In [356]: np.__version__                                                                             
Out[356]: '1.16.4'

My BLAS etc是标准的Linux,没有什么特别之处。

请再试一次,第一个版本的轴索引错误。如果您能解释一下,这将是非常有用的。一行程序没有多大帮助,即使它可以工作。与上面的
einsum
相比,这种方法现在非常低效。您正在创建一个比
C
然后取对角线。对于小数组,它至少比
einsum
慢10倍,而对于大数组,它可能慢几个数量级。@Brella正确,我刚刚对它们进行了基准测试,我想我的速度慢了10倍,加上更大的临时内存占用。不过,我不理解下一票。谢谢!你能翻译一下吗在矩阵乘法运算中晚了吗?@gorjan不确定我是否理解你想让我做什么?我不想使用像
einsum
这样的东西,而是想知道你的解决方案是否可以写成一个/多个矩阵乘法(我想可以).据我所知,
einsum
只是一种语法糖,它包装了对
matmul
@gorjan的一个或多个调用,我看到了,也许有一个简单的解决方案,但不幸的是,我似乎找不到一种不会进行额外计算的方法,这将与Learnin的
tensordot
解决方案类似g是一个消息,将数组转置为kim,kmj->kij,然后使用
@
很好的答案。我在
einsum
@
中得到了类似的计时(np版本1.15.3).您的进步是因为您使用的是np>1.16.0吗?我更新了
numpy
,现在我得到了类似的计时。可能是因为我的意思是与答案中的计时类似。我的numpy是1.16.4。我还使用了
B=A
,matmul在
中走了一条适度的捷径A@A
case,但这并不能解释大部分时间的差异是的,在更新numpy之后,我得到了和你答案中一样的时间安排
In [350]: timeit (A.transpose(1,0,2)@B.transpose(1,2,0)).transpose(1,2,0).shape                      
60.6 ms ± 1.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [352]: B2=np.random.rand(100,200,300)                                                             
In [353]: timeit (A.transpose(1,0,2)@B2.transpose(1,2,0)).transpose(1,2,0).shape                     
97.4 ms ± 164 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [356]: np.__version__                                                                             
Out[356]: '1.16.4'