Warning: file_get_contents(/data/phpspider/zhask/data//catemap/9/java/322.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
kNN算法中k值的改变——Java_Java_Algorithm_Csv_Machine Learning_Knn - Fatal编程技术网

kNN算法中k值的改变——Java

kNN算法中k值的改变——Java,java,algorithm,csv,machine-learning,knn,Java,Algorithm,Csv,Machine Learning,Knn,我应用了KNN算法对手写数字进行分类。这些数字最初是矢量格式的8*8,然后拉伸形成矢量1*64 目前,我的代码应用kNN算法,但只使用k=1。在尝试了几件事情之后,我不完全确定如何更改k值,我一直被抛出错误。如果有人能帮助我朝着正确的方向前进,我将不胜感激。可以找到训练数据集和验证集 ImageMatrix.java import java.util.*; public class ImageMatrix { private int[] data; private int cl

我应用了KNN算法对手写数字进行分类。这些数字最初是矢量格式的8*8,然后拉伸形成矢量1*64

目前,我的代码应用kNN算法,但只使用k=1。在尝试了几件事情之后,我不完全确定如何更改k值,我一直被抛出错误。如果有人能帮助我朝着正确的方向前进,我将不胜感激。可以找到训练数据集和验证集

ImageMatrix.java

import java.util.*;

/**
 * One labelled sample: a flattened pixel vector together with its class code.
 */
public class ImageMatrix {
    private int[] data;
    private int classCode;
    private int curData;

    /**
     * Wraps the given pixel vector and label. The array is stored by reference,
     * not copied.
     *
     * @param data      pixel values; asserted to hold exactly 64 entries
     * @param classCode label attached to this sample
     */
    public ImageMatrix(int[] data, int classCode) {
        // Only enforced when the JVM runs with assertions enabled (-ea).
        assert data.length == 64;
        this.classCode = classCode;
        this.data = data;
    }

    /**
     * Renders the label followed by the pixel values, terminated by a newline.
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("Class Code: ");
        text.append(classCode);
        text.append(" Data :");
        text.append(Arrays.toString(data));
        text.append("\n");
        return text.toString();
    }

    /** @return the pixel array passed to the constructor (same instance) */
    public int[] getData() {
        return this.data;
    }

    /** @return the label passed to the constructor */
    public int getClassCode() {
        return this.classCode;
    }

    /** @return the {@code curData} field; nothing in this class ever assigns it */
    public int getCurData() {
        return this.curData;
    }
}
import java.util.*;
import java.io.*;
import java.util.ArrayList;
public class ImageMatrixDB implements Iterable<ImageMatrix> {
    private List<ImageMatrix> list = new ArrayList<ImageMatrix>();

    /**
     * Reads a CSV file and appends one {@link ImageMatrix} per non-blank line
     * to this database. On each line, every value before the last comma is a
     * pixel and the value after the last comma is the class code.
     *
     * <p>Blank lines are skipped and whitespace around individual values is
     * ignored, so files with a trailing empty line or "1, 2, 3" style spacing
     * load cleanly.
     *
     * @param f path of the CSV file to read
     * @return this database, to allow call chaining
     * @throws IOException           if the file cannot be read, or a non-blank
     *                               line contains no comma
     * @throws NumberFormatException if a value is not a valid integer
     */
    public ImageMatrixDB load(String f) throws IOException {
        try (
            FileReader fr = new FileReader(f);
            BufferedReader br = new BufferedReader(fr)) {
            String line;
            int lineNumber = 0;

            while ((line = br.readLine()) != null) {
                lineNumber++;
                line = line.trim();
                if (line.isEmpty()) {
                    continue; // tolerate blank lines, e.g. at the end of the file
                }
                int lastComma = line.lastIndexOf(',');
                if (lastComma < 0) {
                    // Without a comma there is no way to split pixels from the class code.
                    throw new IOException(f + ":" + lineNumber
                            + ": expected comma-separated pixel values followed by a class code");
                }
                int classCode = Integer.parseInt(line.substring(lastComma + 1).trim());
                int[] data = Arrays.stream(line.substring(0, lastComma).split(","))
                                   .map(String::trim)
                                   .mapToInt(Integer::parseInt)
                                   .toArray();
                list.add(new ImageMatrix(data, classCode));
            }
        }
        return this;
    }

    /**
     * Prints every stored matrix to standard output, one {@code println} per
     * entry, in list order.
     */
    public void printResults() {
        list.forEach(entry -> System.out.println(entry));
    }


    /**
     * Allows this database to be used in a for-each loop.
     *
     * @return an iterator over the backing list
     */
    public Iterator<ImageMatrix> iterator() {
        final Iterator<ImageMatrix> cursor = list.iterator();
        return cursor;
    }

    /// kNN implementation ///
    /**
     * Computes the Euclidean distance between two equally long vectors,
     * truncated toward zero to an {@code int}.
     *
     * <p>Differences and their squares are accumulated in {@code long}, so
     * large component values do not wrap around as they would in
     * {@code int} arithmetic.
     *
     * @param a first vector
     * @param b second vector; must have the same length as {@code a}
     * @return the integer part of the Euclidean distance
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static int distance(int[] a, int[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException(
                    "Vector lengths differ: " + a.length + " vs " + b.length);
        }
        long sum = 0;
        for (int i = 0; i < a.length; i++) {
            long diff = (long) a[i] - b[i];
            sum += diff * diff;
        }
        return (int) Math.sqrt(sum);
    }


    /**
     * Nearest-neighbour (k = 1) classification: returns the class code of the
     * training sample with the smallest distance to {@code curData}.
     *
     * <p>On equal distances the sample encountered first wins, because only a
     * strictly smaller distance replaces the current best. Returns 0 if the
     * training set yields no samples.
     *
     * @param trainingSet labelled samples to compare against
     * @param curData     vector to classify
     * @return class code of the closest training sample
     */
    public static int classify(ImageMatrixDB trainingSet, int[] curData) {
        int nearestLabel = 0;
        int smallestSoFar = Integer.MAX_VALUE;
        Iterator<ImageMatrix> samples = trainingSet.iterator();
        while (samples.hasNext()) {
            ImageMatrix candidate = samples.next();
            int d = distance(candidate.getData(), curData);
            if (d >= smallestSoFar) {
                continue; // not strictly closer than the current best
            }
            smallestSoFar = d;
            nearestLabel = candidate.getClassCode();
        }
        return nearestLabel;
    }


    /**
     * Reports how many matrices are currently stored.
     *
     * @return the number of entries in the backing list
     */
    public int size() {
        final int count = this.list.size();
        return count;
    }


    /**
     * Loads the training and validation CSV files, classifies every validation
     * sample against the training set, and prints the percentage of samples
     * whose predicted class code matches the stored one.
     *
     * @param argv unused
     * @throws IOException if either CSV file cannot be read
     */
    public static void main(String[] argv) throws IOException {
        ImageMatrixDB trainingSet = new ImageMatrixDB().load("cw2DataSet1.csv");
        ImageMatrixDB validationSet = new ImageMatrixDB().load("cw2DataSet2.csv");

        int hits = 0;
        for (ImageMatrix sample : validationSet) {
            int predicted = classify(trainingSet, sample.getData());
            if (predicted == sample.getClassCode()) {
                hits++;
            }
        }

        double accuracyPercent = (double) hits / validationSet.size() * 100;
        System.out.println("Accuracy: " + accuracyPercent + "%");
        System.out.println();
    }
ImageMatrixDB.java

import java.util.*;

/**
 * One labelled sample: a flattened pixel vector together with its class code.
 */
public class ImageMatrix {
    private int[] data;
    private int classCode;
    private int curData;

    /**
     * Wraps the given pixel vector and label. The array is stored by reference,
     * not copied.
     *
     * @param data      pixel values; asserted to hold exactly 64 entries
     * @param classCode label attached to this sample
     */
    public ImageMatrix(int[] data, int classCode) {
        // Only enforced when the JVM runs with assertions enabled (-ea).
        assert data.length == 64;
        this.classCode = classCode;
        this.data = data;
    }

    /**
     * Renders the label followed by the pixel values, terminated by a newline.
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("Class Code: ");
        text.append(classCode);
        text.append(" Data :");
        text.append(Arrays.toString(data));
        text.append("\n");
        return text.toString();
    }

    /** @return the pixel array passed to the constructor (same instance) */
    public int[] getData() {
        return this.data;
    }

    /** @return the label passed to the constructor */
    public int getClassCode() {
        return this.classCode;
    }

    /** @return the {@code curData} field; nothing in this class ever assigns it */
    public int getCurData() {
        return this.curData;
    }
}
import java.util.*;
import java.io.*;
import java.util.ArrayList;
public class ImageMatrixDB implements Iterable<ImageMatrix> {
    private List<ImageMatrix> list = new ArrayList<ImageMatrix>();

    /**
     * Reads a CSV file and appends one {@link ImageMatrix} per non-blank line
     * to this database. On each line, every value before the last comma is a
     * pixel and the value after the last comma is the class code.
     *
     * <p>Blank lines are skipped and whitespace around individual values is
     * ignored, so files with a trailing empty line or "1, 2, 3" style spacing
     * load cleanly.
     *
     * @param f path of the CSV file to read
     * @return this database, to allow call chaining
     * @throws IOException           if the file cannot be read, or a non-blank
     *                               line contains no comma
     * @throws NumberFormatException if a value is not a valid integer
     */
    public ImageMatrixDB load(String f) throws IOException {
        try (
            FileReader fr = new FileReader(f);
            BufferedReader br = new BufferedReader(fr)) {
            String line;
            int lineNumber = 0;

            while ((line = br.readLine()) != null) {
                lineNumber++;
                line = line.trim();
                if (line.isEmpty()) {
                    continue; // tolerate blank lines, e.g. at the end of the file
                }
                int lastComma = line.lastIndexOf(',');
                if (lastComma < 0) {
                    // Without a comma there is no way to split pixels from the class code.
                    throw new IOException(f + ":" + lineNumber
                            + ": expected comma-separated pixel values followed by a class code");
                }
                int classCode = Integer.parseInt(line.substring(lastComma + 1).trim());
                int[] data = Arrays.stream(line.substring(0, lastComma).split(","))
                                   .map(String::trim)
                                   .mapToInt(Integer::parseInt)
                                   .toArray();
                list.add(new ImageMatrix(data, classCode));
            }
        }
        return this;
    }

    /**
     * Prints every stored matrix to standard output, one {@code println} per
     * entry, in list order.
     */
    public void printResults() {
        list.forEach(entry -> System.out.println(entry));
    }


    /**
     * Allows this database to be used in a for-each loop.
     *
     * @return an iterator over the backing list
     */
    public Iterator<ImageMatrix> iterator() {
        final Iterator<ImageMatrix> cursor = list.iterator();
        return cursor;
    }

    /// kNN implementation ///
    /**
     * Computes the Euclidean distance between two equally long vectors,
     * truncated toward zero to an {@code int}.
     *
     * <p>Differences and their squares are accumulated in {@code long}, so
     * large component values do not wrap around as they would in
     * {@code int} arithmetic.
     *
     * @param a first vector
     * @param b second vector; must have the same length as {@code a}
     * @return the integer part of the Euclidean distance
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static int distance(int[] a, int[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException(
                    "Vector lengths differ: " + a.length + " vs " + b.length);
        }
        long sum = 0;
        for (int i = 0; i < a.length; i++) {
            long diff = (long) a[i] - b[i];
            sum += diff * diff;
        }
        return (int) Math.sqrt(sum);
    }


    /**
     * Nearest-neighbour (k = 1) classification: returns the class code of the
     * training sample with the smallest distance to {@code curData}.
     *
     * <p>On equal distances the sample encountered first wins, because only a
     * strictly smaller distance replaces the current best. Returns 0 if the
     * training set yields no samples.
     *
     * @param trainingSet labelled samples to compare against
     * @param curData     vector to classify
     * @return class code of the closest training sample
     */
    public static int classify(ImageMatrixDB trainingSet, int[] curData) {
        int nearestLabel = 0;
        int smallestSoFar = Integer.MAX_VALUE;
        Iterator<ImageMatrix> samples = trainingSet.iterator();
        while (samples.hasNext()) {
            ImageMatrix candidate = samples.next();
            int d = distance(candidate.getData(), curData);
            if (d >= smallestSoFar) {
                continue; // not strictly closer than the current best
            }
            smallestSoFar = d;
            nearestLabel = candidate.getClassCode();
        }
        return nearestLabel;
    }


    /**
     * Reports how many matrices are currently stored.
     *
     * @return the number of entries in the backing list
     */
    public int size() {
        final int count = this.list.size();
        return count;
    }


    /**
     * Loads the training and validation CSV files, classifies every validation
     * sample against the training set, and prints the percentage of samples
     * whose predicted class code matches the stored one.
     *
     * @param argv unused
     * @throws IOException if either CSV file cannot be read
     */
    public static void main(String[] argv) throws IOException {
        ImageMatrixDB trainingSet = new ImageMatrixDB().load("cw2DataSet1.csv");
        ImageMatrixDB validationSet = new ImageMatrixDB().load("cw2DataSet2.csv");

        int hits = 0;
        for (ImageMatrix sample : validationSet) {
            int predicted = classify(trainingSet, sample.getData());
            if (predicted == sample.getClassCode()) {
                hits++;
            }
        }

        double accuracyPercent = (double) hits / validationSet.size() * 100;
        System.out.println("Accuracy: " + accuracyPercent + "%");
        System.out.println();
    }
import java.util.*;
导入java.io.*;
导入java.util.ArrayList;
公共类ImageMatrixDB实现了Iterable{
私有列表=新的ArrayList();
公共ImageMatrixDB加载(字符串f)引发IOException{
试一试(
FileReader fr=新的FileReader(f);
BufferedReader br=新的BufferedReader(fr)){
字符串行=null;
而((line=br.readLine())!=null){
int lastcoma=line.lastIndexOf(',');
int classCode=Integer.parseInt(line.substring(1+lastcoma));
int[]data=Arrays.stream(line.substring(0,lastcoma).split(“,”))
.mapToInt(整数::parseInt)
.toArray();
ImageMatrix矩阵=新的ImageMatrix(数据,类代码);//类代码->0时为100%->1-9时为0%。。
列表。添加(矩阵);
}
}
归还这个;
}
public void printResults(){//输出结果
用于(ImageMatrix矩阵:列表){
系统输出打印LN(矩阵);
}
}
公共迭代器迭代器(){
返回此.list.iterator();
}
///kNN实现///
公共静态整数距离(整数[]a,整数[]b){
整数和=0;
for(int i=0;i
在 classify 方法的 for 循环中,您目前只查找与测试点最接近的那一个训练样本。您需要把这段代码替换为查找与测试数据最接近的 K 个训练点的代码。然后,对这 K 个点中的每一个调用 getClassCode,找出占多数(即出现次数最多)的类别代码;classify 方法返回这个多数类别代码即可。

您可以根据自己的需要,以任意方式处理平局(即在这 K 个训练点中,有两个或更多类别代码出现的次数并列最多的情况)。

我在Java方面确实缺乏经验,但仅通过查看语言参考,我就想到了下面的实现

/**
 * k-nearest-neighbour classification: returns the class code that occurs
 * most often among the {@code k} training samples closest to {@code curData}.
 *
 * <p>If {@code k} exceeds the size of the training set, all training samples
 * vote. Ties are broken deterministically: neighbours are counted from
 * nearest to farthest, and the class that reaches the winning vote count
 * first is returned. Samples at equal distance keep their training-set order
 * because the sort is stable. With {@code k == 1} this yields the single
 * nearest neighbour's class code. Returns 0 if the training set is empty.
 *
 * @param trainingSet labelled samples to compare against
 * @param curData     vector to classify
 * @param k           number of neighbours that vote; must be at least 1
 * @return the majority class code among the nearest neighbours
 * @throws IllegalArgumentException if {@code k} is less than 1
 */
public static int classify(ImageMatrixDB trainingSet, int[] curData, int k) {
    if (k < 1) {
        throw new IllegalArgumentException("k must be at least 1, got " + k);
    }

    // Each row holds {distance to curData, class code} for one training sample.
    int[][] distances = new int[trainingSet.size()][2];
    int i = 0;
    for (ImageMatrix matrix : trainingSet) {
        distances[i][0] = distance(matrix.getData(), curData);
        distances[i][1] = matrix.getClassCode();
        i++;
    }

    // Integer.compare avoids the overflow that subtracting the two values could cause.
    Arrays.sort(distances, (int[] lhs, int[] rhs) -> Integer.compare(lhs[0], rhs[0]));

    // Never read past the end of the array when k is larger than the training set.
    int neighbours = Math.min(k, distances.length);

    // Count votes from nearest to farthest, tracking the current leader.
    Map<Integer, Integer> votes = new HashMap<Integer, Integer>();
    int label = 0;
    int maxVotes = 0;
    for (int n = 0; n < neighbours; n++) {
        int code = distances[n][1];
        int count = votes.merge(code, 1, Integer::sum);
        if (count > maxVotes) {
            maxVotes = count;
            label = code;
        }
    }

    return label;
}
公共静态int分类(ImageMatrixDB训练集,int[]curData,int k){
int label=0,bestDistance=Integer.MAX_值;
int[][]距离=新int[trainingSet.size()][2];
int i=0;
//将距离放置在要排序的数组中
用于(图像矩阵:培训集){
距离[i][0]=距离(matrix.getData(),curData);
距离[i][1]=matrix.getClassCode();
i++;
}
排序(距离,(int[]lhs,int[]rhs)->lhs[0]-rhs[0]);
//查找每个类别代码的频率
i=0;
地图地图;
majorityMap=新HashMap();
而(imaxVal){
最大值=入口值;
label=entry.getKey();
}
}
退货标签;
}

您只需把 K 作为参数添加进去。但是,请记住,上面的代码并没有以任何特定的方式处理平局。

尽管您的问题在于分类方法,但我认为对图像使用欧几里德距离不是一个好主意。一旦把图像拉伸成向量,就会丢失相关信息。例如,两张属于同一个人、但背景颜色不同的图像会产生很大的欧几里德距离。

谢谢你的帮助。看了你的回答后,我发现了我最初的尝试问题出在哪里,真的很有帮助。