Python 用Sympy生成C代码。将Pow(x,2)替换为x*x

Python 用Sympy生成C代码。将Pow(x,2)替换为x*x,python,python-3.x,sympy,code-generation,codegen,Python,Python 3.x,Sympy,Code Generation,Codegen,我正在使用公共子表达式消除(CSE)例程和ccode打印机生成C代码 但是,我希望将增强表达式设置为(x*x),而不是pow(x,2) 不管怎样,你想这么做吗 例如: import sympy as sp a= sp.MatrixSymbol('a',3,3) b=sp.Matrix(a)*sp.Matrix(a) res = sp.cse(b) lines = [] for tmp in res[0]: lines.append(sp.ccode(tmp[1], tmp

我正在使用公共子表达式消除(CSE)例程和ccode打印机生成C代码

但是,我希望将增强表达式设置为(x*x),而不是pow(x,2)

不管怎样,你想这么做吗

例如:

import sympy as sp
a= sp.MatrixSymbol('a',3,3)
b=sp.Matrix(a)*sp.Matrix(a)

res = sp.cse(b)

lines = []
     
for tmp in res[0]:
    lines.append(sp.ccode(tmp[1], tmp[0]))

for i,result in enumerate(res[1]):
    lines.append(sp.ccode(result,"result_%i"%i))
将输出:

x0[0] = a[0];
x0[1] = a[1];
x0[2] = a[2];
x0[3] = a[3];
x0[4] = a[4];
x0[5] = a[5];
x0[6] = a[6];
x0[7] = a[7];
x0[8] = a[8];
x1 = x0[0];
x2 = x0[1];
x3 = x0[3];
x4 = x2*x3;
x5 = x0[2];
x6 = x0[6];
x7 = x5*x6;
x8 = x0[4];
x9 = x0[7];
x10 = x0[5];
x11 = x0[8];
x12 = x10*x9;
result_0[0] = pow(x1, 2) + x4 + x7;
result_0[1] = x1*x2 + x2*x8 + x5*x9;
result_0[2] = x1*x5 + x10*x2 + x11*x5;
result_0[3] = x1*x3 + x10*x6 + x3*x8;
result_0[4] = x12 + x4 + pow(x8, 2);
result_0[5] = x10*x11 + x10*x8 + x3*x5;
result_0[6] = x1*x6 + x11*x6 + x3*x9;
result_0[7] = x11*x9 + x2*x6 + x8*x9;
result_0[8] = pow(x11, 2) + x12 + x7;
致以最诚挚的问候

您可以将子类化,并且只更改一个您希望不同的函数。您需要进行调查,以找到正确的函数名和默认实现,以便确保不会出错。只要小心一点,就可以自动生成所需的括号,精确到所需的时间和位置

下面是一个简单的例子:

将sympy作为sp导入
从sympy.printing.c导入C99CodePrinter
从sympy.printing.preference导入优先级
从sympy.abc导入x
类别CustomCodePrinter(C99CodePrinter):
def_打印功率(自身,expr):
PREC=优先权(expr)
如果expr.exp==2:
返回'({0}*{0}').format(self.branchize(expr.base,PREC))
其他:
返回super()。\u print\u Pow(expr)
默认打印机=C99CodePrinter().doprint
custom_printer=CustomCodePrinter().doprint
表达式=[x,(2+x)**2,x**3,x**15,sp.sqrt(5),sp.sqrt(x)**4,1/x,1/(x*x)]
打印(“默认:{}”。格式(默认打印机(表达式)))
打印(“自定义:{}”。格式(自定义打印机(表达式)))
输出:

Default:[x,pow(x+2,2),pow(x,3),pow(x,15),sqrt(5),pow(x,2),1.0/x,pow(x,-2)]
自定义:[x,((x+2)*(x+2)),功率(x,3),功率(x,15),sqrt(5),(x*x),1.0/x,功率(x,-2)]
PS:为了支持更广泛的指数,您可以使用

class CustomCodePrinter(C99CodePrinter):
def_打印功率(自身,expr):
PREC=优先权(expr)
如果expr.exp在范围(2,7)内:
返回'*'.join([self.branchize(expr.base,PREC)]*int(expr.exp))
范围(-6,0)内的elif expr.exp:
返回'1.0/('+('*')join([self.branchize(expr.base,PREC)]*int(-expr.exp)))+''
其他:
返回super()。\u print\u Pow(expr)

我想我将采用用户功能方法:

正如上面评论中所建议的,我将使用sp.ccode的
用户功能
功能: 假设我们有一个像
a^3

sp.ccode(a**3,user_函数={'Pow':[(lambda x,y:y.is_整数,lambda x,y:'*''..join(['('+x+')]]*int(y)),(lambda x,y:not y.is_整数,'Pow'))

应输出:
”(a)*(a)*(a)

在将来,我将尝试改进该函数,以便在需要时只放入括号


欢迎任何改进

有一个名为
create\u expand\u pow\u optimization
的函数,它创建一个包装器来优化这方面的表达式。它将用显式乘法替换的最高幂作为参数

包装器返回一个
UnevaluatedExpr
,该值受到保护,不会自动简化,从而恢复此更改

import sympy as sp
from sympy.codegen.rewriting import create_expand_pow_optimization

expand_opt = create_expand_pow_optimization(3)

a = sp.Matrix(sp.MatrixSymbol('a',3,3))
res = sp.cse(a@a)

for i,result in enumerate(res[1]):
    print(sp.ccode(expand_opt(result),"result_%i"%i))

最后,请注意,对于足够高的优化级别,您的编译器将处理此问题(并且可能在这方面做得更好)。

也许这有帮助:该函数看起来很有趣!但是,我无法在
ccode
user\u函数中工作。此外,它似乎仅适用于symbol^integer的情况。如果我有:
sp.ccode(expand_opt(a+b)**2)
它会给我
'pow(a+b,2)
应该是
sp.ccode(expand_opt((a+b)**2))
,但即使这样它也不起作用。好像是个bug。我在上面创建的。你仍然需要检查负能量。例如,
1/a
a**(-2)
1/(a**2+1)
出现错误。对于您可能想要使用的高功率,这可能会导致大量的加速。(还要注意,大多数C编译器都会自动进行这种优化。)