基于Spark和Java的分层抽样

基于Spark和Java的分层抽样,java,apache-spark,machine-learning,apache-spark-mllib,Java,Apache Spark,Machine Learning,Apache Spark Mllib,我想确保我是在我的数据分层样本上进行培训的 Spark 2.1和早期版本似乎通过JavaPairdd.sampleByKey(…)和JavaPairdd.sampleByKeyExact(…)支持这一点 但是:我的数据存储在数据集中,而不是javapairdd。第一列是标签,其他所有列都是特征(从libsvm格式的文件导入) 获取我的dataset实例的分层样本并在最后再次创建dataset的最简单方法是什么 从某种程度上说,这个问题与我有关 这根本没有提到数据集,Java中也没有。它没有回答我

我想确保我是在我的数据分层样本上进行培训的

Spark 2.1和早期版本似乎通过
JavaPairdd.sampleByKey(…)
JavaPairdd.sampleByKeyExact(…)
支持这一点

但是:我的数据存储在
数据集中,而不是
javapairdd
。第一列是标签,其他所有列都是特征(从libsvm格式的文件导入)

获取我的dataset实例的分层样本并在最后再次创建
dataset
的最简单方法是什么

从某种程度上说,这个问题与我有关

这根本没有提到数据集,Java中也没有。它没有回答我的问题。

好的,因为的答案实际上不是针对Java,所以我用Java重写了它

推理仍然是相同的想法。我们仍在使用
sampleByKeyExact
。目前没有现成的奇迹功能(spark 2.1.0

那么,给你:

package org.awesomespark.examples;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.sql.*;
import scala.Tuple2;

import java.util.Map;

public class StratifiedDatasets {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("Stratified Datasets")
                .getOrCreate();

        Dataset<Row> data = spark.read().format("libsvm").load("sample_libsvm_data.txt");

        JavaPairRDD<Double, Row> rdd = data.toJavaRDD().keyBy(x -> x.getDouble(0));
        Map<Double, Double> fractions = rdd.map(Tuple2::_1)
                .distinct()
                .mapToPair((PairFunction<Double, Double, Double>) (Double x) -> new Tuple2(x, 0.8))
                .collectAsMap();

        JavaRDD<Row> sampledRDD = rdd.sampleByKeyExact(false, fractions, 2L).values();
        Dataset<Row> sampledData = spark.createDataFrame(sampledRDD, data.schema());

        sampledData.show();
        sampledData.printSchema();
    }
}
对于python用户,您也可以查看我的答案

$ sbt package
[...]
// [success] Total time: 2 s, completed Jan 16, 2017 1:45:51 PM

$ spark-submit --class org.awesomespark.examples.StratifiedDatasets target/scala-2.10/java-stratified-dataset_2.10-1.0.jar 
[...]

// +-----+--------------------+
// |label|            features|
// +-----+--------------------+
// |  0.0|(692,[127,128,129...|
// |  1.0|(692,[158,159,160...|
// |  1.0|(692,[124,125,126...|
// |  1.0|(692,[152,153,154...|
// |  1.0|(692,[151,152,153...|
// |  0.0|(692,[129,130,131...|
// |  1.0|(692,[99,100,101,...|
// |  0.0|(692,[154,155,156...|
// |  0.0|(692,[127,128,129...|
// |  1.0|(692,[154,155,156...|
// |  0.0|(692,[151,152,153...|
// |  1.0|(692,[129,130,131...|
// |  0.0|(692,[154,155,156...|
// |  1.0|(692,[150,151,152...|
// |  0.0|(692,[124,125,126...|
// |  0.0|(692,[152,153,154...|
// |  1.0|(692,[97,98,99,12...|
// |  1.0|(692,[124,125,126...|
// |  1.0|(692,[156,157,158...|
// |  1.0|(692,[127,128,129...|
// +-----+--------------------+
// only showing top 20 rows

// root
//  |-- label: double (nullable = true)
//  |-- features: vector (nullable = true)