Python:如何在进程间共享由numpy数组组成的字典(dict)?

Python numpy阵列的共享目录?,python,numpy,python-multiprocessing,Python,Numpy,Python Multiprocessing,我想用许多numpy数组存储一个dict,并跨进程共享它 import ctypes import multiprocessing from typing import Dict, Any import numpy as np dict_of_np: Dict[Any, np.ndarray] = multiprocessing.Manager().dict() def get_numpy(key): if key not in dict_of_np: share

我想用许多numpy数组存储一个dict,并跨进程共享它

import ctypes
import multiprocessing
from typing import Dict, Any

import numpy as np

# Dict proxy obtained from a multiprocessing Manager; annotated as mapping
# arbitrary keys to numpy arrays.
# NOTE(review): values stored through a manager proxy appear to be copied
# rather than shared (the demo output in this snippet shows writes being
# lost) -- confirm against the multiprocessing.managers documentation.
dict_of_np: Dict[Any, np.ndarray] = multiprocessing.Manager().dict()


def get_numpy(key):
    """Return the array stored under ``key`` in ``dict_of_np``, creating it on first use.

    A missing key gets a fresh five-element int32 array built over a
    ``multiprocessing.Array`` buffer and stored in the dict. The value
    handed back is always whatever the dict lookup yields, not the local
    array object that was just created.
    """
    already_stored = key in dict_of_np
    if not already_stored:
        backing = multiprocessing.Array(ctypes.c_int32, 5)
        dict_of_np[key] = np.frombuffer(backing.get_obj(), dtype=np.int32)
    return dict_of_np[key]


# Demonstration of the problem being asked about: a write made through the
# array returned by the first lookup is not visible after a second lookup of
# the same key.
if __name__ == "__main__":
    a = get_numpy("5")
    a[1] = 5  # modify the array obtained from the first lookup
    print(a)  # prints [0 5 0 0 0]
    b = get_numpy("5")  # look the same key up again
    print(b)  # prints [0 0 0 0 0]

我按照已有的说明,基于共享缓冲区创建了numpy数组,但把得到的numpy数组保存到dict中之后,它就不起作用了。如上所示,再次用同一个键访问dict时,之前对numpy数组所做的修改并没有保留下来。


如何共享一个由numpy数组组成的dict?我需要dict和其中的数组都被共享,并且各个进程使用同一块内存。

根据我们在该问题下的讨论,我可能已经想出了一个解决方案:在主进程中使用一个线程来负责实例化 `multiprocessing.shared_memory.SharedMemory` 对象,这样可以确保对共享内存对象的引用一直存在,底层内存不会被过早删除。这只解决了Windows特有的问题:在Windows中,一旦不再存在对该文件的引用,文件就会被删除。它并不能解决另一个问题:只要还需要底层的memoryview,就必须一直持有每个已打开的SharedMemory实例。

这个管理线程在作为输入的 `multiprocessing.Queue` 上“监听”消息,并创建共享内存对象或返回与之相关的数据。使用一把锁来确保每个响应都由发出请求的那个进程读取(否则响应可能会被别的进程取走而混淆)。

所有共享内存对象首先由主进程创建,并一直保留到显式删除,以便其他进程可以访问它们

例如:

import multiprocessing
from multiprocessing import shared_memory, Queue, Process, Lock
from threading import Thread
import numpy as np

class Exit_Flag:
    """Marker type: an instance placed on the message queue tells the consumer loop to stop."""
 
class SHMController:
    """Create and hold shared-memory segments on behalf of other processes.

    Requests arrive on ``mq`` as strings of the form ``"<method>:<arg>"`` and
    every request is answered with exactly one string on ``rq``. Responses
    start with ``"ok:"`` on success and ``"ko:"`` on failure. Supported
    methods:

    * ``exists:<name>``        -> ``ok:true`` / ``ok:false``
    * ``size:<name>``          -> ``ok:<bytes>`` or ``ko:-1``
    * ``create:<size>`` or ``create:<name>,<size>`` -> ``ok:<name>`` or ``ko:...``
    * ``destroy:<name>``       -> ``ok:...`` or ``ko:...``

    Clients are expected to hold ``lock`` around each put/get pair so that a
    response is read by the process that sent the request.

    Malformed requests are answered with a ``ko:`` response instead of
    raising, because an exception here would end the processing thread and
    leave the client blocked on ``rq.get()``.
    """

    def __init__(self):
        self._shm_objects = {}  # segment name -> SharedMemory kept open here
        self.mq = Queue()  # message input queue
        self.rq = Queue()  # response output queue
        self.lock = Lock()  # only let one child talk to the controller at a time
        self._processing_thread = Thread(target=self.process_messages)

    def start(self):
        """Start the thread that serves requests.

        NOTE(review): the original author's comment says this is to be called
        after all child processes are started -- confirm whether that still
        matters under the start method in use.
        """
        self._processing_thread.start()

    def stop(self):
        """Queue the exit flag; the processing thread ends when it reads it."""
        self.mq.put(Exit_Flag())

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()

    def process_messages(self):
        """Serve requests from ``mq`` until an ``Exit_Flag`` instance arrives.

        On the way out, segments that are still registered are closed and
        unlinked and both queues are closed.
        """
        try:
            while True:
                message_obj = self.mq.get()
                if isinstance(message_obj, Exit_Flag):
                    break
                if isinstance(message_obj, str):
                    response = self.handle_message(message_obj)
                else:
                    # Always answer: a client that gets no response would
                    # wait forever while holding the shared lock.
                    kind = type(message_obj).__name__
                    response = f"ko:unsupported message type '{kind}'"
                self.rq.put(response)
        finally:
            self._release_all()
            self.mq.close()
            self.rq.close()

    def handle_message(self, message):
        """Execute one ``"<method>:<arg>"`` request and return the response string."""
        method, separator, arg = message.partition(":")
        if not separator:
            return f"ko:malformed message '{message}'"
        if method == "exists":
            # Reports whether a segment of that name is registered here.
            return "ok:true" if arg in self._shm_objects else "ok:false"
        if method == "size":
            if arg in self._shm_objects:
                return f"ok:{len(self._shm_objects[arg].buf)}"
            return "ko:-1"
        if method == "create":
            return self._create(arg)
        if method == "destroy":
            return self._destroy(arg)
        return f"ko:unknown method '{method}'"

    def _create(self, arg):
        """Create a segment from ``"<size>"`` or ``"<name>,<size>"`` and register it."""
        fields = arg.split(",")  # name, size or just size
        if len(fields) == 1:
            name, size_text = None, fields[0]
        elif len(fields) == 2:
            name, size_text = fields
        else:
            return f"ko:expected '<size>' or '<name>,<size>', got '{arg}'"
        try:
            size = int(size_text)
        except ValueError:
            return f"ko:invalid size '{size_text}'"
        if size <= 0:
            return f"ko:size must be positive, got {size}"
        if name in self._shm_objects:
            return f"ko:'{name}' already created"
        try:
            shm = shared_memory.SharedMemory(name=name, create=True, size=size)
        except FileExistsError:
            return f"ko:'{name}' already exists"
        except (OSError, ValueError) as err:
            return f"ko:could not create '{name}': {err}"
        # Register under shm.name: when no name was requested, one is generated.
        self._shm_objects[shm.name] = shm
        return f"ok:{shm.name}"

    def _destroy(self, name):
        """Close, unlink and forget the segment registered under ``name``."""
        if name not in self._shm_objects:
            return f"ko:'{name}' does not exist"
        shm = self._shm_objects.pop(name)
        shm.close()
        shm.unlink()
        return f"ok:'{name}' destroyed"

    def _release_all(self):
        """Close and unlink every segment that is still registered."""
        for shm in self._shm_objects.values():
            shm.close()
            try:
                shm.unlink()
            except FileNotFoundError:
                pass  # already removed elsewhere; nothing left to unlink
        self._shm_objects.clear()
    
def create(mq, rq, lock):
    """Worker: request segment ``key123`` (8 bytes), attach to it and write ``(1, 2)``.

    ``mq`` carries requests to the controller, ``rq`` carries its responses,
    and ``lock`` is held around each request/response pair. Every response
    is printed; the function returns early if either request is not
    answered with an ``ok`` response.
    """
    # helper functions here could make access less verbose
    with lock:
        mq.put("create:key123,8")
        reply = rq.get()
    print(reply)
    if reply[:2] != "ok":
        print("Uh oh....")
        return

    segment_name = reply.split(':')[1]
    with lock:
        mq.put(f"size:{segment_name}")
        reply = rq.get()
    print(reply)
    if reply[:2] != "ok":
        print("Oh no....")
        return

    byte_count = int(reply.split(":")[1])
    segment = shared_memory.SharedMemory(name=segment_name, create=False, size=byte_count)
    # View the segment's buffer as two int32 values and fill them in place.
    values = np.ndarray((2,), buffer=segment.buf, dtype=np.int32)
    values[:] = (1,2)
    print(values)
    segment.close()
    
def modify(mq, rq, lock):
    """Worker: wait for segment ``key123``, attach to it and add 5 to its first int32.

    ``mq`` carries requests to the controller, ``rq`` carries its responses,
    and ``lock`` is held around each request/response pair. The function
    returns early if the size request is not answered with an ``ok``
    response.
    """
    key_present = False
    while not key_present:  # until the shm exists
        with lock:
            mq.put("exists:key123")
            reply = rq.get()
        key_present = reply == "ok:true"
    print("key:exists")

    with lock:
        mq.put("size:key123")
        reply = rq.get()
    print(reply)
    if reply[:2] != "ok":
        print("Oh no....")
        return

    byte_count = int(reply.split(":")[1])
    segment = shared_memory.SharedMemory(name="key123", create=False, size=byte_count)
    # View the segment's buffer as two int32 values and update the first in place.
    values = np.ndarray((2,), buffer=segment.buf, dtype=np.int32)
    values[0] += 5
    print(values)
    segment.close()
    
def delete(mq, rq, lock):
    """Placeholder worker with the same signature as the other workers; does nothing yet."""
    # TODO make a test for this?
    return None

 
# Demo driver: start the controller, then run the `create` and `modify`
# workers one after the other, each in its own process.
if __name__ == "__main__":
    multiprocessing.set_start_method("spawn") #because I'm mixing threads and processes
    # Entering the context starts the controller's processing thread; leaving
    # it puts the exit flag on the message queue.
    with SHMController() as controller:
        mq, rq, lock = controller.mq, controller.rq, controller.lock
        # First worker: asks the controller for the segment and writes to it.
        create_task = Process(target=create, args=(mq, rq, lock))
        create_task.start()
        create_task.join()  # wait for it to finish before starting the next worker
        # Second worker: attaches to the same segment and modifies it.
        modify_task = Process(target=modify, args=(mq, rq, lock))
        modify_task.start()
        modify_task.join()
    print("finished")
为了让每个shm至少与使用它的数组存活得一样久,必须保留对该特定shm对象的引用。把这个引用作为属性附加到一个自定义的数组子类上(该子类照搬自numpy的子类化指南),就可以相当简单地让引用与数组保存在一起。


我发现了一个存在同样问题的讨论,但那里给出的解决方案是在修改时复制numpy数组,这对我不适用。使用numpy的目的正是避免除被修改元素之外的任何复制。我当前的问题是,访问由按共享内存名称创建的缓冲区支持的numpy数组时会出错。我想我必须提交一份bug报告。最奇怪的是,它在同一个函数内运行良好,但当我从另一个函数返回numpy数组时就会出错。这与引用无关。github评论中的代码与此无关。当函数返回时,函数中的 `shm` 会被垃圾回收(GC)。啊,我原以为只需要保留对原始shm的引用,而不需要保留对用于创建数组的那个特定shm的引用。现在一切都说得通了,谢谢。是的,这里有两个不同的问题。保留原始的shm可以防止Windows删除它(一旦不存在任何引用,就会发生删除),而保留用于创建数组的那个shm则可以防止段错误(segfault)。
class SHMArray(np.ndarray): #copied from https://numpy.org/doc/stable/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array
    """ndarray subclass that carries an extra ``shm`` attribute alongside the data.

    ``SHMArray(input_array, shm)`` views ``input_array`` as this class and
    stores ``shm`` on the result. Arrays derived from an instance take over
    its ``shm`` value, or ``None`` when the source has no such attribute.
    """

    def __new__(cls, input_array, shm=None):
        # View the input as this subclass, then attach the object to keep
        # next to the array.
        instance = np.asarray(input_array).view(cls)
        instance.shm = shm
        return instance

    def __array_finalize__(self, obj):
        # Nothing to inherit when there is no source object.
        if obj is not None:
            self.shm = getattr(obj, 'shm', None)
#example
# `name` and `shape` are not defined in this snippet; the caller must supply
# the segment name and the array shape.
shm = shared_memory.SharedMemory(name=name)
# Build the array over the segment's buffer and pass the SharedMemory object
# as the second argument so that it is stored on the array as `.shm`.
np_array = SHMArray(np.ndarray(shape, buffer=shm.buf, dtype=np.int32), shm)