Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/343.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python scipy stats zmap函数的替代方法_Python_Performance_Scipy_Numba_Jax - Fatal编程技术网

Python scipy stats zmap函数的替代方法

Python scipy stats zmap函数的替代方法,python,performance,scipy,numba,jax,Python,Performance,Scipy,Numba,Jax,zmap函数的scipy stats模块是否有其他替代方案?我目前正在使用它来获得两个非常大的阵列的zmap分数,这需要相当长的时间 是否有任何库或替代方案可以提高其性能?或者甚至是另一种获得zmap函数功能的方法 您的想法和意见将不胜感激 下面是我的最小可复制代码: from scipy import stats import numpy as np FeatureData = np.random.rand(483, 1) goodData = np.random.rand(4640, 48

zmap函数的scipy stats模块是否有其他替代方案?我目前正在使用它来获得两个非常大的阵列的zmap分数,这需要相当长的时间

是否有任何库或替代方案可以提高其性能?或者甚至是另一种获得zmap函数功能的方法

您的想法和意见将不胜感激

下面是我的最小可复制代码:

from scipy import stats
import numpy as np

FeatureData = np.random.rand(483, 1)
goodData = np.random.rand(4640, 483)
FeatureNorm= stats.zmap(FeatureData, goodData)
下面是scipy stats.zmap在引擎盖下的功能:

def zmap(scores, compare, axis=0, ddof=0):
    scores, compare = map(np.asanyarray, [scores, compare])
    mns = compare.mean(axis=axis, keepdims=True)
    sstd = compare.std(axis=axis, ddof=ddof, keepdims=True)
    return (scores - mns) / sstd

关于如何针对我的用例优化它,有什么想法吗?我可以使用像numba或JAX这样的库来进一步增强这一点吗?

幸运的是,
zmap
代码非常简单。然而,numpy的开销将来自它必须实例化中间数组这一事实。如果使用
numba
jax
中提供的数值编译器,它可以融合这些操作并以较少的开销进行计算

不幸的是,NUBA不支持可选的参数<代码>平均和 STD,所以让我们看看JAX。以下是在Google Colab CPU运行时计算的scipy和原始numpy版本函数的基准,供参考:

将numpy导入为np
从scipy导入统计信息
FeatureData=np.random.rand(483,1)
goodData=np.random.rand(4640483)
%timeit stats.zmap(FeatureData、goodData)
#100圈,最佳3圈:每圈13.9毫秒
def np_zmap(分数、比较、轴=0、ddof=0):
分数,比较=地图(np.asanyarray,[分数,比较])
mns=比较。平均值(轴=轴,keepdims=真)
sstd=compare.std(axis=axis,ddof=ddof,keepdims=True)
返回(分数-mns)/sstd
%timeit np_zmap(功能数据、良好数据)
#100个回路,最佳3个:每个回路13.8毫秒
以下是在JAX中执行的等效代码,包括急切模式和JIT编译:

将jax.numpy作为jnp导入
从jax导入jit
def jnp_zmap(分数、比较、轴=0、ddof=0):
分数,比较=映射(jnp.asarray,[分数,比较])
mns=比较。平均值(轴=轴,keepdims=真)
sstd=compare.std(axis=axis,ddof=ddof,keepdims=True)
返回(分数-mns)/sstd
jit_jnp_zmap=jit(jnp_zmap)
FeatureData=jnp.array(FeatureData)
goodData=jnp.array(goodData)
%timeit jnp_zmap(FeatureData,goodData).block_直到_就绪()
#100圈,最佳3圈:每圈8.59毫秒
jit_jnp_zmap(FeatureData,goodData)#触发器编译
%timeit jit_jnp_zmap(FeatureData,goodData).block_直到_就绪()
#100圈,最佳3圈:每圈2.78毫秒
JIT编译版本大约比scipy或numpy代码快5倍。在Colab T4 GPU运行时上,编译版本获得另一个因子10:

%timeit jit\u jnp\u zmap(FeatureData,goodData)。阻止\u直到\u就绪()
1000个回路,最好为3个:每个回路286µs

如果这种操作是您分析中的一个瓶颈,那么像JAX这样的编译器可能是一个不错的选择。

您看过吗?如果你为你的问题添加一个最小的、可复制的例子,这里的用户将更容易为你提供具体的解决方案。我已经看过源代码,但不太确定如何优化它。我添加了一个最小的可复制代码示例。非常感谢!这正是我要找的。我之前尝试过这个版本,但忘记了我需要将numpy数组转换为JAX样式的数组。