Numpy `np.dot`在剩余轴上没有笛卡尔积
根据报告: 对于N维Numpy `np.dot`在剩余轴上没有笛卡尔积,numpy,array-broadcasting,Numpy,Array Broadcasting,根据报告: 对于N维dot是a最后一个轴与b倒数第二个轴的和积: dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m]) 我想计算a的最后一个轴和b的倒数第二个轴的和积,但不在其余轴上形成笛卡尔积,因为其余轴的形状相同。让我举例说明: a = np.random.normal(size=(11, 12, 13)) b = np.random.normal(size=(11, 12, 13, 13)) c = np.dot(a, b) c.shape # =
dot
是a
最后一个轴与b
倒数第二个轴的和积:
dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])
我想计算a
的最后一个轴和b
的倒数第二个轴的和积,但不在其余轴上形成笛卡尔积,因为其余轴的形状相同。让我举例说明:
a = np.random.normal(size=(11, 12, 13))
b = np.random.normal(size=(11, 12, 13, 13))
c = np.dot(a, b)
c.shape # = (11, 12, 11, 12, 13)
但是我希望形状是(11、12、13)
。使用广播可以达到预期的效果
c = np.sum(a[..., None] * b, axis=-2)
c.shape # = (11, 12, 13)
但是我的数组相对较大,我希望使用并行BLAS实现的强大功能,它似乎不受np.sum
的支持,但受np.dot
的支持。关于如何实现这一点有什么想法吗?您可以使用-
您还可以使用:
这相当于Python 3.5+中的:
c = (a[..., None, :] @ b)[..., 0, :]
在性能上没有太大的差异-如果有什么np.einsum
对于您的示例阵列来说似乎稍微快一点:
In [1]: %%timeit a = np.random.randn(11, 12, 13); b = np.random.randn(11, 12, 13, 13)
....: np.einsum('...i,...ij->...j', a, b)
....:
The slowest run took 5.24 times longer than the fastest. This could mean that an
intermediate result is being cached.
10000 loops, best of 3: 26.7 µs per loop
In [2]: %%timeit a = np.random.randn(11, 12, 13); b = np.random.randn(11, 12, 13, 13)
np.matmul(a[..., None, :], b)[..., 0, :]
....:
10000 loops, best of 3: 28 µs per loop
我真的很喜欢
einsum
,但它不能并行使用BLAS。@TillHoffmann嗯,AFAIK,基于点的函数,如np.dot
和np.tensordot
使用BLAS
将不允许对齐轴,就像einsum
在这里我们保持前两个轴在a
和b
之间对齐。在性能方面,对于示例数据量,这种einsum
方法被证明是4x+
更好:)非常好,感谢您的澄清。对任何感兴趣的人来说,np.einsum(“…i,…ij->…j',a,b)
对维度是不可知的。
c = (a[..., None, :] @ b)[..., 0, :]
In [1]: %%timeit a = np.random.randn(11, 12, 13); b = np.random.randn(11, 12, 13, 13)
....: np.einsum('...i,...ij->...j', a, b)
....:
The slowest run took 5.24 times longer than the fastest. This could mean that an
intermediate result is being cached.
10000 loops, best of 3: 26.7 µs per loop
In [2]: %%timeit a = np.random.randn(11, 12, 13); b = np.random.randn(11, 12, 13, 13)
np.matmul(a[..., None, :], b)[..., 0, :]
....:
10000 loops, best of 3: 28 µs per loop