Scala: Why does a mutable Map automatically become immutable in Spark's UserDefinedAggregateFunction (UDAF)?


I am trying to define a UserDefinedAggregateFunction (UDAF) in Spark that counts the number of occurrences of each unique value in a column within a group.

This is an example. Suppose I have a dataframe df like this:

+----+----+
|col1|col2|
+----+----+
|   a|  a1|
|   a|  a1|
|   a|  a2|
|   b|  b1|
|   b|  b2|
|   b|  b3|
|   b|  b1|
|   b|  b1|
+----+----+
I will have a UDAF DistinctValues:

val func = new DistinctValues
Then I apply it to the dataframe df:

val agg_value = df.groupBy("col1").agg(func(col("col2")).as("DV"))
I am expecting to get something like this:

+----+--------------------------+
|col1|DV                        |
+----+--------------------------+
|   a|  Map(a1->2, a2->1)       |
|   b|  Map(b1->3, b2->1, b3->1)|
+----+--------------------------+
So I came up with a UDAF like this:

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.MapType
import org.apache.spark.sql.types.LongType
import Array._

class DistinctValues extends UserDefinedAggregateFunction {
  def inputSchema: org.apache.spark.sql.types.StructType = StructType(StructField("value", StringType) :: Nil)

  def bufferSchema: StructType = StructType(StructField("values", MapType(StringType, LongType))::Nil)

  def dataType: DataType =  MapType(StringType, LongType)
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = scala.collection.mutable.Map()
  }

  def update(buffer: MutableAggregationBuffer, input: Row) : Unit = {
    val str = input.getAs[String](0)
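    // this getAs is what throws the ClassCastException shown in the stack trace below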
    var mp = buffer.getAs[scala.collection.mutable.Map[String, Long]](0)
    var c:Long = mp.getOrElse(str, 0)
    c = c + 1
    mp.put(str, c)
    buffer(0) = mp
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) : Unit = {
    var mp1 = buffer1.getAs[scala.collection.mutable.Map[String, Long]](0)
    var mp2 = buffer2.getAs[scala.collection.mutable.Map[String, Long]](0)
    mp2 foreach {
        case (k ,v) => {
            var c:Long = mp1.getOrElse(k, 0)
            c = c + v
            mp1.put(k ,c)
        }
    }
    buffer1(0) = mp1
  }

  def evaluate(buffer: Row): Any = {
      buffer.getAs[scala.collection.mutable.Map[String, LongType]](0)
  }
}
Then I use this function on my dataframe:

val func = new DistinctValues
val agg_values = df.groupBy("col1").agg(func(col("col2")).as("DV"))
It gave an error like this:

func: DistinctValues = $iwC$$iwC$DistinctValues@17f48a25
org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 32.0 failed 4 times, most recent failure: Lost task 1.3 in stage 32.0 (TID 884, ip-172-31-22-166.ec2.internal): java.lang.ClassCastException: scala.collection.immutable.Map$EmptyMap$ cannot be cast to scala.collection.mutable.Map
at $iwC$$iwC$DistinctValues.update(<console>:39)
at org.apache.spark.sql.execution.aggregate.ScalaUDAF.update(udaf.scala:431)
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$12.apply(AggregationIterator.scala:187)
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$12.apply(AggregationIterator.scala:180)
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.processCurrentSortedGroup(SortBasedAggregationIterator.scala:116)
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:152)
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:29)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:149)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:73)
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:41)
at org.apache.spark.scheduler.Task.run(Task.scala:89)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:213)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615)
at java.lang.Thread.run(Thread.java:745)
It seems that in the update(buffer: MutableAggregationBuffer, input: Row) method, the map stored in buffer is an immutable.Map, and the program fails when it tries to cast it to a mutable.Map.

But I used a mutable.Map to initialize the buffer variable in the initialize(buffer: MutableAggregationBuffer) method. Is it the same variable that gets passed to the update method? Also, buffer is a MutableAggregationBuffer, so it should be mutable, right?

Why did my mutable.Map become immutable? Does anyone know what happened?


I really need a mutable Map in this function to complete the task. I know there is a workaround: create a mutable Map from the immutable one, then update it. But I really want to know why the mutable Map automatically turns into an immutable one in this program; it makes no sense to me.
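For reference, that workaround would look roughly like this inside update (a minimal sketch, not from the original post): copy the map out of the buffer into a mutable one, update it, and write it back. Writing into the buffer converts the value back to Spark's own map representation, which is why a later getAs hands out an immutable Map again.

def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
  val str = input.getAs[String](0)
  // copy the immutable Map the buffer hands back into a mutable one
  val mp = scala.collection.mutable.Map[String, Long]() ++ buffer.getAs[Map[String, Long]](0)
  mp(str) = mp.getOrElse(str, 0L) + 1
  buffer(0) = mp.toMap  // write the updated counts back into the buffer
}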

I believe it is the MapType in your StructType: buffer therefore holds an immutable Map.

You could convert it, but why not just leave it immutable and do the following instead:

mp = mp + (k -> c)

to add an entry to the immutable Map.

A working example is below:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._

class DistinctValues extends UserDefinedAggregateFunction {
  def inputSchema: StructType = StructType(StructField("value", StringType) :: Nil)

  def bufferSchema: StructType = StructType(StructField("values", MapType(StringType, LongType)) :: Nil)

  def dataType: DataType = MapType(StringType, LongType)
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map.empty[String, Long]
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val str = input.getAs[String](0)
    var mp = buffer.getAs[Map[String, Long]](0)
    var c: Long = mp.getOrElse(str, 0L)
    c = c + 1
    mp = mp + (str -> c)  // builds a new immutable Map with the updated count
    buffer(0) = mp
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    var mp1 = buffer1.getAs[Map[String, Long]](0)
    val mp2 = buffer2.getAs[Map[String, Long]](0)
    mp2.foreach { case (k, v) =>
      var c: Long = mp1.getOrElse(k, 0L)
      c = c + v
      mp1 = mp1 + (k -> c)
    }
    buffer1(0) = mp1
  }

  def evaluate(buffer: Row): Any = {
    buffer.getAs[Map[String, Long]](0)
  }
}
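For completeness, applying it the same way as in the question should then produce the expected per-group maps (assuming df is the example dataframe above):

import org.apache.spark.sql.functions.col

val func = new DistinctValues
val agg_values = df.groupBy("col1").agg(func(col("col2")).as("DV"))
agg_values.show(false)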

Late to the party, but I just discovered that one can use

override def bufferSchema: StructType = StructType(List(
    StructField("map", ObjectType(classOf[mutable.Map[String, Long]]))
))

to get a mutable.Map in the buffer.
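Sketching how the complete UDAF might look with that buffer schema (an illustration, not from the answer: the class name DistinctValuesMutable is made up here, ObjectType is a Spark developer API, and whether a UDAF buffer accepts it can vary across Spark versions):

import scala.collection.mutable
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._

class DistinctValuesMutable extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = StructType(StructField("value", StringType) :: Nil)

  // ObjectType keeps the buffer value as a plain JVM object, so getAs returns
  // the same mutable.Map instance that was stored.
  override def bufferSchema: StructType = StructType(List(
    StructField("map", ObjectType(classOf[mutable.Map[String, Long]]))
  ))

  override def dataType: DataType = MapType(StringType, LongType)
  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer(0) = mutable.Map.empty[String, Long]

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val mp = buffer.getAs[mutable.Map[String, Long]](0)
    val str = input.getAs[String](0)
    mp(str) = mp.getOrElse(str, 0L) + 1
    buffer(0) = mp  // write back so Spark sees the updated value
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val mp1 = buffer1.getAs[mutable.Map[String, Long]](0)
    val mp2 = buffer2.getAs[mutable.Map[String, Long]](0)
    mp2.foreach { case (k, v) => mp1(k) = mp1.getOrElse(k, 0L) + v }
    buffer1(0) = mp1
  }

  // dataType is a MapType, so convert back to an immutable Map at the end
  override def evaluate(buffer: Row): Any =
    buffer.getAs[mutable.Map[String, Long]](0).toMap
}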

Nice catch! Well, it may indeed be the MapType in the StructType. But there is no other, mutable map type in spark.sql.types, unless I define my own map type.

Like I said, don't: just use the immutable Map. mp = mp + (k -> c) on an immutable Map gives you the same functionality as mp.put(k, c) on a mutable Map.

mp = mp + (k -> c) works! I am new to Scala and did not know you could work with immutable data types like this. Thank you very much!

It is not so much manipulating it as creating a brand-new instance based on the previous one and then discarding the previous one. But yes, at this point I use almost exclusively immutable collections; there really is not much reason to use the mutable ones!

So mp needs to be a var, and you are re-assigning a new map to the variable mp.
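To make the last two comments concrete, here is a tiny standalone snippet (an illustration, not from the thread): + on an immutable Map builds a brand-new instance, and the var mp is simply re-bound to it.

object ImmutableMapDemo extends App {
  var mp = Map("a1" -> 1L)
  val before = mp         // keep a reference to the old instance
  mp = mp + ("a1" -> 2L)  // builds a new Map; mp is re-bound to it
  println(before)         // Map(a1 -> 1)  -- the old instance is untouched
  println(mp)             // Map(a1 -> 2)
  println(before eq mp)   // false -- two distinct instances
}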