Pyspark多个简单聚合最佳实践-countif/suif格式

Pyspark多个简单聚合最佳实践-countif/suif格式,pyspark,Pyspark,我是Pyspark的新手,我正在寻找关于在一个长数据帧上进行多个简单聚合的最佳方法的建议 我有一个事务数据框架,其中客户一天有多个事务,我想按客户分组,并保留一些变量,如total sum,以及一些变量,如countdistinct,以区分条件所适用的日期 因此,我想了解每位客户: 他们在多少天内购买了A类产品 他们在多少个周末购买 所有交易的总支出 再加上一些其他的事情,比如上个月的交易,最高消费,周末的最高消费等等。 因此,本质上,excel术语中有相当多的countifs或sumif 我觉

我是Pyspark的新手,我正在寻找关于在一个长数据帧上进行多个简单聚合的最佳方法的建议

我有一个事务数据框架,其中客户一天有多个事务,我想按客户分组,并保留一些变量,如total sum,以及一些变量,如countdistinct,以区分条件所适用的日期

因此,我想了解每位客户:

他们在多少天内购买了A类产品 他们在多少个周末购买 所有交易的总支出 再加上一些其他的事情,比如上个月的交易,最高消费,周末的最高消费等等。 因此,本质上,excel术语中有相当多的countifs或sumif

我觉得这不是最好的事情,以计算所有这些单独如下,然后把他们结合在一起,根据答案 ,因为我有相当多的客户,所以加入将非常昂贵,而且由于一些客户在任何周末都不进行交易,我认为这将需要一个加入,而不仅仅是一个承诺:

total_variables = transactions.groupby('cust_id').agg(sum("spend").alias("total_spend")) 
weekend_variables = transactions.where(transactions.weekend_flag == "Y").groupby('cust_id').agg(countDistinct("date").alias("days_txn_on_weekend"))  
catA_variables = transactions.where(transactions.category == "CatA").groupby('cust_id').agg(countDistinct("date").alias("days_txn_cat_a")) 
final_df = total_variables.join(weekend_variables, col('total_variables.id') == col('weekend_variables.id'), 'left') \
                          .join(catA_variables, col('df1.id') == col('catA_variables.id'), 'left')
一种方法是使列部分为空,然后对它们调用count distint或sum,如下所示:

transactions_additional = transactions.withColumn('date_if_weekend',
                                                psf.when(psf.col("weekend_flag") == "Y",
                                                psf.col('date')).otherwise(psf.lit(None)))
                                      .withColumn('date_if_CatA',
                                                psf.when(psf.col("category") == "CatA",
                                                psf.col('date')).otherwise(psf.lit(None)))
final_df = total_variables .groupby('cust_id').agg(psf.countDistinct("date_if_weekend").alias("days_txn_on_weekend"),
                                                   psf.countDistinct("date_if_CatA").alias("days_txn_cat_a"),
                                                   psf.sum("spend").alias("total_spend"))
但就生成列而言,这似乎是浪费,而且可能会与我最终想要计算的内容失控

我想我可以在pyspark sql中使用countdistinct和case,但我希望有更好的方法使用pyspark语法-可能使用以下格式的自定义聚合UDF:

aggregated_df = transactions.groupby('cust_id').agg(<something that returns total spend>,
                                                    <something that returns days purchased cat A>,
                                                    <something that returns days purchased on the weekend>,)
这可能吗?

对于聚合结果,udf函数非常有用且可读。 下面是示例代码,对于所需的输出,您可以扩展以添加任何其他聚合结果

import pyspark.sql.functions as F
from pyspark.sql.types import ArrayType,IntegerType,LongType,StructType,StructField,StringType
import pandas as pd

#you can add last month maximum spend, maximum spend on the weekend etc and 
#update agg_data function
agg_schema = StructType(
    [StructField("cust_id", StringType(), True),
     StructField("days_txn_on_weekend", IntegerType(), True),
     StructField("days_txn_cat_a", IntegerType(), True),
     StructField("total_spend", IntegerType(), True)
     ]
)

@F.pandas_udf(agg_schema, F.PandasUDFType.GROUPED_MAP)
def agg_data(pdf):
    days_txn_on_weekend =  pdf.query("weekend_flag == 'Y'")['date'].nunique()
    days_txn_cat_a = pdf.query("category == 'CatA'")['date'].nunique()
    total_spend = pdf['spend'].sum()
    return pd.DataFrame([(pdf.cust_id[0],days_txn_on_weekend,days_txn_cat_a,total_spend)])

transactions = spark.createDataFrame(
    [
    ('cust_1', 'CatA', 20190101, 'N', 10),
    ('cust_1', 'CatA', 20190101, 'N', 20),
    ('cust_1', 'CatA', 20190105, 'Y',40),
    ('cust_1', 'CatA', 20190105, 'Y',10),
    ('cust_1', 'CatA', 20190112, 'Y', 20),
    ('cust_1', 'CatA', 20190113, 'Y', 10),
    ('cust_1', 'CatA', 20190101, 'N',20),
    ('cust_1', 'CatB', 20190105, 'Y', 50),
    ('cust_1', 'CatB', 20190105, 'Y', 50),
    ('cust_2', 'CatA', 20190115, 'N', 10),
    ('cust_2', 'CatA', 20190116, 'N', 20),
    ('cust_2', 'CatA', 20190117, 'N', 40),
    ('cust_2', 'CatA', 20190119, 'Y', 10),
    ('cust_2', 'CatA', 20190119, 'Y', 20),
    ('cust_2', 'CatA', 20190120, 'Y', 10),
    ('cust_3', 'CatB', 20190108, 'N', 10),
    ],
    ['cust_id','category','date','weekend_flag','spend']
)
transactions.groupBy('cust_id').apply(agg_data).show()
结果是什么

+-------+-------------------+--------------+-----------+
|cust_id|days_txn_on_weekend|days_txn_cat_a|total_spend|
+-------+-------------------+--------------+-----------+
| cust_2|                  2|             5|        110|
| cust_3|                  0|             0|         10|
| cust_1|                  3|             4|        230|
+-------+-------------------+--------------+-----------+

谢谢,这正是我想要的。我想知道是否有办法通过在Scala中定义这些函数来提高这些函数的性能?