Python 将pyspark dataframe中列的字符串列表转换为一个热编码的字符串

Python 将pyspark dataframe中列的字符串列表转换为一个热编码的字符串,python,dataframe,apache-spark,pyspark,Python,Dataframe,Apache Spark,Pyspark,我的问题与我的前一个问题相关: 我已经创建了一个表“my_df”(pyspark中的数据帧): 现在,我需要对表进行聚类,以便找到“id”的相似性。 一开始我在尝试k-means。因此,我需要通过一个热编码将分类值转换为数值。 我指的是 我的代码: from pyspark.ml import Pipeline from pyspark.ml.feature import StringIndexer, OneHotEncoderEstimator inputs, my_indx_list =

我的问题与我的前一个问题相关:

我已经创建了一个表“my_df”(pyspark中的数据帧):

现在,我需要对表进行聚类,以便找到“id”的相似性。 一开始我在尝试k-means。因此,我需要通过一个热编码将分类值转换为数值。 我指的是

我的代码:

from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoderEstimator

inputs, my_indx_list = [], []
for a_col in my_df.columns: 
  my_indx = StringIndexer(inputCol = a_col, outputCol = a_col + "_index")
  inputs.append(my_indx.getOutputCol())
  my_indx_list.append(my_indx)

  encoder = OneHotEncoderEstimator(inputCols=inputs, outputCols=[x + "_vector" for x in inputs])
  a_pipeline = Pipeline(stages = my_indx_list + [encoder])
  pipeline.fit(my_df).transform(my_df).show() # error here !
但是,我有一个错误:

A column must be either string type or numeric type, but got ArrayType(StringType,true)
那么,我怎样才能解决这个问题呢

我的想法是:对每列的列表值进行排序,并将列表中的每个字符串连接为每列的长字符串

但是,对于每一列,这些值都是一些调查问题的答案,并且每个答案都具有相同的权重。 我不知道该怎么解决

谢谢 更新

根据提出的解决方案,它可以工作,但速度非常慢。 在一个拥有300 GB内存和32核的集群上运行大约需要3.5个小时

我的代码:

   from pyspark.ml.feature import CountVectorizer
   tmp_df = original_df # 3.5 million rows and 300 columns

   for a_col in original_df.columns: 
        a_vec = CountVectorizer(inputCol = a_col, outputCol = a_col + "_index", binary=True)
        tmp_df = a_vec.fit(tmp_df).transform(tmp_df)

  tmp_df.show()
“原始_df”有350万行和300列

我怎样才能加速


感谢@jxc建议在您的案例中使用
CountVectorizer
作为一种热编码,它通常用于自然语言处理中的令牌计数

使用
CountVectorizer
可以省去使用
onehotcoderestimator处理
explode
collect\u set
时的麻烦;如果您试图使用
udf
实现它,则情况会更糟

考虑到这个数据帧

df = spark.createDataFrame([
                            {'id': 'dapd', 'payment': ['credit', 'cash'], 'shop': ['retail', 'on-line']},
                            {'id': 'wrfr', 'payment': ['cash', 'debit'], 'shop': ['supermarket', 'brand store']}
                           ])
df.show()

+----+--------------+--------------------+
|  id|       payment|                shop|
+----+--------------+--------------------+
|dapd|[credit, cash]|   [retail, on-line]|
|wrfr| [cash, debit]|[supermarket, bra...|
+----+--------------+--------------------+
您可以通过在自然语言处理中将字符串数组作为标记来进行热编码。注意使用
binary=True
强制它只返回0或1

from pyspark.ml.feature import CountVectorizer

payment_cv = CountVectorizer(inputCol="payment", outputCol="paymentEnc", binary=True)
first_res_df = payment_cv.fit(df).transform(df)

shop_cv = CountVectorizer(inputCol="shop", outputCol="shopEnc", binary=True)
final_res_df = shop_cv.fit(first_res_df).transform(first_res_df)

final_res_df.show()

+----+--------------+--------------------+-------------------+-------------------+
|  id|       payment|                shop|         paymentEnc|            shopEnc|
+----+--------------+--------------------+-------------------+-------------------+
|dapd|[credit, cash]|   [retail, on-line]|(3,[0,2],[1.0,1.0])|(4,[0,3],[1.0,1.0])|
|wrfr| [cash, debit]|[supermarket, bra...|(3,[0,1],[1.0,1.0])|(4,[1,2],[1.0,1.0])|
+----+--------------+--------------------+-------------------+-------------------+

你是怎么得到这样的输出的,是手动的吗?“+-”扫描您是否添加了返回错误的代码行?添加的代码,感谢两个数组列,请使用CountVectorizer with
binary=True
:请查看一下:它工作正常,但速度非常慢。拥有350万行和300列的数据帧大约需要3个小时。有可能加快速度吗?感谢您正确地设置了执行器,Spark计算速度慢的最常见原因是数据分区不好。尝试
重新分区
您的数据帧,并检查您的Spark UI(默认端口为4040)中的任务是否正确地使您的执行者饱和。您能告诉我如何确定“numPartitions”的最佳值吗?我当前的分区是200。
from pyspark.ml.feature import CountVectorizer

payment_cv = CountVectorizer(inputCol="payment", outputCol="paymentEnc", binary=True)
first_res_df = payment_cv.fit(df).transform(df)

shop_cv = CountVectorizer(inputCol="shop", outputCol="shopEnc", binary=True)
final_res_df = shop_cv.fit(first_res_df).transform(first_res_df)

final_res_df.show()

+----+--------------+--------------------+-------------------+-------------------+
|  id|       payment|                shop|         paymentEnc|            shopEnc|
+----+--------------+--------------------+-------------------+-------------------+
|dapd|[credit, cash]|   [retail, on-line]|(3,[0,2],[1.0,1.0])|(4,[0,3],[1.0,1.0])|
|wrfr| [cash, debit]|[supermarket, bra...|(3,[0,1],[1.0,1.0])|(4,[1,2],[1.0,1.0])|
+----+--------------+--------------------+-------------------+-------------------+