Python 在scikit学习中为k-nn使用用户定义的距离度量
我有以下代码:Python 在scikit学习中为k-nn使用用户定义的距离度量,python,scikit-learn,Python,Scikit Learn,我有以下代码: import pandas as pd import numpy as np import matplotlib.pyplot as plt import sklearn.neighbors as ng def mydist(x, y): return np.sum((x-y)**2) if __name__ == '__main__': nn = ng.KNeighborsClassifier(n_neighbors=3, algorithm='bal
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sklearn.neighbors as ng
def mydist(x, y):
return np.sum((x-y)**2)
if __name__ == '__main__':
nn = ng.KNeighborsClassifier(n_neighbors=3, algorithm='ball_tree',metric='mydist')
我正在使用sci工具包学习0.18.1,我得到了这个错误
ValueError: Metric 'mydist' not valid for algorithm 'ball_tree'
我还尝试使用algorithm='brute',但错误仍然存在
这是什么原因造成的?如何正确使用用户定义的距离度量?以下是
ball_树算法的有效度量列表-scikit learn
在内部检查指定度量是否在其中:
In [114]: from sklearn.neighbors import BallTree
In [115]: BallTree.valid_metrics
Out[115]:
['euclidean',
'l2',
'minkowski',
'p',
'manhattan',
'cityblock',
'l1',
'chebyshev',
'infinity',
'seuclidean',
'mahalanobis',
'wminkowski',
'hamming',
'canberra',
'braycurtis',
'matching',
'jaccard',
'dice',
'kulsinski',
'rogerstanimoto',
'russellrao',
'sokalmichener',
'sokalsneath',
'haversine',
'pyfunc'] # <--- NOTE
knn = ng.KNeighborsClassifier(n_neighbors=3, algorithm='ball_tree',
metric='pyfunc', metric_params={"func":mydist})