Python numpy沿第一个和最后一个出现值所在的轴查找切片
我有一个带有整数值的3D numpy数组,定义如下:Python numpy沿第一个和最后一个出现值所在的轴查找切片,python,numpy,Python,Numpy,我有一个带有整数值的3D numpy数组,定义如下: import numpy as np x = np.random.randint(0, 100, (10, 10, 10)) 现在我要做的是沿着给定的轴(比如1)找到出现特定值的最后一个切片(或者第一个切片)。目前,我做了如下工作: first=None last=None val = 20 for i in range(len(x.shape[1]): slice = x[:, i, :] if len(slice[s
import numpy as np
x = np.random.randint(0, 100, (10, 10, 10))
现在我要做的是沿着给定的轴(比如1)找到出现特定值的最后一个切片(或者第一个切片)。目前,我做了如下工作:
first=None
last=None
val = 20
for i in range(len(x.shape[1]):
slice = x[:, i, :]
if len(slice[slice==val]) > 0:
if not first:
first = i
last = i
return first, last
这似乎有点不合音律,我想知道是否有一些
numpy
魔法可以做到这一点?您可能可以将其优化得更快,但以下是您搜索内容的矢量化版本:
axis = 1
mask = np.where(x==val)[axis]
first, last = np.amin(mask), np.amax(mask)
它首先使用
np查找数组中的元素val
,其中
,并返回沿所需轴的索引的min
和max
。根据您的问题,您需要检查是否存在任何此类有效切片,从而获得开始/首先、停止/最后索引。在没有任何有效切片的情况下,我们必须返回其中的None。那需要额外检查。此外,我们可以使用掩蔽
以高效的方式获得这些索引,如下所示-
def slice_info(x, val):
n = (x==val).any((0,2))
if n.any():
return n.argmax(), len(n)-n[::-1].argmax()-1
else:
return None,None
标杆管理
其他拟议解决方案:
时间安排-
# Same setup as in given sample
In [157]: np.random.seed(0)
...: x = np.random.randint(0, 100, (10, 10, 10))
In [158]: %timeit where_amin_amax(x, val=20)
...: %timeit slice_info(x, val=20)
15.1 µs ± 287 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
9.63 µs ± 43.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# Bigger
In [159]: np.random.seed(0)
...: x = np.random.randint(0, 100, (100, 100, 100))
In [160]: %timeit where_amin_amax(x, val=20)
...: %timeit slice_info(x, val=20)
3.34 ms ± 31.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
691 µs ± 3.69 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
请注意,如果循环实际为0,则将首先替换
,因为if
语句将0
和None
都视为False
。建议的答案解决了这个问题。
# Same setup as in given sample
In [157]: np.random.seed(0)
...: x = np.random.randint(0, 100, (10, 10, 10))
In [158]: %timeit where_amin_amax(x, val=20)
...: %timeit slice_info(x, val=20)
15.1 µs ± 287 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
9.63 µs ± 43.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# Bigger
In [159]: np.random.seed(0)
...: x = np.random.randint(0, 100, (100, 100, 100))
In [160]: %timeit where_amin_amax(x, val=20)
...: %timeit slice_info(x, val=20)
3.34 ms ± 31.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
691 µs ± 3.69 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)