
Python: PySpark DecisionTree model precision and recall differ from manually computed results


I trained a DecisionTree model on a PySpark DataFrame. The resulting DataFrame is mocked up as follows:

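# sc and sqlContext are provided by the PySpark shell
# Each tuple is (prediction, target_index)
rdd = sc.parallelize(
    [
        (0., 1.),
        (0., 0.),
        (0., 0.),
        (1., 1.),
        (1., 0.),
        (1., 0.),
        (1., 1.),
        (1., 1.)
    ]
)
df = sqlContext.createDataFrame(rdd, ["prediction", "target_index"])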
df.show()
+----------+------------+
|prediction|target_index|
+----------+------------+
|       0.0|         1.0|
|       0.0|         0.0|
|       0.0|         0.0|
|       1.0|         1.0|
|       1.0|         0.0|
|       1.0|         0.0|
|       1.0|         1.0|
|       1.0|         1.0|
+----------+------------+
Let's compute a metric, recall:

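from pyspark.mllib.evaluation import MulticlassMetrics

# MulticlassMetrics expects an RDD of (prediction, label) pairs
metricsp = MulticlassMetrics(df.rdd)
print(metricsp.recall())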
0.625
Hmm. Let's try to confirm that this is correct:

tp = df[(df.target_index == 1) & (df.prediction == 1)].count()
tn = df[(df.target_index == 0) & (df.prediction == 0)].count()
fp = df[(df.target_index == 0) & (df.prediction == 1)].count()
fn = df[(df.target_index == 1) & (df.prediction == 0)].count()
print("True Positives:", tp)
print("True Negatives:", tn)
print("False Positives:", fp)
print("False Negatives:", fn)
print("Total", df.count())
True Positives: 3
True Negatives: 2
False Positives: 2
False Negatives: 1
Total 8
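As a cross-check that does not need Spark, the same four counts can be reproduced in plain Python. This is only a sketch: the list `pairs` copies the mocked rows as (prediction, target_index) tuples, and `confusion_counts` is a hypothetical helper, not a Spark function.

```python
# (prediction, target_index) pairs copied from the mocked DataFrame
pairs = [
    (0.0, 1.0), (0.0, 0.0), (0.0, 0.0), (1.0, 1.0),
    (1.0, 0.0), (1.0, 0.0), (1.0, 1.0), (1.0, 1.0),
]

def confusion_counts(pairs, positive=1.0):
    """Return (tp, tn, fp, fn), treating `positive` as the positive class.

    Hypothetical helper for illustration; not a Spark API.
    """
    tp = sum(1 for p, t in pairs if t == positive and p == positive)
    tn = sum(1 for p, t in pairs if t != positive and p != positive)
    fp = sum(1 for p, t in pairs if t != positive and p == positive)
    fn = sum(1 for p, t in pairs if t == positive and p != positive)
    return tp, tn, fp, fn

tp, tn, fp, fn = confusion_counts(pairs)
print(tp, tn, fp, fn)  # 3 2 2 1
```

Passing `positive=0.0` swaps the roles of the two classes, which is useful when checking the per-class numbers Spark reports for label 0.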
And compute the recall:

r = float(tp) / (tp + fn)
print("recall", r)

recall 0.75
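Written out with the counts above, the hand computation for label 1, alongside precision for label 1 and the overall fraction of correct predictions (the number that matches what `recall()` printed), is:

```latex
\text{recall}_1 = \frac{TP}{TP + FN} = \frac{3}{3 + 1} = 0.75,
\qquad
\text{precision}_1 = \frac{TP}{TP + FP} = \frac{3}{3 + 2} = 0.6,
\qquad
\frac{TP + TN}{N} = \frac{3 + 2}{8} = 0.625
```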
The results differ. What am I doing wrong?

By the way, all the functions in the Metrics class give the same result:

print(metricsp.recall())
print(metricsp.precision())
print(metricsp.fMeasure())
0.625
0.625
0.625
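One way to see why the three methods agree: if precision, recall and F-measure are aggregated over both labels (micro-averaging), every wrong prediction counts once as a false positive (for the predicted label) and once as a false negative (for the true label). All three therefore collapse to the share of correct predictions. The sketch below assumes that this is what the argument-free methods compute; `micro_metrics` is a hypothetical helper, not part of Spark.

```python
# (prediction, target_index) pairs copied from the mocked DataFrame
pairs = [
    (0.0, 1.0), (0.0, 0.0), (0.0, 0.0), (1.0, 1.0),
    (1.0, 0.0), (1.0, 0.0), (1.0, 1.0), (1.0, 1.0),
]

def micro_metrics(pairs):
    """Micro-averaged (precision, recall, F-measure) over all labels.

    Hypothetical helper for illustration; not a Spark API.
    """
    labels = {t for _, t in pairs} | {p for p, _ in pairs}
    # A correct prediction is a true positive for its own label.
    tp = sum(1 for p, t in pairs if p == t)
    # Summed over labels, each wrong prediction adds one false positive
    # (for the predicted label) and one false negative (for the true label).
    fp = sum(1 for lab in labels for p, t in pairs if p == lab and t != lab)
    fn = sum(1 for lab in labels for p, t in pairs if t == lab and p != lab)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f_measure = 2 * precision * recall / (precision + recall)
    return precision, recall, f_measure

print(micro_metrics(pairs))  # (0.625, 0.625, 0.625)
```

Because the total false positives always equal the total false negatives in single-label classification, precision and recall are identical here, and so is their harmonic mean.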

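The problem is that you are using MulticlassMetrics on the output of a binary classifier. Called without a label, recall() returns a single figure aggregated over both classes, which here equals the fraction of correct predictions, (3 + 2) / 8 = 0.625, rather than the recall of the positive class.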
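What matters is which rows end up in the denominator. Per-class recall only looks at rows whose true label is the chosen class, which is exactly what the question computes by hand. A plain-Python sketch under that reading (`recall_for_label` is a hypothetical name, not a Spark function):

```python
# (prediction, target_index) pairs copied from the mocked DataFrame
pairs = [
    (0.0, 1.0), (0.0, 0.0), (0.0, 0.0), (1.0, 1.0),
    (1.0, 0.0), (1.0, 0.0), (1.0, 1.0), (1.0, 1.0),
]

def recall_for_label(pairs, label):
    """Recall for one class: TP / (TP + FN), with `label` as the positive class.

    Hypothetical helper mirroring the idea of recall(label=...); not a Spark API.
    """
    # Predictions for the rows whose true label is `label` (the TP + FN rows)
    actual = [p for p, t in pairs if t == label]
    hits = sum(1 for p in actual if p == label)  # true positives
    return hits / len(actual)

print(recall_for_label(pairs, 1.0))  # 0.75
print(recall_for_label(pairs, 0.0))  # 0.5
```

With label 1.0 treated as positive this reproduces the hand-computed 0.75; label 0.0 gives 0.5.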

To get the correct result, use recall(label=1):

>>> print(metricsp.recall(label=1))
0.75

By the way, the headers in your df.show() output look messed up; they should be:

+----------+------------+
|prediction|target_index|
+----------+------------+
|       0.0|         1.0|
|       0.0|         0.0|
|       0.0|         0.0|
|       1.0|         1.0|
|       1.0|         0.0|
|       1.0|         0.0|
|       1.0|         1.0|
|       1.0|         1.0|
+----------+------------+

Your headers are messed up?