Python 获取2D numpy数组中大于阈值的元素的列数
我有一个这样的数组,希望返回值超过阈值0.6的每一行的列号:Python 获取2D numpy数组中大于阈值的元素的列数,python,numpy,Python,Numpy,我有一个这样的数组,希望返回值超过阈值0.6的每一行的列号: X = array([[ 0.16, 0.40, 0.61, 0.48, 0.20], [ 0.42, 0.79, 0.64, 0.54, 0.52], [ 0.64, 0.64, 0.24, 0.63, 0.43], [ 0.33, 0.54, 0.61, 0.43, 0.29], [ 0.25, 0.56, 0.42, 0.69,
X = array([[ 0.16, 0.40, 0.61, 0.48, 0.20],
[ 0.42, 0.79, 0.64, 0.54, 0.52],
[ 0.64, 0.64, 0.24, 0.63, 0.43],
[ 0.33, 0.54, 0.61, 0.43, 0.29],
[ 0.25, 0.56, 0.42, 0.69, 0.62]])
结果将是:
[[2],
[1, 2],
[0, 1, 3],
[2],
[3, 4]]
有没有更好的方法通过双for循环来实现这一点
def get_column_over_threshold(data, threshold):
column_numbers = [[] for x in xrange(0,len(data))]
for sample in data:
for i, value in enumerate(data):
if value >= threshold:
column_numbers[i].extend(i)
return topic_predictions
使用
np.where
获取行、列索引,然后使用np.split
获取列索引列表作为数组输出-
In [18]: r,c = np.where(X>0.6)
In [19]: np.split(c,np.flatnonzero(r[:-1] != r[1:])+1)
Out[19]: [array([2]), array([1, 2]), array([0, 1, 3]), array([2]), array([3, 4])]
为了使它更通用,可以处理没有任何匹配的行,我们可以循环通过从np.where
获得的列索引,并分配到一个初始化数组中,如下所示-
def col_indices_per_row(X, thresh):
mask = X>thresh
r,c = np.where(mask)
out = np.empty(len(X), dtype=object)
grp_idx = np.r_[0,np.flatnonzero(r[:-1] != r[1:])+1,len(r)]
valid_rows = r[np.r_[True,r[:-1] != r[1:]]]
for (row,i,j) in zip(valid_rows,grp_idx[:-1],grp_idx[1:]):
out[row] = c[i:j]
return out
样本运行-
In [92]: X
Out[92]:
array([[0.16, 0.4 , 0.61, 0.48, 0.2 ],
[0.42, 0.79, 0.64, 0.54, 0.52],
[0.1 , 0.1 , 0.1 , 0.1 , 0.1 ],
[0.33, 0.54, 0.61, 0.43, 0.29],
[0.25, 0.56, 0.42, 0.69, 0.62]])
In [93]: col_indices_per_row(X, thresh=0.6)
Out[93]:
array([array([2]), array([1, 2]), None, array([2]), array([3, 4])],
dtype=object)
对于每一行,您可以询问元素大于0.6的索引:
result = [where(row > 0.6) for row in X]
这将执行您想要的计算,但是结果
的格式有点不方便,因为在本例中的结果是一个大小为1的元组
,包含一个带有索引的NumPy数组。我们可以用flatnonzero
替换where
,直接获取数组,而不是元组。为了获得列表列表,我们将此数组显式转换为列表:
result = [list(flatnonzero(row > 0.6)) for row in X]
(在上面的代码中,我假设您使用了numpy import*的)