Warning: file_get_contents(/data/phpspider/zhask/data//catemap/4/wpf/13.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
Python 如何获取scikit学习决策树所有节点的pos/neg实例计数?_Python_Scikit Learn_Decision Tree - Fatal编程技术网

Python 如何获取scikit学习决策树所有节点的pos/neg实例计数?

Python 如何获取scikit学习决策树所有节点的pos/neg实例计数?,python,scikit-learn,decision-tree,Python,Scikit Learn,Decision Tree,我训练了一个sklearn决策树 from sklearn.tree import DecisionTreeClassifier c=DecisionTreeClassifier(class_weight="auto") c.fit([[0,0], [0,1], [1,1], ],[0,1,0]) 现在我想检查每个节点有多少阳性/阴性样本。因此,一个类似于 counts: [2,1] labels: (010)

我训练了一个sklearn决策树

from sklearn.tree import DecisionTreeClassifier
c=DecisionTreeClassifier(class_weight="auto")
c.fit([[0,0],
       [0,1],
       [1,1],
      ],[0,1,0])
现在我想检查每个节点有多少阳性/阴性样本。因此,一个类似于

  counts: [2,1]            labels: (010)
                                 split by x0
    [1,1]       [1,0]         (01)        (0)
                           split by x1
 [1,0] [0,1]      0        (0)   (1)
   0     1
我如何从经过训练的决策树中得到这个(左计数)


我可以看到一个
c.tree\uu
变量,但内容似乎不是很有用。有零、权重。。。很难猜测如何返回计数。

每个类的样本数存储在
tree.value
中,但是它只存储叶子的节点值,所以我使用后序遍历来获取所有节点的值

import numpy as np

def get_value(dt):
    left = dt.tree_.children_left
    right = dt.tree_.children_right
    value = dt.tree_.value
    leaves = np.argwhere(left == -1)[:, 0]

    def visit(node):
        if node in leaves:
            return
        visit(left[node])
        visit(right[node])
        value[node, :] = value[left[node], :] + value[right[node], :]

    visit(0)
    return value
In [1]: from sklearn.tree import DecisionTreeClassifier

In [2]: dt = DecisionTreeClassifier()

In [3]: dt.fit([[0,0],
   ...:         [0,1],
   ...:         [1,1]], [0,1,0])
Out[3]:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            random_state=None, splitter='best')

In [4]: dt.tree_.value
Out[4]:
array([[[ 2.,  1.]],

       [[ 1.,  1.]],

       [[ 1.,  0.]],

       [[ 0.,  1.]],

       [[ 1.,  0.]]])
比如说,

from sklearn.tree import DecisionTreeClassifier
dt = DecisionTreeClassifier()
dt.fit([[0,0],
        [0,1],
        [1,1]], [0,1,0])
get_value(dt)
输出:

更新#1

我想知道为什么树的值只存储叶节点的值,然后我找到了和

事实证明,在scikit learn 0.17.dev0中,
tree_u2;.value
已经返回所有节点的值

import numpy as np

def get_value(dt):
    left = dt.tree_.children_left
    right = dt.tree_.children_right
    value = dt.tree_.value
    leaves = np.argwhere(left == -1)[:, 0]

    def visit(node):
        if node in leaves:
            return
        visit(left[node])
        visit(right[node])
        value[node, :] = value[left[node], :] + value[right[node], :]

    visit(0)
    return value
In [1]: from sklearn.tree import DecisionTreeClassifier

In [2]: dt = DecisionTreeClassifier()

In [3]: dt.fit([[0,0],
   ...:         [0,1],
   ...:         [1,1]], [0,1,0])
Out[3]:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            random_state=None, splitter='best')

In [4]: dt.tree_.value
Out[4]:
array([[[ 2.,  1.]],

       [[ 1.,  1.]],

       [[ 1.,  0.]],

       [[ 0.,  1.]],

       [[ 1.,  0.]]])
更新#2

虽然我认为在给定
类权重时“撤消权重”是有意义的,但这是可能的

class_权重
由以下公式计算:

In [1]: from sklearn.utils import compute_class_weight

In [2]: compute_class_weight('auto', [0, 1], [0, 1, 0])
Out[2]: array([ 0.66666667,  1.33333333])

因此,您可以将
value[node,:]/=class_weight
添加到
if-node in leaves:
之后,以重新计算叶节点的值。

仅计算每个节点的计数值。要创建类似于中的饼图的内容,我仍然对所需的输出有点困惑。为什么
计数中的叶子标记为1、0、1?样本的标签是
[0,1,0]
。为什么根节点是
[2,1]
,而它的子节点是
[1,1]
[0,1]
?你是对的。我把号码弄错了。他们现在被纠正了。但是我怎样才能弥补我的体重呢?到目前为止看起来还不错。如何撤消权重?如果我使用
class\u weight=“auto”
它会以某种方式更改数字。如果给定
class\u weight
,则会对值进行加权。我认为“撤销权重”是没有意义的。也许您可以通过根据叶节点所属的类对其重新加权来抵消权重。在我的示例中,重新加权会是什么?我相信在我的例子中,单位大小的叶片得到了0.66和1.33。但我不确定它们与我的类分布有何关联。请查看我的更新。我没有给出完整的例子,但希望你们能理解。谢谢!我不知道utils.compute\u class\u weight