Python 在scikit学习中为k-nn使用用户定义的距离度量

Python 在scikit学习中为k-nn使用用户定义的距离度量,python,scikit-learn,Python,Scikit Learn,我有以下代码: import pandas as pd import numpy as np import matplotlib.pyplot as plt import sklearn.neighbors as ng def mydist(x, y): return np.sum((x-y)**2) if __name__ == '__main__': nn = ng.KNeighborsClassifier(n_neighbors=3, algorithm='bal

我有以下代码:

import pandas as pd
import numpy as np 
import matplotlib.pyplot as plt 
import sklearn.neighbors as ng 

def mydist(x, y):
    return np.sum((x-y)**2)

if __name__ == '__main__':
    nn = ng.KNeighborsClassifier(n_neighbors=3, algorithm='ball_tree',metric='mydist')
我正在使用sci工具包学习0.18.1,我得到了这个错误

ValueError: Metric 'mydist' not valid for algorithm 'ball_tree'
我还尝试使用algorithm='brute',但错误仍然存在


这是什么原因造成的?如何正确使用用户定义的距离度量?

以下是
ball_树
算法的有效度量列表-
scikit learn
在内部检查指定度量是否在其中:

In [114]: from sklearn.neighbors import BallTree

In [115]: BallTree.valid_metrics
Out[115]:
['euclidean',
 'l2',
 'minkowski',
 'p',
 'manhattan',
 'cityblock',
 'l1',
 'chebyshev',
 'infinity',
 'seuclidean',
 'mahalanobis',
 'wminkowski',
 'hamming',
 'canberra',
 'braycurtis',
 'matching',
 'jaccard',
 'dice',
 'kulsinski',
 'rogerstanimoto',
 'russellrao',
 'sokalmichener',
 'sokalsneath',
 'haversine',
 'pyfunc']       # <--- NOTE
knn = ng.KNeighborsClassifier(n_neighbors=3, algorithm='ball_tree',
                              metric='pyfunc', metric_params={"func":mydist})