flatMap over collect_set in a PySpark DataFrame

I have two DataFrames, and after using groupby I am using collect_set() in agg. What is the best way to flatMap the arrays produced by the aggregation?

from pyspark.sql.functions import collect_set

schema = ['col1', 'col2', 'col3', 'col4']

a = [[1, [23, 32], [11, 22], [9989]]]

df1 = spark.createDataFrame(a, schema=schema)

b = [[1, [34], [43, 22], [888, 777]]]

df2 = spark.createDataFrame(b, schema=schema)

df = df1.union(
        df2
    ).groupby(
        'col1'
    ).agg(
        collect_set('col2').alias('col2'),
        collect_set('col3').alias('col3'),
        collect_set('col4').alias('col4')
    )

df.collect()
I get this as output:

[Row(col1=1, col2=[[34], [23, 32]], col3=[[11, 22], [43, 22]], col4=[[9989], [888, 777]])]

However, I want this as output:

[Row(col1=1, col2=[23, 32, 34], col3=[11, 22, 43], col4=[9989, 888, 777])]

You can use a udf:

from itertools import chain
from pyspark.sql.types import ArrayType, IntegerType
from pyspark.sql.functions import udf

flatten = udf(lambda x: list(chain.from_iterable(x)), ArrayType(IntegerType()))

df.withColumn('col2_flat', flatten('col2'))
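
Note that chain.from_iterable only concatenates the inner lists, so for col3 it would keep the repeated 22, whereas the desired output shows each value once. A minimal sketch of a flattening function that also drops duplicates is shown below. The name flatten_distinct is hypothetical and not part of the original answer, and the function is plain Python so it can be checked without a Spark session.

```python
from itertools import chain


def flatten_distinct(nested):
    """Flatten a list of lists and drop duplicates, keeping first-seen order."""
    # dict.fromkeys preserves insertion order (Python 3.7+), so this
    # deduplicates without sorting or otherwise reordering the values.
    return list(dict.fromkeys(chain.from_iterable(nested)))


# Values taken from the col3 array in the question's output
print(flatten_distinct([[11, 22], [43, 22]]))  # [11, 22, 43]
```

Assuming the same setup as the answer above, this function could be passed to udf in place of the lambda, with the same ArrayType(IntegerType()) return type. The order of elements produced by collect_set is not guaranteed, so the flattened order may differ from the one shown in the question. On Spark 2.4 and later, the built-in flatten and array_distinct functions in pyspark.sql.functions may be a simpler alternative that avoids a Python UDF.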