Python 如何在数组的一部分调用scipy.minimize?

Python 如何在数组的一部分调用scipy.minimize?,python,numpy,scipy,scipy-optimize,scipy-optimize-minimize,Python,Numpy,Scipy,Scipy Optimize,Scipy Optimize Minimize,我正在努力获得scipy.minimize以获得一个优化参数,该参数是一个数组,其中我只查看目标函数中数组的一部分 import numpy as np from scipy.optimize import minimize n = 5 X_true = np.random.normal(size=(n,n)) X_guess = np.random.normal(size=(n,n)) indices = np.triu_indices(n) def mean_square_error(X

我正在努力获得
scipy.minimize
以获得一个优化参数,该参数是一个数组,其中我只查看目标函数中数组的一部分

import numpy as np
from scipy.optimize import minimize

n = 5
X_true = np.random.normal(size=(n,n))
X_guess = np.random.normal(size=(n,n))
indices = np.triu_indices(n)

def mean_square_error(X):
    return ((X.flatten() - X_true.flatten()) ** 2).mean()

def mean_square_error_over_indices(X):
    return ((X[indices].flatten() - X_true[indices].flatten()) ** 2).mean()

# works fine
print(mean_square_error(X_guess)) 

# works fine
print(mean_square_error_over_indices(X_guess)) 

# works fine (flatten is necessary inside the objective function)
print(minimize(mean_square_error, X_guess).x)

# IndexError
print(minimize(mean_square_error_over_indices, X_guess).x)
回溯:

IndexError                                Traceback (most recent call last)
<ipython-input-1-08d40604e22a> in <module>
     20 print(minimize(mean_square_error, X_guess).x) # works fine
     21 
---> 22 print(minimize(mean_square_error_over_indices, X_guess).x) # error

C:\Anaconda\lib\site-packages\scipy\optimize\_minimize.py in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
    593         return _minimize_cg(fun, x0, args, jac, callback, **options)
    594     elif meth == 'bfgs':
--> 595         return _minimize_bfgs(fun, x0, args, jac, callback, **options)
    596     elif meth == 'newton-cg':
    597         return _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,

C:\Anaconda\lib\site-packages\scipy\optimize\optimize.py in _minimize_bfgs(fun, x0, args, jac, callback, gtol, norm, eps, maxiter, disp, return_all, **unknown_options)
    968     else:
    969         grad_calls, myfprime = wrap_function(fprime, args)
--> 970     gfk = myfprime(x0)
    971     k = 0
    972     N = len(x0)

C:\Anaconda\lib\site-packages\scipy\optimize\optimize.py in function_wrapper(*wrapper_args)
    298     def function_wrapper(*wrapper_args):
    299         ncalls[0] += 1
--> 300         return function(*(wrapper_args + args))
    301 
    302     return ncalls, function_wrapper

C:\Anaconda\lib\site-packages\scipy\optimize\optimize.py in approx_fprime(xk, f, epsilon, *args)
    728 
    729     """
--> 730     return _approx_fprime_helper(xk, f, epsilon, args=args)
    731 
    732 

C:\Anaconda\lib\site-packages\scipy\optimize\optimize.py in _approx_fprime_helper(xk, f, epsilon, args, f0)
    662     """
    663     if f0 is None:
--> 664         f0 = f(*((xk,) + args))
    665     grad = numpy.zeros((len(xk),), float)
    666     ei = numpy.zeros((len(xk),), float)

C:\Anaconda\lib\site-packages\scipy\optimize\optimize.py in function_wrapper(*wrapper_args)
    298     def function_wrapper(*wrapper_args):
    299         ncalls[0] += 1
--> 300         return function(*(wrapper_args + args))
    301 
    302     return ncalls, function_wrapper

<ipython-input-1-08d40604e22a> in mean_square_error_over_indices(X)
     11 
     12 def mean_square_error_over_indices(X):
---> 13     return ((X[indices].flatten() - X_true[indices].flatten()) ** 2).mean()
     14 
     15 

IndexError: too many indices for array
索引器错误回溯(最近一次调用)
在里面
20打印(最小化(均方误差,X_猜测)。X)#效果良好
21
--->22打印(最小化(指数上的均方误差,X_猜测)。X)#误差
最小化中的C:\Anaconda\lib\site packages\scipy\optimize\\u minimize.py(fun、x0、args、method、jac、hess、hessp、bounds、constraints、tol、callback、options)
593返回最小化cg(乐趣、x0、参数、jac、回调,**选项)
594 elif meth==“bfgs”:
-->595返回_最小化_bfgs(乐趣、x0、参数、jac、回调,**选项)
596 elif meth==‘牛顿重心’:
597返回_最小化_newtoncg(fun、x0、args、jac、hess、hessp、callback、,
C:\Anaconda\lib\site packages\scipy\optimize\optimize.py in_minimize_bfgs(fun、x0、args、jac、callback、gtol、norm、eps、maxiter、disp、return_all、**未知选项)
968其他:
969 grad_调用,myfprime=wrap_函数(fprime,args)
-->970 gfk=myfprime(x0)
971K=0
972 N=len(x0)
C:\Anaconda\lib\site packages\scipy\optimize\optimize.py在函数\u wrapper(*wrapper\u args)中
298 def函数包装器(*包装器参数):
299 nCall[0]+=1
-->300返回函数(*(包装器参数+参数))
301
302返回NCALL,函数包装器
C:\Anaconda\lib\site packages\scipy\optimize\optimize.py,单位约为(xk,f,epsilon,*args)
728
729     """
-->730返回近似优先辅助对象(xk,f,ε,args=args)
731
732
C:\Anaconda\lib\site packages\scipy\optimize\optimize.py在\u近似\u fprime\u帮助程序中(xk,f,epsilon,args,f0)
662     """
663如果f0为无:
-->664 f0=f(*(xk,)+args))
665梯度=整数零((len(xk),浮点)
666 ei=整数零((len(xk),浮点)
C:\Anaconda\lib\site packages\scipy\optimize\optimize.py在函数\u wrapper(*wrapper\u args)中
298 def函数包装器(*包装器参数):
299 nCall[0]+=1
-->300返回函数(*(包装器参数+参数))
301
302返回NCALL,函数包装器
指数(X)上的均方误差
11
指数(X)上的12 def均方误差:
--->13返回((X[index].flatte()-X_true[index].flatte())**2).mean()
14
15
索引器:数组的索引太多
基于
scipy.optimize.minimize
接受1d数组,因此您使用“展平()”是正确的,但您也应该将其用于传递到minimize()的初始猜测。以下是我解决问题的建议:

将numpy导入为np
从scipy.optimize导入最小化
#初始化
n=5
x_true=np.random.normal(大小=(n,n))
x_guess=np.random.normal(大小=(n,n))
指数=np.triu_指数(n)
#展平初始值以最小化
guess\u x0=x\u guess.flatten()
guess\u indices\u x0=x\u guess[index].flatten()
#定义目标函数
mse=lambda x:((x-x_true.flatte())**2.mean()
指数上方的mse_=lambda x:((x-x_真[指数].flatte())**2).mean()
#很好
打印(“MSE:%5f”%MSE(猜测)
打印(“索引的MSE:%5f”%MSE超过索引(猜测索引x0))
#工作正常(目标函数内部需要展平)
打印(“结果1:”,最小化(mse,guess_x0).x)
打印(“结果2:”,最小化(mse超过指数,猜测指数)。x)
输出:

MSE: 2.763674
MSE for indices: 3.192139
Result 1: [-1.2828193   0.49468516 -0.99500157 -0.47284983  1.6380719  -0.33051017
  0.13769163 -0.23920633 -0.87430572  0.63945803  1.38327467  0.8484247
  0.31888506 -1.15764468  1.06891773 -0.28372002  1.34104286  1.21024251
 -0.11020374  1.37024001  1.08940389  1.82391261  0.32469148  0.64567877
  0.54364199]
Result 2: [-1.28281964  0.49468503 -0.99500147 -0.47284976  1.63807209  0.13769154
 -0.23920624 -0.87430606  0.63945812  0.31888521 -1.15764475  1.06891776
 -0.11020373  1.37024006  0.54364213]

您是否检查了函数从
minimize
获得的
X
的形状?谢谢,这帮助我解决了问题!但是,我最大的错误是
X[index]
是在
索引中处理
X
的错误方法,因此我问了一个错误的问题!它们需要解包,例如
X,y=np.转置(索引)
,否则代码会解释得很奇怪。如果你能在答案中提到这一点,我很乐意勾选它。我不确定,为了编辑问题,你到底指的是什么。欢迎你编辑答案并添加缺少的部分。