Scala Spark:如何将行分组到固定大小的数组中?

Scala Spark:如何将行分组到固定大小的数组中?,scala,apache-spark,apache-spark-sql,partitioning,Scala,Apache Spark,Apache Spark Sql,Partitioning,我的数据集如下所示: +---+ |col| +---+ | a| | b| | c| | d| | e| | f| | g| +---+ 我想重新格式化此数据集,以便将行聚合为固定长度的数组,如下所示: +------+ | col| +------+ |[a, b]| |[c, d]| |[e, f]| | [g]| +------+ 我试过这个: spark.sqlselect collect_listcol from select col,row_number

我的数据集如下所示:

+---+
|col|
+---+
|  a|
|  b|
|  c|
|  d|
|  e|
|  f|
|  g|
+---+
我想重新格式化此数据集,以便将行聚合为固定长度的数组,如下所示:

+------+
|   col|
+------+
|[a, b]|
|[c, d]|
|[e, f]|
|   [g]|
+------+
我试过这个:

spark.sqlselect collect_listcol from select col,row_number over order by col row_number from dataset group by floorrow_number/2


但问题是,我的实际数据集太大,无法在单个分区中处理行号

因为您希望分发此数据集,所以需要执行几个步骤

如果您希望运行代码,我将从以下内容开始:

var df = List(
  "a", "b", "c", "d", "e", "f", "g"
).toDF("col")
val desiredArrayLength = 2
首先,将您的数据帧拆分为一个小的数据帧,您可以在单个节点上处理,而较大的数据帧的行数是您的示例中所需数组大小的倍数,这是2

val nRowsPrune = 1 //number of rows to prune such that remaining dataframe has number of
                   // rows is multiples of the desired length of array
val dfPrune = df.sort(desc("col")).limit(nRowsPrune)
df = df.join(dfPrune,Seq("col"),"left_anti") //separate small from large dataframe
通过构造,您可以在小数据帧上应用原始代码

val groupedPruneDf = dfPrune//.withColumn("g",floor((lit(-1)+row_number().over(w))/lit(desiredArrayLength ))) //added -1 as row-number starts from 1
                            //.groupBy("g")
                            .agg( collect_list("col").alias("col"))
                            .select("col")
现在,我们需要找到一种方法来处理剩余的大型数据帧。但是,现在我们确定df有许多行,它们是数组大小的倍数。 这就是我们使用的一个很好的技巧,即使用repartitionByRange重新分区。基本上,分区保证保留排序,并且在进行分区时,每个分区的大小都相同。 现在,您可以收集每个分区中的每个数组

   val nRows = df.count()
   val maxNRowsPartition = desiredArrayLength //make sure its a multiple of desired array length
   val nPartitions = math.max(1,math.floor(nRows/maxNRowsPartition) ).toInt
   df = df.repartitionByRange(nPartitions, $"col".desc)
          .withColumn("partitionId",spark_partition_id())

    val w = Window.partitionBy($"partitionId").orderBy("col")
    val groupedDf = df
        .withColumn("g",  floor( (lit(-1)+row_number().over(w))/lit(desiredArrayLength ))) //added -1 as row-number starts from 1
        .groupBy("partitionId","g")
        .agg( collect_list("col").alias("col"))
        .select("col")
最后,将这两个结果结合起来,得到您想要的结果

val result = groupedDf.union(groupedPruneDf)
result.show(truncate=false)

谢谢这对我有用。对我来说,新的和有用的概念是repartitionByRange和spark_partition_id。第一部分拆分成单独的剩余数据帧是否重要?即使我不做这部分工作,似乎我只会有一条记录,大小=记录数%分区大小。修剪是必要的,以确保所有数组的大小相同,并将连续行分组,但剩余行除外