Python 有效地从函数填充数组
我想从一个函数构造一个2D数组,这样我就可以利用Python 有效地从函数填充数组,python,numpy,jax,Python,Numpy,Jax,我想从一个函数构造一个2D数组,这样我就可以利用jax.jit 我通常使用numpy执行此操作的方法是创建一个空数组,然后将该数组填充到位 xx = jnp.empty((num_a, num_b)) yy = jnp.empty((num_a, num_b)) zz = jnp.empty((num_a, num_b)) for ii_a in range(num_a): for ii_b in range(num_b): a = aa[ii_a, ii_b]
jax.jit
我通常使用numpy
执行此操作的方法是创建一个空数组,然后将该数组填充到位
xx = jnp.empty((num_a, num_b))
yy = jnp.empty((num_a, num_b))
zz = jnp.empty((num_a, num_b))
for ii_a in range(num_a):
for ii_b in range(num_b):
a = aa[ii_a, ii_b]
b = bb[ii_a, ii_b]
xyz = self.get_coord(a, b)
xx[ii_a, ii_b] = xyz[0]
yy[ii_a, ii_b] = xyz[1]
zz[ii_a, ii_b] = xyz[2]
为了在jax
中实现这一点,我尝试使用jax.opt.index\u更新
xx = xx.at[ii_a, ii_b].set(xyz[0])
yy = yy.at[ii_a, ii_b].set(xyz[1])
zz = zz.at[ii_a, ii_b].set(xyz[2])
这运行时没有错误,但当我尝试使用@jax.jit
装饰器时,速度非常慢(至少比纯python/numpy版本慢一个数量级)
使用jax
从函数填充多维数组的最佳方法是什么?jax有一个专门为此类应用程序设计的
只要您的get_coords
函数与JAX兼容(即是一个没有副作用的纯函数),您就可以在一行中完成这一点:
从jax导入vmap
xx,yy,zz=vmap(vmap(get_coord))(aa,bb)
这可以通过使用或函数有效地实现
使用矢量化的示例:
import jax.numpy as jnp
def get_coord(a, b):
return jnp.array([a, b, a+b])
f0 = jnp.vectorize(get_coord, signature='(),()->(i)')
f1 = jnp.vectorize(f0, excluded=(1,), signature='()->(i,j)')
xyz = f1(a,b)
vectorize
功能在发动机罩下使用vmap
,因此这应该完全等同于:
f0 = jax.vmap(get_coord, (None, 0))
f1 = jax.vmap(f0, (0, None))
使用vectorize
的优点是代码仍然可以在标准numpy中运行。缺点是代码不够简洁,而且由于包装器的存在,可能会产生少量开销。为了让它正常工作,我似乎需要在_轴中显式地添加,jax.vmap(jax.vmap(get_coord,(None,0)),(0,None))
。