Python/NumPy: first occurrence of a subarray


In Python or NumPy, what is the best way to find the first occurrence of a subarray?

For example, I have

a = [1, 2, 3, 4, 5, 6]
b = [2, 3, 4]
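What is the fastest way (in terms of run time) to find where b occurs in a? I know this is very simple for strings, but what about a list or a numpy ndarray?

Thanks a lot!


[EDITED] I prefer a NumPy solution, since in my experience NumPy vectorization is much faster than Python list comprehensions. Also, the big array is huge, so I don't want to convert it into a string; that would take too long.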

Another attempt, but I'm sure there is a more Pythonic and more efficient way to do this:

def array_match(a, b):
    # range() replaces the Python 2 xrange()
    for i in range(0, len(a)-len(b)+1):
        if a[i:i+len(b)] == b:
            return i
    return None
The following code should work:

[x for x in range(len(a)) if a[x:x+len(b)] == b]

It returns the indices at which the pattern starts.

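I'm assuming you're looking for a NumPy-specific solution rather than a simple list comprehension or for loop. One simple approach is to use the rolling window technique to search for windows of the appropriate size.

This approach is simple, correct, and much faster than any pure Python solution. It should be sufficient for many use cases. However, it is not the most efficient approach possible, for a number of reasons. For an approach that is more complicated but asymptotically optimal in the expected case, see the numba-based implementation in one of the other answers.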

Here is the rolling_window function:

>>> import numpy
>>> def rolling_window(a, size):
...     shape = a.shape[:-1] + (a.shape[-1] - size + 1, size)
...     strides = a.strides + (a.strides[-1],)
...     return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
... 
Then you could do something like

>>> a = numpy.arange(10)
>>> numpy.random.shuffle(a)
>>> a
array([7, 3, 6, 8, 4, 0, 9, 2, 1, 5])
>>> rolling_window(a, 3) == [8, 4, 0]
array([[False, False, False],
       [False, False, False],
       [False, False, False],
       [ True,  True,  True],
       [False, False, False],
       [False, False, False],
       [False, False, False],
       [False, False, False]], dtype=bool)
To make this really useful, you have to reduce it along axis 1 using all:

>>> numpy.all(rolling_window(a, 3) == [8, 4, 0], axis=1)
array([False, False, False,  True, False, False, False, False], dtype=bool)
Then you can use the result however you would use a boolean array. A simple way to get the index out:

>>> bool_indices = numpy.all(rolling_window(a, 3) == [8, 4, 0], axis=1)
>>> numpy.mgrid[0:len(bool_indices)][bool_indices]
array([3])
For lists, you could adapt one of the rolling window iterators to use a similar approach.
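
As a rough illustration of the list variant, here is a minimal sketch of a rolling window iterator built on collections.deque. The names sliding_windows and first_index are hypothetical helpers, not part of the original answer, and the code assumes the elements can be compared with ==.

```python
from collections import deque
from itertools import islice


def sliding_windows(iterable, size):
    """Yield consecutive windows of `size` items as tuples (hypothetical helper)."""
    it = iter(iterable)
    # Fill the first window; maxlen makes the deque drop the oldest item on append.
    window = deque(islice(it, size), maxlen=size)
    if len(window) == size:
        yield tuple(window)
    for item in it:
        window.append(item)
        yield tuple(window)


def first_index(seq, sub):
    """Return the index of the first window equal to `sub`, or -1 if there is none."""
    target = tuple(sub)
    for i, window in enumerate(sliding_windows(seq, len(target))):
        if window == target:
            return i
    return -1


print(first_index([1, 2, 3, 4, 5, 6], [2, 3, 4]))
```

Because the windows are produced lazily, this also works on iterators that cannot be sliced, at the cost of building one tuple per position.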

For very large arrays and subarrays, you could save memory like this:

>>> windows = rolling_window(a, 3)
>>> sub = [8, 4, 0]
>>> hits = numpy.ones((len(a) - len(sub) + 1,), dtype=bool)
>>> for i, x in enumerate(sub):
...     hits &= numpy.in1d(windows[:,i], [x])
... 
>>> hits
array([False, False, False,  True, False, False, False, False], dtype=bool)
>>> hits.nonzero()
(array([3]),)
On the other hand, this will probably be somewhat slower.
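
For readers on NumPy 1.20 or later, the same windowed comparison can probably be written without hand-computed strides, using numpy.lib.stride_tricks.sliding_window_view. The sketch below assumes 1-D input; first_occurrence is a hypothetical name and the -1 return value for "not found" is an arbitrary choice.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def first_occurrence(a, sub):
    """Index of the first occurrence of 1-D `sub` in 1-D `a`, or -1 (hypothetical helper)."""
    a = np.asarray(a)
    sub = np.asarray(sub)
    if sub.size == 0 or sub.size > a.size:
        return -1
    # Read-only view of shape (len(a) - len(sub) + 1, len(sub)); no data is copied.
    windows = sliding_window_view(a, sub.size)
    # The comparison itself does allocate a boolean array of that shape.
    hits = np.flatnonzero((windows == sub).all(axis=1))
    return int(hits[0]) if hits.size else -1


a = np.array([7, 3, 6, 8, 4, 0, 9, 2, 1, 5])
print(first_occurrence(a, [8, 4, 0]))
```

The view is safer than a raw as_strided call because the shape and strides are validated by NumPy, but the memory cost of the boolean comparison is the same as in the rolling_window version.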

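You can call the array's tobytes() method (named tostring() in older NumPy releases) to convert the array to a byte string, and then use fast string search. This method may be faster when you have many subarrays to check.

import numpy as np

a = np.array([1,2,3,4,5,6])
b = np.array([2,3,4])
# byte offset of the match divided by the size of one element gives the index
print(a.tobytes().index(b.tobytes())//a.itemsize)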
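
One caveat of searching raw bytes is that a match can start in the middle of an element. The sketch below is one possible way to guard against that, by accepting only byte offsets that are a multiple of the item size. find_first_bytes is a hypothetical name, and the function assumes both arrays can share one dtype.

```python
import numpy as np


def find_first_bytes(a, b):
    """First element index of `b` in `a` via byte search, or -1 (hypothetical helper)."""
    a = np.ascontiguousarray(a)
    b = np.ascontiguousarray(b, dtype=a.dtype)  # both sides must use the same dtype
    hay, needle = a.tobytes(), b.tobytes()
    start = 0
    while True:
        pos = hay.find(needle, start)
        if pos < 0:
            return -1
        if pos % a.itemsize == 0:  # match starts on an element boundary
            return pos // a.itemsize
        start = pos + 1  # misaligned match: keep searching


print(find_first_bytes(np.array([1, 2, 3, 4, 5, 6]), np.array([2, 3, 4])))
```

For example, with little-endian 16-bit integers the bytes of [0x0100, 0x0001] contain the bytes of 0x0101 at offset 1, which is not a real occurrence; the alignment check rejects it.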

A convolution-based approach, which should be more memory efficient than the stride_tricks-based approach:

def find_subsequence(seq, subseq):
    target = np.dot(subseq, subseq)
    candidates = np.where(np.correlate(seq,
                                       subseq, mode='valid') == target)[0]
    # some of the candidates entries may be false positives, double check
    check = candidates[:, np.newaxis] + np.arange(len(subseq))
    mask = np.all((np.take(seq, check) == subseq), axis=-1)
    return candidates[mask]
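For really big arrays it may not be possible to use the stride_tricks approach, but this one still works:

# sizes and bounds are integers; float values such as 1e6 are rejected as a size
haystack = np.random.randint(1000, size=10**6)
needle = np.random.randint(1000, size=(100,))
# Hide 10 needles in the haystack
place = np.random.randint(10**6 - 100 + 1, size=10)
for idx in place:
    haystack[idx:idx+100] = needle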

In [3]: find_subsequence(haystack, needle)
Out[3]: 
array([253824, 321497, 414169, 456777, 635055, 879149, 884282, 954848,
       961100, 973481], dtype=int64)

In [4]: np.all(np.sort(place) == find_subsequence(haystack, needle))
Out[4]: True

In [5]: %timeit find_subsequence(haystack, needle)
10 loops, best of 3: 79.2 ms per loop
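
To see why the "double check" step in find_subsequence is needed, here is a small worked example. The correlation at each position is the dot product of a window with the pattern, and a different window can produce the same dot product as the pattern with itself. The function name candidates_and_matches is hypothetical; it just exposes both intermediate results.

```python
import numpy as np


def candidates_and_matches(seq, subseq):
    """Return (candidate positions, verified positions) for the correlation method."""
    seq = np.asarray(seq)
    subseq = np.asarray(subseq)
    target = np.dot(subseq, subseq)
    # Positions whose window has the same dot product with subseq as subseq itself.
    corr = np.correlate(seq, subseq, mode='valid')
    candidates = np.where(corr == target)[0]
    # Element-wise verification removes the false positives.
    check = candidates[:, np.newaxis] + np.arange(len(subseq))
    mask = np.all(np.take(seq, check) == subseq, axis=-1)
    return candidates, candidates[mask]


# Window [5, 0] gives 5*1 + 0*2 = 5, the same as [1, 2] . [1, 2] = 5.
cands, matches = candidates_and_matches([5, 0, 1, 2], [1, 2])
print(cands, matches)
```

Here position 0 is a candidate only by coincidence, and the verification step keeps position 2 alone.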

Here is a fairly straightforward option:

def first_subarray(full_array, sub_array):
    n = len(full_array)
    k = len(sub_array)
    matches = np.argwhere([np.all(full_array[start_ix:start_ix+k] == sub_array) 
                   for start_ix in range(0, n-k+1)])
    return matches[0]
Then, using the original a and b vectors, we get:

a = [1, 2, 3, 4, 5, 6]
b = [2, 3, 4]
first_subarray(a, b)
Out[44]: 
array([1], dtype=int64)

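A quick comparison of three of the proposed solutions (average time over 100 iterations for randomly created vectors):

This results in (on my old machine):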


First, convert the lists into strings.

a = ''.join(str(i) for i in a)
b = ''.join(str(i) for i in b)
After converting to strings, you can easily find the index of the substring with the following string method:

a.index(b)
Cheers!
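
Note that joining without a separator can blur element boundaries: [1, 23] and [12, 3] both become "123", and the returned value is a character offset rather than an element index. A possible variant that avoids both issues is sketched below. find_first_joined is a hypothetical name, and the sketch assumes that no element's string form contains the separator.

```python
def find_first_joined(a, b, sep=","):
    """Element index of the first occurrence of `b` in `a`, or -1 (hypothetical helper)."""
    # Wrap both strings in separators so that only whole elements can match.
    hay = sep + sep.join(map(str, a)) + sep
    needle = sep + sep.join(map(str, b)) + sep
    pos = hay.find(needle)
    if pos < 0:
        return -1
    # The number of separators before the match equals the element index.
    return hay.count(sep, 0, pos)


print(find_first_joined([1, 2, 3, 4, 5, 6], [2, 3, 4]))
print(find_first_joined([1, 23], [12, 3]))
```

This keeps the speed of the built-in string search, but the conversion to strings still costs time and memory proportional to the input.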

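(EDITED to include a deeper discussion, better code and more benchmarks)


Summary

For raw speed and efficiency, one can use a Cython- or Numba-accelerated version (when the input is a Python sequence or a NumPy array, respectively) of one of the classical algorithms.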

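The recommended approaches are:

  • find_kmp_cy() for Python sequences (list, tuple, etc.)
  • find_kmp_nb() for NumPy arrays

Other efficient approaches are find_rk_cy() and find_rk_nb(), which are more memory efficient but are not guaranteed to run in linear time.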
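
The extra memory used by the KMP variants is the table of "offsets" computed from the sub-sequence before the search starts. As an illustration, the sketch below computes that table on its own, using the same logic as the first phase of find_kmp(). The name prefix_table is hypothetical.

```python
def prefix_table(subseq):
    """For each j, length of the longest proper prefix of subseq[:j + 1]
    that is also a suffix of it (hypothetical helper)."""
    m = len(subseq)
    offsets = [0] * m
    j, k = 1, 0
    while j < m:
        if subseq[j] == subseq[k]:
            k += 1
            offsets[j] = k
            j += 1
        elif k != 0:
            # Fall back to the next shorter prefix that is also a suffix.
            k = offsets[k - 1]
        else:
            offsets[j] = 0
            j += 1
    return offsets


print(prefix_table([1, 2, 1, 2, 1, 3]))  # [0, 0, 1, 2, 3, 0]
```

During the search, this table tells the algorithm how far the pattern can be shifted after a mismatch without re-reading elements of the sequence, which is what gives the linear worst-case bound.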

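If Cython / Numba are not available, both find_kmp() and find_rk() are again a good all-around solution for most use cases, although in the average case and for Python sequences the naive approach in some form, notably find_pivot(), may be faster. For NumPy arrays, find_conv() (taken from another answer) outperforms any non-accelerated naive approach.

(The full code is below.)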


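Theory

This is a classical problem in computer science, known as the string-searching or string-matching problem. The naive approach, based on two nested loops, has a computational complexity of O(n + m) on average, but O(n m) in the worst case, where n and m are the lengths of the sequence and of the sub-sequence. Over the years, a number of alternative approaches have been developed that guarantee better worst-case performance.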
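
The gap between the average and the worst case can be made concrete by counting element comparisons. The sketch below instruments the two nested loops of the naive approach. count_naive_comparisons is a hypothetical helper written for this illustration only.

```python
def count_naive_comparisons(seq, subseq):
    """Count element comparisons made by the two-nested-loops search."""
    n, m = len(seq), len(subseq)
    comparisons = 0
    for i in range(n - m + 1):
        for j in range(m):
            comparisons += 1
            if seq[i + j] != subseq[j]:
                break  # a mismatch ends the inner loop early
    return comparisons


# Typical input: most windows are rejected on their first element.
print(count_naive_comparisons([1, 2, 3, 4, 5, 6], [2, 3, 4]))  # 6

# Worst case: every window matches until its last element,
# giving (n - m + 1) * m comparisons.
print(count_naive_comparisons([0] * 10, [0, 0, 1]))  # 24
```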

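Of the classical algorithms, the ones best suited to generic sequences (since they do not rely on an alphabet) are:

  • the naive algorithm (basically consisting of two nested loops)
  • the Knuth–Morris–Pratt (KMP) algorithm
  • the Rabin–Karp (RK) algorithm

This last algorithm relies on the computation of a rolling hash for its efficiency and may therefore require some additional knowledge of the input for optimal performance. Ultimately, it is best suited to homogeneous data, such as numeric arrays. A notable example of numeric arrays in Python is, of course, NumPy arrays.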
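
To make the rolling hash idea concrete, here is a minimal sketch of a Rabin–Karp search with a polynomial rolling hash. It is not the hashing scheme used by the find_rk() functions in this answer, which sum the element hashes. The name find_rk_poly and the base and mod values are assumptions chosen for illustration.

```python
def find_rk_poly(seq, subseq, base=257, mod=(1 << 61) - 1):
    """Yield start indices of `subseq` in `seq` using a polynomial rolling hash."""
    n, m = len(seq), len(subseq)
    if m == 0 or m > n:
        return
    high = pow(base, m - 1, mod)  # weight of the element leaving the window
    target = curr = 0
    for j in range(m):
        target = (target * base + hash(subseq[j])) % mod
        curr = (curr * base + hash(seq[j])) % mod
    for i in range(n - m + 1):
        # Equal hashes only suggest a match; confirm element by element.
        if curr == target and all(seq[i + j] == subseq[j] for j in range(m)):
            yield i
        if i + m < n:
            # Slide the window: drop seq[i], append seq[i + m], in O(1).
            curr = ((curr - hash(seq[i]) * high) * base + hash(seq[i + m])) % mod


print(list(find_rk_poly([1, 2, 3, 4, 2, 3, 4], [2, 3, 4])))  # [1, 4]
```

The hash update costs constant time per position, so the search is linear as long as hash collisions are rare; many collisions force repeated full comparisons, which is why linear time is not guaranteed.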

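Remarks

  • The naive algorithm is so simple that it lends itself to different implementations with varying run-time speed in Python.
  • The other algorithms are less flexible in what can be optimized via language tricks.
  • Explicit looping in Python may be a speed bottleneck, and several tricks can be used to perform the looping outside of the interpreter.
  • Cython is especially good at speeding up explicit loops for generic Python code.
  • Numba is especially good at speeding up explicit loops on NumPy arrays.
  • This is an excellent use case for generators, so all the code will use those instead of regular functions.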
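
Since the functions are generators, the first occurrence can be obtained lazily with next(), without scanning the rest of the sequence. The sketch below uses a slicing-based generator of the same shape as find_slice() in this answer. first_or_none is a hypothetical wrapper.

```python
def find_slice(seq, subseq):
    """Yield every index at which `subseq` occurs in `seq` (slice comparison)."""
    n, m = len(seq), len(subseq)
    for i in range(n - m + 1):
        if seq[i:i + m] == subseq:
            yield i


def first_or_none(seq, subseq):
    """Return the first match only; the generator stops as soon as it is found."""
    return next(find_slice(seq, subseq), None)


print(first_or_none([1, 2, 3, 4, 5, 6], [2, 3, 4]))  # 1
print(list(find_slice([1, 2, 1, 2, 1], [1, 2])))     # [0, 2]
```

The same generator serves both the "first occurrence" and the "all occurrences" use cases, which is the main reason for preferring generators here.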

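Python Sequences (list, tuple, etc.)

Based on the Naive Algorithm

  • find_loop(), find_loop_cy() and find_loop_nb() are the explicit-loop-only implementations in pure Python, Cython and with Numba JITing, respectively. Note the forceobj=True in the Numba version, which is required because we are using Python object inputs.
  • find_all() replaces the inner loop with all() on a generator expression.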
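    # Benchmark: average run time of three of the proposed solutions
    import time
    import collections
    import numpy as np
    
    
    def function_1(seq, sub):
        # direct comparison
        seq = list(seq)
        sub = list(sub)
        # + 1 so that a match ending at the last element is also found
        return [i for i in range(len(seq) - len(sub) + 1) if seq[i:i+len(sub)] == sub]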
    
    def function_2(seq, sub):
        # Jamie's solution
        target = np.dot(sub, sub)
        candidates = np.where(np.correlate(seq, sub, mode='valid') == target)[0]
        check = candidates[:, np.newaxis] + np.arange(len(sub))
        mask = np.all((np.take(seq, check) == sub), axis=-1)
        return candidates[mask]
    
    def function_3(seq, sub):
        # HYRY solution (tobytes() is the current name of the deprecated tostring())
        return seq.tobytes().index(sub.tobytes())//seq.itemsize
    
    
    # --- assessment time performance
    N = 100
    
    seq = np.random.choice([0, 1, 2, 3, 4, 5, 6], 3000)
    sub = np.array([1, 2, 3])
    
    tim = collections.OrderedDict()
    tim.update({function_1: 0.})
    tim.update({function_2: 0.})
    tim.update({function_3: 0.})
    
    for function in tim.keys():
        for _ in range(N):
            seq = np.random.choice([0, 1, 2, 3, 4], 3000)
            sub = np.array([1, 2, 3])
            start = time.time()
            function(seq, sub)
            end = time.time()
            tim[function] += end - start
    
    timer_dict = collections.OrderedDict()
    for key, val in tim.items():
        timer_dict.update({key.__name__: val / N})
    
    print(timer_dict)
    
    # Output of print(timer_dict), average seconds per call:
    OrderedDict([
    ('function_1', 0.0008518099784851074), 
    ('function_2', 8.157730102539063e-05), 
    ('function_3', 6.124973297119141e-06)
    ])
    
    def find_loop(seq, subseq):
        n = len(seq)
        m = len(subseq)
        for i in range(n - m + 1):
            found = True
            for j in range(m):
                if seq[i + j] != subseq[j]:
                    found = False
                    break
            if found:
                yield i
    
    %%cython -c-O3 -c-march=native -a
    #cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True
    
    
    def find_loop_cy(seq, subseq):
        cdef Py_ssize_t n = len(seq)
        cdef Py_ssize_t m = len(subseq)
        for i in range(n - m + 1):
            found = True
            for j in range(m):
                if seq[i + j] != subseq[j]:
                    found = False
                    break
            if found:
                yield i
    
    import numba as nb
    
    find_loop_nb = nb.jit(find_loop, forceobj=True)
    find_loop_nb.__name__ = 'find_loop_nb'
    
    def find_all(seq, subseq):
        n = len(seq)
        m = len(subseq)
        for i in range(n - m + 1):
            if all(seq[i + j] == subseq[j] for j in range(m)):
                yield i
    
    def find_slice(seq, subseq):
        n = len(seq)
        m = len(subseq)
        for i in range(n - m + 1):
            if seq[i:i + m] == subseq:
                yield i
    
    def find_mix(seq, subseq):
        n = len(seq)
        m = len(subseq)
        for i in range(n - m + 1):
            if seq[i] == subseq[0] and seq[i:i + m] == subseq:
                yield i
    
    def find_mix2(seq, subseq):
        n = len(seq)
        m = len(subseq)
        for i in range(n - m + 1):
            if seq[i] == subseq[0] and seq[i + m - 1] == subseq[m - 1] \
                    and seq[i:i + m] == subseq:
                yield i
    
    def index_all(seq, item, start=0, stop=-1):
        try:
            n = len(seq)
            if n > 0:
                start %= n
                stop %= n
                i = start
                while True:
                    i = seq.index(item, i)
                    if i <= stop:
                        yield i
                        i += 1
                    else:
                        return
            else:
                return
        except ValueError:
            pass
    
    
    def find_pivot(seq, subseq):
        n = len(seq)
        m = len(subseq)
        if m > n:
            return
        for i in index_all(seq, subseq[0], 0, n - m):
            if seq[i:i + m] == subseq:
                yield i
    
    def find_pivot2(seq, subseq):
        n = len(seq)
        m = len(subseq)
        if m > n:
            return
        for i in index_all(seq, subseq[0], 0, n - m):
            if seq[i + m - 1] == subseq[m - 1] and seq[i:i + m] == subseq:
                yield i
    
    def find_kmp(seq, subseq):
        n = len(seq)
        m = len(subseq)
        # : compute offsets
        offsets = [0] * m
        j = 1
        k = 0
        while j < m: 
            if subseq[j] == subseq[k]: 
                k += 1
                offsets[j] = k
                j += 1
            else: 
                if k != 0: 
                    k = offsets[k - 1] 
                else: 
                    offsets[j] = 0
                    j += 1
        # : find matches
        i = j = 0
        while i < n: 
            if seq[i] == subseq[j]: 
                i += 1
                j += 1
            if j == m:
                yield i - j
                j = offsets[j - 1] 
            elif i < n and seq[i] != subseq[j]: 
                if j != 0: 
                    j = offsets[j - 1] 
                else: 
                    i += 1
    
    %%cython -c-O3 -c-march=native -a
    #cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True
    
    
    def find_kmp_cy(seq, subseq):
        cdef Py_ssize_t n = len(seq)
        cdef Py_ssize_t m = len(subseq)
        # : compute offsets
        offsets = [0] * m
        cdef Py_ssize_t j = 1
        cdef Py_ssize_t k = 0
        while j < m: 
            if subseq[j] == subseq[k]: 
                k += 1
                offsets[j] = k
                j += 1
            else: 
                if k != 0: 
                    k = offsets[k - 1] 
                else: 
                    offsets[j] = 0
                    j += 1
        # : find matches
        cdef Py_ssize_t i = 0
        j = 0
        while i < n: 
            if seq[i] == subseq[j]: 
                i += 1
                j += 1
            if j == m:
                yield i - j
                j = offsets[j - 1] 
            elif i < n and seq[i] != subseq[j]: 
                if j != 0: 
                    j = offsets[j - 1] 
                else: 
                    i += 1
    
    def find_rk(seq, subseq):
        n = len(seq)
        m = len(subseq)
        if seq[:m] == subseq:
            yield 0
        hash_subseq = sum(hash(x) for x in subseq)  # compute hash
        curr_hash = sum(hash(x) for x in seq[:m])  # compute hash
        for i in range(1, n - m + 1):
            curr_hash += hash(seq[i + m - 1]) - hash(seq[i - 1])   # update hash
            if hash_subseq == curr_hash and seq[i:i + m] == subseq:
                yield i
    
    %%cython -c-O3 -c-march=native -a
    #cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True
    
    
    def find_rk_cy(seq, subseq):
        cdef Py_ssize_t n = len(seq)
        cdef Py_ssize_t m = len(subseq)
        if seq[:m] == subseq:
            yield 0
        cdef Py_ssize_t hash_subseq = sum(hash(x) for x in subseq)  # compute hash
        cdef Py_ssize_t curr_hash = sum(hash(x) for x in seq[:m])  # compute hash
        cdef Py_ssize_t old_item, new_item
        for i in range(1, n - m + 1):
            old_item = hash(seq[i - 1])
            new_item = hash(seq[i + m - 1])
            curr_hash += new_item - old_item  # update hash
            if hash_subseq == curr_hash and seq[i:i + m] == subseq:
                yield i
    
    import random
    
    def gen_input(n, k=2):
        return tuple(random.randint(0, k - 1) for _ in range(n))
    
    def gen_input_worst(n, k=-2):
        result = [0] * n
        result[k] = 1
        return tuple(result)
    
    # --- NumPy array versions (these reuse some of the names defined above) ---
    @nb.jit
    def _is_equal_nb(seq, subseq, m, i):
        for j in range(m):
            if seq[i + j] != subseq[j]:
                return False
        return True
    
    
    @nb.jit
    def find_loop_nb(seq, subseq):
        n = len(seq)
        m = len(subseq)
        for i in range(n - m + 1):
            if _is_equal_nb(seq, subseq, m, i):
                yield i
    
    def find_pivot(seq, subseq):
        n = len(seq)
        m = len(subseq)
        if m > n:
            return
        max_i = n - m
        for i in np.where(seq == subseq[0])[0]:
            if i > max_i:
                return
            elif np.all(seq[i:i + m] == subseq):
                yield i
    
    
    def find_pivot2(seq, subseq):
        n = len(seq)
        m = len(subseq)
        if m > n:
            return
        max_i = n - m
        for i in np.where(seq == subseq[0])[0]:
            if i > max_i:
                return
            elif seq[i + m - 1] == subseq[m - 1] \
                    and np.all(seq[i:i + m] == subseq):
                yield i
    
    def rolling_window(arr, size):
        shape = arr.shape[:-1] + (arr.shape[-1] - size + 1, size)
        strides = arr.strides + (arr.strides[-1],)
        return np.lib.stride_tricks.as_strided(arr, shape=shape, strides=strides)
    
    
    def find_rolling(seq, subseq):
        bool_indices = np.all(rolling_window(seq, len(subseq)) == subseq, axis=1)
        yield from np.mgrid[0:len(bool_indices)][bool_indices]
    
    def find_rolling2(seq, subseq):
        windows = rolling_window(seq, len(subseq))
        hits = np.ones((len(seq) - len(subseq) + 1,), dtype=bool)
        for i, x in enumerate(subseq):
            hits &= np.in1d(windows[:, i], [x])
        yield from hits.nonzero()[0]
    
    find_kmp_nb = nb.jit(find_kmp)
    find_kmp_nb.__name__ = 'find_kmp_nb'
    
    @nb.jit
    def sum_hash_nb(arr):
        result = 0
        for x in arr:
            result += hash(x)
        return result
    
    
    @nb.jit
    def find_rk_nb(seq, subseq):
        n = len(seq)
        m = len(subseq)
        if _is_equal_nb(seq, subseq, m, 0):
            yield 0
        hash_subseq = sum_hash_nb(subseq)  # compute hash
        curr_hash = sum_hash_nb(seq[:m])  # compute hash
        for i in range(1, n - m + 1):
            curr_hash += hash(seq[i + m - 1]) - hash(seq[i - 1])  # update hash
            if hash_subseq == curr_hash and _is_equal_nb(seq, subseq, m, i):
                yield i
    
    def find_conv(seq, subseq):
        target = np.dot(subseq, subseq)
        candidates = np.where(np.correlate(seq, subseq, mode='valid') == target)[0]
        check = candidates[:, np.newaxis] + np.arange(len(subseq))
        mask = np.all((np.take(seq, check) == subseq), axis=-1)
        yield from candidates[mask]
    
    def gen_input(n, k=2):
        return np.random.randint(0, k, n)
    
    def gen_input_worst(n, k=-2):
        result = np.zeros(n, dtype=int)
        result[k] = 1
        return result