Python 从numpy数组B中选择切片(如果从A中选择其他切片)
我有两个这样的阵列:Python 从numpy数组B中选择切片(如果从A中选择其他切片),python,numpy,Python,Numpy,我有两个这样的阵列: import numpy as np A = np.array([100, 100, 3, 0, 0, 0, 0, 0, 0, 100, 3, 5], dtype=int) A = np.reshape(A, (2,2,3)) B = np.array([3, 6, 2, 6, 3, 2, 100, 3, 2, 100, 100, 5]) B = np.reshape(B, (2,2,3)) print(repr(A)) # array([[[100, 100, 3
import numpy as np
A = np.array([100, 100, 3, 0, 0, 0, 0, 0, 0, 100, 3, 5], dtype=int)
A = np.reshape(A, (2,2,3))
B = np.array([3, 6, 2, 6, 3, 2, 100, 3, 2, 100, 100, 5])
B = np.reshape(B, (2,2,3))
print(repr(A))
# array([[[100, 100, 3],
# [ 0, 0, 0]],
# [[ 0, 0, 0],
# [100, 3, 5]]])
print(repr(B))
# array([[[ 3, 6, 2],
# [ 6, 3, 2]],
# [[100, 3, 2],
# [100, 100, 5]]])
# desired result
out = np.array([100, 100, 3, 0, 0, 0, 100, 3, 2, 100, 100, 5])
out = np.reshape(out, (2,2,3))
print(repr(out))
# array([[[100, 100, 3],
# [ 0, 0, 0]],
# [[100, 3, 2],
# [100, 100, 5]]])
我想做的是从B中选择2x3个切片,其中至少有一个值大于10。如果不满足此条件,我希望从A获得相应的切片,如下所示:
import numpy as np
A = np.array([100, 100, 3, 0, 0, 0, 0, 0, 0, 100, 3, 5], dtype=int)
A = np.reshape(A, (2,2,3))
B = np.array([3, 6, 2, 6, 3, 2, 100, 3, 2, 100, 100, 5])
B = np.reshape(B, (2,2,3))
print(repr(A))
# array([[[100, 100, 3],
# [ 0, 0, 0]],
# [[ 0, 0, 0],
# [100, 3, 5]]])
print(repr(B))
# array([[[ 3, 6, 2],
# [ 6, 3, 2]],
# [[100, 3, 2],
# [100, 100, 5]]])
# desired result
out = np.array([100, 100, 3, 0, 0, 0, 100, 3, 2, 100, 100, 5])
out = np.reshape(out, (2,2,3))
print(repr(out))
# array([[[100, 100, 3],
# [ 0, 0, 0]],
# [[100, 3, 2],
# [100, 100, 5]]])
我可以找到我想要的索引:
filt = ~np.all(B < 10, axis=2)
可能有一种更直接的方法,我怀疑是NumPy select,但我还没有弄清楚如何使尺寸合适。想法
import numpy as np
A = np.array([100, 100, 3, 0, 0, 0, 0, 0, 0, 100, 3, 5], dtype=int)
A = np.reshape(A, (2,2,3))
B = np.array([3, 6, 2, 6, 3, 2, 100, 3, 2, 100, 100, 5])
B = np.reshape(B, (2,2,3))
B[B<10] = A[B<10]
# out = B
使用numpy切片,可以轻松地比较和替换大小匹配数组之间的值。我希望这就是您想要的。您可以使用:
print(np.where(np.any(B > 10, axis=2)[..., None], B, A))
# [[[100 100 3]
# [ 0 0 0]]
# [[100 3 2]
# [100 100 5]]]
np.anyB>10,轴=2相当于您的过滤索引。由于在最后一个轴上缩小,因此将产生一个2,2数组,而a和B都是2,2,3,因此np.wherenp.anyB>10,axis=2,B,a将产生索引错误
幸运的是,np.where支持,因此您可以通过无索引插入大小为1的新最终轴,而np.where将有效地将其视为一个2、2、3数组,其中包含重复3次的filt索引。通过将keepdims=True传递给np.any以保留单例最终维度,可以实现相同的效果:
np.where(np.any(B > 10, axis=2, keepdims=1), B, A)
差不多,但不完全是B您的结果输出[13]:数组[[100100,3],[0,0,0],[100,0,0],[100100,5]]输出所需的结果输出[15]:数组[[100100,3],[0,0,0],[100,3,2],[100100,5]]你的解决方案丢失了第三行中的值。是的,这就是我在B之后得到的结果[B]在我的黑客解决方案中,结果是在数据数组的最内层进行过滤。假设它是一个二维单元格数组,而不是一个三维整数数组,我们在单元格上操作,而不是在整数上操作。这有帮助吗?