Warning: file_get_contents(/data/phpspider/zhask/data//catemap/4/r/80.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
R 从中提取模型树并使其可视化_R_Apache Spark_Random Forest_Decision Tree_Sparklyr - Fatal编程技术网

R 从中提取模型树并使其可视化

R 从中提取模型树并使其可视化,r,apache-spark,random-forest,decision-tree,sparklyr,R,Apache Spark,Random Forest,Decision Tree,Sparklyr,关于如何将Sparkyr的ml_决策树分类器、ml_gbt_分类器或ml_随机林分类器模型中的树信息转换为a.)其他R树相关库可以理解的格式,以及(最终)b.)用于非技术消费的树可视化,有人有什么建议吗?这将包括将向量汇编程序生成的替换字符串索引值转换回实际特征名称的能力 出于提供示例的目的,大量复制了以下代码: library(sparklyr) library(dplyr) # If needed, install Spark locally via `spark_install()` s

关于如何将Sparkyr的ml_决策树分类器、ml_gbt_分类器或ml_随机林分类器模型中的树信息转换为a.)其他R树相关库可以理解的格式,以及(最终)b.)用于非技术消费的树可视化,有人有什么建议吗?这将包括将向量汇编程序生成的替换字符串索引值转换回实际特征名称的能力

出于提供示例的目的,大量复制了以下代码:

library(sparklyr)
library(dplyr)

# If needed, install Spark locally via `spark_install()`
sc <- spark_connect(master = "local")
iris_tbl <- copy_to(sc, iris)

# split the data into train and validation sets
iris_data <- iris_tbl %>%
  sdf_partition(train = 2/3, validation = 1/3, seed = 123)


iris_pipeline <- ml_pipeline(sc) %>%
  ft_dplyr_transformer(
    iris_data$train %>%
      mutate(Sepal_Length = log(Sepal_Length),
             Sepal_Width = Sepal_Width ^ 2)
  ) %>%
  ft_string_indexer("Species", "label")

iris_pipeline_model <- iris_pipeline %>%
  ml_fit(iris_data$train)

iris_vector_assembler <- ft_vector_assembler(
  sc, 
  input_cols = setdiff(colnames(iris_data$train), "Species"), 
  output_col = "features"
)
random_forest <- ml_random_forest_classifier(sc,features_col = "features")

# obtain the labels from the fitted StringIndexerModel
iris_labels <- iris_pipeline_model %>%
  ml_stage("string_indexer") %>%
  ml_labels()

# IndexToString will convert the predicted numeric values back to class labels
iris_index_to_string <- ft_index_to_string(sc, "prediction", "predicted_label", 
                                      labels = iris_labels)

# construct a pipeline with these stages
iris_prediction_pipeline <- ml_pipeline(
  iris_pipeline, # pipeline from previous section
  iris_vector_assembler, 
  random_forest,
  iris_index_to_string
)

# fit to data and make some predictions
iris_prediction_model <- iris_prediction_pipeline %>%
  ml_fit(iris_data$train)
iris_predictions <- iris_prediction_model %>%
  ml_transform(iris_data$validation)
iris_predictions %>%
  select(Species, label:predicted_label) %>%
  glimpse()
库(年)
图书馆(dplyr)
#如果需要,通过“Spark_install()本地安装Spark”`
sc%
ft_字符串索引器(“种类”、“标签”)
iris_管道_模型%
ml_拟合(iris_数据$train)
iris_向量_汇编程序%cat()
##打印如下##
具有20棵树的随机森林分类模型(uid=随机森林分类器)
树0(权重1.0):
如果(功能2.5)
如果(功能部件2 4.95)
如果(功能2 5.05)
预测:2.0
树1(重量1.0):
如果(功能部件3 0.8)
如果(功能部件3 1.75)
预测:2.0
树2(重量1.0):
如果(功能部件3 0.8)
如果(功能部件0 1.766405134230237)
如果(功能3 1.45)
如果(功能部件3 1.65)
预测:2.0
树3(重量1.0):
如果(功能部件0 1.6675287895788053)
如果(功能部件3 1.75)
预测:2.0
树4(重量1.0):
如果(功能部件2 4.85)
如果(功能2 5.05)
预测:2.0
树5(重量1.0):
如果(功能2 1.65)
如果(功能部件3 1.65)
预测:2.0
树6(重量1.0):
如果(功能2.5)
如果(功能2 5.05)
预测:2.0
树7(重量1.0):
如果(功能部件3 0.55)
如果(功能部件3 1.65)
如果(功能部件2 4.85)
预测:2.0
树8(重量1.0):
如果(功能部件3 0.8)
如果(功能3 1.85)
预测:2.0
树9(重量1.0):
如果(功能2.5)
如果(功能部件2 4.95)
预测:2.0
树10(重量1.0):
如果(功能部件3 0.8)
如果(功能部件2 4.95)
如果(功能2 5.05)
预测:2.0
树11(重量1.0):
如果(功能部件3 0.8)
如果(功能2 5.05)
预测:2.0
树12(重量1.0):
如果(功能部件3 0.8)
如果(功能部件3 1.75)
如果(特征0 1.7833559100698644)
预测:2.0
树13(重量1.0):
如果(功能部件3 0.55)
如果(功能部件2 4.95)
预测:2.0
树14(重量1.0):
如果(功能2.5)
如果(功能部件3 1.65)
如果(特征0 1.7833559100698644)
预测:2.0
树15(重量1.0):
如果(功能2.5)
如果(功能部件3 1.75)
预测:2.0
树16(重量1.0):
如果(功能部件3 0.8)
如果(功能部件0 1.7491620461964392)
如果(功能部件3 1.75)
预测:2.0
树17(重量1.0):
如果(特征0 1.695573522904327)
如果(功能部件2 4.75)
如果(功能部件3 1.75)
预测:2.0
树18(重量1.0):
如果(功能部件3 0.8)
如果(功能部件3 1.65)
如果(特征0 1.7833559100698644)
预测:2.0
树19(重量1.0):
如果(功能2.5)
如果(功能部件2 4.95)
预测:2.0
正如您所看到的,这种格式对于将决策树图形可视化的许多漂亮方法中的一种传递到您的最佳选择中来说并不太理想,而不涉及复杂的第三方工具(例如,您可以查看MLeap),可能是为了阅读:

根目录
|--treeID:integer(nullable=true)
|--nodeData:struct(nullable=true)
||--id:integer(nullable=true)
||--预测:双精度(可空=真)
||--杂质:双精度(可空=真)
||--inpurityStats:array(nullable=true)
|| |--元素:双精度(containsnall=true)
||--增益:加倍(可为空=真)
||--leftChild:integer(nullable=true)
||--rightChild:integer(nullable=true)
||--split:struct(nullable=true)
|| |--featureIndex:integer(nullable=true)
|| |--leftCategoriesOrThreshold:array(nullable=true)
|| | |--元素:double(containsnall=true)
|| |--numCategories:integer(nullable=true)
提供有关所有节点和拆分的信息

可以使用列元数据检索要素映射:

meta <- iris_predictions %>% 
    select(features) %>% 
    spark_dataframe() %>% 
    invoke("schema") %>% invoke("apply", 0L) %>% 
    invoke("metadata") %>% 
    invoke("getMetadata", "ml_attr") %>% 
    invoke("getMetadata", "attrs") %>% 
    invoke("json") %>%
    jsonlite::fromJSON() %>% 
    dplyr::bind_rows() %>% 
    copy_to(sc, .) %>%
    rename(featureIndex = idx)

meta


*在不久的将来,随着新引入的格式不可知的ML writer API(它已经为选定的模型支持PMML writer。希望新的模型和格式将随之而来),它将得到改进

**如果使用分类功能,可能需要将
leftCategoriesOrThreshold
映射到相应的索引级别

如果特征向量包含分类变量,则
jsonlite::fromJSON()
的输出将包含
nominal
组。例如,如果您有三个级别的索引列
foo
,在第一个位置组装,它将如下所示:

$nominal
     vals idx      name
1 a, b, c   1       foo
其中
vals
列是可变长度向量的列表

length(meta$nominal$vals[[1]])
标签对应于此结构的索引,因此在示例中:

  • a
    的标签为0.0(并非标签是双精度浮点数,编号从0.0开始)
  • b
    标签为1.0
以此类推,如果您使用
leftCategoriesOrThreshold
等于
c(0.0,2.0)
进行拆分,则表示拆分在标签
{“a”,“c”}

还请注意,如果存在分类数据,您可能需要在调用
copy\u to
之前对其进行处理-目前它似乎不支持复杂字段


在Spark中,这看起来很棒!谢谢!事实上,我正在使用分类功能。这显然是我问的,但是如果你对如何使用
meta <- iris_predictions %>% 
    select(features) %>% 
    spark_dataframe() %>% 
    invoke("schema") %>% invoke("apply", 0L) %>% 
    invoke("metadata") %>% 
    invoke("getMetadata", "ml_attr") %>% 
    invoke("getMetadata", "attrs") %>% 
    invoke("json") %>%
    jsonlite::fromJSON() %>% 
    dplyr::bind_rows() %>% 
    copy_to(sc, .) %>%
    rename(featureIndex = idx)

meta
labels <- tibble(prediction = seq_along(iris_labels) - 1, label = iris_labels) %>%
  copy_to(sc, .)
full_rf_spec <- rf_spec %>% 
  spark_dataframe() %>% 
  invoke("selectExpr", list("treeID", "nodeData.*", "nodeData.split.*")) %>% 
  sdf_register() %>% 
  select(-split, -impurityStats) %>% 
  left_join(meta, by = "featureIndex") %>% 
  left_join(labels, by = "prediction")

full_rf_spec
library(igraph)

gframe <- full_rf_spec %>% 
  filter(treeID == 0) %>%   # Take the first tree
  mutate(
    leftCategoriesOrThreshold = ifelse(
      size(leftCategoriesOrThreshold) == 1,
      # Continuous variable case
      concat("<= ", round(concat_ws("", leftCategoriesOrThreshold), 3)),
      # Categorical variable case. Decoding variables might be involved
      # but can be achieved if needed, using column metadata or indexer labels
      concat("in {", concat_ws(",", leftCategoriesOrThreshold), "}")
    ),
    name = coalesce(name, label)) %>% 
 select(
   id, label, impurity, gain, 
   leftChild, rightChild, leftCategoriesOrThreshold, name) %>%
 collect()

vertices <- gframe %>% rename(label = name, name = id)

edges <- gframe %>%
  transmute(from = id, to = leftChild, label = leftCategoriesOrThreshold) %>% 
  union_all(gframe %>% select(from = id, to = rightChild)) %>% 
  filter(to != -1)

g <- igraph::graph_from_data_frame(edges, vertices = vertices)

plot(
  g, layout = layout_as_tree(g, root = c(1)),
  vertex.shape = "rectangle",  vertex.size = 45)
$nominal
     vals idx      name
1 a, b, c   1       foo
length(meta$nominal$vals[[1]])
[1] 3