Python: How to get probabilities from a PySpark OneVsRest multiclass classifier

Tags: python, apache-spark, pyspark, prediction

The PySpark OneVsRest classifier does not seem to provide probabilities. Is there a way to get them?

My code is below. I am also fitting a standard multiclass classifier for comparison.

from pyspark.ml.classification import LogisticRegression, OneVsRest
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# load data file.
inputData = spark.read.format("libsvm") \
    .load("/data/mllib/sample_multiclass_classification_data.txt")

(train, test) = inputData.randomSplit([0.8, 0.2])

# instantiate the base classifier.
lr = LogisticRegression(maxIter=10, tol=1E-6, fitIntercept=True)

# instantiate the One Vs Rest Classifier.
ovr = OneVsRest(classifier=lr)


# train the multiclass model.
ovrModel = ovr.fit(train)
lrm = lr.fit(train)

# score the model on test data.
predictions = ovrModel.transform(test)
predictions2 = lrm.transform(test)

predictions.show(6)
predictions2.show(6)
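For reference, the gap is easy to see by comparing the output schemas. With a typical Spark 2.x setup (an assumption about the version, not something stated above), the plain LogisticRegression output carries rawPrediction, probability and prediction columns, while the OneVsRest output only carries prediction:

# compare the output columns produced by the two models
predictions.printSchema()
predictions2.printSchema()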

I don't think you can access the probability (confidence) vector, because OneVsRest takes the maximum of the per-class confidences and then discards the confidence vector. To test this, you can copy the class, modify it, and remove the .drop(accColName) call:

# excerpt from OneVsRestModel._transform in pyspark/ml/classification.py
# (it relies on operator, udf and DoubleType imported in that module)

# output the index of the classifier with highest confidence as prediction
labelUDF = udf(
    lambda predictions: float(max(enumerate(predictions), key=operator.itemgetter(1))[0]),
    DoubleType())

# output label and label metadata as prediction
return aggregatedDataset.withColumn(
    self.getPredictionCol(), labelUDF(aggregatedDataset[accColName])).drop(accColName)
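If copying and patching the class is not an option, a minimal workaround sketch (assuming you only need the per-class confidences rather than a normalized probability distribution) is to score the test set with each underlying binary model in ovrModel.models and keep the positive-class probability from each. Note that these values come from independent binary models, so they do not sum to 1 across classes:

from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType

# extract P(class i) from each binary model's 2-element probability vector
positive_prob = F.udf(lambda v: float(v[1]), DoubleType())

scored = test
for i, binary_model in enumerate(ovrModel.models):
    # each transform adds rawPrediction/probability/prediction; keep only P(class i)
    scored = (binary_model.transform(scored)
              .withColumn("p_class_{}".format(i), positive_prob("probability"))
              .drop("rawPrediction", "probability", "prediction"))

scored.select(["p_class_{}".format(i) for i in range(len(ovrModel.models))]).show(6)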