Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/361.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python 一系列矩阵的快速乘法_Python_Numpy_Matrix Multiplication - Fatal编程技术网

Python 一系列矩阵的快速乘法

Python 一系列矩阵的快速乘法,python,numpy,matrix-multiplication,Python,Numpy,Matrix Multiplication,最快的跑步方式是什么: reduce(lambda x,y : x@y, ls) 用python 查看矩阵列表ls。我没有Nvidia GPU,但我有很多CPU内核可供使用。我认为我可以使流程并行工作(将其拆分为logiterations),但似乎对于小型(1000x1000)矩阵,这实际上是最糟糕的。以下是我尝试的代码: from multiprocessing import Pool import numpy as np from itertools import zip_long

最快的跑步方式是什么:

    reduce(lambda x,y : x@y, ls)
用python

查看矩阵列表
ls
。我没有Nvidia GPU,但我有很多CPU内核可供使用。我认为我可以使流程并行工作(将其拆分为
log
iterations),但似乎对于小型(
1000x1000
)矩阵,这实际上是最糟糕的。以下是我尝试的代码:

from multiprocessing import Pool
import numpy as np
from itertools import zip_longest

def matmul(x):
    if x[1] is None:
        return x[0]
    return x[1]@x[0]

def fast_mul(ls):
    while True:
        
        n = len(ls)
        if n == 0:
            raise Exception("Splitting Error")
        if n == 1:
            return ls[0]
        if n == 2:
            return ls[1]@ls[0]

        with Pool(processes=(n//2+1)) as pool:
            ls = pool.map(matmul, list(zip_longest(*[iter(ls)]*2)))
    


编辑:加入了另一个可能的函数

编辑:我添加了带有的结果,期望它会比其他结果快,但实际上不知何故它要慢得多。我想它是在设计时考虑了其他类型的用例


我不确定你能跑得比那快得多。对于数据为三维方阵阵列的情况,以下是几种不同的简化实现:

from string import ascii_lowercase

ls = [...]
index = ','.join(ascii_lowercase[x:x + 2] for x in range(len(ls)))
index += f'->{index[0]}{index[-1]}'
np.einsum(index, *ls)
来自多处理导入池的

从functools导入reduce
将numpy作为np导入
进口麻木为nb
def matmul_n_naive(数据):
返回减少(np.matmul,数据)
#如果您不关心修改数据,那么pass copy=False
def matmul_n_二进制(数据,副本=真):
如果len(数据)<1:
升值误差
data=np.array(data,copy=copy)
n、 r,c=data.shape
dt=data.dtype
s=1
而(n+s-1)//s>1:
a=数据[:n-s:2*s]
b=数据[s:n:2*s]
np.matmul(a,b,out=a)
s*=2
返回np.array(a[0])
def matmul_n_池(数据):
如果len(数据)<1:
升值误差
lst=数据
使用Pool()作为池:
而len(lst)>1:
lst_next=pool.starmap(np.matmul,zip(lst[::2],lst[1::2]))
如果len(lst)%2!=0:
lst_next.append(lst[-1])
lst=lst\u下一个
返回lst[0]
@注意:njit(平行=假)
def matmul_n_numba_nopar(数据):
res=np.eye(data.shape[1],data.shape[2],dtype=data.dtype)
对于nb.prange中的i(len(数据)):
res=res@data[i]
返回res
@注意:njit(平行=真实)
def matmul_n_numba_par(数据):
res=np.eye(data.shape[1],data.shape[2],dtype=data.dtype)
对于我在nb.prange(len(data)):#Numba知道如何正确地进行并行归约
res=res@data[i]
返回res
def matmul_n_多点(数据):
返回np.linalg.multi_点(数据)
还有一个测试:

#测试
将numpy作为np导入
np.random.seed(0)
a=np.rand.rand(10100100)*2-1
b1=matmul_n_naive(a)
b2=matmul_n_二进制(a)
b3=matmul_n_池(a)
b4=matmul\u n\u numba\u nopar(a)
b5=matmul_n_numba_par(a)
b6=matmul\u n\u多点(a)
打印(np.allclose(b1、b2))
#真的
打印(np.allclose(b1,b3))
#真的
打印(np.allclose(b1,b4))
#真的
打印(np.allclose(b1,b5))
#真的
打印(np.allclose(b1,b6))
#真的
这里有一些基准测试,似乎没有一致的赢家,但“朴素”的解决方案在各个方面都很好,二进制和Numba各不相同,进程池不是很好,
np.linalg.multi_dot
似乎对方阵不是很有利

将numpy导入为np
#10个矩阵1000x1000
np.random.seed(0)
a=np.rand.rand(1010001000)*0.1-0.05
%timeit matmul_n_naive(a)
#每个回路121 ms±6.09 ms(7次运行的平均值±标准偏差,每个10个回路)
%timeit matmul n_二进制(a)
#每个回路165 ms±3.68 ms(7次运行的平均值±标准偏差,每个10个回路)
%timeit matmul_n_numba_nopar(a)
#每个回路108 ms±510µs(7次运行的平均值±标准偏差,每个10个回路)
%timeit matmul_n_numba_par(a)
#每个回路244 ms±7.66 ms(7次运行的平均值±标准偏差,每个回路1次)
%timeit matmul\u n\u多点(a)
#每个回路132 ms±2.41 ms(7次运行的平均值±标准偏差,每个10个回路)
#200个矩阵100x100
np.random.seed(0)
a=np.rand.rand(200100100)*0.1-0.05
%timeit matmul_n_naive(a)
#每个回路4.4 ms±226µs(7次运行的平均值±标准偏差,每个100个回路)
%timeit matmul n_二进制(a)
#每个回路13.4 ms±299µs(7次运行的平均值±标准偏差,每个100个回路)
%timeit matmul_n_numba_nopar(a)
#每个回路9.51 ms±126µs(7次运行的平均值±标准偏差,每个100个回路)
%timeit matmul_n_numba_par(a)
#每个回路4.93 ms±146µs(7次运行的平均值±标准偏差,每个100个回路)
%timeit matmul\u n\u多点(a)
#每个回路1.14 s±22.1 ms(7次运行的平均值±标准偏差,每个回路1次)
#300矩阵10x10
np.random.seed(0)
a=np.rand.rand(300,10,10)*0.1-0.05
%timeit matmul_n_naive(a)
#每个回路526µs±953 ns(7次运行的平均值±标准偏差,每个1000个回路)
%timeit matmul n_二进制(a)
#每个回路152µs±508 ns(7次运行的平均值±标准偏差,每个10000个回路)
%timeit matmul n_池(a)
#每个回路610 ms±5.93 ms(7次运行的平均值±标准偏差,每个回路1次)
%timeit matmul_n_numba_nopar(a)
#每个回路239µs±1.1µs(7次运行的平均值±标准偏差,每个1000个回路)
%timeit matmul_n_numba_par(a)
#每个回路175µs±422 ns(7次运行的平均值±标准偏差,每个10000个回路)
%timeit matmul\u n\u多点(a)
#每个回路3.68 s±87 ms(7次运行的平均值±标准偏差,每个回路1次)
#1000矩阵10x10
np.random.seed(0)
a=np.rand.rand(1000,10,10)*0.1-0.05
%timeit matmul_n_naive(a)
#每个回路1.56 ms±4.49µs(7次运行的平均值±标准偏差,每个1000个回路)
%timeit matmul n_二进制(a)
#每个回路392µs±790 ns(7次运行的平均值±标准偏差,每个1000个回路)
%timeit matmul n_池(a)
#每个回路727 ms±12.2 ms(7次运行的平均值±标准偏差,每个回路1次)
%timeit matmul_n_numba_nopar(a)
#每个回路589µs±356 ns(7次运行的平均值±标准偏差,每个1000个回路)
%timeit matmul_n_numba_par(a)
#每个回路451µs±1.68µs(7次运行的平均值±标准偏差,每个1000个回路)
%timeit matmul\u n\u多点(a)
#从未完成。。。

有一个函数可以实现这一点:,据说是为最佳评估顺序而优化的:

np.linalg.multi_dot(ls)
事实上,医生说的话与你最初的措辞非常接近:

可以认为:

您也可以尝试
np.einsum
,这将允许您将最多25个矩阵相乘:

from string import ascii_lowercase

ls = [...]
index = ','.join(ascii_lowercase[x:x + 2] for x in range(len(ls)))
index += f'->{index[0]}{index[-1]}'
np.einsum(index, *ls)
定时

简单案例:

ls = np.random.rand(100, 1000, 1000) - 0.5

%timeit reduce(lambda x, y : x @ y, ls)
4.3 s ± 76.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
4.35 s ± 84.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
4.86 s ± 68.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
5.24 s ± 66.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
更复杂的情况:

ls = [x.T if i % 2 else x for i, x in enumerate(np.random.rand(100, 2000, 500) - 0.5)]

%timeit reduce(lambda x, y : x @ y, ls)
7.94 s ± 96.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
7.91 s ± 33.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
9.38 s ± 111 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
2.03 s ± 52.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
请注意
ls = [x.T if i % 2 else x for i, x in enumerate(np.random.rand(100, 400, 300) - 0.5)]

%timeit reduce(lambda x, y : x @ y, ls)
245 ms ± 8.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
245 ms ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
284 ms ± 12.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
638 ms ± 12.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)