Warning: file_get_contents(/data/phpspider/zhask/data//catemap/3/apache-spark/5.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Scala Spark 2.0.0:如何使用自定义编码类型聚合数据集?_Scala_Apache Spark_Aggregate Functions_Apache Spark Dataset - Fatal编程技术网

Scala Spark 2.0.0:如何使用自定义编码类型聚合数据集?

Scala Spark 2.0.0:如何使用自定义编码类型聚合数据集?,scala,apache-spark,aggregate-functions,apache-spark-dataset,Scala,Apache Spark,Aggregate Functions,Apache Spark Dataset,我将一些数据存储为DataSet[(Long,LineString)],使用元组编码器和用于LineString的kryo编码器 implicit def single[A](implicit c: ClassTag[A]): Encoder[A] = Encoders.kryo[A](c) implicit def tuple2[A1, A2](implicit e1: Encoder[A1],

我将一些数据存储为DataSet[(Long,LineString)],使用元组编码器和用于LineString的kryo编码器

implicit def single[A](implicit c: ClassTag[A]): Encoder[A] = Encoders.kryo[A](c)
implicit def tuple2[A1, A2](implicit
                            e1: Encoder[A1],
                            e2: Encoder[A2]
                           ): Encoder[(A1,A2)] = Encoders.tuple[A1,A2](e1, e2)
implicit val lineStringEncoder = Encoders.kryo[LineString]

val ds = segmentPoints.map(
  sp => {
    val p1 = new Coordinate(sp.lon_ini, sp.lat_ini)
    val p2 = new Coordinate(sp.lon_fin, sp.lat_fin)
    val coords = Array(p1, p2)

    (sp.id, gf.createLineString(coords))
  })
  .toDF("id", "segment")
  .as[(Long, LineString)]
  .cache

ds.show

    +----+--------------------+
    | id |       segment      |
    +----+--------------------+
    | 347|[01 00 63 6F 6D 2...|
    | 347|[01 00 63 6F 6D 2...|
    | 347|[01 00 63 6F 6D 2...|
    | 808|[01 00 63 6F 6D 2...|
    | 808|[01 00 63 6F 6D 2...|
    | 808|[01 00 63 6F 6D 2...|
    +----+--------------------+
我可以对“线段”列应用任何贴图操作,并使用基础的LinesSign方法

ds.map(_._2.getClass.getName).show(false)

+--------------------------------------+
|value                                 |
+--------------------------------------+
|com.vividsolutions.jts.geom.LineString|
|com.vividsolutions.jts.geom.LineString|
|com.vividsolutions.jts.geom.LineString|
我想创建一些UDAFs来处理具有相同id的段,我尝试了两种不同的方法,但没有成功:

1)使用聚合器:

val length = new Aggregator[LineString, Double, Double] with Serializable {
  def zero: Double = 0                     // The initial value.
  def reduce(b: Double, a: LineString) = b + a.getLength    // Add an element to the running total
  def merge(b1: Double, b2: Double) = b1 + b2 // Merge intermediate values.
  def finish(b: Double) = b
  // Following lines are missing on the API doc example but necessary to get
  // the code compile
  override def bufferEncoder: Encoder[Double] = Encoders.scalaDouble
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}.toColumn

ds.groupBy("id")
  .agg(length(col("segment")).as("kms"))
  .show(false)
这里我得到了以下错误:

 Exception in thread "main" org.apache.spark.sql.AnalysisException: unresolved operator 'Aggregate [id#603L], [id#603L, anon$1(com.test.App$$anon$1@5bf1e07, None, input[0, double, true] AS value#715, cast(value#715 as double), input[0, double, true] AS value#714, DoubleType, DoubleType)['segment] AS kms#721];
2)使用用户定义的聚合函数

class Length extends UserDefinedAggregateFunction {
  val e = Encoders.kryo[LineString]

  // This is the input fields for your aggregate function.
  override def inputSchema: StructType = StructType(
    StructField("segment", DataTypes.BinaryType) :: Nil
  )

  // This is the internal fields you keep for computing your aggregate.
  override def bufferSchema: StructType = StructType(
      StructField("length", DoubleType) :: Nil
  )

  // This is the output type of your aggregatation function.
  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  // This is the initial value for your buffer schema.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0
  }

  // This is how to update your buffer schema given an input.
  override def update(buffer : MutableAggregationBuffer, input : Row) : Unit = {
    // val l0 = input.getAs[LineString](0) // Can't cast to LineString (I guess because it is searialized using given encoder)
    val b = input.getAs[Array[Byte]](0) // This works fine
    val lse = e.asInstanceOf[ExpressionEncoder[LineString]]
    val ls = lse.fromRow(???) // it expects InternalRow  but input is a Row instance
    // I also tried casting b.asInstance[InternalRow] without success.
    buffer(0) = buffer.getAs[Double](0) + ls.getLength
  }

  // This is how to merge two objects with the bufferSchema type.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Double](0) + buffer2.getAs[Double](0)
  }

  // This is where you output the final value, given the final value of your bufferSchema.
  override def evaluate(buffer: Row): Any = {
    buffer.getDouble(0)
  }
}

val length = new Length
rseg
  .groupBy("id")
  .agg(length(col("segment")).as("kms"))
  .show(false)
我做错了什么?我希望将聚合API与自定义类型一起使用,而不是使用rdd groupBy API。我搜索了Spark文档,但找不到这个问题的答案,似乎目前还处于早期阶段

谢谢。

根据这一点,没有简单的方法可以传递嵌套类型的自定义编码器,例如在您的案例中的(Long,LineString)

一个选项可以是定义一个
案例类LineStringWithID
,它将扩展
LineString
id:Long
属性,并使用


另外,你能把你的问题分解成更小的部分,每个主题一个吗?

也许有人也会想知道:当使用kryo编码器时,你不能使用非类型化的、基于SQL的API来处理数据集。您只能使用类型化API,就分组而言,这意味着您需要使用自定义的
聚合器,而不是自定义的
UserDefinedAggregateFunction
。我认为您的
Aggregator
实现还可以,但是您的分组应该更改为在自定义聚合器实例中使用typed
groupByKey

ds.groupByKey(_._1)
  .agg(length)
  .show(false)