Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/347.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python 不同于交叉验证和迭代Kfold的RMSE_Python_Scikit Learn_Cross Validation_K Fold - Fatal编程技术网

Python 不同于交叉验证和迭代Kfold的RMSE

Python 不同于交叉验证和迭代Kfold的RMSE,python,scikit-learn,cross-validation,k-fold,Python,Scikit Learn,Cross Validation,K Fold,我想为交叉验证编写自己的函数,因为在这种情况下我不能使用交叉验证 如果我错了,请更正我的交叉验证代码: cv = cross_validate(elastic.est,X,y,cv=5,scoring='neg_mean_squared_error') 输出: {'fit_time': array([3.90563273, 5.272861 , 2.19111824, 6.42427135, 5.62084389]), 'score_time': array([0.05504966, 0.

我想为交叉验证编写自己的函数,因为在这种情况下我不能使用交叉验证

如果我错了,请更正我的交叉验证代码:

cv = cross_validate(elastic.est,X,y,cv=5,scoring='neg_mean_squared_error')
输出:

{'fit_time': array([3.90563273, 5.272861  , 2.19111824, 6.42427135, 5.62084389]),
 'score_time': array([0.05504966, 0.06105542, 0.0530467 , 0.06006551, 0.05603933]),
 'test_score': array([-0.00942235, -0.01220626, -0.01157624, -0.00998556, -0.01144867])}
我这样做是为了计算RMSE

math.sqrt(abs(cv["test_score"]).mean())
结果总是在0.104左右

然后我编写了下面的函数来循环kFolds,我总是得到一个更低的RMSE分数,它的运行速度大约快10倍

def get_rmse(y_true,y_pred):    
    score = math.sqrt(((y_pred-y_true) ** 2).mean())
    return score

listval=[]

kf = KFold(n_splits=5,shuffle=True)

for train_index, test_index in kf.split(X,y):

    Xx = np.array(X)
    yy = np.array(y)

    X_train, X_test = Xx[train_index], Xx[test_index]
    y_train, y_test = yy[train_index], yy[test_index]

    elastic.est.fit(X_train,y_train)
    preds = elastic.est.predict(X_test)
    listval.append(get_rmse(y_test,preds))

np.mean(listval)
结果为0.0729,并始终在该值附近着陆


我错过了什么?相同的数据、相同的esitmator、相同的折叠量?

您观察到的差异来自于这样一个事实,即您计算最终数字的方式不同:

对于cross_validate(交叉验证)输出,您首先对折叠的MSE求平均值,然后取平方根。 对于自定义实现,首先取根,然后才取折叠的平均值。
当然,在一般情况下,平均值的根不等于根的平均值。

不确定它是否有效-请尝试发送KFold对象,而不是使用int 5作为cross_validate的cv参数