Python “Numpy”;其中;函数无法避免计算Sqrt(负)

Python “Numpy”;其中;函数无法避免计算Sqrt(负),python,numpy,Python,Numpy,似乎np.where函数首先计算所有可能的结果,然后再计算条件。这意味着,在我的例子中,它将计算-5,-4,-3,-2,-1的平方根,即使以后不再使用它 我的代码运行正常。但我的问题是警告。我避免使用循环来计算每个元素,因为它的运行速度比np.where慢得多 所以,在这里,我在问 有没有办法使np。在哪里首先评估条件 我能关掉这个特别的警告吗?怎么做 如果你有更好的建议,还有更好的方法 这里只是一个与我的真实代码相对应的简短示例代码,它是巨大的。但本质上也有同样的问题 输入: import n

似乎
np.where
函数首先计算所有可能的结果,然后再计算条件。这意味着,在我的例子中,它将计算-5,-4,-3,-2,-1的平方根,即使以后不再使用它

我的代码运行正常。但我的问题是警告。我避免使用循环来计算每个元素,因为它的运行速度比
np.where
慢得多

所以,在这里,我在问

  • 有没有办法使
    np。在哪里
    首先评估条件
  • 我能关掉这个特别的警告吗?怎么做
  • 如果你有更好的建议,还有更好的方法
  • 这里只是一个与我的真实代码相对应的简短示例代码,它是巨大的。但本质上也有同样的问题

    输入:

    import numpy as np
    
    c=np.arange(10)-5
    d=np.where(c>=0, np.sqrt(c) ,c )
    
    输出:

    RuntimeWarning: invalid value encountered in sqrt
    d=np.where(c>=0,np.sqrt(c),c)
    

    这是你第二个问题的答案

    是的,您可以关闭警告。使用模块


    一种解决方案是不使用
    np.where
    ,而是使用索引

    c = np.arange(10)-5
    d = c.copy()
    c_positive = c > 0
    d[c_positive] = np.sqrt(c[c_positive])
    
    有一种更好的方法可以做到这一点。让我们看看你的代码在做什么,看看为什么。

    np.其中
    接受三个数组作为输入。数组不支持延迟求值

    d = np.where(c >= 0, np.sqrt(c), c)
    
    因此,这一行相当于

    a = (c >= 0)
    b = np.sqrt(c)
    d = np.where(a, b, c)
    
    请注意,输入是在调用
    之前立即计算的

    幸运的是,您根本不需要使用
    where
    。相反,只需使用布尔掩码:

    mask = (c >= 0)
    d = np.empty_like(c)
    d[mask] = np.sqrt(c[mask])
    d[~mask] = c[~mask]
    
    如果您希望看到大量的负片,您可以复制所有元素,而不仅仅是负片:

    d = c.copy()
    d[mask] = np.sqrt(c[mask])
    
    更好的解决方案可能是使用屏蔽阵列:

    d = np.ma.masked_array(c, c < 0)
    d = np.ma.sqrt(d)
    
    d=np.ma.masked_数组(c,c<0)
    d=np.ma.sqrt(d)
    

    要在屏蔽部分不变的情况下访问整个数据数组,请使用
    d.data

    np.sqrt
    ufunc
    并接受
    where
    参数。在这种情况下,它可用作遮罩:

    In [61]: c = np.arange(10)-5.0
    In [62]: d = c.copy()
    In [63]: np.sqrt(c, where=c>=0, out=d);
    In [64]: d
    Out[64]: 
    array([-5.        , -4.        , -3.        , -2.        , -1.        ,
            0.        ,  1.        ,  1.41421356,  1.73205081,  2.        ])
    

    np.where
    情况相反,这并不计算~where元素处的函数。

    根据numpy,语句
    d=np.where(c>=0,np.sqrt(c),c)
    相当于
    [sqcv if cond else cv for(cond,sqcv,cv)In zip(c>=0,np.sqrt(c),c)]
    。换句话说,术语
    np.sqrt(c)
    的计算与
    c>=0的条件无关。您好,谢谢。它工作得很好。但当我尝试你最后的方法时。我仍然需要负值
    d1=np.ma.masked_数组(c.copy(),c0)
    。然后d=d1+d2。它不起作用。你能再帮我一点忙吗?我在末尾加了一张便条
    In [61]: c = np.arange(10)-5.0
    In [62]: d = c.copy()
    In [63]: np.sqrt(c, where=c>=0, out=d);
    In [64]: d
    Out[64]: 
    array([-5.        , -4.        , -3.        , -2.        , -1.        ,
            0.        ,  1.        ,  1.41421356,  1.73205081,  2.        ])