如何在python joblib中写入共享变量
下面的代码并行化for循环如何在python joblib中写入共享变量,python,parallel-processing,shared-memory,joblib,Python,Parallel Processing,Shared Memory,Joblib,下面的代码并行化for循环 import networkx as nx; import numpy as np; from joblib import Parallel, delayed; import multiprocessing; def core_func(repeat_index, G, numpy_arrary_2D): for u in G.nodes(): numpy_arrary_2D[repeat_index][u] = 2; return; if __n
import networkx as nx;
import numpy as np;
from joblib import Parallel, delayed;
import multiprocessing;
def core_func(repeat_index, G, numpy_arrary_2D):
for u in G.nodes():
numpy_arrary_2D[repeat_index][u] = 2;
return;
if __name__ == "__main__":
G = nx.erdos_renyi_graph(100000,0.99);
nRepeat = 5000;
numpy_array = np.zeros([nRepeat,G.number_of_nodes()]);
Parallel(n_jobs=4)(delayed(core_func)(repeat_index, G, numpy_array) for repeat_index in range(nRepeat));
print(np.mean(numpy_array));
可以看出,要打印的预期值为2。但是,当我在集群(多核、共享内存)上运行代码时,它返回0.0
我认为问题在于每个worker都创建了自己的
numpy\u数组
对象副本,而在main函数中创建的副本没有更新。如何修改代码以更新numpy数组numpy\u数组
。joblib
默认使用多进程池,如下所示:
在引擎盖下,并行对象创建一个多处理池
在多个进程中分叉Python解释器以执行以下各项
列表中的项目。延迟函数是一个简单的技巧
能够通过函数调用创建元组(函数、args、kwargs)
语法
这意味着,每个进程都继承了数组的原始状态,但当进程退出时,它在数组中写入的任何内容都将丢失。只有函数结果被传递回调用(主)进程。但是您不返回任何内容,因此None
将被返回
要使共享阵列可修改,有两种方法:使用线程和使用共享内存
与进程不同,线程共享内存。因此,您可以写入数组,每个作业都会看到此更改。根据
joblib
手册,它是这样做的:
Parallel(n_jobs=4, backend="threading")(delayed(core_func)(repeat_index, G, numpy_array) for repeat_index in range(nRepeat));
运行时:
$ python r1.py
2.0
但是,当您将复杂的内容写入数组时,请确保正确处理数据或数据块周围的锁,否则您将遇到竞争条件(googleit)
还要仔细阅读GIL,因为Python中的计算多线程是有限的(与I/O多线程不同)
如果您仍然需要这些进程(例如,由于GIL),您可以将该阵列放入共享内存中
这是一个有点复杂的主题,但在
joblib
手册中也有显示。正如Sergey在回答中所写,进程不共享状态和内存。这就是为什么你看不到预期的答案
线程在同一进程下运行时,共享状态和内存空间。如果您有许多I/O操作,这将非常有用。因为GIL,它不会给你带来更多的处理能力(更多的CPU)
进程间通信的一种技术是使用Manager代理对象。您可以创建一个管理器对象,用于在进程之间同步资源
manager()返回的manager对象控制一个服务器进程,该进程保存Python对象,并允许其他进程使用代理操作它们
我还没有测试这段代码(我没有您使用的所有模块),它可能需要对代码进行更多的修改,但是使用Manager对象应该是这样的
if __name__ == "__main__":
G = nx.erdos_renyi_graph(100000,0.99);
nRepeat = 5000;
manager = multiprocessing.Manager()
numpys = manager.list(np.zeros([nRepeat, G.number_of_nodes()])
Parallel(n_jobs=4)(delayed(core_func)(repeat_index, G, numpys, que) for repeat_index in range(nRepeat));
print(np.mean(numpys));
那里的数据结构在语义上是浮点列表(矩阵/表),但实际上是
numpy.array
的numpy.array
的numpy.float64
值的一个实例。通过默认管理器同步这些自定义数据类型会有很多问题,默认管理器只支持很少的标量值、本机列表和dict。那么,您决定答案了吗?;-)