Python sklearn中GridSearchCV的评分器(scorer)问题

Python sklearn中GridSearchCV的评分器(scorer)问题,python,scikit-learn,Python,Scikit Learn,我正在尝试在RF分类器上执行网格搜索,其中评分函数是sklearn.metrics模块的precision_score。这是代码 from sklearn.metrics import precision_score param_grid = {'n_estimators': [51, 101, 201, 301, 501], 'max_depth': [3, 5, 10, None], 'min_samples_split': [2,

我正在尝试在RF分类器上执行网格搜索,其中评分函数是sklearn.metrics模块的precision_score。这是代码

from sklearn.metrics import precision_score

# Hyper-parameter grid for the random-forest search: GridSearchCV evaluates
# every combination of the values below (5 * 4 * 3 * 2 * 2 = 240 candidates).
param_grid = dict(
    n_estimators=[51, 101, 201, 301, 501],
    max_depth=[3, 5, 10, None],
    min_samples_split=[2, 5, 10],
    criterion=['gini', 'entropy'],
    bootstrap=[True, False],
)

def fit_gridCV_RFclassifier(param_grid, X=None, y=None, scoring="precision", cv=5):
    """Grid-search RandomForestClassifier hyper-parameters and return the fitted search.

    Args:
        param_grid: mapping of RandomForestClassifier parameter names to the
            lists of values to try.
        X: training features. Defaults to the module-level ``train_X``
            (defined elsewhere, not shown here).
        y: training labels. Defaults to the module-level ``train_y``
            (defined elsewhere, not shown here).
        scoring: a scorer *name* understood by scikit-learn (e.g. "precision")
            or a callable with signature ``scorer(estimator, X, y)``, such as
            one built with ``sklearn.metrics.make_scorer(precision_score)``.
            Do NOT pass a raw metric function like ``precision_score``.
            NOTE(review): "precision" assumes a binary target; for multiclass
            labels a variant such as "precision_macro" may be needed — confirm
            against the installed scikit-learn version.
        cv: number of cross-validation folds (original behaviour used 5).

    Returns:
        The fitted ``GridSearchCV`` object, refit on the best parameters.
    """
    from sklearn.ensemble import RandomForestClassifier

    # Backward compatibility: the original signature took no data arguments
    # and always trained on the globals train_X / train_y.
    if X is None:
        X = train_X
    if y is None:
        y = train_y

    rf = RandomForestClassifier()
    # GridSearchCV invokes its scorer as scorer(estimator, X_test, y_test).
    # Passing sklearn.metrics.precision_score (signature (y_true, y_pred, ...))
    # makes it receive the estimator as y_true, which fails with
    # "Found array with dim ... Expected 51" (51 == first n_estimators value).
    # A scorer name such as "precision" is wrapped correctly by scikit-learn.
    # GridSearchCV is assumed to be imported at module level (not visible here).
    clf = GridSearchCV(estimator=rf, param_grid=param_grid,
                       cv=cv, scoring=scoring,
                       refit=True)
    clf.fit(X, y)
    return clf

# Run the grid search; the returned GridSearchCV was created with refit=True,
# so it is refit on the best parameter combination found.
gridsearch_rf = fit_gridCV_RFclassifier(param_grid)
在运行该函数时,出现以下错误

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-34-6f91362a017c> in <module>()
----> 1 gridsearch_rf = fit_gridCV_RFclassifier(param_grid)

<ipython-input-33-974d026d5dc8> in fit_gridCV_RFclassifier(param_grid)
     11                        scoring=precision_score,
     12                        cv=5, refit=True)
---> 13     clf.fit(train_X, train_y)
     14     return clf

/anaconda/lib/python2.7/site-packages/sklearn/grid_search.pyc in fit(self, X, y)
    594 
    595         """
--> 596         return self._fit(X, y, ParameterGrid(self.param_grid))
    597 
    598 

/anaconda/lib/python2.7/site-packages/sklearn/grid_search.pyc in _fit(self, X, y, parameter_iterable)
    376                                     train, test, self.verbose, parameters,
    377                                     self.fit_params, return_parameters=True)
--> 378             for parameters in parameter_iterable
    379             for train, test in cv)
    380 

/anaconda/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.pyc in __call__(self, iterable)
    651             self._iterating = True
    652             for function, args, kwargs in iterable:
--> 653                 self.dispatch(function, args, kwargs)
    654 
    655             if pre_dispatch == "all" or n_jobs == 1:

/anaconda/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.pyc in dispatch(self, func, args, kwargs)
    398         """
    399         if self._pool is None:
--> 400             job = ImmediateApply(func, args, kwargs)
    401             index = len(self._jobs)
    402             if not _verbosity_filter(index, self.verbose):

/anaconda/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.pyc in __init__(self, func, args, kwargs)
    136         # Don't delay the application, to avoid keeping the input
    137         # arguments in memory
--> 138         self.results = func(*args, **kwargs)
    139 
    140     def get(self):

/anaconda/lib/python2.7/site-packages/sklearn/cross_validation.pyc in _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, return_train_score, return_parameters)
   1238     else:
   1239         estimator.fit(X_train, y_train, **fit_params)
-> 1240     test_score = _score(estimator, X_test, y_test, scorer)
   1241     if return_train_score:
   1242         train_score = _score(estimator, X_train, y_train, scorer)

/anaconda/lib/python2.7/site-packages/sklearn/cross_validation.pyc in _score(estimator, X_test, y_test, scorer)
   1294         score = scorer(estimator, X_test)
   1295     else:
-> 1296         score = scorer(estimator, X_test, y_test)
   1297     if not isinstance(score, numbers.Number):
   1298         raise ValueError("scoring must return a number, got %s (%s) instead."

/anaconda/lib/python2.7/site-packages/sklearn/metrics/metrics.pyc in precision_score(y_true, y_pred, labels, pos_label, average, sample_weight)
   1883                                                  average=average,
   1884                                                  warn_for=('precision',),
-> 1885                                                  sample_weight=sample_weight)
   1886     return p
   1887 

/anaconda/lib/python2.7/site-packages/sklearn/metrics/metrics.pyc in precision_recall_fscore_support(y_true, y_pred, beta, labels, pos_label, average, warn_for, sample_weight)
   1667         raise ValueError("beta should be >0 in the F-beta score")
   1668 
-> 1669     y_type, y_true, y_pred = _check_clf_targets(y_true, y_pred)
   1670 
   1671     label_order = labels  # save this for later

/anaconda/lib/python2.7/site-packages/sklearn/metrics/metrics.pyc in _check_clf_targets(y_true, y_pred)
    107     y_pred : array or indicator matrix
    108     """
--> 109     y_true, y_pred = check_arrays(y_true, y_pred, allow_lists=True)
    110     type_true = type_of_target(y_true)
    111     type_pred = type_of_target(y_pred)

/anaconda/lib/python2.7/site-packages/sklearn/utils/validation.pyc in check_arrays(*arrays, **options)
    252         if size != n_samples:
    253             raise ValueError("Found array with dim %d. Expected %d"
--> 254                              % (size, n_samples))
    255 
    256         if not allow_lists or hasattr(array, "shape"):

ValueError: Found array with dim 317760. Expected 51
---------------------------------------------------------------------------
ValueError回溯(最近一次调用上次)
在()
---->1 gridsearch\u rf=fit\u gridCV\u RFclassifier(参数网格)
在fit_gridCV_RFclassifier(参数网格)中
11分=精度分数,
12 cv=5,重新安装=正确)
--->13 clf.配合(第X列、第y列)
14返回clf
/anaconda/lib/python2.7/site-packages/sklearn/grid_search.pyc in-fit(self,X,y)
594
595         """
-->596返回自拟合(X,y,参数网格(自参数网格))
597
598
/anaconda/lib/python2.7/site-packages/sklearn/grid_search.pyc in_fit(self,X,y,parameter_iterable)
376列车、试验、自详细、参数、,
377 self.fit_参数,返回_参数=真)
-->378用于参数_iterable中的参数
379用于列车,在cv中进行试验)
380
/anaconda/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.pyc in_u_调用(self,iterable)
651自迭代=真
652对于iterable中的函数、参数和kwargs:
-->653自动调度(功能、参数、kwargs)
654
655如果预调度==“所有”或n个作业==1:
/调度中的anaconda/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.pyc(self、func、args、kwargs)
398         """
399如果self.\u池为无:
-->400作业=立即应用(func、args、kwargs)
401索引=len(自作业)
402如果不是详细过滤器(索引,self.verbose):
/anaconda/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.pyc in_u________(self、func、args、kwargs)
136#不要延迟应用程序,以免保留输入
137#内存中的参数
-->138 self.results=func(*args,**kwargs)
139
140 def get(自我):
/anaconda/lib/python2.7/site-packages/sklearn/cross_validation.pyc in_fit_和_分数(估计器、X、y、计分器、训练、测试、详细、参数、拟合参数、返回训练分数、返回参数)
1238其他:
1239估算器拟合(X_序列、y_序列、**拟合参数)
->1240测试分数=_分数(估计员、X测试、y测试、计分员)
1241如果返回列车评分:
1242训练分数=_分数(估计员、X训练、y训练、计分员)
/anaconda/lib/python2.7/site-packages/sklearn/cross_validation.pyc in_分数(估计器、X_测试、y_测试、计分器)
1294分=记分员(估计员,X_检验)
1295其他:
->1296分=记分员(估计员、X_检验、y_检验)
1297如果不存在(分数、数字、数字):
1298 raise VALUERROR(“评分必须返回一个数字,取而代之的是%s(%s)。”
/精度评分中的anaconda/lib/python2.7/site-packages/sklearn/metrics/metrics.pyc(y_真、y_pred、标签、位置标签、平均值、样本重量)
1883平均值=平均值,
1884 warn_for=(‘精度’,),
->1885样品重量=样品重量)
1886返回p
1887
/anaconda/lib/python2.7/site-packages/sklearn/metrics/metrics.pyc的精确性、召回率、核心支持率(y_真、y_pred、beta、标签、位置标签、平均值、警告、样本重量)
1667 raise VALUE ERROR(“F-beta分数中的贝塔值应大于0”)
1668
->1669 y_type,y_true,y_pred=\u check\u clf\u targets(y_true,y_pred)
1670
1671 label_order=labels#将此保存以备以后使用
/anaconda/lib/python2.7/site-packages/sklearn/metrics/metrics.pyc in_check_clf_targets(y_true,y_pred)
107 y_pred:阵列或指示符矩阵
108     """
-->109 y_true,y_pred=检查_数组(y_true,y_pred,允许_列表=true)
110 type_true=_目标的类型(y_true)
111 type_pred=_目标的类型(y_pred)
/检查数组(*数组,**选项)中的anaconda/lib/python2.7/site-packages/sklearn/utils/validation.pyc
252如果尺寸!=n_样本:
253 raise VALUEMERROR(“找到的数组具有dim%d。应为%d)”
-->254%(尺寸,不含样品))
255
256如果不允许使用列表或hasattr(数组,“形状”):
ValueError:找到dim为317760的数组。应为51
错误似乎来自评分函数。如有任何帮助,将不胜感激。谢谢

我的 scikit-learn 版本:0.15.2

`scoring` 参数的取值如下(摘自 GridSearchCV 文档):

scoring:字符串、可调用对象或 None,可选,默认值:None

A string (see model evaluation documentation) or a scorer callable object / function with signature scorer(estimator, X, y).
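`precision_score` 函数的签名是 `precision_score(y_true, y_pred, ...)`,与 `scorer(estimator, X, y)` 不同。GridSearchCV 会以 `scorer(estimator, X_test, y_test)` 的方式调用它。因此 `precision_score` 把估计器当作 `y_true`、把 `X_test` 当作 `y_pred`,于是出现长度不匹配的错误(报错中的 51 正是第一组参数的 `n_estimators`)。只需传入字符串 `"precision"`,因为它是内置评分指标之一。也可以用 `sklearn.metrics.make_scorer(precision_score)` 把度量函数包装成评分器: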


你能用一个玩具数据集发布你的代码吗?现在很难说发生了什么
# Use the built-in scorer name instead of the raw metric function, so that
# GridSearchCV can call it as scorer(estimator, X, y).
clf = GridSearchCV(
    estimator=rf,
    param_grid=param_grid,
    scoring="precision",
    refit=True,
    cv=5,
)