Python 带Sympy的多项式根下误差
我创建了一个函数,在给定系数范围的情况下,用这些系数构造多项式,并输出其所有根的列表。然而,Numba不喜欢它。是这样的:Python 带Sympy的多项式根下误差,python,sympy,numba,Python,Sympy,Numba,我创建了一个函数,在给定系数范围的情况下,用这些系数构造多项式,并输出其所有根的列表。然而,Numba不喜欢它。是这样的: import math import numpy as np import itertools from numba import jit from sympy.solvers import solve from sympy import Symbol from sympy import Poly @jit def polyn(ranges=[[-20,20],[-20,
import math
import numpy as np
import itertools
from numba import jit
from sympy.solvers import solve
from sympy import Symbol
from sympy import Poly
@jit
def polyn(ranges=[[-20,20],[-20,20],[-20,20],[-20,20]],step=4):
l = []
x = Symbol('x')
rangl = [np.linspace(i[0],i[1],math.floor((i[1]-i[0])/step)) for i in ranges]
coeffl = iter(itertools.product(*rangl))
leng = 1
for i in rangl:
leng *= len(i)
for i in range(0, leng):
a = solve(Poly(list(next(coeffl)),x),x)
for j in a:
l.append(j)
return np.array(l)
当我试着运行它时,它会输出一个神秘的:
AssertionError:在对象处失败(对象模式前端)
我不明白。。。有人能帮忙吗?你的代码中有很多东西是Numba目前无法处理的。第一个是您在其中构建的列表理解
rangl
:
[np.linspace(i[0],i[1],math.floor((i[1]-i[0])/step)) for i in ranges]
您应该将其替换为NumPy解决方案,如:
rangl = np.empty((len(ranges), step))
for i in ranges:
rangl[i] = np.linspace(i[0],i[1],math.floor((i[1]-i[0])/step))
第二件Numba无法处理的事情是itertools.product。您也可以用NumPy和for循环替换它
一般来说,试着通过注释代码的下半部分来减少代码,直到Numba接受它,然后自上而下地工作,看看哪些部分无法编译。要有条不紊,一步一步地进行,并尝试对循环和数组使用简单的结构,如simple
。Numba将无法加快SymPy代码的速度。如果这是一个瓶颈,您可以尝试使用数值解算器。另一件要尝试的事情是解一个一般的三次方(带有符号系数),并插入一般解的值。不,我知道-但是它不会加快在一个大数组上的这种运算的迭代吗?(sympy解决)我怀疑它是否会,除非你用它与nogil选项并行运行。谢谢!您如何建议将itertools.product替换为任意数量的数组?@IskyMathews:使其适用于固定数量,然后进行推广。您可以从itertools实现中获得灵感(甚至可能是一些代码)。