Numpy conditional nd argmin: how to find the coordinates of the min of a subset of a multidimensional array?


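I know I can use argmin together with unravel_index to find the index of the minimum value in a data array, but what should I do if I want to find the smallest non-zero element, or the smallest element that is not NaN?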
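To make the problem concrete, here is a minimal sketch of the plain argmin plus unravel_index approach the question refers to, run on a small made-up array. The array values and variable names are illustrative assumptions, not taken from the thread.

```python
import numpy as np

# Small illustrative array; the zero at (0, 1) is the global minimum.
data = np.array([[3.0, 0.0, 5.0],
                 [2.0, 7.0, 1.0]])

# argmin works on the flattened array, unravel_index maps back to nd coordinates.
flat_pos = np.argmin(data)
coords = tuple(int(i) for i in np.unravel_index(flat_pos, data.shape))
print("plain argmin picks", coords)  # the zero, not the smallest non-zero value

# With a NaN in the data, argmin lands on the NaN rather than on a real minimum.
with_nan = data.copy()
with_nan[0, 1] = np.nan
nan_coords = tuple(int(i) for i in np.unravel_index(np.argmin(with_nan), with_nan.shape))
print("argmin with NaN picks", nan_coords)
```

In both cases the element that is wanted is the 1.0 at (1, 2), which is why a condition has to be applied before taking the argmin.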

This should work (where the condition is data != 0 or ~np.isnan(data))


Here is an approach that uses flat indices -

import numpy as np

def flatnonzero_based(a, condition):  # condition = a != 0 or ~np.isnan(a)
    idx = np.flatnonzero(condition)  # flat indices where the condition holds
    return np.unravel_index(idx[np.take(a, idx).argmin()], a.shape)
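The function above packs several steps into one line. As a rough illustration, the same steps can be unrolled on a tiny array; the array and the intermediate variable names (candidates, winner) are assumptions made up for this walk-through.

```python
import numpy as np

a = np.array([[4.0, 0.0, 6.0],
              [0.0, 2.0, 9.0]])
condition = a != 0

# Step 1: flat positions of the elements that satisfy the condition.
idx = np.flatnonzero(condition)              # [0, 2, 4, 5]

# Step 2: pull out only those elements (np.take indexes the flattened array).
candidates = np.take(a, idx)                 # [4., 6., 2., 9.]

# Step 3: argmin over the candidates, mapped back to a flat position in a.
winner = idx[candidates.argmin()]            # flat position 4

# Step 4: convert the flat position to nd coordinates.
coords = np.unravel_index(winner, a.shape)   # row 1, column 1
```

The key point is that argmin only sees the filtered values, while idx keeps track of where each of them lives in the original array.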
Benchmarking

Approaches -

def flatnonzero_based(a,condition): # Proposed soln
    idx = np.flatnonzero(condition)
    return np.unravel_index(idx[np.take(a, idx).argmin()], a.shape)

def where_based(a, condition):  # @Paul Panzer's soln
    nz = np.where(condition)
    return np.array(nz)[:, np.argmin(a[nz])]
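Both functions above assume that at least one element satisfies the condition; if none does, argmin is called on an empty array and raises a ValueError. A minimal sketch of a guarded variant follows. The name conditional_argmin and the choice to return None are hypothetical, not something proposed in the thread.

```python
import numpy as np

def conditional_argmin(a, condition):
    """Hypothetical helper: coordinates of the smallest element of `a`
    where `condition` is True, or None if no element qualifies."""
    idx = np.flatnonzero(condition)
    if idx.size == 0:
        # Nothing satisfies the condition; avoid argmin on an empty array.
        return None
    flat_pos = idx[np.take(a, idx).argmin()]
    # Convert to plain ints so the result prints the same across NumPy versions.
    return tuple(int(i) for i in np.unravel_index(flat_pos, a.shape))

a = np.array([[0.0, 3.0],
              [1.5, 0.0]])
hit = conditional_argmin(a, a != 0)    # smallest non-zero element
miss = conditional_argmin(a, a > 10)   # no element qualifies
```

Whether returning None or raising is preferable depends on the caller; this sketch only shows where the check would go.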

Timings and verification -

In [233]: a = np.random.rand(40,50,30)

In [234]: nan_idx = np.random.choice(range(a.size), size = a.size//100, replace=0)

In [235]: a.ravel()[nan_idx] = np.nan

In [236]: condition = ~np.isnan(a)

In [237]: where_based(a, condition)
Out[237]: array([16, 10,  8])

In [238]: flatnonzero_based(a, condition)
Out[238]: (16, 10, 8)

In [239]: %timeit where_based(a, condition)
1000 loops, best of 3: 877 µs per loop

In [240]: %timeit flatnonzero_based(a, condition)
10000 loops, best of 3: 143 µs per loop

With 4D data -

In [255]: a = np.random.rand(40,50,30,30)

In [256]: nan_idx = np.random.choice(range(a.size), size = a.size//100, replace=0)

In [257]: a.ravel()[nan_idx] = np.nan

In [258]: condition = ~np.isnan(a)

In [259]: where_based(a, condition)
Out[259]: array([34, 14,  5, 10])

In [260]: flatnonzero_based(a, condition)
Out[260]: (34, 14, 5, 10)

In [261]: %timeit where_based(a, condition)
10 loops, best of 3: 64.9 ms per loop

In [262]: %timeit flatnonzero_based(a, condition)
100 loops, best of 3: 5.32 ms per loop
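
Comments:
For the NaN case, there is a very good suggestion here. Adding it to the timing tests in my post; it looks very efficient.
@Demetri P Did the posted solutions work for you?
@Divakar Yes, I will definitely accept a solution.

With np.nanargmin included in the timings, on the same 4D data -
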
In [267]: np.unravel_index(np.nanargmin(a), a.shape)
Out[267]: (34, 14, 5, 10)

In [268]: %timeit np.unravel_index(np.nanargmin(a), a.shape)
100 loops, best of 3: 4.54 ms per loop
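
The IPython sessions above rely on %timeit, which is not available in a plain script. As a rough, hedged sketch, the same comparison can be reproduced with the standard timeit module. The seed, the repeat counts and the wrapper name nanargmin_based are assumptions for illustration, and the absolute timings will differ from the ones quoted in the thread.

```python
import timeit
import numpy as np

def flatnonzero_based(a, condition):
    idx = np.flatnonzero(condition)
    return np.unravel_index(idx[np.take(a, idx).argmin()], a.shape)

def where_based(a, condition):
    nz = np.where(condition)
    return np.array(nz)[:, np.argmin(a[nz])]

def nanargmin_based(a):
    # Only applies to the NaN case; np.nanargmin raises ValueError if a is all NaN.
    return np.unravel_index(np.nanargmin(a), a.shape)

rng = np.random.default_rng(0)  # arbitrary seed, for reproducibility only
a = rng.random((40, 50, 30))
nan_idx = rng.choice(a.size, size=a.size // 100, replace=False)
a.ravel()[nan_idx] = np.nan     # a is contiguous, so ravel() returns a view
condition = ~np.isnan(a)

candidates = {
    "where_based": lambda: where_based(a, condition),
    "flatnonzero_based": lambda: flatnonzero_based(a, condition),
    "nanargmin_based": lambda: nanargmin_based(a),
}

# Verification: all three approaches should report the same coordinates.
results = {name: tuple(int(i) for i in fn()) for name, fn in candidates.items()}

# Timing: best of 3 repeats, 20 calls each, reported per call.
timings = {name: min(timeit.repeat(fn, number=20, repeat=3)) / 20
           for name, fn in candidates.items()}
for name, seconds in timings.items():
    print(f"{name}: {seconds * 1e6:.0f} µs per call, result {results[name]}")
```

Note that np.nanargmin covers only the non-NaN condition; for other conditions such as a != 0, the condition-based functions are still needed.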