Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/336.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python 用numpy求解大量小方程组_Python_Numpy_Linear Algebra - Fatal编程技术网

Python 用numpy求解大量小方程组

Python 用numpy求解大量小方程组,python,numpy,linear-algebra,Python,Numpy,Linear Algebra,我有很多小的线性方程组,我想用numpy有效地解决它们。基本上,给定A[:,:,:]和b[:,:],我希望找到x[:,:]由A[I,:,:]给出的.dot(x[I,:])=b[I,:]。所以,如果我不在乎速度,我可以解决这个问题 for i in range(n): x[i,:] = np.linalg.solve(A[i,:,:],b[i,:]) 但是,由于这涉及到python中的显式循环,并且由于A通常具有类似(1000000,3,3)的形状,这样的解决方案将非常缓慢。如果nump

我有很多小的线性方程组,我想用numpy有效地解决它们。基本上,给定
A[:,:,:]
b[:,:]
,我希望找到
x[:,:]
A[I,:,:]给出的.dot(x[I,:])=b[I,:]
。所以,如果我不在乎速度,我可以解决这个问题

for i in range(n):
    x[i,:] = np.linalg.solve(A[i,:,:],b[i,:])

但是,由于这涉及到python中的显式循环,并且由于
A
通常具有类似
(1000000,3,3)
的形状,这样的解决方案将非常缓慢。如果numpy不能做到这一点,我可以用fortran(即使用f2py)来完成这个循环,但如果可能的话,我更喜欢用python。

我认为你可以一次完成,用一个(3x100000,3x100000)矩阵在对角线上由3x3个块组成

未测试:

b_new = np.vstack([ b[i,:] for i in range(len(i)) ])
x_new = np.zeros(shape=(3x10000,3) )

A_new = np.zeros(shape=(3x10000,3x10000) )
n,m = A.shape
for i in range(n):
   A_new[3*i:3*(i+1),3*i:3*(i+1)] = A[i,:,:]

x = np.linalg.solve(A_new,b_new)

我猜回答自己的问题有点失礼,但这是我目前掌握的fortran解决方案,也就是说,其他解决方案在速度和简洁性方面都在与之竞争

function pixsolve(A, b) result(x)
    implicit none
    real*8    :: A(:,:,:), b(:,:), x(size(b,1),size(b,2))
    integer*4 :: i, n, m, piv(size(b,1)), err
    n = size(A,3); m = size(A,1)
    x = b
    do i = 1, n
        call dgesv(m, 1, A(:,:,i), m, piv, x(:,i), m, err)
    end do
end function
这将被汇编为:

f2py -c -m foo{,.f90} -llapack -lblas
并从python中调用为

x = foo.pixsolve(A.T, b.T).T
(需要
.T
s,因为f2py中的设计选择不好,如果省略
.T
s,这都会导致不必要的复制、低效的内存访问模式和看起来不自然的fortran索引。)


这也避免了setup.py等。我没有必要选择fortran(只要不涉及字符串),但我希望numpy可能有一些简短而优雅的东西,可以做同样的事情。

我认为您错误地认为显式循环是一个问题。通常它只是最里面的循环,值得优化,我认为这在这里是正确的。例如,我们可以测量开销代码与实际计算成本:

import numpy as np

n = 10**6
A = np.random.random(size=(n, 3, 3))
b = np.random.random(size=(n, 3))
x = b*0

def f():
    for i in xrange(n):
        x[i,:] = np.linalg.solve(A[i,:,:],b[i,:])

np.linalg.pseudosolve = lambda a,b: b

def g():
    for i in xrange(n):
        x[i,:] = np.linalg.pseudosolve(A[i,:,:],b[i,:])
这让我

In [66]: time f()
CPU times: user 54.83 s, sys: 0.12 s, total: 54.94 s
Wall time: 55.62 s

In [67]: time g()
CPU times: user 5.37 s, sys: 0.01 s, total: 5.38 s
Wall time: 5.40 s
瞧,它只花了10%的时间做任何事情,而不是真正解决你的问题。现在,我完全可以相信,
np.linalg.solve
本身相对于Fortran来说太慢了,所以你想做些别的事情。想想看,在小问题上可能尤其如此:IIRC我曾经发现手工展开某些小解决方案的速度更快,尽管那是很久以前的事了


但就其本身而言,在第一个索引上使用显式循环并不会使整体解决方案变得相当缓慢。如果
np.linalg.solve
足够快,这里的循环不会有太大的改变。

对于现在回来阅读这个问题的人,我想我应该节省其他人的时间,并提到numpy现在使用广播来处理这个问题

因此,在numpy 1.8.0及更高版本中,以下可用于求解N个线性方程组

x = np.linalg.solve(A,b)

当然,大量的小型线性方程组可以组合在一个大型稀疏线性系统中,一次求解一次?是的,它们可以。那会有效率吗?你认为它会比fortran中的循环更好吗?老实说,像这样的东西才是Cython的亮点。它并没有完全停留在python中,但也没有偏离太远,并且使用numpy数组是完全无缝的。它不会像fortran那么快,但也不会慢。@joferkington:那么你会在循环中使用Cython建立稀疏矩阵,然后作为稀疏系统求解它?从Cython循环调用
np.linalg.solve
将是徒劳无益的,不是吗,因为Cython将无法消除该函数调用的python开销。你完全正确,瓶颈将是python函数调用的开销,但在尝试稀疏解决方案之前,我仍然会尝试在cython的循环中调用
np.linalg.solve
。若有必要,可以通过使用numpy的C接口避免大量python开销。当然,由于黑客太多,仅仅使用fortran会更干净。让我看看是否能拼凑出一个例子。它可能没有我说的那么快…嗯。不过,这种方法仍然有n次迭代的python循环,这正是我试图避免的。但我认为scipy.sparse.block_diag可以一次性构建新的稀疏A_。经过一些测试,它看起来像scipy.sparse.block_diag(A)非常慢,并且使用了大量内存。我可能读错了,但你是否建议我生成并解决一个300000×300000的密集矩阵系统,这是一个巨大而沉重的问题,只是为了避免一些python函数调用开销?嗯,这些代码似乎不适合我。我对f2py或fortran不太了解,但这对我来说确实有用:(基本上,我在调用dgesv时需要另一个参数,并且我假设A具有形状(N,3,3)就像你的问题一样。@jorgeca:f2py处理C/fortran维度排序不好-它的默认行为是假设你希望它制作一个昂贵的副本来保留索引方案而不是实际的内存布局。为了避免这种情况,必须首先转换传递给f2py的每个numpy数组(此转置实际上不会导致复制,而是用于防止复制)。我上面的示例假设您这样做,即您从python中调用x=pixsolve(a.T,b.T).T。避免f2py复制会使速度提高约20%。但它应该是
a.transpose((1,2,0))
而不是
a.T
(否则您将解决系统的转置问题)。
A.transpose((1,2,0))
会涉及到复制,不是吗?你能详细说明我将如何用
A.t
解决我的系统的转置问题吗?实际上,f2py在任何地方都隐式插入
.t
,而我插入的
.t
只是抵消了这些。我认为它不是在复制(使用
.base
np进行检查。可以共享内存
)。使用
A.T
A.transpose((1,2,0))
调用pixsolve的速度差很小(5%,而不是超过20%),我同意基准测试对于这类事情很重要。我用sa设置了
A
b