Python 3.x 如何在从头开始实现K近邻的同时修复代码中的KeyError错误?

Python 3.x 如何在从头开始实现K近邻的同时修复代码中的KeyError错误?,python-3.x,machine-learning,data-science,knn,Python 3.x,Machine Learning,Data Science,Knn,我正在尝试用Python从头开始实现K-最近邻算法。我编写的代码在威斯康星州乳腺癌数据集上运行良好。 然而,当我尝试运行Iris.csv数据集时,我的实现失败,并给出了KeyError 这两个数据集的唯一区别在于,威斯康星州乳腺癌.csv中只有两个类别('2'表示恶性,4'表示良性),两个标签都是整数,Iris.csv中有3个类别('setosa','versicolor','virginica'),所有这3个标签都是字符串类型 这是我写的代码(用于Iris.csv): 当我运行上面的代码时,我

我正在尝试用Python从头开始实现K-最近邻算法。我编写的代码在威斯康星州乳腺癌数据集
上运行良好。
然而,当我尝试运行Iris.csv数据集时,我的实现失败,并给出了KeyError

这两个数据集的唯一区别在于,威斯康星州乳腺癌.csv中只有两个类别('2'表示恶性,4'表示良性),两个标签都是整数,
Iris.csv中有3个类别('setosa','versicolor','virginica'),所有这3个标签都是字符串类型

这是我写的代码(用于
Iris.csv
):

当我运行上面的代码时,我在第49行得到一条
KeyError
消息

谁能告诉我哪里出了问题?另外,如果有人能指出我该如何修改这个算法,以便在将来对多个类(而不是2或3个)进行分类,那就太好了

另外,如果类是字符串类型而不是整数类型,如何处理

我想到的一个解决方案是将所有字符串类型转换为整数类型,然后尝试求解,但这样行吗

参考资料


让我们从你的最后一个问题开始:

我想到的一个解决方案是将所有字符串类型转换为整数类型,然后尝试求解,但这样行吗

是的,那会有用的。您不必硬编码代码中每个问题的所有类的名称。相反,您只需编写一个函数,读取class属性的所有不同值,并为每个不同值指定一个数值

谁能告诉我哪里出了问题

最有可能的问题是,您正在读取的实例的class属性不是
'setosa',versicolor',virginica'
(可能类似于
Iris setosa
)。上面的想法应该可以解决这个问题

另外,如果有人能指出我该如何修改这个算法,以便在将来对多个类(而不是2或3个)进行分类,那就太好了

正如前面所讨论的,您只需要避免在代码中硬编码类的名称

另外,如果类是字符串类型而不是整数类型,如何处理


像这样的函数将返回所有类(无论类型)和数字代码(从0到N-1)之间的映射。使用此映射还可以解决前面提到的所有问题。

将CSV文件中的字符串标签转换为整数标签

在经历了一些GitHub回购之后,我遇到了一段非常简单但优雅的代码,它解决了上述问题。希望它能帮助那些以前遇到过这个问题的人(尤其是初学者!)

调试后

事实证明,我们也不需要使用上面的代码,也就是说,我可以在不显式地将字符串标签转换为整数标签的情况下得到答案(使用上面的代码)

我已经发布了一些小改动后的原始代码(如下),关键错误现在已经修复。此外,我现在获得了97%到100%的准确率(仅在IRIS数据集上)

这是唯一需要对我发布的原始代码进行的更改,以使其正常工作!!简单

但是,请注意,数字必须以整数而不是字符串形式给出(否则会导致键错误!)

总结

在原始代码中有一些注释行,我认为如果有人遇到一些问题,最好解释一下。下面是一个删除了注释的代码段(与问题中的原始代码进行比较)

以下是您得到的输出:

ValueError:无法将字符串转换为浮点:“virginica”

出了什么问题

注意,这里我们没有将字符串标签转换为整数标签。因此,当我们试图将CSV中的数据转换为浮点值时,内核抛出了一个错误,因为字符串无法转换为浮点值

因此,一种方法是不将数据转换为浮点值,这样就不会出现此错误。但是,在许多情况下,您需要将所有数据转换为浮点(例如,归一化、精度、长时间数学计算、防止精度损失等)

因此,经过大量调试和阅读大量文章后,我终于找到了原始代码的简单版本(如下所示):


希望这有帮助

是的,那会有用的。谢谢你的回答。然而,我发现了一个更简单的解决方案,并将其发布为下面的答案。不,我在算法中使用CSV文件之前检查了它。所有实例都命名为“Setosa”、“Virginica”、“Versicolor”。因此,错误不是因为iris setosa或不同名称的实例。事实上,这个错误是一个微妙的错误,我已经在下面发布了解决方案,并尽我所能对其进行了解释。请务必让我知道我的答案是否可以改进。
import numpy as np
from math import sqrt
import matplotlib.pyplot as plt
from matplotlib import style
from collections import Counter
import warnings
import pandas as pd
import random

style.use('fivethirtyeight')

dataset = {'k':[[1,2],[2,3],[3,1]], 'r':[[6,5],[7,7],[8,6]]}
new_features = [5,7]

#[[plt.scatter(j[0],j[1], s=100, color=i) for j in dataset[i]] for i in dataset]
#plt.scatter(new_features[0], new_features[1], s=100)
#plt.show()

def k_nearest_neighbors(data, predict, k=3):
    if len(data) >= k:
        warnings.warn('K is set to a value less than total voting groups!')

    distances = []

    for group in data:
        for features in data[group]:
            euclidean_distance = np.linalg.norm(np.array(features) - np.array(predict))
            distances.append([euclidean_distance, group])

    votes = [i[1] for i in sorted(distances)[:k]]
    vote_result = Counter(votes).most_common(1)[0][0]

    return vote_result

df = pd.read_csv('iris.csv')
df.replace('?', -99999, inplace=True)

#full_data = df.astype(float).values.tolist()
#random.shuffle(full_data)

test_size = 0.2
train_set = {'setosa':[], 'versicolor':[], 'virginica':[]}
test_set = {'setosa':[], 'versicolor':[], 'virginica':[]}

train_data = full_data[:-int(test_size*len(full_data))]
test_data = full_data[-int(test_size*len(full_data)):]

for i in train_data:
    train_set[i[-1]].append(i[:-1])

for i in test_data:
    test_set[i[-1]].append(i[:-1])

correct = 0
total = 0

for group in test_set:
    for data in test_set[group]:
        vote = k_nearest_neighbors(train_set, data, k=5)
        if group == vote:
            correct += 1
        total += 1

print('Accuracy : ', correct/total)
def get_class_values(data):

    classes_seen = {}
    for i in data:
       _class = data[-1]
       if _class not in classes_seen:
           classes_seen[_class] = len(classes_seen)

    return classes_seen
% read the csv file
df = pd.read_csv('iris.csv')

% clean the data file
df.replace('?', -99999, inplace=True)

% convert the string classes into integer types.
% integers are assigned from 0 to N-1.
% species is the name of the column which has class labels.

df['species'] = df['species'].astype('category')
df['species_value'] = df['species'].cat.codes
df.drop(['species'], 1, inplace=True)

% convert the data frame to list
full_data = df.astype(float).values.tolist()
random.shuffle(full_data)
test_size = 0.2
train_set = {0:[], 1:[], 2:[]}
test_set = {0:[], 1:[], 2:[]}
df = pd.read_csv('iris.csv')
df.replace('?', -99999, inplace=True)

full_data = df.astype(float).values.tolist()
random.shuffle(full_data)
import numpy as np
from math import sqrt
import matplotlib.pyplot as plt
from matplotlib import style
from collections import Counter
import warnings
import pandas as pd
import random

def k_nearest_neighbors(data, predict, k=3):
    if len(data) >= k:
        warnings.warn('K is set to a value less than total voting groups!')

    distances = []

    for group in data:
        for features in data[group]:
            euclidean_distance = np.linalg.norm(np.array(features) - np.array(predict))
            distances.append([euclidean_distance, group])

    votes = [i[1] for i in sorted(distances)[:k]]
    vote_result = Counter(votes).most_common(1)[0][0]

    return vote_result

df = pd.read_csv('iris.csv')
df.replace('?', -99999, inplace=True)

df['species'] = df['species'].astype('category')
df['species_value'] = df['species'].cat.codes
df.drop(['species'], 1, inplace=True)

full_data = df.astype(float).values.tolist()
random.shuffle(full_data)

test_size = 0.2
train_set = {0:[], 1:[], 2:[]}
test_set = {0:[], 1:[], 2:[]}

train_data = full_data[:-int(test_size*len(full_data))]
test_data = full_data[-int(test_size*len(full_data)):]

for i in train_data:
    train_set[i[-1]].append(i[:-1])

for i in test_data:
    test_set[i[-1]].append(i[:-1])

correct = 0
total = 0

for group in test_set:
    for data in test_set[group]:
        vote = k_nearest_neighbors(train_set, data, k=5)
        if group == vote:
            correct += 1
        total += 1

print('Accuracy : ', (correct/total)*100,'%')