Creating a column in PySpark that references itself after the first row

I want to create a column in PySpark that references itself after the first row:

Customer  | Week | Price | Index change | Column to be created
A           1      10      0.5            10  
A           2      13      0.1            10* (1+0.1)  = 11
A           3      16      0.6            11* (1+0.6)  = 17.6
A           4      16      0.1            17.6 * (1+0.1)  = 19.36
There are multiple customers in this dataset, each with 52 weeks of data. In other words, the new column should satisfy new_1 = Price_1 and new_w = new_{w-1} * (1 + Index change_w) for w > 1. I know I need a window function, but I'm struggling to apply it while writing an expression that essentially references itself after the first row while also referencing another column. I feel it should look something like the snippet below, but I don't know how to make it work, or whether you can even reference a column while it is being created:

# This does not work: 'Column to be created' does not exist yet,
# so lag() cannot reference it inside its own withColumn.
df = df.withColumn('Column to be created',
                   F.when(F.col('Week') != 1,
                          F.lag('Column to be created').over(win) * (1 + F.col('Index change')))
                    .otherwise(F.col('Price')))

* win refers to a partitionBy window that I have created already
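
(The question does not show how win is defined; presumably it is a per-customer window ordered by week, something like the sketch below.)

from pyspark.sql import Window

# Assumed definition of the window referred to above (not shown in the question).
win = Window.partitionBy('Customer').orderBy('Week')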

From what I understand, you are trying to adjust the price according to the index change. Also, keep in mind that a new column cannot be reused before it has been created. I tried it my way; hope this helps.

dff = spark.createDataFrame(
    [('A', 1, 10, 0.5), ('A', 2, 13, 0.1), ('A', 3, 16, 0.6), ('A', 4, 16, 0.1)],
    ['Customer', 'Week', 'Price', 'Index_change'])
dff.show()
+--------+----+-----+------------+
|Customer|Week|Price|Index_change|
+--------+----+-----+------------+
|       A|   1|   10|         0.5|
|       A|   2|   13|         0.1|
|       A|   3|   16|         0.6|
|       A|   4|   16|         0.1|
+--------+----+-----+------------+

from pyspark.sql import Window
from pyspark.sql import functions as F

w = Window.partitionBy('Customer').orderBy('Week').rowsBetween(Window.unboundedPreceding, 0)

# 2nd row: 10*(1+0.1); 3rd row: 10*(1+0.1)*(1+0.6); 4th row: 10*(1+0.1)*(1+0.6)*(1+0.1); and so on.
# Each row needs the cumulative product of (Index_change + 1) over all weeks after the first.
# Since log(a*b) = log(a) + log(b), a windowed sum of logs gives that product:
# for the 3rd row, log_sum = log(1+0.1) + log(1+0.6).
# cum_idx converts back from log space to the original scale via exp(log_sum).

log_sum = F.sum(F.when(F.col('Week') != 1, F.log(F.col('Index_change') + 1))).over(w)  # sum of logs = product of the factors
cum_idx = F.exp(log_sum)               # back to the original scale
base_value = F.first('Price').over(w)  # week-1 price for each customer

dff = dff.withColumn('new_column',
                     F.when(F.col('Week') != 1, cum_idx * base_value)
                      .otherwise(F.col('Price')))
dff.show()

+--------+----+-----+------------+------------------+
|Customer|Week|Price|Index_change|        new_column|
+--------+----+-----+------------+------------------+
|       A|   1|   10|         0.5|              10.0|
|       A|   2|   13|         0.1|              11.0|
|       A|   3|   16|         0.6|              17.6|
|       A|   4|   16|         0.1|             19.36|
+--------+----+-----+------------+------------------+
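
As a side note, on Spark 3.2 and later the built-in F.product aggregate can express the same cumulative product directly, without the log/exp detour. A minimal sketch reusing the dff and window w from above (F.product only exists from Spark 3.2):

# Assumes Spark >= 3.2, where F.product is available as an aggregate.
# The nulls that when() produces for week 1 are ignored by the aggregate.
cum_prod = F.product(F.when(F.col('Week') != 1, F.col('Index_change') + 1)).over(w)

dff.withColumn('new_column',
               F.when(F.col('Week') != 1, cum_prod * F.first('Price').over(w))
                .otherwise(F.col('Price'))).show()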

You can use a UDF and two window functions:

from pyspark.sql import Window
from pyspark.sql import functions as F
from pyspark.sql.types import FloatType

data = sc.parallelize([
    ('A', 1, 10, 0.5),
    ('A', 2, 13, 0.1),
    ('A', 3, 16, 0.6),
    ('A', 4, 16, 0.1),
    ('B', 1, 10, 0.5),
    ('B', 2, 13, 0.1),
    ('B', 3, 16, 0.6),
    ('B', 4, 16, 0.1),
])

df = spark.createDataFrame(data, ['Customer', 'Week', 'Price', 'Index change'])

window1 = Window.partitionBy('Customer').orderBy('Week')
window2 = Window.partitionBy('Customer').orderBy('Week').rangeBetween(Window.unboundedPreceding, 0)

from functools import reduce

# Multiply every factor except the first: week 1 keeps the raw price.
@F.udf(FloatType())
def mul_list(l):
    if len(l) == 1:
        return None
    return reduce(lambda x, y: x * y, l[1:])

df.withColumn('new_col', F.collect_list(F.col('Index change') + 1).over(window2))\
.withColumn('mult', mul_list('new_col'))\
.withColumn('result', F.first(F.col('Price')).over(window1) * F.coalesce(F.col('mult'), F.lit(1))).show()
This gives:

+--------+----+-----+------------+--------------------+-----+------+
|Customer|Week|Price|Index change|             new_col| mult|result|
+--------+----+-----+------------+--------------------+-----+------+
|       B|   1|   10|         0.5|               [1.5]| null|  10.0|
|       B|   2|   13|         0.1|          [1.5, 1.1]|  1.1|  11.0|
|       B|   3|   16|         0.6|     [1.5, 1.1, 1.6]| 1.76|  17.6|
|       B|   4|   16|         0.1|[1.5, 1.1, 1.6, 1.1]|1.936| 19.36|
|       A|   1|   10|         0.5|               [1.5]| null|  10.0|
|       A|   2|   13|         0.1|          [1.5, 1.1]|  1.1|  11.0|
|       A|   3|   16|         0.6|     [1.5, 1.1, 1.6]| 1.76|  17.6|
|       A|   4|   16|         0.1|[1.5, 1.1, 1.6, 1.1]|1.936| 19.36|
+--------+----+-----+------------+--------------------+-----+------+

I created the intermediate columns to make each step more explicit.
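
If you would rather avoid the Python UDF, the same fold can run on the JVM side with the SQL higher-order function aggregate, available since Spark 2.4. A sketch against the same df, window1, and window2 as above; the CASE mirrors the len(l) == 1 guard in mul_list:

# Assumes Spark >= 2.4 for higher-order functions.
# slice(new_col, 2, size(new_col) - 1) drops the first factor, like l[1:] in the UDF.
mult_expr = F.expr("""
    CASE WHEN size(new_col) > 1
         THEN aggregate(slice(new_col, 2, size(new_col) - 1),
                        CAST(1.0 AS DOUBLE), (acc, x) -> acc * x)
    END
""")

df.withColumn('new_col', F.collect_list(F.col('Index change') + 1).over(window2)) \
  .withColumn('mult', mult_expr) \
  .withColumn('result',
              F.first(F.col('Price')).over(window1) * F.coalesce(F.col('mult'), F.lit(1))) \
  .show()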

Why isn't the index change applied to the first row? Is that right?

Thanks! This works, although I'm not sure I fully understand what is happening in log_sum and cum_idx. Could you explain?

Updated the answer with some explanation. Hope it helps.