Warning: file_get_contents(/data/phpspider/zhask/data//catemap/2/python/281.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python scikit学习决策树节点深度_Python_Scikit Learn_Decision Tree - Fatal编程技术网

Python scikit学习决策树节点深度

Python scikit学习决策树节点深度,python,scikit-learn,decision-tree,Python,Scikit Learn,Decision Tree,我的目标是确定决策树中两个样本的分离深度。在scikit learn的开发版本中,您可以使用decision\u path()方法识别最后一个公共节点: from sklearn import tree import numpy as np clf = tree.DecisionTreeClassifier() clf.fit(data, outcomes) n_nodes = clf.tree_.node_count node_indicator = clf.decision_path(da

我的目标是确定决策树中两个样本的分离深度。在scikit learn的开发版本中,您可以使用
decision\u path()
方法识别最后一个公共节点:

from sklearn import tree
import numpy as np

clf = tree.DecisionTreeClassifier()
clf.fit(data, outcomes)
n_nodes = clf.tree_.node_count
node_indicator = clf.decision_path(data).toarray()
sample_ids = [0,1]
common_nodes = (node_indicator[sample_ids].sum(axis=0) == len(sample_ids))
common_node_id = np.arange(n_nodes)[common_nodes]
max_node = np.max(common_node_id)

是否有一种方法可以确定
max\u节点在树中出现的深度,可能是右
clf.tree\uu.children\u
clf.tree\uu.chldren\u左

这里有一个函数,可以用来递归遍历节点并计算节点深度

def get_node_depths(tree):
    """
    Get the node depths of the decision tree

    >>> d = DecisionTreeClassifier()
    >>> d.fit([[1,2,3],[4,5,6],[7,8,9]], [1,2,3])
    >>> get_node_depths(d.tree_)
    array([0, 1, 1, 2, 2])
    """
    def get_node_depths_(current_node, current_depth, l, r, depths):
        depths += [current_depth]
        if l[current_node] != -1 and r[current_node] != -1:
            get_node_depths_(l[current_node], current_depth + 1, l, r, depths)
            get_node_depths_(r[current_node], current_depth + 1, l, r, depths)

    depths = []
    get_node_depths_(0, 0, tree.children_left, tree.children_right, depths) 
    return np.array(depths)