Python 提取每个终端节点的路径

Python 提取每个终端节点的路径,python,json,nested,xgboost,Python,Json,Nested,Xgboost,我有一个python嵌套字典结构,如下所示。 这是一个小例子,但我有更大的例子,可以有不同层次的嵌套 从中,我需要提取一个包含以下内容的列表: 每个终端“叶”节点一条记录 表示通向该节点的逻辑路径的字符串、列表或对象 (例如,“nodeid_3:X

我有一个python嵌套字典结构,如下所示。 这是一个小例子,但我有更大的例子,可以有不同层次的嵌套

从中,我需要提取一个包含以下内容的列表:

  • 每个终端“叶”节点一条记录
  • 表示通向该节点的逻辑路径的字符串、列表或对象
    • (例如,“nodeid_3:X<0.500007和X<0.279907”)
  • 我花了这个周末的大部分时间试图让一些东西工作起来,我意识到我在递归方面有多么糟糕

    # Extract json string
    json_string = booster.get_dump(with_stats=True, dump_format='json')[0]
    
    # Convert to python dictionary
    json.loads(json_string)
    
    {u'children': [{u'children': [
        {u'cover': 2291, u'leaf': -0.0611795, u'nodeid': 3},
        {u'cover': 1779, u'leaf': -0.00965727, u'nodeid': 4}],
       u'cover': 4070,
       u'depth': 1,
       u'gain': 265.811,
       u'missing': 3,
       u'no': 4,
       u'nodeid': 1,
       u'split': u'X',
       u'split_condition': 0.279907,
       u'yes': 3},
      {u'cover': 3930, u'leaf': -0.0611946, u'nodeid': 2}],
     u'cover': 8000,
     u'depth': 0,
     u'gain': 101.245,
     u'missing': 1,
     u'no': 2,
     u'nodeid': 0,
     u'split': u'X',
     u'split_condition': 0.500007,
     u'yes': 1}
    

    数据结构是递归的。如果一个节点有子密钥,那么我们可以考虑它不是终端。 要分析数据,需要一个递归函数来跟踪祖先(路径)

    我会这样实施:

    def find_path(obj, path=None):
        path = path or []
        if 'children' in obj:
            child_obj = {k: v for k, v in obj.items()
                         if k in ['nodeid', 'split_condition']}
            child_path = path + [child_obj]
            children = obj['children']
            for child in children:
                find_path(child, child_path)
        else:
            pprint.pprint((obj, path))
    
    如果你打电话:

    find_path(data)
    
    您将获得3个结果:

    ({'cover': 2291, 'leaf': -0.0611795, 'nodeid': 3},
     [{'nodeid': 0, 'split_condition': 0.500007},
      {'nodeid': 1, 'split_condition': 0.279907}])
    ({'cover': 1779, 'leaf': -0.00965727, 'nodeid': 4},
     [{'nodeid': 0, 'split_condition': 0.500007},
      {'nodeid': 1, 'split_condition': 0.279907}])
    ({'cover': 3930, 'leaf': -0.0611946, 'nodeid': 2},
     [{'nodeid': 0, 'split_condition': 0.500007}])
    
    当然,您可以将对
    pprint.pprint()
    的调用替换为
    yield
    以将此函数转换为生成器:

    def iter_path(obj, path=None):
        path = path or []
        if 'children' in obj:
            child_obj = {k: v for k, v in obj.items()
                         if k in ['nodeid', 'split_condition']}
            child_path = path + [child_obj]
            children = obj['children']
            for child in children:
                # for o, p in iteration_path(child, child_path):
                #     yield o, p
                yield from iter_path(child, child_path)
        else:
            yield obj, path
    
    注意递归调用的
    yield from
    。您可以按如下方式使用此生成器:

    for obj, path in iter_path(data):
        pprint.pprint((obj, path))
    
    您还可以更改
    child\u obj
    对象的构建方式,以满足您的需要


    要保持对象的顺序,请反转
    if
    条件:
    if'children'不在obj中:…

    这真是太好了!我可以打印,但是如果我迭代生成的生成器,切换到
    yield
    不会返回任何内容。你介意添加带有相关调用的收益率版本吗(以防我搞砸了)?@Chris我添加了一个
    iter\u路径
    generator。这个有效:)<代码>查找路径确实需要在注释行中更改为
    iter\u路径
    ,但您首先提供这两个版本真是太好了。谢谢,这节省了我很多时间!