Python pyspark窗口函数的avg计算问题
我有一个输入数据框,如下所示:Python pyspark窗口函数的avg计算问题,python,dataframe,pyspark,average,pyspark-dataframes,Python,Dataframe,Pyspark,Average,Pyspark Dataframes,我有一个输入数据框,如下所示: partner_id|month_id|value1 |value2 1001 | 01 |10 |20 1002 | 01 |20 |30 1003 | 01 |30 |40 1001 | 02 |40 |50 1002 | 02 |50 |60 1003 | 02 |60 |
partner_id|month_id|value1 |value2
1001 | 01 |10 |20
1002 | 01 |20 |30
1003 | 01 |30 |40
1001 | 02 |40 |50
1002 | 02 |50 |60
1003 | 02 |60 |70
1001 | 03 |70 |80
1002 | 03 |80 |90
1003 | 03 |90 |100
使用下面的代码,我创建了两个新列,它们使用窗口函数进行平均:
rnum = (Window.partitionBy("partner_id").orderBy("month_id").rangeBetween(Window.unboundedPreceding, 0))
df = df.withColumn("value1_1", F.avg("value1").over(rnum))
df = df.withColumn("value1_2", F.avg("value2").over(rnum))
输出:
partner_id|month_id|value1 |value2|value1_1|value2_2
1001 | 01 |10 |20 |10 |20
1002 | 01 |20 |30 |20 |30
1003 | 01 |30 |40 |30 |40
1001 | 02 |40 |50 |25 |35
1002 | 02 |50 |60 |35 |45
1003 | 02 |60 |70 |45 |55
1001 | 03 |70 |80 |40 |50
1002 | 03 |80 |90 |50 |60
1003 | 03 |90 |100 |60 |70
使用pyspark窗口函数,value1和value2列的累积平均值表现良好。
但是,如果我们在下面的输入中遗漏了一个月的数据,那么下个月的平均值计算应该基于月号,而不是正常平均值。
例如,如果输入如下(缺少第02个月的数据)
然后,第三个月记录的平均计算如下:对于ex:(70+10)/2
但是,如果缺少某个月的值,那么计算平均值的正确方法是什么呢?Spark不够聪明,无法理解缺少一个月,因为它甚至不知道一个月可能是什么 如果希望“缺失”月份包含在平均计算中,则需要生成缺失数据 只需使用数据帧[“month_id”,“defaultValue”]执行完整的外部联接,其中month_id是1到12之间的值,defaultValue=0
另一种解决方案,不是执行平均值,而是执行值的总和,然后除以月份数。如果您使用的是spark 2.4+。可以使用序列函数和数组函数。 此解决方案的灵感来源于此
您能显示“错误”的输出以及您期望的输出吗?谢谢。我们可以创建与其他月份相同数量的虚拟记录吗???如果您不想从结果集中获得第二个月。您可以删除那些值为1或值为2=0的行
partner_id|month_id|value1 |value2
1001 | 01 |10 |20
1002 | 01 |20 |30
1003 | 01 |30 |40
1001 | 03 |70 |80
1002 | 03 |80 |90
1003 | 03 |90 |100
from pyspark.sql import functions as F
from pyspark.sql.window import Window
w= Window().partitionBy("partner_id")
df1 =df.withColumn("month_seq", F.sequence(F.min("month_id").over(w), F.max("month_id").over(w), F.lit(1)))\
.groupBy("partner_id").agg(F.collect_list("month_id").alias("month_id"), F.collect_list("value1").alias("value1"), F.collect_list("value2").alias("value2")
,F.first("month_seq").alias("month_seq")).withColumn("month_seq", F.array_except("month_seq","month_id"))\
.withColumn("month_id",F.flatten(F.array("month_id","month_seq"))).drop("month_seq")\
.withColumn("zip", F.explode(F.arrays_zip("month_id","value1", "value2"))) \
.select("partner_id", "zip.month_id", F.when(F.col("zip.value1").isNull() , \
F.lit(0)).otherwise(F.col("zip.value1")).alias("value1"),
F.when(F.col("zip.value2").isNull(), F.lit(0)).otherwise(F.col("zip.value2")
).alias("value2")).orderBy("month_id")
rnum = (Window.partitionBy("partner_id").orderBy("month_id").rangeBetween(Window.unboundedPreceding, 0))
df2 = df1.withColumn("value1_1", F.avg("value1").over(rnum)).withColumn("value1_2", F.avg("value2").over(rnum))
df2.show()
# +----------+--------+------+------+------------------+------------------+
# |partner_id|month_id|value1|value2| value1_1| value1_2|
# +----------+--------+------+------+------------------+------------------+
# | 1002| 1| 10| 20| 10.0| 20.0|
# | 1002| 2| 0| 0| 5.0| 10.0|
# | 1002| 3| 80| 90| 30.0|36.666666666666664|
# | 1001| 1| 10| 10| 10.0| 10.0|
# | 1001| 2| 0| 0| 5.0| 5.0|
# | 1001| 3| 70| 80|26.666666666666668| 30.0|
# | 1003| 1| 30| 40| 30.0| 40.0|
# | 1003| 2| 0| 0| 15.0| 20.0|
# | 1003| 3| 90| 100| 40.0|46.666666666666664|
# +----------+--------+------+------+------------------+------------------+