Scala 从训练集数据中排除项目

Scala 从训练集数据中排除项目,scala,apache-spark,apache-spark-sql,apache-spark-mllib,Scala,Apache Spark,Apache Spark Sql,Apache Spark Mllib,我的数据有两种颜色和排除的颜色 颜色包含所有颜色 excluded_colors包含一些我希望从培训集中排除的颜色 我正在尝试将数据拆分为一个训练集和测试集,并确保排除的颜色中的颜色不在我的训练集中,而是存在于测试集中 为了实现上述目标,我做了以下工作 var colors = spark.sql(""" select colors.* from colors LEFT JOIN excluded_colors ON excluded_colors.color_id

我的数据有两种
颜色
排除的颜色

颜色
包含所有颜色
excluded_colors
包含一些我希望从培训集中排除的颜色

我正在尝试将数据拆分为一个训练集和测试集,并确保
排除的颜色中的颜色不在我的训练集中,而是存在于测试集中

为了实现上述目标,我做了以下工作

var colors = spark.sql("""
   select colors.* 
   from colors 
   LEFT JOIN excluded_colors 
   ON excluded_colors.color_id = colors.color_id
   where excluded_colors.color_id IS NULL
"""
)
val trainer: (Int => Int) = (arg:Int) => 0
val sqlTrainer = udf(trainer)
val tester: (Int => Int) = (arg:Int) => 1
val sqlTester = udf(tester)

val rsplit = colors.randomSplit(Array(0.7, 0.3))  
val train_colors = splits(0).select("color_id").withColumn("test",sqlTrainer(col("color_id")))
val test_colors = splits(1).select("color_id").withColumn("test",sqlTester(col("color_id")))
然而,我意识到通过执行上述操作,
排除的颜色中的颜色被完全忽略。它们甚至不在我的测试集中

问题
如何将数据分成70/30,同时确保
排除的颜色中的颜色不在训练中,而是在测试中出现

我们要做的是从训练集中删除“排除的颜色”,但在测试中使用它们,并将训练/测试划分为70/30

我们需要的是一点数学知识

给定总数据集(TD)和排除的颜色数据集(E),我们可以说对于列车数据集(Tr)和测试数据集(Ts):

我们还知道
|Tr |=0.7 | TD |

因此
x=0.7 | TD |/(| TD |-| E |)

现在我们知道了采样因子
x
,我们可以说:

Tr = (TD-E).sample(withReplacement = false, fraction = x)
// where (TD - E) is the result of the SQL expr above

Ts = TD.sample(withReplacement = false, fraction = 0.3)
// we sample the test set from the original dataset

在你的例子中,
x
是什么?
x
是抽样概率,就像你问题中的
0.7
0.3
。我们的想法是,你需要从较小的数据集中选取一个较大的样本,以保持问题中提出的70/30的比例。好的,我想我明白了。我将把你的数学转换成代码,看看我在这方面是否成功。我将相应地更新问题。感谢与
TD
的乘法不起作用。这就是我所拥有的
scala>TD.count res54:Long=46520 scala>E.count res55:Long=41868 scala>val x=0.7*TD:55:error:重载方法值*,可选项:
如下,
|TD |
totalDataset.count()
|E |
排除的颜色.count。因此,有了
val-totalDatasetCount=totalDataset.count()
val-exclocorscont=excluded\u colors.count()
,我们应该能够计算
val-adjustedSamplingFactor=0.7d*totalDatasetCount/(totalDatasetCount-exclocorscourscont)
Tr = (TD-E).sample(withReplacement = false, fraction = x)
// where (TD - E) is the result of the SQL expr above

Ts = TD.sample(withReplacement = false, fraction = 0.3)
// we sample the test set from the original dataset