Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/328.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
设计矩阵函数的Matlab到Python转换_Python_Matlab_Numpy - Fatal编程技术网

设计矩阵函数的Matlab到Python转换

设计矩阵函数的Matlab到Python转换,python,matlab,numpy,Python,Matlab,Numpy,去年,我在Matlab中为线性回归程序中的设计矩阵编写了一个代码。它很好用。现在,我需要将其转换为Python并在Pycharm中运行。我已经做了好几天了,虽然我对Python真的很陌生,但我在翻译中找不到任何错误,但在代码与程序的其余部分一起运行时,我遇到了一个错误 matlab中的代码: function DesignMatrix = design_matrix( xTrain, M ) % This function calculates the Design Matrix for % a

去年,我在Matlab中为线性回归程序中的设计矩阵编写了一个代码。它很好用。现在,我需要将其转换为Python并在Pycharm中运行。我已经做了好几天了,虽然我对Python真的很陌生,但我在翻译中找不到任何错误,但在代码与程序的其余部分一起运行时,我遇到了一个错误

matlab中的代码:

function DesignMatrix = design_matrix( xTrain, M )
% This function calculates the Design Matrix for
% a M-th degree polynomial
% xTrain - training set Nx1
% M - polynomial degree 0,1,2,...

N = size(xTrain,1);
DesignMatrix = zeros(N,M+1); 
for i=1:M+1
  DesignMatrix(:,i)=xTrain.^(i-1)
end
end
我的Python翻译(np代表numpy,它是导入的):

错误指向这一行:
desm[:,i]=np.power(x_列,(i-1))
,这是一个值错误。我尝试使用在线翻译ompc,但它似乎已经过时,因为它不适合我。如果我的翻译中有任何明显的错误,谁能给我解释一下吗?我知道这是一个更大程序的一部分,但我要问的只是语法翻译本身。如果它是正确的,我将尝试找出任何其他错误,尽管我到目前为止没有发现任何错误。多谢各位

编辑:回溯

ERROR: test_design_matrix (test.TestDesignMatrix)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "...\test.py", line 61, in test_design_matrix
    dm_computed = design_matrix(x_train, M)
  File "...\content.py", line 34, in design_matrix
    desm[:,i] = np.power(x_train, (i-1))
ValueError: could not broadcast input array from shape (20,1) into shape (20)
我无法更改test.py文件,它是提供给我的,无法更改,因此我只依赖第二个错误

从给出错误的函数的test.py中摘录:

def test_design_matrix(self):
    x_train = TEST_DATA['design_matrix']['x_train']
    M = TEST_DATA['design_matrix']['M']
    dm = TEST_DATA['design_matrix']['dm']
    dm_computed = design_matrix(x_train, M)
    max_diff = np.max(np.abs(dm - dm_computed))
    self.assertAlmostEqual(max_diff, 0, 8)
我看到三个错误:

  • 在Python中,索引从零开始

  • 要为阵列的所有项目供电,可以使用
    **
    操作符

  • pass
    不执行任何操作,因为它放在
    return
    语句之后。函数从未达到这一点

我想试试这个:

def design_matrix(x_train,M):
    '''
    :param x_train: input vector Nx1
    :param M: polynomial degree 0,1,2,...
    :return: Design Matrix Nx(M+1) for M degree polynomial
    '''
    desm = np.zeros(shape =(len(x_train), M+1))
    for i in range(0, M+1):
        desm[:,i] = x_train.squeeze() ** (i-1)
    return desm
你能试试这个吗

def design_matrix(x_train,M):
    '''
    :param x_train: input vector Nx1
    :param M: polynomial degree 0,1,2,...
    :return: Design Matrix Nx(M+1) for M degree polynomial
    '''
    x_train = np.asarray(x_train)
    desm = np.zeros(shape =(len(x_train), M+1))
    for i in range(0, M+1):
        desm[:,i] = np.power(x_train, i).reshape(x_train.shape[0],)
    return desm

错误来自不兼容的Numpy数组维度。desm[:,i]具有形状(n,),但您试图存储到它的值具有形状(n,1),因此您需要将其重塑为(n,)。此外,正如GLR所提到的,Python索引从0开始,因此您需要修改索引,函数执行在返回行停止,因此根本达不到传递行。

您可能有兴趣知道,您可以使用patsy语言和模块为多项式回归创建正交设计矩阵

>>> import numpy as np
>>> from patsy import dmatrices, dmatrix, demo_data, Poly
>>> data = demo_data("a", "b", "x1", "x2", "y", "z column")
>>> dmatrix('C(x1, Poly)', data)
DesignMatrix with shape (8, 8)
Columns:
['Intercept', 'C(x1, Poly).Linear', 'C(x1, Poly).Quadratic', 'C(x1, Poly).Cubic', 'C(x1, Poly)^4', 'C(x1, Poly)^5', 'C(x1, Poly)^6', 'C(x1, Poly)^7']
Terms:
'Intercept' (column 0), 'C(x1, Poly)' (columns 1:8)
(to view full data, use np.asarray(this_obj))
>>> dm = dmatrix('C(x1, Poly)', data)
>>> np.asarray(dm)
array([[ 1.        ,  0.23145502, -0.23145502, -0.43082022, -0.12087344,
         0.36376642,  0.55391171,  0.35846409],
       [ 1.        , -0.23145502, -0.23145502,  0.43082022, -0.12087344,
        -0.36376642,  0.55391171, -0.35846409],
       [ 1.        ,  0.07715167, -0.38575837, -0.18463724,  0.36262033,
         0.32097037, -0.30772873, -0.59744015],
       [ 1.        ,  0.54006172,  0.54006172,  0.43082022,  0.28203804,
         0.14978617,  0.06154575,  0.01706972],
       [ 1.        ,  0.38575837,  0.07715167, -0.30772873, -0.52378493,
        -0.49215457, -0.30772873, -0.11948803],
       [ 1.        , -0.54006172,  0.54006172, -0.43082022,  0.28203804,
        -0.14978617,  0.06154575, -0.01706972],
       [ 1.        , -0.07715167, -0.38575837,  0.18463724,  0.36262033,
        -0.32097037, -0.30772873,  0.59744015],
       [ 1.        , -0.38575837,  0.07715167,  0.30772873, -0.52378493,
         0.49215457, -0.30772873,  0.11948803]])

你能添加回溯以便我们能看到有关错误的更多详细信息吗?当然,刚刚添加。你能添加test\u design\u matrix的代码块以便我们能看到你是如何调用design\u matrix的吗?现在就在上面。我试过了,但得到的错误与之前相同。在接下来的尝试中,我将保持pass被删除,索引为零。请尝试编辑。问题是x_列是一个矩阵(它有两个维度),当你做desm[:,i]时,你正在访问一个数组。为了消除不必要的维度,您可以使用
压缩
。不幸的是,仍然不起作用,我尝试了它,但现在它说这是一个失败而不是错误,我得到了以下回溯:文件“…\test.py”,第63行,在test\u design\u matrix self.assertAlmostEqual(max\u diff,0,8)断言错误:22357.537052901051!=8个位置内为0。我猜这是一个计算错误?这个错误是因为
max_diff
与零非常不同。我会尝试在一个交互式Python会话中加载
dm
(在
test.py
中),看看那里发生了什么。对不起,我对Python真的很陌生。。。您是说问题会出现在该函数的test.py文件中吗?我开始认为可能是这样的。我试过了,并且理解了变化,但我还是得到了和以前一样的错误:(嗯。所以你仍然得到了ValueError:无法将输入数组从形状(20,1)广播到形状(20)?是的,就是这个。你能试一下desm[:,I]=np.power(x_-train,I)。重塑(x_-train.shape[0],1)吗唉,什么都没有改变。仍然是相同的错误,只有相同的值,(20,1)和(20)
>>> import numpy as np
>>> from patsy import dmatrices, dmatrix, demo_data, Poly
>>> data = demo_data("a", "b", "x1", "x2", "y", "z column")
>>> dmatrix('C(x1, Poly)', data)
DesignMatrix with shape (8, 8)
Columns:
['Intercept', 'C(x1, Poly).Linear', 'C(x1, Poly).Quadratic', 'C(x1, Poly).Cubic', 'C(x1, Poly)^4', 'C(x1, Poly)^5', 'C(x1, Poly)^6', 'C(x1, Poly)^7']
Terms:
'Intercept' (column 0), 'C(x1, Poly)' (columns 1:8)
(to view full data, use np.asarray(this_obj))
>>> dm = dmatrix('C(x1, Poly)', data)
>>> np.asarray(dm)
array([[ 1.        ,  0.23145502, -0.23145502, -0.43082022, -0.12087344,
         0.36376642,  0.55391171,  0.35846409],
       [ 1.        , -0.23145502, -0.23145502,  0.43082022, -0.12087344,
        -0.36376642,  0.55391171, -0.35846409],
       [ 1.        ,  0.07715167, -0.38575837, -0.18463724,  0.36262033,
         0.32097037, -0.30772873, -0.59744015],
       [ 1.        ,  0.54006172,  0.54006172,  0.43082022,  0.28203804,
         0.14978617,  0.06154575,  0.01706972],
       [ 1.        ,  0.38575837,  0.07715167, -0.30772873, -0.52378493,
        -0.49215457, -0.30772873, -0.11948803],
       [ 1.        , -0.54006172,  0.54006172, -0.43082022,  0.28203804,
        -0.14978617,  0.06154575, -0.01706972],
       [ 1.        , -0.07715167, -0.38575837,  0.18463724,  0.36262033,
        -0.32097037, -0.30772873,  0.59744015],
       [ 1.        , -0.38575837,  0.07715167,  0.30772873, -0.52378493,
         0.49215457, -0.30772873,  0.11948803]])