Dataframe: how to add a column to a PySpark dataframe that contains the mean of one column grouped by another column

Tags: dataframe, pyspark, aggregate, mean

This is similar to some other questions, but it is different.

Suppose we have a PySpark dataframe df that looks like this:

+-----+------+-----+        
|col1 | col2 | col3| 
+-----+------+-----+        
|A    |   5  |  6  |
+-----+------+-----+        
|A    |   5  |  8  |
+-----+------+-----+        
|A    |   6  |  3  |
+-----+------+-----+        
|A    |   5  |  9  |
+-----+------+-----+        
|B    |   9  |  6  |
+-----+------+-----+        
|B    |   3  |  8  |
+-----+------+-----+        
|B    |   9  |  8  |
+-----+------+-----+        
|C    |   3  |  4  |
+-----+------+-----+
|C    |   5  |  1  |
+-----+------+-----+        
I want to add another column, new_col, containing the mean of col2 grouped by col1. The result should therefore look like this:

   +-----+------+------+--------+
   |col1 | col2 | col3 | new_col|
   +-----+------+------+--------+
   |  A  |   5  |  6   | 5.25   |
   +-----+------+------+--------+
   |  A  |   5  |  8   | 5.25   |
   +-----+------+------+--------+
   |  A  |   6  |  3   | 5.25   |
   +-----+------+------+--------+
   |  A  |   5  |  9   | 5.25   |
   +-----+------+------+--------+
   |  B  |   9  |  6   | 7      |
   +-----+------+------+--------+
   |  B  |   3  |  8   | 7      |
   +-----+------+------+--------+    
   |  B  |   9  |  8   | 7      |
   +-----+------+------+--------+
   |  C  |   3  |  4   | 4      |
   +-----+------+------+--------+
   |  C  |   5  |  1   | 4      |
   +-----+------+------+--------+

Any help would be much appreciated.

Step 1: Create the dataframe.

from pyspark.sql.functions import avg, col
from pyspark.sql.window import Window

values = [('A',5,6),('A',5,8),('A',6,3),('A',5,9),('B',9,6),('B',3,8),('B',9,8),('C',3,4),('C',5,1)]
# 'sqlContext' is preconfigured on platforms such as Databricks; with a plain
# SparkSession you can call spark.createDataFrame instead
df = sqlContext.createDataFrame(values, ['col1','col2','col3'])
df.show()
+----+----+----+
|col1|col2|col3|
+----+----+----+
|   A|   5|   6|
|   A|   5|   8|
|   A|   6|   3|
|   A|   5|   9|
|   B|   9|   6|
|   B|   3|   8|
|   B|   9|   8|
|   C|   3|   4|
|   C|   5|   1|
+----+----+----+
Step 2: Using a window partitioned by col1, create another column holding the group mean of col2.

# Window partitioned by col1; no ordering is needed for a per-group average
w = Window().partitionBy('col1')
df = df.withColumn('new_col', avg(col('col2')).over(w))
df.show()
+----+----+----+-------+
|col1|col2|col3|new_col|
+----+----+----+-------+
|   B|   9|   6|    7.0|
|   B|   3|   8|    7.0|
|   B|   9|   8|    7.0|
|   C|   3|   4|    4.0|
|   C|   5|   1|    4.0|
|   A|   5|   6|   5.25|
|   A|   5|   8|   5.25|
|   A|   6|   3|   5.25|
|   A|   5|   9|   5.25|
+----+----+----+-------+
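
For reference, the same per-group average can also be written as a raw SQL window expression via selectExpr, with no extra imports; a minimal sketch, assuming the df built in Step 1:

# Equivalent formulation using a SQL window expression
df.selectExpr("*", "avg(col2) over (partition by col1) as new_col").show()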

OK, after many attempts I was able to answer this question myself. I am posting the answer here for anyone else with a similar problem. The original data here comes from a csv file.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Reading the file ('filename.csv' stands in for the actual path)
df = spark.read.csv('filename.csv', header=True)
df.show()
Output:

+----+----+----+
|col1|col2|col3|
+----+----+----+
|   A|   5|   6|
|   A|   5|   8|
|   A|   6|   3|
|   A|   5|   9|
|   B|   9|   6|
|   B|   3|   8|
|   B|   9|   8|
|   C|   3|   4|
|   C|   5|   1|
+----+----+----+
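
Note: read this way, every CSV column comes in as a string; the average below still works because Spark casts implicitly, but it is cleaner to infer the types up front. A minimal sketch ('filename.csv' is again a placeholder):

# Let Spark infer numeric types while reading the CSV
df = spark.read.csv('filename.csv', header=True, inferSchema=True)
df.printSchema()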


from pyspark.sql import functions as func

# Grouping the dataframe based on col1
col1group = df.groupBy('col1')
# Computing the average of col2 within each col1 group
a = col1group.agg(func.avg("col2"))
a.show()
+----+---------+
|col1|avg(col2)|
+----+---------+
|   A|     5.25|
|   B|      7.0|
|   C|      4.0|
+----+---------+
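
As a side note, naming the aggregate with alias() up front avoids the column rename in the final step below; a minimal sketch:

# Name the aggregate column directly instead of renaming it later
a = col1group.agg(func.avg('col2').alias('new_col'))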
Now we join this aggregated table back onto the initial dataframe to produce the desired dataframe:

# Attach the per-group average to every row by joining on col1
df = df.join(a, on='col1', how='inner')
df.show()
Output:

+----+----+----+---------+
|col1|col2|col3|avg(col2)|
+----+----+----+---------+
|   A|   5|   6|     5.25|
|   A|   5|   8|     5.25|
|   A|   6|   3|     5.25|
|   A|   5|   9|     5.25|
|   B|   9|   6|      7.0|
|   B|   3|   8|      7.0|
|   B|   9|   8|      7.0|
|   C|   3|   4|      4.0|
|   C|   5|   1|      4.0|
+----+----+----+---------+
Finally, rename the last column to the name we want:

# Rename the aggregate column to the desired name
df = df.withColumnRenamed('avg(col2)', 'new_col')
df.show()
Output:

+----+----+----+-------+
|col1|col2|col3|new_col|
+----+----+----+-------+
|   A|   5|   6|   5.25|
|   A|   5|   8|   5.25|
|   A|   6|   3|   5.25|
|   A|   5|   9|   5.25|
|   B|   9|   6|    7.0|
|   B|   3|   8|    7.0|
|   B|   9|   8|    7.0|
|   C|   3|   4|    4.0|
|   C|   5|   1|    4.0|
+----+----+----+-------+
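
Putting it all together, the group/join/rename sequence can be chained into a single statement; a minimal sketch, assuming df is the original three-column dataframe:

from pyspark.sql import functions as func

# groupBy + aliased aggregate + join in one expression
result = df.join(
    df.groupBy('col1').agg(func.avg('col2').alias('new_col')),
    on='col1',
    how='inner',
)
result.show()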

Comments:

Please provide sample data and the expected output.

Thanks a lot, but I get the following error after w = Window().partitionBy('col1'): AttributeError: 'NoneType' object has no attribute '_jvm'.

Check this - I am quite new. How do I import sqlContext? Could you list all the packages you imported?

My dataframe is imported from a csv file; could that be the difference between yours and mine? I am using https://community.cloud.databricks.com and get my sqlContext from the site.

Please search stackoverflow for AttributeError: 'NoneType' object has no attribute '_jvm'. There are many links that address your problem. Since I cannot reproduce the error, I cannot tell where exactly it comes from.
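
Regarding the 'NoneType' object has no attribute '_jvm' error above: it typically appears when window or column expressions are built before any SparkSession/SparkContext is active, since pyspark.sql.functions reach into the JVM of the active context. A hedged sketch of an initialization order that avoids it:

from pyspark.sql import SparkSession
from pyspark.sql.functions import avg, col
from pyspark.sql.window import Window

# Create (or reuse) the session first; only then build window expressions
spark = SparkSession.builder.getOrCreate()

w = Window.partitionBy('col1')  # safe now that a context is active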