
Scala Spark DataFrame filter with a random condition does not work as expected


This is my DataFrame:

df.groupBy($"label").count.show
+-----+---------+                                                               
|label|    count|
+-----+---------+
|  0.0|400000000|
|  1.0| 10000000|
+-----+---------+
I am trying to subsample the records with label == 0.0 using the following:

val r = scala.util.Random
val df2 = df.filter($"label" === 1.0 || r.nextDouble > 0.5) // keep 50% of 0.0
My output looks like this:

df2.groupBy($"label").count.show
+-----+--------+                                                                
|label|   count|
+-----+--------+
|  1.0|10000000|
+-----+--------+

r.nextDouble
is a constant in this expression, so the actual evaluation is quite different from what you meant. Depending on the value that happens to be sampled, the filter is for example

scala> r.setSeed(0)

scala> $"label" === 1.0 || r.nextDouble > 0.5
res0: org.apache.spark.sql.Column = ((label = 1.0) OR true)

So after simplification it is just:

true

(keep all records) or

label = 1.0

(keep only label 1.0, the case you observed), respectively.
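The point can be seen without Spark at all. Below is a minimal plain-Scala sketch (variable names are illustrative, not from the question) showing that `r.nextDouble` is evaluated exactly once, while the expression is being built, rather than once per row:

```scala
import scala.util.Random

// r.nextDouble runs once, on the driver, while the Column expression
// is constructed -- not once per row of the DataFrame.
val r = new Random(0)
val constant = r.nextDouble  // a single fixed Double (~0.73 for seed 0)

// Every simulated "row" of the filter sees the same constant,
// so the decision is identical for all rows:
val decisions = (1 to 5).map(_ => constant > 0.5)
// decisions is all-true or all-false, never a 50/50 mix
```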

To generate a random number per row you should use the rand function:

scala> import org.apache.spark.sql.functions.rand
import org.apache.spark.sql.functions.rand

scala> $"label" === 1.0 || rand > 0.5
res1: org.apache.spark.sql.Column = ((label = 1.0) OR (rand(3801516599083917286) > 0.5))

though Spark already provides stratified sampling tools:

df.stat.sampleBy(
  "label",  // column
  Map(0.0 -> 0.5, 1.0 -> 1.0),  // fractions
  42 // seed
)
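If you want to keep the shape of the original filter rather than switch to sampleBy, a sketch of the per-row version follows (it assumes the `df` and `label` column from the question and a spark-shell session; the seed 42 is an arbitrary choice for reproducibility):

```scala
import org.apache.spark.sql.functions.rand
// Outside spark-shell you also need: import spark.implicits._ for the $ syntax

// rand(42) is re-evaluated for every row, so roughly half of the
// label == 0.0 records pass the filter, and all label == 1.0 records do.
val df2 = df.filter($"label" === 1.0 || rand(42) > 0.5)
```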