Apache Spark: how to extract rules from a Spark MLlib decision tree

Tags: apache-spark, apache-spark-mllib

I am using Spark MLlib 1.4.1 to create a decision tree model. Now I want to extract the rules from the decision tree.


How can I extract the rules?

You can get the full model as a string by calling model.toDebugString(), or persist it by calling model.save(sc, filePath) (the metadata is written as JSON, the node data as Parquet).
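A minimal save-and-reload sketch (the path here is just an example):

from pyspark.mllib.tree import DecisionTreeModel

model.save(sc, "/tmp/dtModel")                          # writes metadata (JSON) and node data (Parquet)
sameModel = DecisionTreeModel.load(sc, "/tmp/dtModel")  # reload the persisted model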

The Spark documentation contains an example with a small sample dataset, so you can check the output format from the command line. Here I have formatted that script so you can run it directly:

from numpy import array
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree

# assumes a live SparkContext `sc`, as predefined in the pyspark shell
data = [
    LabeledPoint(0.0, [0.0]),
    LabeledPoint(1.0, [1.0]),
    LabeledPoint(1.0, [2.0]),
    LabeledPoint(1.0, [3.0])
]

model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})
print(model)

print(model.toDebugString())
The output is as follows (the first line comes from print(model); print(model.toDebugString()) then prints the summary line again, followed by the tree):

DecisionTreeModel classifier of depth 1 with 3 nodes
DecisionTreeModel classifier of depth 1 with 3 nodes
  If (feature 0 <= 0.0)
   Predict: 0.0
  Else (feature 0 > 0.0)
   Predict: 1.0 
Here is a sample of the output of the above script for my dtModel:

DecisionTreeModel classifier of depth 20 with 20031 nodes
  If (feature 0 <= -35.0)
   If (feature 24 <= 176.0)
    If (feature 0 <= -200.0)
     If (feature 29 <= 109.0)
      If (feature 6 <= -156.0)
       If (feature 9 <= 0.0)
        If (feature 20 <= -116.0)
         If (feature 16 <= 203.0)
          If (feature 11 <= 163.0)
           If (feature 5 <= 384.0)
            If (feature 15 <= 325.0)
             If (feature 13 <= -248.0)
              If (feature 20 <= -146.0)
               Predict: 0.0
              Else (feature 20 > -146.0)
               If (feature 19 <= -58.0)
                Predict: 6.0
               Else (feature 19 > -58.0)
                Predict: 0.0
             Else (feature 13 > -248.0)
              If (feature 9 <= -26.0)
               Predict: 0.0
              Else (feature 9 > -26.0)
               If (feature 10 <= 218.0)
...
Load the model data. If you previously called model.save(location), the node data will be stored in Hadoop under that path:

# 'location' is the path that was passed to model.save(...); each row
# describes one tree node: its id, prediction, children and split definition
modeldf = spark.read.parquet(location + "/data/*")

noderows = modeldf.select("id", "prediction", "leftChild", "rightChild", "split").collect()

 
Create a dummy array of feature names (index i stands for feature i; substitute your real feature names here if you have them):

features = ["feature"+str(i) for i in range(0,700)]
Initialize a directed graph with one vertex per tree node; each splitter vertex gets edges to its children, labelled with the split condition:

import networkx as nx

G = nx.DiGraph()
for rw in noderows:
    if rw['leftChild'] < 0 and rw['rightChild'] < 0:
        # leaf node: carries the prediction
        G.add_node(rw['id'], cat="Prediction", predval=rw['prediction'])
    else:
        # internal node: carries the split definition
        G.add_node(rw['id'], cat="splitter",
                   featureIndex=rw['split']['featureIndex'],
                   thresh=rw['split']['leftCategoriesOrThreshold'],
                   leftChild=rw['leftChild'], rightChild=rw['rightChild'],
                   numCat=rw['split']['numCategories'])

for rw in modeldf.where("leftChild > 0 and rightChild > 0").collect():
    tempnode = G.nodes[rw['id']]  # attribute dict of the splitter (networkx 2.x API)
    G.add_edge(rw['id'], rw['leftChild'],
               reason="{0} less than {1}".format(features[tempnode['featureIndex']],
                                                 tempnode['thresh']))
    G.add_edge(rw['id'], rw['rightChild'],
               reason="{0} greater than {1}".format(features[tempnode['featureIndex']],
                                                    tempnode['thresh']))
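Finally, collect the leaf vertices (out-degree 0) and, for each, print the edge labels along the path from the root (node 0); every such path is one rule:

nodes = [x for x in G.nodes() if G.out_degree(x) == 0 and G.in_degree(x) == 1]

for n in nodes:
    p = nx.shortest_path(G, 0, n)
    print("Rule No:", n)
    print(" & ".join([G.get_edge_data(p[i], p[i+1])['reason'] for i in range(0, len(p)-1)]))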

 

 
For reference, the model in question was trained, and its debug string written to a file, like this:

import os

dtModel = DecisionTree.trainClassifier(parsedTrainData, numClasses=7,
                                       categoricalFeaturesInfo={},
                                       impurity='gini', maxDepth=20, maxBins=24)

# dump the debug string to a text file ("~" must be expanded explicitly)
modelFile = os.path.expanduser("~/decisionTreeModel.txt")
f = open(modelFile, "w")
f.write(dtModel.toDebugString())
f.close()
The rules printed by the script above look like this:

('Rule No:', 5)

feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 less than [1.0]

('Rule No:', 8)

feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 less than [0.0] & feature385 less than [0.0]

('Rule No:', 9)

feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 less than [0.0] & feature385 greater than [0.0]

('Rule No:', 11)

feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 greater than [0.0] & feature266 less than [0.0]

('Rule No:', 12)

feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 less than [1.0] & feature367 greater than [1.0] & feature318 greater than [0.0] & feature266 greater than [0.0]

('Rule No:', 16)

feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 greater than [1.0] & feature158 less than [1.0] & feature274 less than [0.0] & feature89 less than [1.0]

('Rule No:', 17)

feature457 less than [0.0] & feature353 less than [0.0] & feature185 less than [1.0] & feature294 greater than [1.0] & feature158 less than [1.0] & feature274 less than [0.0] & feature89 greater than [1.0]


A comment from the question author on this answer:

"Thanks for the suggestion. Actually I want to extract rules like this: it is a good start, but decoding feature X into the real feature names is not addressed, so the rules are of no use as they stand."

Edit: modified the initial code accordingly.

We can extract the rules using the model's toDebugString property. A complete example follows.

Note: if you want to understand the details of the code below, check the linked write-up.


import ast
import operator

import pandas as pd

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import udf, lit
from pyspark.sql.types import StringType, ArrayType, DoubleType
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier

# map the comparison tokens that appear in toDebugString conditions to
# Python operator functions
operators = {
    ">=": operator.ge,
    "<=": operator.le,
    ">": operator.gt,
    "<": operator.lt,
    "==": operator.eq,
    'and': operator.and_,
    'or': operator.or_
}
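When a node condition such as 'feature 3 <= 7.0' is parsed (step 3 below), its operator token is looked up in this dict and applied to the instance's feature value:

operators["<="](8.0, 7.0)   # False, so this instance does not satisfy the condition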

# toy data
data = pd.DataFrame({
    'ball': [0, 1, 1, 3, 1, 0, 1, 3],
    'keep': [4, 5, 6, 7, 7, 4, 6, 7],
    'hall': [8, 9, 10, 11, 2, 6, 10, 11],
    'fall': [12, 13, 14, 15, 15, 12, 14, 15],
    'mall': [16, 17, 18, 10, 10, 16, 18, 10],
    'label': [21, 31, 41, 51, 51, 51, 21, 31]
})

spark = SparkSession.builder.getOrCreate()  # predefined in the pyspark shell
df = spark.createDataFrame(data)

f_list = ['ball', 'keep', 'mall', 'hall', 'fall']
assemble_numerical_features = VectorAssembler(inputCols=f_list, outputCol='features',
                                              handleInvalid='skip')

dt = DecisionTreeClassifier(featuresCol='features', labelCol='label')

pipeline = Pipeline(stages=[assemble_numerical_features, dt])
model = pipeline.fit(df)
df = model.transform(df)
dt_m = model.stages[-1]  # the fitted DecisionTreeClassificationModel

# Step 1: convert the model.toDebugString output into a nested dict of nodes
# and their children
def parse_debug_string_lines(lines):
    block = []
    while lines:
        if lines[0].startswith('If'):
            bl = ' '.join(lines.pop(0).split()[1:]).replace('(', '').replace(')', '')
            block.append({'name': bl, 'children': parse_debug_string_lines(lines)})

            if lines and lines[0].startswith('Else'):
                be = ' '.join(lines.pop(0).split()[1:]).replace('(', '').replace(')', '')
                block.append({'name': be, 'children': parse_debug_string_lines(lines)})
        elif not lines[0].startswith(('If', 'Else')):
            # leaf line, e.g. 'Predict: 21.0'
            block.append({'name': lines.pop(0)})
        else:
            break

    return block

def debug_str_to_json(debug_string):
    data = []
    for line in debug_string.splitlines():
        line = line.strip()
        if not line:
            break
        data.append(line)
    # drop the summary line and parse the indented tree body
    return {'name': 'Root', 'children': parse_debug_string_lines(data[1:])}

# Step 2: using the metadata stored in the features column, build a dict that
# maps each feature's index in the feature vector to its name
f_type_to_flist_dict = df.schema['features'].metadata["ml_attr"]["attrs"]
f_index_to_name_dict = {}
for f_type, f_attrs in f_type_to_flist_dict.items():
    for f in f_attrs:
        f_index_to_name_dict[f['idx']] = f['name']
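For the toy DataFrame above, the VectorAssembler was given inputCols ['ball', 'keep', 'mall', 'hall', 'fall'], so the mapping should come out as:

print(f_index_to_name_dict)
# {0: 'ball', 1: 'keep', 2: 'mall', 3: 'hall', 4: 'fall'}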


def generate_explanations(dt_as_json, df: DataFrame, f_index_to_name_dict, operators):

    dt_as_json_str = str(dt_as_json)
    cond_parsing_exception_occured = False

    # expose the assembled feature vector as a plain Python list column
    df = df.withColumn('features_list',
                       udf(lambda x: x.toArray().tolist(), ArrayType(DoubleType()))
                       (df['features'])
                       )
    # Step 3: parse the condition in a node and check whether the current
    # instance satisfies it
    def parse_validate_cond(cond: str, f_vector: list):
        # cond looks like 'feature 3 <= 7.0'
        cond_parts = cond.split()
        condition_f_index = int(cond_parts[1])
        condition_op = cond_parts[2]
        condition_value = float(cond_parts[3])

        f_value = f_vector[condition_f_index]
        f_name = f_index_to_name_dict[condition_f_index].replace('numerical_features_', '').replace('encoded_numeric_', '').lower()

        if operators[condition_op](f_value, condition_value):
            return True, f_name + ' ' + condition_op + ' ' + str(round(condition_value, 2))

        return False, ''
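    # For example, with f_vector = [0.0, 4.0, 16.0, 8.0, 12.0]
    # (ball, keep, mall, hall, fall: the first row of the toy data):
    #   parse_validate_cond('feature 3 <= 7.0', f_vector) -> (False, '')  since hall is 8.0
    #   parse_validate_cond('feature 3 > 7.0', f_vector)  -> (True, 'hall > 7.0')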
        
    # Step 4: extract the rules for one instance by walking the tree through the
    # nodes whose conditions the instance satisfies, down to a prediction node
    def extract_rule(dt_as_json_str: str, f_vector: list, rule=""):

        # a variable declared in the outer function is read-only in the inner
        # function unless it is explicitly declared nonlocal
        nonlocal cond_parsing_exception_occured

        dt_as_json = ast.literal_eval(dt_as_json_str)
        child_l = dt_as_json['children']

        for child in child_l:
            name = child['name'].strip()

            if name.startswith('Predict:'):
                # remove the trailing comma
                return rule[0:rule.rindex(',')]

            if name.startswith('feature'):
                try:
                    res, cond = parse_validate_cond(child['name'], f_vector)
                except Exception:
                    res = False
                    cond_parsing_exception_occured = True
                if res:
                    rule += cond + ', '
                    rule = extract_rule(str(child), f_vector, rule=rule)
        return rule

    df = df.withColumn('explanation',
                       udf(lambda dt, fv: extract_rule(dt, fv), StringType())
                       (lit(dt_as_json_str), df['features_list'])
                       )
    # log an exception that occurred while parsing a node condition; note that
    # on a cluster the flag is set on the executors, so this check is only
    # reliable in local mode
    if cond_parsing_exception_occured:
        print('some node in decision tree has unexpected format')

    return df

df = generate_explanations(debug_str_to_json(dt_m.toDebugString), df, f_index_to_name_dict, operators)
rows = df.select(['ball','keep','mall','hall','fall','explanation','prediction']).collect()

Output:
-----------------------
[Row(ball=0, keep=4, mall=16, hall=8, fall=12, explanation='hall > 7.0, mall > 13.0, ball <= 0.5', prediction=21.0),
 Row(ball=1, keep=5, mall=17, hall=9, fall=13, explanation='hall > 7.0, mall > 13.0, ball > 0.5, keep <= 5.5', prediction=31.0),
 Row(ball=1, keep=6, mall=18, hall=10, fall=14, explanation='hall > 7.0, mall > 13.0, ball > 0.5, keep > 5.5', prediction=21.0),
 Row(ball=3, keep=7, mall=10, hall=11, fall=15, explanation='hall > 7.0, mall <= 13.0', prediction=31.0),
 Row(ball=1, keep=7, mall=10, hall=2, fall=15, explanation='hall <= 7.0', prediction=51.0),
 Row(ball=0, keep=4, mall=16, hall=6, fall=12, explanation='hall <= 7.0', prediction=51.0),
 Row(ball=1, keep=6, mall=18, hall=10, fall=14, explanation='hall > 7.0, mall > 13.0, ball > 0.5, keep > 5.5', prediction=21.0),
 Row(ball=3, keep=7, mall=10, hall=11, fall=15, explanation='hall > 7.0, mall <= 13.0', prediction=31.0)]

Output of dt_m.toDebugString:
-----------------------------------
'DecisionTreeClassificationModel (uid=DecisionTreeClassifier_2a17ae7633b9) of depth 4 with 9 nodes\n  If (feature 3 <= 7.0)\n   Predict: 51.0\n  Else (feature 3 > 7.0)\n   If (feature 2 <= 13.0)\n    Predict: 31.0\n   Else (feature 2 > 13.0)\n    If (feature 0 <= 0.5)\n     Predict: 21.0\n    Else (feature 0 > 0.5)\n     If (feature 1 <= 5.5)\n      Predict: 31.0\n     Else (feature 1 > 5.5)\n      Predict: 21.0\n'

Output of debug_str_to_json(dt_m.toDebugString):
------------------------------------
{'name': 'Root',
'children': [{'name': 'feature 3 <= 7.0',
   'children': [{'name': 'Predict: 51.0'}]},
  {'name': 'feature 3 > 7.0',
   'children': [{'name': 'feature 2 <= 13.0',
     'children': [{'name': 'Predict: 31.0'}]},
    {'name': 'feature 2 > 13.0',
     'children': [{'name': 'feature 0 <= 0.5',
       'children': [{'name': 'Predict: 21.0'}]},
      {'name': 'feature 0 > 0.5',
       'children': [{'name': 'feature 1 <= 5.5',
         'children': [{'name': 'Predict: 31.0'}]},
        {'name': 'feature 1 > 5.5',
         'children': [{'name': 'Predict: 21.0'}]}]}]}]}]}