我可以用全局dict并行化这个小Python脚本吗?
我有一个问题,在我的Python3 MBA课程中需要2.8秒。因为在它的核心我们有一个缓存字典,所以我认为哪个调用首先命中缓存并不重要,所以也许我可以从线程中获得一些好处。不过我还是不太明白。这比我通常问的问题要高一点,但是有人能带我完成这个问题的并行化过程吗我可以用全局dict并行化这个小Python脚本吗?,python,Python,我有一个问题,在我的Python3 MBA课程中需要2.8秒。因为在它的核心我们有一个缓存字典,所以我认为哪个调用首先命中缓存并不重要,所以也许我可以从线程中获得一些好处。不过我还是不太明白。这比我通常问的问题要高一点,但是有人能带我完成这个问题的并行化过程吗 import time import threading even = lambda n: n%2==0 next_collatz = lambda n: n//2 if even(n) else 3*n+1 cache = {1:
import time
import threading
even = lambda n: n%2==0
next_collatz = lambda n: n//2 if even(n) else 3*n+1
cache = {1: 1}
def collatz_chain_length(n):
if n not in cache: cache[n] = 1 + collatz_chain_length(next_collatz(n))
return cache[n]
if __name__ == '__main__':
valid = range(1, 1000000)
for n in valid:
# t = threading.Thread(target=collatz_chain_length, args=[n] )
# t.start()
collatz_chain_length(n)
print( max(valid, key=cache.get) )
或者,如果它是一个不好的候选者,为什么?线程在性能方面不会给您带来太多好处,因为它无法绕过全局解释器锁,全局解释器锁在任何给定时刻只运行一个线程。它甚至可能会因为上下文切换而减慢您的速度
如果您想在Python中利用并行化来提高性能,那么必须使用多处理来一次实际利用多个内核。如果您的工作负载是CPU密集型的,那么Python中的线程将无法得到很好的提升。这是因为由于GIL(全局解释器锁),一次只有一个线程实际使用处理器 但是,如果您的工作负载是I/O绑定的(例如等待网络请求的响应),线程会给您带来一些提升,因为如果您的线程在等待网络响应时被阻塞,那么另一个线程可以执行有用的工作 正如HDN提到的,使用多处理将有所帮助——这将使用多个Python解释器来完成工作 我的方法是将迭代次数除以您计划创建的进程数。例如,如果创建4个进程,则为每个进程分配一个
1000000/4
工作片段
最后,您需要汇总每个过程的结果,并应用
max()
来获得结果。我成功地将代码在单核上加速了16.5x,请进一步阅读
正如前面所说的,在纯Python中,多线程并没有带来任何改进,这是因为
关于多处理-有两个选项1)实现共享字典并直接从不同的进程读取/写入。2) 若要将值的范围划分为多个部分,并为不同进程上的不同子范围求解任务,则只需从所有进程的答案中取最大值即可
第一个选项将非常慢,因为在您的代码中,读/写字典是主要的耗时操作,使用进程间共享字典将使其速度慢5倍以上,而不会带来多核的改进
第二个选项将提供一些改进,但也不是很好,因为不同的进程将多次重新计算相同的值。只有在集群中有很多内核或使用许多独立的机器时,此选项才能提供相当大的改进
我决定实施另一种方法来改进您的任务(选项3)——使用和执行其他优化。我的解决方案也适用于选项2(子范围的并行化)
要使用numba运行代码,只需安装
pip install numba
(目前,Python版本支持Nuba,这是否回答了您的问题?@Tomerikoo这里的任务更难。问题是,有一个基于共享字典的缓存,需要由所有进程更新。而且该缓存读/写占用了大部分计算时间。这意味着如果在所有进程之间使用共享dicte程序将变得比单核版本慢得多。所以简单的并行化是没有帮助的。@Tomerikoo缓存在这里是需要的,并且使用得很好。这是众所周知的Collatz问题在这里得到解决。许多数字将在途中重复使用,并且会有很多缓存命中。
import time, threading, time, numba
def solve_py(start, stop):
even = lambda n: n%2==0
next_collatz = lambda n: n//2 if even(n) else 3*n+1
cache = {1: 1}
def collatz_chain_length(n):
if n not in cache: cache[n] = 1 + collatz_chain_length(next_collatz(n))
return cache[n]
for n in range(start, stop):
collatz_chain_length(n)
r = max(range(start, stop), key = cache.get)
return r, cache[r]
@numba.njit(cache = True, locals = {'n': numba.int64, 'l': numba.int64, 'zero': numba.int64})
def solve_nm(start, stop):
zero, l, cs = 0, 0, stop * 10
ns = [zero] * 10000
cache_lo = [zero] * cs
cache_lo[1] = 1
cache_hi = {zero: zero}
for n in range(start, stop):
if cache_lo[n] != 0:
continue
nsc = 0
while True:
if n < cs:
cg = cache_lo[n]
else:
cg = cache_hi.get(n, zero)
if cg != 0:
l = 1 + cg
break
ns[nsc] = n
nsc += 1
n = (n >> 1) if (n & 1) == 0 else 3 * n + 1
for i in range(nsc - 1, -1, -1):
if ns[i] < cs:
cache_lo[ns[i]] = l
else:
cache_hi[ns[i]] = l
l += 1
maxn, maxl = 0, 0
for k in range(start, stop):
v = cache_lo[k]
if v > maxl:
maxn, maxl = k, v
return maxn, maxl
if __name__ == '__main__':
solve_nm(1, 100000) # heat-up, precompile numba
for stop in [1000000, 2000000, 4000000, 8000000, 16000000, 32000000, 64000000]:
tr, resr = None, None
for is_nm in [False, True]:
if stop > 16000000 and not is_nm:
continue
tb = time.time()
res = (solve_nm if is_nm else solve_py)(1, stop)
te = time.time()
print(('py', 'nm')[is_nm], 'limit', stop, 'time', round(te - tb, 2), 'secs', end = '')
if not is_nm:
resr, tr = res, te - tb
print(', n', res[0], 'len', res[1])
else:
if tr is not None:
print(', boost', round(tr / (te - tb), 2))
assert resr == res, (resr, res)
else:
print(', n', res[0], 'len', res[1])
py limit 1000000 time 3.34 secs, n 837799 len 525
nm limit 1000000 time 0.19 secs, boost 17.27
py limit 2000000 time 6.72 secs, n 1723519 len 557
nm limit 2000000 time 0.4 secs, boost 16.76
py limit 4000000 time 13.47 secs, n 3732423 len 597
nm limit 4000000 time 0.83 secs, boost 16.29
py limit 8000000 time 27.32 secs, n 6649279 len 665
nm limit 8000000 time 1.68 secs, boost 16.27
py limit 16000000 time 55.42 secs, n 15733191 len 705
nm limit 16000000 time 3.48 secs, boost 15.93
nm limit 32000000 time 7.38 secs, n 31466382 len 706
nm limit 64000000 time 16.83 secs, n 63728127 len 950
#include <cstdint>
#include <vector>
#include <unordered_map>
#include <tuple>
#include <iostream>
#include <stdexcept>
#include <chrono>
typedef int64_t i64;
static std::tuple<i64, i64> Solve(i64 start, i64 stop) {
i64 cs = stop * 10, n = 0, l = 0, nsc = 0;
std::vector<i64> cache_lo(cs), ns(10000);
cache_lo[1] = 1;
std::unordered_map<i64, i64> cache_hi;
for (i64 i = start; i < stop; ++i) {
if (cache_lo[i] != 0)
continue;
n = i;
nsc = 0;
while (true) {
i64 cg = 0;
if (n < cs)
cg = cache_lo[n];
else {
auto it = cache_hi.find(n);
if (it != cache_hi.end())
cg = it->second;
}
if (cg != 0) {
l = 1 + cg;
break;
}
ns.at(nsc) = n;
++nsc;
n = (n & 1) ? 3 * n + 1 : (n >> 1);
}
for (i64 i = nsc - 1; i >= 0; --i) {
i64 n = ns[i];
if (n < cs)
cache_lo[n] = l;
else
cache_hi[n] = l;
++l;
}
}
i64 maxn = 0, maxl = 0;
for (size_t i = start; i < stop; ++i)
if (cache_lo[i] > maxl) {
maxn = i;
maxl = cache_lo[i];
}
return std::make_tuple(maxn, maxl);
}
int main() {
try {
for (auto stop: std::vector<i64>({1000000, 2000000, 4000000, 8000000, 16000000, 32000000, 64000000})) {
auto tb = std::chrono::system_clock::now();
auto r = Solve(1, stop);
auto te = std::chrono::system_clock::now();
std::cout << "cpp limit " << stop
<< " time " << double(std::chrono::duration_cast<std::chrono::milliseconds>(te - tb).count()) / 1000.0 << " secs"
<< ", n " << std::get<0>(r) << " len " << std::get<1>(r) << std::endl;
}
return 0;
} catch (std::exception const & ex) {
std::cout << "Exception: " << ex.what() << std::endl;
return -1;
}
}
cpp limit 1000000 time 0.17 secs, n 837799 len 525
cpp limit 2000000 time 0.357 secs, n 1723519 len 557
cpp limit 4000000 time 0.757 secs, n 3732423 len 597
cpp limit 8000000 time 1.571 secs, n 6649279 len 665
cpp limit 16000000 time 3.275 secs, n 15733191 len 705
cpp limit 32000000 time 7.112 secs, n 31466382 len 706
cpp limit 64000000 time 17.165 secs, n 63728127 len 950