Python Spark决策树模型中节点信息的获取

Python Spark决策树模型中节点信息的获取,python,scala,pyspark,apache-spark-mllib,apache-spark-ml,Python,Scala,Pyspark,Apache Spark Mllib,Apache Spark Ml,我想通过Spark MLlib的决策树获得关于生成模型的每个节点的更详细信息。使用API我能得到的最接近的结果是print(model.toDebugString()),它返回类似这样的结果(取自PySpark文档) DecisionTreeModel分类器,深度为1,有3个节点 如果(功能0.0) 预测:1.0 如何修改MLlib源代码以获得每个节点的杂质和深度?(如果有必要,我如何在PySpark中调用新的Scala函数?不幸的是,我找不到任何方法直接访问PySpark或Spark(Sca

我想通过Spark MLlib的决策树获得关于生成模型的每个节点的更详细信息。使用API我能得到的最接近的结果是
print(model.toDebugString())
,它返回类似这样的结果(取自PySpark文档)

DecisionTreeModel分类器,深度为1,有3个节点
如果(功能0.0)
预测:1.0

如何修改MLlib源代码以获得每个节点的杂质和深度?(如果有必要,我如何在PySpark中调用新的Scala函数?

不幸的是,我找不到任何方法直接访问PySpark或Spark(Scala API)中的节点。但是有一种方法可以从根节点开始,遍历到不同的节点

(我刚才在这里提到了杂质,但对于深度,人们可以很容易地用
subtreeDepth
替换
杂质

假设决策树模型实例为
dt

皮斯帕克 现在,如果我们看一下适用于
root
的方法:

dir(root)
[u'apply', u'deepCopy', u'emptyNode', u'equals', 'getClass', u'getNode', u'hashCode', u'id', 'impurity', u'impurity_$eq', u'indexToLevel', u'initializeLogIfNecessary', u'isLeaf', u'isLeaf_$eq', u'isLeftChild', u'isTraceEnabled', u'leftChildIndex', u'leftNode', u'leftNode_$eq', u'log', u'logDebug', u'logError', u'logInfo', u'logName', u'logTrace', u'logWarning', u'maxNodesInLevel', u'notify', u'notifyAll', u'numDescendants', u'org$apache$spark$internal$Logging$$log_', u'org$apache$spark$internal$Logging$$log__$eq', u'parentIndex', u'predict', u'predict_$eq', u'rightChildIndex', u'rightNode', u'rightNode_$eq', u'split', u'split_$eq', u'startIndexInLevel', u'stats', u'stats_$eq', u'subtreeDepth', u'subtreeIterator', u'subtreeToString', u'subtreeToString$default$1', u'toString', u'wait']
我们可以做到:

root.leftNode().get().impurity()
root.leftNode.get.impurity
这可能会深入到树中,例如:

root.leftNode().get().rightNode().get().impurity()
root.leftNode.get.rightNode.get.impurity
由于在应用
leftNode()
rightNode()
之后,我们得到了一个
选项
,因此必须应用
get
或getOrElse
才能得到所需的
节点类型

如果你想知道我是怎么学会这些奇怪的方法的,我得承认,我作弊了!!,i、 e.我首先研究了Scala API:

火花 假设
dt
相同,以下几行与上述几行完全相同,并给出相同的结果:

val root = dt.topNode
root.impurity
我们可以做到:

root.leftNode().get().impurity()
root.leftNode.get.impurity
这可能会深入到树中,例如:

root.leftNode().get().rightNode().get().impurity()
root.leftNode.get.rightNode.get.impurity

我将通过描述我如何使用PySpark 2.4.3来补充@mostOfMajority的答案

根节点 给定一个经过训练的决策树模型,以下是获取其根节点的方法:

def _get_root_node(tree: DecisionTreeClassificationModel):
    return tree._call_java('rootNode')
杂质 我们可以通过从根节点向下遍历树来获取杂质。可以这样做:

def get_impurities(tree: DecisionTreeClassificationModel) -> List[float]:
    def recur(node):
        if node.numDescendants() == 0:
            return []
        ni = node.impurity()
        return (
            recur(node.leftChild()) + [ni] + recur(node.rightChild())
        )
    return recur(_get_root_node(tree))
例子 [1]中的
:打印(tree.toDebugString)
深度为3且具有7个节点的DecisionTreeClassifier模型(uid=DecisionTreeClassifier_f90ba6dbb0fe)
如果(功能0.6.5)
预测:0.0
在[2]中:cat.get_杂质(树)
输出[2]:[0.4444,0.5,0.5]

决策树模型实例在pyspark 2.3中没有call()方法。您使用的是什么版本的spark?这是pyspark 2.2.1的版本。尚未尝试2.3。@sgu找到pyspark2.3的解决方案了吗?