Sql 为Spark ML编码转置或透视分类变量数组的最佳方法

Sql 为Spark ML编码转置或透视分类变量数组的最佳方法,sql,scala,apache-spark,machine-learning,apache-spark-sql,Sql,Scala,Apache Spark,Machine Learning,Apache Spark Sql,我正在为spark ML模型设置分类变量。我没有一个包含单个分类变量的列,而是一个包含分类变量数组的列。参见下面的示例数据 (尽管这些是数字,但它们代表一个类别) 我需要将这些特性分离为单独的特性,例如,重要的是要保留#1、#3、#6和#7具有类别19,而不管数组中还有哪些其他类别 我可以使用SQL手动识别所有分类变量,并为每个变量创建一列。但这看起来并不优雅,我认为必须有更好的方法让所有类别都以列为轴心,然后指定1或0,这可能是一个热编码。或者,我想知道是否有一个更好的方法来思考这个问题 我正

我正在为spark ML模型设置分类变量。我没有一个包含单个分类变量的列,而是一个包含分类变量数组的列。参见下面的示例数据

(尽管这些是数字,但它们代表一个类别)

我需要将这些特性分离为单独的特性,例如,重要的是要保留#1、#3、#6和#7具有类别19,而不管数组中还有哪些其他类别

我可以使用SQL手动识别所有分类变量,并为每个变量创建一列。但这看起来并不优雅,我认为必须有更好的方法让所有类别都以列为轴心,然后指定1或0,这可能是一个热编码。或者,我想知道是否有一个更好的方法来思考这个问题

我正在使用scala 2.2.0(目前无法升级),因此无法使用较新的阵列函数

+---------------+----------------+
|id             |categorical_code|
+---------------+----------------+
|1              |           [19] |
|2              |       [87, 19] |
|3              |           [18] |
|4              |           [96] |
|5              |           [18] |
|6              |  [111, 22, 19] |
|7              |  [161, 19, 18] |
|8              |           [12] |
|9              |          [170] |
+---------------+----------------+
输出需要(我认为)类似于:

id,cat_12,cat_18,cat_19,cat_22,cat_87,cat_111,cat_161,cat_170
1,,,1,,,,,
2,,,1,,1,,,
3,,1,,,,,,
4,,,,,,,,
5,,1,,,,,,
6,,,1,1,,1,1,
7,,1,1,,,,,
8,1,,,,,,,1
9,,,,,,,,

我们可以将数组分解成单独的行,然后使用groupby透视获得所需的输出

val df2 =
  df.
    select(
      df("id"),
      explode(df("categorical_code")).as("categorical_code"),
      lit(1).as("categorical_code_exist")
    )

df2.show()
+---+----------------+----------------------+
| id|categorical_code|categorical_code_exist|
+---+----------------+----------------------+
|  1|              19|                     1|
|  2|              87|                     1|
|  2|              19|                     1|
|  3|              18|                     1|
|  4|              96|                     1|
|  5|              18|                     1|
|  6|             111|                     1|
|  6|              22|                     1|
|  6|              19|                     1|
|  7|             161|                     1|
|  7|              19|                     1|
|  7|              18|                     1|
|  8|              12|                     1|
|  9|             170|                     1|
+---+----------------+----------------------+

val df3 =
  df2.
    groupBy("id").
    pivot("categorical_code").
    agg(coalesce(first(df2("categorical_code_exist")))).
    orderBy("id")

df3.show()
+---+----+----+----+----+----+----+----+----+----+
| id|  12|  18|  19|  22|  87|  96| 111| 161| 170|
+---+----+----+----+----+----+----+----+----+----+
|  1|null|null|   1|null|null|null|null|null|null|
|  2|null|null|   1|null|   1|null|null|null|null|
|  3|null|   1|null|null|null|null|null|null|null|
|  4|null|null|null|null|null|   1|null|null|null|
|  5|null|   1|null|null|null|null|null|null|null|
|  6|null|null|   1|   1|null|null|   1|null|null|
|  7|null|   1|   1|null|null|null|null|   1|null|
|  8|   1|null|null|null|null|null|null|null|null|
|  9|null|null|null|null|null|null|null|null|   1|
+---+----+----+----+----+----+----+----+----+----+

df3.printSchema()
root
 |-- id: integer (nullable = true)
 |-- 12: integer (nullable = true)
 |-- 18: integer (nullable = true)
 |-- 19: integer (nullable = true)
 |-- 22: integer (nullable = true)
 |-- 87: integer (nullable = true)
 |-- 96: integer (nullable = true)
 |-- 111: integer (nullable = true)
 |-- 161: integer (nullable = true)
 |-- 170: integer (nullable = true)