Python 比较numpy.ndarray中特定位置的元素

Python 比较numpy.ndarray中特定位置的元素,python,python-3.x,numpy,Python,Python 3.x,Numpy,我不知道标题是否描述了我的问题。我有一个从sigmoid激活函数获得的浮动列表 outputs = [[0.015161413699388504, 0.6720218658447266, 0.0024502829182893038, 0.21356457471847534, 0.002232735510915518, 0.026410426944494247], [0.006432057358324528, 0.0059209042228758335, 0

我不知道标题是否描述了我的问题。我有一个从sigmoid激活函数获得的浮动列表

  outputs = 
  [[0.015161413699388504,
  0.6720218658447266,
  0.0024502829182893038,
  0.21356457471847534,
  0.002232735510915518,
  0.026410426944494247],
 [0.006432057358324528,
  0.0059209042228758335,
  0.9866275191307068,
  0.004609372932463884,
  0.007315939292311668,
  0.010821194387972355],
 [0.02358204871416092,
  0.5838017225265503,
  0.005475651007145643,
  0.012086033821106,
  0.540218658447266,
  0.010054176673293114]] 
为了计算我的度量,我想说,如果任何神经元的输出值大于0.5,则假定注释属于类(多标签问题)。我可以很容易地使用
outputs=np.where(np.array(outputs)>=0.5,1,0)
但是,我想添加一个条件,只考虑更大的值,如果类γ5和任何其他类都有值>0.5(因为类5不能与其他类一起发生)。如何编写该条件? 在我的示例中,输出应为:

[[0 1 0 0 0 0]
 [0 0 1 0 0 0]
 [0 1 0 0 0 0]]
而不是:

[[0 1 0 0 0 0]
 [0 0 1 0 0 0]
 [0 1 0 0 1 0]]

谢谢,

您可以编写一个自定义函数,然后使用
np将该函数应用于
输出中的每个子数组。
函数:

def choose_class(a):
    if (len(np.argwhere(a >= 0.5)) > 1) & (a[4] >= 0.5):
        return np.where(a == a.max(), 1, 0)
    return np.where(a >= 0.5, 1, 0)

outputs = np.apply_along_axis(choose_class, 1, outputs)
outputs
# array([[0, 1, 0, 0, 0, 0],
#       [0, 0, 1, 0, 0, 0],
#       [0, 1, 0, 0, 0, 0]])
另一个解决方案:

# Your example input array
out = np.array([[0.015, 0.672, 0.002, 0.213, 0.002, 0.026],
                [0.006, 0.005, 0.986, 0.004, 0.007, 0.010],
                [0.023, 0.583, 0.005, 0.012, 0.540, 0.010]])

# We get the desired result
val = (out>=0.5)*out//(out.max(axis=1))[:,None]
此解决方案执行以下操作:

  • 将所有小于0.5的值设置为零
  • 按行将最大值设置为1(如果此值大于等于0.5)

  • 对于简单的掩码,您不需要
    np.where

    mask = outputs >= 0.5
    
    如果需要整数而不是布尔值:

    mask = (outputs >= 0.5).view(np.uint8)
    
    要检查第五列,需要保留对周围原始数据的引用。您可以使用以下命令获得每个相关行中的最大屏蔽值:

    rows = np.flatnonzero(mask[:, 4])
    keep = (outputs[mask] * mask[rows]).argmax()
    
    然后,您可以清空行并仅设置最大值:

    mask[rows] = 0
    mask[rows, keep] = 1
    

    这不是很有效。如果你使用//而不是/,你不需要比较:整数除法会将所有值设置为零,最大值除外。假设每行的值>=0.5,负值不重要。我感谢“o”米是空的,但是再次感谢你。我真的很喜欢你的答案,所以我一直在寻找改进