LibSVM gives different output in Java and Python


I have been using the high-level Python script svmutil.py to generate SVM models.

But in practice I need to run my SVM in Java. Training a model on the same dataset (3000 entries with 18 features and 5 classes) yields a different number of nSV:

>>> from svmutil import *
>>> y,x = svm_read_problem('train.txt')
>>> m = svm_train(y[:3000],x[:3000], '-t 2 -s 0')
*
optimization finished, #iter = 67
nu = 0.105257
obj = -89.960869, rho = -0.027008
nSV = 128, nBSV = 126
***
[a couple of more iters here... ]
***
optimization finished, #iter = 19
nu = 0.016800
obj = -10.178571, rho = -0.078282
nSV = 22, nBSV = 19
Total nSV = 430
As you can see, the total nSV is 430.

My Java implementation looks like this (I should mention that I will be running it in Processing):

So basically, for the same dataset I get a different number of nSV, and as you can see in the output of the Processing sketch, the predict function says entry no. 111 is in class 3, while it is actually in class 0.
(The Python program reports that with these training and test datasets the SVM reaches 97% accuracy.)
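For reference, the 97% figure on the Python side is the accuracy line printed by svmutil's svm_predict; LIBSVM's Java API prints no such line, so it has to be computed manually. A minimal sketch in plain Java (the class name and the example label arrays are hypothetical, standing in for the values returned by the evaluate function below):

```java
public class AccuracyCheck {
    // Fraction of predictions matching the actual labels, as a percentage,
    // mirroring the "Accuracy = ..." line printed by svmutil's svm_predict.
    static double accuracy(double[] actual, double[] predicted) {
        int hits = 0;
        for (int i = 0; i < actual.length; i++) {
            if (actual[i] == predicted[i]) hits++;
        }
        return 100.0 * hits / actual.length;
    }

    public static void main(String[] args) {
        double[] actual    = {0, 1, 3, 3, 2}; // hypothetical true labels
        double[] predicted = {0, 1, 3, 2, 2}; // hypothetical model output
        System.out.println(accuracy(actual, predicted)); // prints 80.0
    }
}
```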

I solved the problem. I had forgotten that I normalized the dataset in my Processing code while gathering the data, in a different way than for the SVM.
import libsvm.*;

double[][] train = new double[3000][];  // 3000 entries in training file
double[][] test = new double[952][];    // 952 entries in testing file

Table t; //generated data is saved in a csv in the form of a processing-table
int classes = 5;
svm_model m; 

private svm_model svmTrain() {
    svm_problem prob = new svm_problem();
    int dataCount = train.length;
    prob.y = new double[dataCount];
    prob.l = dataCount;
    prob.x = new svm_node[dataCount][];     

    for (int i = 0; i < dataCount; i++){            
        double[] features = train[i];
        prob.x[i] = new svm_node[features.length-1];
        for (int j = 1; j < features.length; j++){
            svm_node node = new svm_node();
            node.index = j;
            node.value = features[j];
            prob.x[i][j-1] = node;
        }           
        prob.y[i] = features[0];
    }               

    svm_parameter param = new svm_parameter();
    param.probability = 1;
    param.gamma = 0.5;
    param.nu = 0.5;
    param.C = 1;
    param.svm_type = svm_parameter.C_SVC;
    param.kernel_type = svm_parameter.RBF;       
    param.cache_size = 10000;
    param.eps = 0.1;      

    svm_model model = svm.svm_train(prob, param);

    return model;
}
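One thing worth noting when comparing the two runs: the Python call `-t 2 -s 0` leaves every other option at LIBSVM's command-line defaults, while svmTrain above hard-codes gamma = 0.5 and eps = 0.1, which by itself can change the number of support vectors. Matching the defaults would look roughly like this (a config sketch; 18 is the feature count from the question):

```java
// Config sketch: svm_parameter values matching the LIBSVM command-line
// defaults that "-t 2 -s 0" uses implicitly on the Python side.
svm_parameter param = new svm_parameter();
param.svm_type    = svm_parameter.C_SVC;  // -s 0
param.kernel_type = svm_parameter.RBF;    // -t 2
param.C           = 1;                    // default -c 1
param.gamma       = 1.0 / 18;             // default -g 1/num_features
param.eps         = 0.001;                // default -e 0.001
param.cache_size  = 100;                  // default -m 100 (MB)
```

param.probability = 1 would still be needed on top of this for svm_predict_probability to work.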

public double evaluate(double[] features, svm_model model) 
{
    svm_node[] nodes = new svm_node[features.length-1];
    for (int i = 1; i < features.length; i++)
    {
        svm_node node = new svm_node();
        node.index = i;
        node.value = features[i];
        nodes[i-1] = node;
    }

    int totalClasses = classes;       
    int[] labels = new int[totalClasses];
    svm.svm_get_labels(model,labels);

    double[] prob_estimates = new double[totalClasses];

    double v = svm.svm_predict_probability(model, nodes, prob_estimates);

    for (int i = 0; i < totalClasses; i++){
        System.out.print("(" + labels[i] + ":" + prob_estimates[i] + ")");
    }
    System.out.println("(Actual:" + features[0] + " Prediction:" + v + ")");                
    return v;
}

void setup() {
  int q = 0;
  t = loadTable("train.csv", "header");
  println(t.getRowCount() + " total rows in table");

  for (TableRow row : t.rows()) {
    double[] vals = new double[19]; // 18 features + classID
    for (int p = 0; p < 19; p++) {
      vals[p] = row.getFloat(p);
    }
    train[q] = vals;
    q++;
  }

  m = svmTrain();
  q = 0;

  t = loadTable("test.csv", "header");
  println(t.getRowCount() + " total rows in table");

  for (TableRow row : t.rows()) {
    double[] vals = new double[18];
    for (int p = 0; p < 18; p++) {
      vals[p] = row.getFloat(p);
    }
    test[q] = vals;
    q++;
  }

  double b = evaluate(test[111], m);
}
optimization finished, #iter = 11
nu = 0.005126779895638029
obj = -2.974715882922519, rho = -0.24619879130684083
nSV = 13, nBSV = 2
Total nSV = 154
952 total rows in table
(0:0.012058316041050699)(1:0.004177821114087953)(2:0.0010059539653873603)(3:0.9816075047230208)(4:0.0011504041564533865)(Actual:0.0 Prediction:3.0)
3.0
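The fix described above boils down to one rule: whatever scaling is applied to the training data must be applied to the test data using the same per-feature minima and maxima, learned from the training set only. A minimal min-max scaler sketch in plain Java (the class name and API are illustrative, not from the original code; in the question's data the label sits in column 0, so that column should be skipped when scaling):

```java
// Illustrative min-max scaler: fit the per-feature ranges on the training
// set only, then reuse those same ranges to scale every test row.
public class MinMaxScaler {
    private double[] min, max;

    // Learn per-feature minima and maxima from the training rows.
    public void fit(double[][] rows) {
        int n = rows[0].length;
        min = new double[n];
        max = new double[n];
        java.util.Arrays.fill(min, Double.POSITIVE_INFINITY);
        java.util.Arrays.fill(max, Double.NEGATIVE_INFINITY);
        for (double[] row : rows) {
            for (int j = 0; j < n; j++) {
                min[j] = Math.min(min[j], row[j]);
                max[j] = Math.max(max[j], row[j]);
            }
        }
    }

    // Scale one row into [0, 1] using the *training* ranges.
    public double[] transform(double[] row) {
        double[] out = new double[row.length];
        for (int j = 0; j < row.length; j++) {
            double range = max[j] - min[j];
            out[j] = range == 0 ? 0 : (row[j] - min[j]) / range;
        }
        return out;
    }
}
```

Calling fit on the training rows and then transform on both training and test rows guarantees both sides of the prediction see features on the same scale.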