Warning: file_get_contents(/data/phpspider/zhask/data//catemap/3/apache-spark/6.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Apache spark 如何为自定义DeclarativeAggregate(在catalyst包中)定义合并表达式_Apache Spark - Fatal编程技术网

Apache spark 如何为自定义DeclarativeAggregate(在catalyst包中)定义合并表达式

Apache spark 如何为自定义DeclarativeAggregate(在catalyst包中)定义合并表达式,apache-spark,Apache Spark,我不理解确定非平凡聚合器的mergeExpressions函数的一般方法。 类似org.apache.spark.sql.catalyst.expressions.aggregate.Average的mergeexpressions方法非常简单: override lazy val mergeExpressions = Seq( /* sum = */ sum.left + sum.right, /* count = */ count.left + count.right )

我不理解确定非平凡聚合器的mergeExpressions函数的一般方法。 类似org.apache.spark.sql.catalyst.expressions.aggregate.Average的mergeexpressions方法非常简单:

override lazy val mergeExpressions = Seq(
    /* sum = */ sum.left + sum.right,
    /* count = */ count.left + count.right
  )
CentralMomentaggregators的合并表达式更为复杂。 我想做的是创建一个以sparks CentralMomentAgg为模型的加权stddevsamp聚合器。 我几乎让它工作了,但它产生的加权标准差与我手工计算的结果仍然有点偏差。 调试它时遇到问题,因为我不明白如何计算mergeExpressions方法的确切逻辑。 下面是我的代码。updateExpressions方法就是基于此,所以我非常确定该方法是正确的。我认为我的问题在于合并表达式方法。如有任何提示,将不胜感激

abstract class WeightedCentralMomentAgg(child: Expression, weight: Expression) extends DeclarativeAggregate {

  override def children: Seq[Expression] = Seq(child, weight)
  override def nullable: Boolean = true
  override def dataType: DataType = DoubleType
  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType)

  protected val wSum = AttributeReference("wSum", DoubleType, nullable = false)()
  protected val mean = AttributeReference("mean", DoubleType, nullable = false)()
  protected val s = AttributeReference("s", DoubleType, nullable = false)()
  override val aggBufferAttributes = Seq(wSum, mean, s)
  override val initialValues: Seq[Expression] = Array.fill(3)(Literal(0.0))

  // See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm
  override val updateExpressions: Seq[Expression] = {

    val newWSum = wSum + weight
    val newMean = mean + (weight / newWSum) * (child - mean)
    val newS = s + weight * (child - mean) * (child - newMean)

    Seq(
      If(IsNull(child), wSum, newWSum),
      If(IsNull(child), mean, newMean),
      If(IsNull(child), s, newS)
    )
  }

  override val mergeExpressions: Seq[Expression] = {
    val wSum1 = wSum.left
    val wSum2 = wSum.right
    val newWSum = wSum1 + wSum2
    val delta = mean.right - mean.left
    val deltaN = If(newWSum === Literal(0.0), Literal(0.0), delta / newWSum)
    val newMean = mean.left + wSum1 / newWSum * delta                //  ???
    val newS = s.left + s.right + wSum1 * wSum2 * delta * deltaN     //  ???
    Seq(newWSum, newMean, newS)
  }
}


// Compute the weighted sample standard deviation of a column
case class WeightedStddevSamp(child: Expression, weight: Expression)
  extends WeightedCentralMomentAgg(child, weight) {

  override val evaluateExpression: Expression = {
    If(wSum === Literal(0.0), Literal.create(null, DoubleType),
      If(wSum === Literal(1.0), Literal(Double.NaN),
        Sqrt(s / wSum) ) )
  }

  override def prettyName: String = "wtd_stddev_samp"
}

对于任何哈希聚合,它分为四个步骤:

1) 初始化缓冲区(wSum、mean、s)

2) 在分区内,更新给定所有输入的键的缓冲区(为每个输入调用updateExpression)

3) 洗牌后,使用mergeExpression合并同一密钥的所有缓冲区。wSum.left表示左侧缓冲区中的wSum,wSum.right表示另一个缓冲区中的wSum


4) 使用valueExpression从缓冲区获取最终结果对于任何哈希聚合,它分为四个步骤:

1) 初始化缓冲区(wSum、mean、s)

2) 在分区内,更新给定所有输入的键的缓冲区(为每个输入调用updateExpression)

3) 洗牌后,使用mergeExpression合并同一密钥的所有缓冲区。wSum.left表示左侧缓冲区中的wSum,wSum.right表示另一个缓冲区中的wSum


4) 使用valueExpression从缓冲区获取最终结果我发现了如何为加权标准偏差编写mergeExpressions函数。实际上我是对的,但是在evaluateExpression中使用了总体方差而不是样本方差计算。下面显示的实现给出了与上面相同的结果,但更容易理解

override val mergeExpressions: Seq[Expression] = {   
    val newN = n.left + n.right
    val wSum1 = wSum.left
    val wSum2 = wSum.right
    val newWSum = wSum1 + wSum2
    val delta = mean.right - mean.left

    val deltaN = If(newWSum === Literal(0.0), Literal(0.0), delta / newWSum)
    val newMean = mean.left + deltaN * wSum2
    val newS =  (((wSum1 * s.left) + (wSum2 * s.right)) / newWSum) + (wSum1 * wSum2 * deltaN * deltaN)

    Seq(newN, newWSum, newMean, newS)
}
这里有一些参考资料

  • (最后一个为我提供了mergeExpressions函数所需的线索)
Davies的帖子给出了该方法的概要,但对于许多非平凡的聚合器,我认为mergeExpressions函数可能非常复杂,需要使用高等数学来确定正确有效的解决方案。幸运的是,在这种情况下,我发现有人已经解决了这个问题


这个解决方案与我手工计算的结果相符。需要注意的是,如果希望使用样本方差而不是总体方差,则需要稍微修改evaluateExpression(即s/((n-1)*wSum/n))

我发现了如何编写加权标准差的mergeExpressions函数。实际上我是对的,但是在evaluateExpression中使用了总体方差而不是样本方差计算。下面显示的实现给出了与上面相同的结果,但更容易理解

override val mergeExpressions: Seq[Expression] = {   
    val newN = n.left + n.right
    val wSum1 = wSum.left
    val wSum2 = wSum.right
    val newWSum = wSum1 + wSum2
    val delta = mean.right - mean.left

    val deltaN = If(newWSum === Literal(0.0), Literal(0.0), delta / newWSum)
    val newMean = mean.left + deltaN * wSum2
    val newS =  (((wSum1 * s.left) + (wSum2 * s.right)) / newWSum) + (wSum1 * wSum2 * deltaN * deltaN)

    Seq(newN, newWSum, newMean, newS)
}
这里有一些参考资料

  • (最后一个为我提供了mergeExpressions函数所需的线索)
Davies的帖子给出了该方法的概要,但对于许多非平凡的聚合器,我认为mergeExpressions函数可能非常复杂,需要使用高等数学来确定正确有效的解决方案。幸运的是,在这种情况下,我发现有人已经解决了这个问题

这个解决方案与我手工计算的结果相符。需要注意的是,如果希望使用样本方差而不是总体方差,则需要稍微修改evaluateExpression(即s/((n-1)*wSum/n))