Python Numba parallel prange does not scale - very slow


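I'm using Numba to speed up some number-crunching code. The work is trivially parallelizable, and I'm trying to parallelize it with Numba. I expected close to linear scaling with the number of threads (at least until memory bandwidth becomes the limit), but I get hardly any scaling at all.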

I'm simply writing into a NumPy array in slices, with each thread working on a different slice:

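import math

import numba
import numpy as np


@numba.njit
def do_work(i_row_begin, i_row_end, out):
    # Scratch buffer, allocated once per call (once per block of rows)
    a = np.empty(5)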
    for i_row in range(i_row_begin, i_row_end):
        a[0] = i_row
        for index_s in range(out.shape[1]):
            a[1] = index_s
            for index_t in range(out.shape[2]):
                a[2] = index_t
                a[3] = index_t / (1.2 + i_row)
                a[4] = index_t / (1.8 + i_row)
                # a / (...) creates a temporary array on every innermost iteration
                out[i_row, index_s, index_t] = np.sum(a / (1 + np.sum(a)))


@numba.njit(parallel=True)
def do_work_parallel(num_threads, num_rows):

    out = np.empty((num_rows, 3, 300))

    # calculate threads
    num_rows_per_thread = int(math.ceil(num_rows / num_threads))

    for index_thread in numba.prange(num_threads):
        # Each prange iteration handles one contiguous block of rows
        i_row_begin = index_thread * num_rows_per_thread
        i_row_end = min(num_rows, (index_thread + 1) * num_rows_per_thread)

        do_work(i_row_begin, i_row_end, out)

    return out
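The manual split in `do_work_parallel` gives each `prange` iteration a contiguous block of rows, and since the blocks never overlap, no two threads write the same element of `out`. The arithmetic can be checked in plain Python; `row_chunks` is a hypothetical helper that only reproduces the index calculation, not the Numba code itself.

```python
import math


def row_chunks(num_rows, num_threads):
    # Same arithmetic as do_work_parallel: ceil-divide the rows, then clip
    # the last block to num_rows. Trailing blocks may come out empty.
    per_thread = int(math.ceil(num_rows / num_threads))
    chunks = []
    for index_thread in range(num_threads):
        begin = index_thread * per_thread
        end = min(num_rows, (index_thread + 1) * per_thread)
        chunks.append((begin, end))
    return chunks


print(row_chunks(10000, 3))
```

With 10000 rows and 3 threads each block holds 3334 rows and the last one is clipped at 10000. When there are more threads than needed, a trailing block can have `begin > end`, which `range` simply treats as empty.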
This is the "main" script:

import time

n_rows = 10000

def run(num_threads):
    return do_work_parallel(num_threads, n_rows)

# Execute function once to compile numba functions
_ = run(2)

for num_threads in [1, 2, 3, 4]:
    print("Num threads:", num_threads)
    now = time.time()
    _ = run(num_threads)
    diff = time.time() - now
    print("Time elapsed: {:.3e}".format(diff))
    print("Speed: {:.3e} rows/s/core".format(n_rows/(diff * num_threads)))
    print("")

print("DONE")
The output I get is:

Num threads: 1
Time elapsed: 1.078e+00
Speed: 9.275e+03 rows/s/core

Num threads: 2
Time elapsed: 1.484e+00
Speed: 3.368e+03 rows/s/core

Num threads: 3
Time elapsed: 1.442e+00
Speed: 2.311e+03 rows/s/core

Num threads: 4
Time elapsed: 1.469e+00
Speed: 1.702e+03 rows/s/core
So there is clearly no speedup at all. In fact, parallelizing makes performance worse. Can anyone explain why this happens?
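The "rows/s/core" figure divides by the thread count, so under perfect scaling it would stay flat rather than rise. The stronger evidence of a problem is total throughput, which can be recomputed from the elapsed times reported above:

```python
n_rows = 10000
# Elapsed times in seconds, copied from the output above
elapsed = {1: 1.078, 2: 1.484, 3: 1.442, 4: 1.469}

total = {k: n_rows / t for k, t in elapsed.items()}       # rows/s, all threads
per_core = {k: total[k] / k for k in elapsed}             # rows/s/core
speedup = {k: elapsed[1] / t for k, t in elapsed.items()}  # relative to 1 thread

for k in elapsed:
    print(f"{k} threads: {total[k]:.0f} rows/s total, "
          f"{per_core[k]:.0f} rows/s/core, speedup {speedup[k]:.2f}x")
```

Total throughput falls from roughly 9300 rows/s with one thread to under 7000 rows/s with two or more, a speedup of about 0.73x to 0.75x, so the extra threads are actively slowing each other down.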

EDIT:
It seems to be related to this bug:

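What is read/write? It is nearly impossible to answer a question that only contains pseudocode. Could you give a short example based on a real part of the problem?

I just edited the question and posted a simplified version of the code that reproduces the behavior.

Just found out that it's related to this bug: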