PySpark: VectorIndexer or OneHotEncoder for categorical variables?

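When preparing categorical variables as input to ML algorithms in Spark, I am a bit confused about when to use VectorIndexer and when to use OneHotEncoder. Is it the case that I need OneHotEncoder when I want to see the effect of each category level in the ML output, and that VectorIndexer is fine otherwise?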

Here is an example:

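from pyspark.sql import SparkSession
from pyspark.ml.feature import OneHotEncoder, VectorAssembler, VectorIndexer

# Reuse the active session (or start one) instead of relying on a predefined sqlContext.
spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([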
    (0.0, 3.0, 3.8),
    (1.0, 0.0, 6.7),
    (2.0, 3.0, 3.3),
    (0.0, 2.0, 1.2),
    (0.0, 1.0, 7.8),
    (2.0, 0.0, 4.4)
], ["category1", "category2", "readings"])

# dropLast=True omits the last category, so that level shows up as an all-zero vector.
encoder = OneHotEncoder(dropLast=True, inputCols=["category1", "category2"],
                        outputCols=["categoryVec1", "categoryVec2"])
model = encoder.fit(df)
encoded = model.transform(df)
encoded.show()


+---------+---------+--------+-------------+-------------+
|category1|category2|readings| categoryVec1| categoryVec2|
+---------+---------+--------+-------------+-------------+
|      0.0|      3.0|     3.8|(2,[0],[1.0])|    (3,[],[])|
|      1.0|      0.0|     6.7|(2,[1],[1.0])|(3,[0],[1.0])|
|      2.0|      3.0|     3.3|    (2,[],[])|    (3,[],[])|
|      0.0|      2.0|     1.2|(2,[0],[1.0])|(3,[2],[1.0])|
|      0.0|      1.0|     7.8|(2,[0],[1.0])|(3,[1],[1.0])|
|      2.0|      0.0|     4.4|    (2,[],[])|(3,[0],[1.0])|
+---------+---------+--------+-------------+-------------+
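
To read the sparse vectors in the output above: each one prints as (size, [indices], [values]), and with dropLast=True the last category level becomes a vector with no non-zero entries. The following is a plain-Python sketch of that encoding for the category1 column only. It is an illustration, not Spark's implementation, and the helper names one_hot_drop_last and to_dense are made up for this example.

```python
def one_hot_drop_last(index, num_categories):
    """Sparse one-hot triple (size, indices, values) with the last level dropped."""
    size = num_categories - 1
    if index < size:
        return (size, [index], [1.0])
    # The last category level is represented by an all-zero vector.
    return (size, [], [])


def to_dense(sparse):
    """Expand a (size, indices, values) triple into a dense list."""
    size, indices, values = sparse
    dense = [0.0] * size
    for i, v in zip(indices, values):
        dense[i] = v
    return dense


# category1 column from the example DataFrame (3 levels: 0.0, 1.0, 2.0)
category1 = [0.0, 1.0, 2.0, 0.0, 0.0, 2.0]
encoded = [one_hot_drop_last(int(c), 3) for c in category1]
dense = [to_dense(s) for s in encoded]

for raw, sparse, d in zip(category1, encoded, dense):
    print(raw, sparse, d)
```

Each level of category1 gets its own vector slot (apart from the dropped one), which is why a linear model fitted on categoryVec1 would report one coefficient per level.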


# Assemble all three columns into one vector, then let VectorIndexer treat any
# feature with at most maxCategories distinct values as categorical.
va = VectorAssembler(inputCols=df.columns, outputCol='features')
assembled = va.transform(df)
idx = VectorIndexer(inputCol='features', outputCol='features_indexed', maxCategories=4)
idx_model = idx.fit(assembled)
transformed = idx_model.transform(assembled)
transformed.show()

+---------+---------+--------+-------------+----------------+
|category1|category2|readings|     features|features_indexed|
+---------+---------+--------+-------------+----------------+
|      0.0|      3.0|     3.8|[0.0,3.0,3.8]|   [0.0,3.0,3.8]|
|      1.0|      0.0|     6.7|[1.0,0.0,6.7]|   [1.0,0.0,6.7]|
|      2.0|      3.0|     3.3|[2.0,3.0,3.3]|   [2.0,3.0,3.3]|
|      0.0|      2.0|     1.2|[0.0,2.0,1.2]|   [0.0,2.0,1.2]|
|      0.0|      1.0|     7.8|[0.0,1.0,7.8]|   [0.0,1.0,7.8]|
|      2.0|      0.0|     4.4|[2.0,0.0,4.4]|   [2.0,0.0,4.4]|
+---------+---------+--------+-------------+----------------+

idx_model.categoryMaps

{0: {0.0: 0, 1.0: 1, 2.0: 2}, 1: {0.0: 0, 1.0: 1, 2.0: 2, 3.0: 3}}
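
The map above can be reproduced with a small plain-Python sketch of the rule that maxCategories controls: a feature is treated as categorical when it has at most maxCategories distinct values, and each distinct value is then assigned an index. This is only an illustration under the assumption that values are indexed in ascending order, which matches the output above; Spark's actual ordering rules may differ in edge cases. The function name category_maps is made up for this example.

```python
rows = [
    (0.0, 3.0, 3.8),
    (1.0, 0.0, 6.7),
    (2.0, 3.0, 3.3),
    (0.0, 2.0, 1.2),
    (0.0, 1.0, 7.8),
    (2.0, 0.0, 4.4),
]


def category_maps(rows, max_categories):
    """Map feature position -> {raw value: category index} for categorical features."""
    maps = {}
    for j in range(len(rows[0])):
        distinct = sorted({row[j] for row in rows})
        # Features with too many distinct values are left as continuous.
        if len(distinct) <= max_categories:
            maps[j] = {value: i for i, value in enumerate(distinct)}
    return maps


maps = category_maps(rows, 4)
print(maps)
```

Feature 2 (readings) has six distinct values, so it is left out of the map and stays continuous. The raw values of features 0 and 1 already equal their indices, which is why features_indexed looks identical to features in the table above.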