在scala中编写udf函数并在pyspark作业中使用

在scala中编写udf函数并在pyspark作业中使用,scala,apache-spark,pyspark,Scala,Apache Spark,Pyspark,我们正在尝试编写scala udf函数,并从pyspark中的map函数调用它。dateframe架构非常复杂,我们要传递给此函数的列是StructType数组 trip\u force\u speeds=trip\u details.groupby(“车辆id”、“驾驶员id”、“起点本地”、“终点本地”)\ .agg(collect\u list(结构(col)(“事件开始\u dt\u本地”), 上校(“部队”), col(“速度”), col(“从开始到结束”), col(“第二段自第二

我们正在尝试编写scala udf函数,并从pyspark中的map函数调用它。dateframe架构非常复杂,我们要传递给此函数的列是StructType数组

trip\u force\u speeds=trip\u details.groupby(“车辆id”、“驾驶员id”、“起点本地”、“终点本地”)\
.agg(collect\u list(结构(col)(“事件开始\u dt\u本地”),
上校(“部队”),
col(“速度”),
col(“从开始到结束”),
col(“第二段自第二段”),
col(“StartDocal”),
col(“EndDtLocal”),
col(“verisk_车辆识别号”),
col(行程持续时间秒)\
.别名(“行程详细信息”)

在映射函数中,我们需要进行一些计算

def calculateVariables(rec: Row):HashMap[String,Float] = {
val trips = rec.getAs[List]("trips")
val base_variables = new HashMap[String, Float]()   

val entropy_variables = new HashMap[String, Float]()

val week_day_list = List("monday", "tuesday", "wednesday", "thursday", "friday")

for (trip <- trips)
{
  if (trip("start_dt_local") >= trip("StartDtLocal") && trip("start_dt_local") <= trip("EndDtLocal"))
  {
    base_variables("trip_summary_count") += 1

    if (trip("duration_sec").toFloat >= 300 && trip("duration_sec").toFloat <= 1800) {
      base_variables ("bounded_trip") +=  1

      base_variables("bounded_trip_duration") = trip("duration_sec") + base_variables("bounded_trip_duration")

      base_variables("total_bin_1") += 30

      base_variables("total_bin_2") += 30

      base_variables("total_bin_3") += 60

      base_variables("total_bin_5") += 60

      base_variables("total_bin_6") += 30

      base_variables("total_bin_7") += 30
    }
    if (trip("duration_sec") > 120 && trip("duration_sec") < 21600 )
    {
      base_variables("trip_count") += 1
    }

    base_variables("trip_distance") += trip("distance_km")

    base_variables("trip_duration") = trip("duration_sec") + base_variables("trip_duration")

    base_variables("speed_event_distance") = trip("speed_event_distance_km")  + base_variables("speed_event_distance")

    base_variables("speed_event_duration") = trip("speed_event_duration_sec") + base_variables("speed_event_duration")

    base_variables("speed_event_distance_ratio") = trip("speed_distance_ratio") + base_variables("speed_event_distance_ratio")

    base_variables("speed_event_duration_ratio") = trip("speed_duration_ratio") + base_variables("speed_event_duration_ratio")

  }
}
return base_variables
}
我们定义函数签名的方式是否有问题?我们试图重写spark structtype,但这对我不起作用


我来自python背景,在python作业中面临一些性能问题,这就是为什么我决定用Scala编写此映射函数。

您必须在udf中使用行类型而不是结构类型。StructType表示模式本身,而不是数据。您可以使用Scala中的一个小示例:

object test{

  import org.apache.spark.sql.functions.{udf, collect_list, struct}

  val hash = HashMap[String, Float]("start_dt_local" -> 0)
  // This simple type to store you results
  val sampleDataset = Seq(Row(Instant.now().toEpochMilli, Instant.now().toEpochMilli))

  implicit val spark: SparkSession =
    SparkSession
      .builder()
      .appName("Test")
      .master("local[*]")
      .getOrCreate()

  def calculateVariablesUdf = udf { trip: Row =>

    if(trip.getAs[Long]("start_dt_local") >= trip.getAs[Long]("StartDtLocal")) {
      // crate a new instance with your results
      hash("start_dt_local") + 1
    } else {
      hash("start_dt_local") + 0
    }

  }


  def main(args: Array[String]) : Unit = {

    Logger.getLogger("org").setLevel(Level.OFF)
    Logger.getLogger("akka").setLevel(Level.OFF)

    val rdd = spark.sparkContext.parallelize(sampleDataset)
    val df = spark.createDataFrame(rdd, StructType(List(StructField("start_dt_local", LongType, false), StructField("StartDtLocal", LongType, false))))

    df.agg(collect_list(calculateVariablesUdf(struct(col("start_dt_local"), col("StartDtLocal")))).as("result")).show(false)

  }
}
编辑。为了更好地理解:

当你考虑一个模式描述:StructType(list(Strutfield))作为你的字段的类型时,你就错了。数据帧中没有列表类型

如果将CalculateVariable视为udf,则不需要for循环。我的意思是:

def calculateVariables = udf { trip: Row =>
  trip("start_dt_local").getAs[Long] 
  // your logic ....

}

正如我在示例中所说,您可以直接在udf中返回更新后的哈希值。您可以使用java.sql.Timestamp来处理时间戳类型。列表的错误是因为Scala中的列表类型是一种类似于阻碍的类型,所以需要声明列表元素的类型:List[Int][或List[String],等等。。。
def calculateVariables = udf { trip: Row =>
  trip("start_dt_local").getAs[Long] 
  // your logic ....

}