Python 为什么JAX与numpy中的这个函数速度较慢?

Python 为什么JAX与numpy中的这个函数速度较慢?,python,performance,numpy,optimization,jax,Python,Performance,Numpy,Optimization,Jax,我有下面的numpy函数,如下所示,我正试图通过使用JAX进行优化,但无论出于什么原因,它都比较慢 有人能指出我能做些什么来提高性能吗?我怀疑这与Cg_new的列表理解有关,但将其分解并不能在JAX中进一步提高性能 import numpy as np def testFunction_numpy(C, Mi, C_new, Mi_new): Wg_new = np.zeros((len(Mi_new[:,0]), len(Mi[0]))) Cg_new = np.zeros

我有下面的numpy函数,如下所示,我正试图通过使用JAX进行优化,但无论出于什么原因,它都比较慢

有人能指出我能做些什么来提高性能吗?我怀疑这与Cg_new的列表理解有关,但将其分解并不能在JAX中进一步提高性能

import numpy as np 

def testFunction_numpy(C, Mi, C_new, Mi_new):
    Wg_new = np.zeros((len(Mi_new[:,0]), len(Mi[0])))
    Cg_new = np.zeros((1, len(Mi[0])))
    invertCsensor_new = np.linalg.inv(C_new)

    Wg_new = np.dot(invertCsensor_new, Mi_new)
    Cg_new = [np.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))] 

    return C_new, Mi_new, Wg_new, Cg_new

C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)

%timeit testFunction_numpy(C, Mi, C_new, Mi_new)
#1000 loops, best of 3: 1.73 ms per loop
以下是JAX的等价物:

import jax.numpy as jnp
import numpy as np
import jax

def testFunction_JAX(C, Mi, C_new, Mi_new):
    Wg_new = jnp.zeros((len(Mi_new[:,0]), len(Mi[0])))
    Cg_new = jnp.zeros((1, len(Mi[0])))
    invertCsensor_new = jnp.linalg.inv(C_new)

    Wg_new = jnp.dot(invertCsensor_new, Mi_new)
    Cg_new = [jnp.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))] 

    return C_new, Mi_new, Wg_new, Cg_new

C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)

C = jnp.asarray(C)
Mi = jnp.asarray(Mi)
C_new = jnp.asarray(C_new)
Mi_new = jnp.asarray(Mi_new)

jitter = jax.jit(testFunction_JAX) 

%timeit jitter(C, Mi, C_new, Mi_new)
#1 loop, best of 3: 4.96 ms per loop

当JAX-jit编译遇到Python控制流(包括列表理解)时,它会有效地使循环平坦化并将整个操作序列分段。这可能会导致jit编译时间变慢和代码不理想。幸运的是,函数中的列表理解很容易用本机numpy广播表示。此外,您还可以进行其他两项改进:

  • 在计算它们之前,无需向前声明
    Wg_new
    Cg_new
  • 计算
    点(inv(A),B)
    时,使用
    np.linalg.solve
    比显式计算逆运算更高效、更精确
对numpy和JAX版本进行这三项改进,结果如下:

def testFunction\u numpy\u v2(C、Mi、C\u new、Mi\u new):
Wg_new=np.linalg.solve(C_new,Mi_new)
Cg_new=-0.5*(Mi_new.conj()*Wg_new).sum(0)
返回C_new、Mi_new、Wg_new、Cg_new
@jax.jit
def testFunction_JAX_v2(C,Mi,C_new,Mi_new):
Wg_new=jnp.linalg.solve(C_new,Mi_new)
Cg_new=-0.5*(Mi_new.conj()*Wg_new).sum(0)
返回C_new、Mi_new、Wg_new、Cg_new
%timeit testFunction\u numpy\u v2(C、Mi、C\u新建、Mi\u新建)
#1000个回路,最佳3个:每个回路1.11毫秒
%timeit测试函数_JAX_v2(C_JAX,Mi_JAX,C_new_JAX,Mi_new_JAX)
#1000个回路,最佳3个:每个回路1.35毫秒

由于改进了实现,这两个功能都比以前快了一点。然而,您会注意到,这里JAX仍然比numpy慢;这在某种程度上是意料之中的,因为对于这种简单程度的函数,JAX和numpy都有效地生成了在CPU体系结构上执行的相同短系列的BLAS和LAPACK调用。与numpy的参考实现相比,没有太多的改进空间,而且使用如此小的阵列,JAX的开销是显而易见的。

Wow再次感谢您!numpy的优化版本确实非常方便。