Python 优化此功能--numpy广播问题

Python 优化此功能--numpy广播问题,python,numpy,Python,Numpy,我有一个函数contains,它检查给定的2D数组u,如果[min,max]框包含u的每一行。如果需要,我需要它重塑u,但是u的值的数量将始终是d的倍数(可以是零) 我正在使用下面的代码片段。此函数可以运行数千次。能产生更快的代码吗?如果你这么认为的话,有什么建议吗 import numpy as np def contains(u, min, max, dim, strict = True): u = np.array(u).reshape(-1 ,dim) if stric

我有一个函数
contains
,它检查给定的2D数组
u
,如果[min,max]框包含
u
的每一行。如果需要,我需要它重塑
u
,但是
u
的值的数量将始终是
d
的倍数(可以是零)

我正在使用下面的代码片段。此函数可以运行数千次。能产生更快的代码吗?如果你这么认为的话,有什么建议吗

import numpy as np

def contains(u, min, max, dim, strict = True):
    u = np.array(u).reshape(-1 ,dim)
    if strict:
        return np.all((u > min) & (u < max), axis=1)
    else:
        return np.all((u >= min) & (u <= max), axis=1)

# Usage examples : 
d = 4
min = np.random.uniform(size=d)*1/2
max = np.random.uniform(size=d)*1/2+1/2
u1 = np.random.uniform(size=d)
u2 = np.random.uniform(size=(100,d))
u3 = u2[np.repeat(False,100)]

contains(u1,min,max,d) # should return a boolean array of shape (1,)
contains(u2,min,max,d) # shape (100,)
contains(u3,min,max,d) # shape (0,)
将numpy导入为np
def包含(u、min、max、dim、strict=True):
u=np.阵列(u).重塑(-1,尺寸)
如果严格:
返回np.all((u>min)和(u返回np.all((u>=min)和(u试着加速,阅读更多

从numba导入jit
@jit(nopython=True)
def包含(u、min、max、dim、strict=True):
u=np.阵列(u).重塑(-1,尺寸)
如果严格:
返回np.all((u>min)和(u返回np.all((u>=min)和(u试着加速,阅读更多

从numba导入jit
@jit(nopython=True)
def包含(u、min、max、dim、strict=True):
u=np.阵列(u).重塑(-1,尺寸)
如果严格:
返回np.all((u>min)和(u=min)和(u)(编辑的:修复注释中@max9111提出的计时测量问题,并包括
numexpr
-修改的解决方案)

瓶颈最终会出现在
np.all()
调用中。 使用Numba可以加快速度,如下所示:

import numpy as np
import numba as nb


@nb.jit(nopython=True)
def contains_nb(arr, a_arr, b_arr):
    m = a_arr.size
    arr = arr.reshape(-1, m)
    n = arr.shape[0]
    result = np.ones(n, dtype=np.bool8)
    for i in range(n):       
        for j in range(m):
            if not a_arr[j] < arr[i, j] < b_arr[j]:
                result[i] = False
                break
    return result
可获得以下基准:

这表明Numba解决方案始终是最快的。 相反,使用
numexpr
似乎对所研究的参数范围不利

(提供完整基准测试)

已编辑):修复@max9111在评论中提出的计时测量问题,并包括
numexpr
-修改的解决方案)

瓶颈最终会出现在
np.all()
调用中。 使用Numba可以加快速度,如下所示:

import numpy as np
import numba as nb


@nb.jit(nopython=True)
def contains_nb(arr, a_arr, b_arr):
    m = a_arr.size
    arr = arr.reshape(-1, m)
    n = arr.shape[0]
    result = np.ones(n, dtype=np.bool8)
    for i in range(n):       
        for j in range(m):
            if not a_arr[j] < arr[i, j] < b_arr[j]:
                result[i] = False
                break
    return result
可获得以下基准:

这表明Numba解决方案始终是最快的。 相反,使用
numexpr
似乎对所研究的参数范围不利



(提供完整的基准测试)

此函数的多个调用可以组合吗?有些已经是:/但是实际上有些不是。我将对此进行研究。例如使用BroadcTable最小/最大参数?请注意,改进工作代码的请求通常更受欢迎;在堆栈溢出方面,我们重点关注具有特定、狭窄、众所周知问题的程序“X太慢”可以是具体的和狭隘的,但前提是你知道你的问题来自于什么特定的函数调用,并且已经知道它可以毫无疑问地得到改进。@CharlesDuffy,速度改进问题是常规问题,至少在
[numpy]上是如此
tag。没有那么多的
numpy
关注CR。CR也倾向于更挑剔代码和示例的完整性。因此,解决方案往往侧重于更好的整体数组操作,以及编译工具,如
numba
cython
。CR解决方案往往侧重于风格和组织。与那些
包含的
调用?不同的min,max参数?在numpy中,我们通常通过对整个数组或更高的dumensions进行操作来获得最好的加速。调整任何一个调用都不会带来很大的收益。这个函数的多个调用可以组合吗?有些已经是:/但是实际上,有些不是。我会研究这个问题。例如BroadcTable最小/最大参数?请注意,改进工作代码的请求通常更受欢迎;在堆栈溢出方面,我们重点关注那些存在特定、狭窄、众所周知的问题的程序。“X太慢了”可以是具体和狭义的,但前提是您知道问题来自哪个特定函数调用,并且已经知道它可以毫无疑问地得到改进。@CharlesDuffy,速度改进问题是如此常规-至少在
[numpy]中是如此
tag。没有那么多的
numpy
关注CR。CR也倾向于更挑剔代码和示例的完整性。因此,解决方案往往侧重于更好的整体数组操作,以及编译工具,如
numba
cython
。CR解决方案往往侧重于风格和组织。与那些
包含
调用?不同的最小、最大参数?在numpy中,我们通常通过对整个阵列或更高的密度进行操作来获得最佳加速。调整任何一个调用都不会带来很大的收益。感谢基准测试。但是,关于您的基准测试,
n
的值我倾向于要求numpy版本!哟ur计时有问题。使用像这样的默认值
包含\u nb(arr,a\u arr=a,b\u arr=b)
并调用像
包含\u nb(输入\u arr)这样的函数
导致每次调用都要重新编译。每次调用的开销约为270µs,这完全决定了小输入大小的计时。如果显式提供输入,则大约为1µs。@max9111感谢您发现了这一点,我已在编辑中对其进行了修复。@Irnv请查看更新的基准测试。也许您需要重新编译在基于Numba的方法下。本质上,代码没有改变,它只是基准测试的方式。感谢基准测试。然而,
n
的值,关于您的基准测试,我倾向于要求numpy版本!您的计时有问题。使用像这样的默认值
包含\nb(arr,a\u arr=a,b\u arr=b)
调用像
这样的函数contains\u nb(input\u arr)
会在每次调用时导致重新编译。这会导致
import numpy as np


def contains_np(arr, a_arr, b_arr):
    m = a_arr.size
    arr = arr.reshape(-1, m)
    return np.all((arr >= a_arr) & (arr <= b_arr), axis=1)
import numpy as np
import numexpr as ne


def contains_ne(arr, a_arr, b_arr):
    m = a_arr.size
    arr = arr.reshape(-1, m)
    result = ne.evaluate('(arr >= a_arr) & (arr <= b_arr)')
    return np.all(result, axis=1)