Warning: file_get_contents(/data/phpspider/zhask/data//catemap/4/r/70.json): failed to open stream: No such file or directory in /data/phpspider/zhask/libs/function.php on line 167

Warning: Invalid argument supplied for foreach() in /data/phpspider/zhask/libs/tag.function.php on line 1116

Notice: Undefined index: in /data/phpspider/zhask/libs/function.php on line 180

Warning: array_chunk() expects parameter 1 to be array, null given in /data/phpspider/zhask/libs/function.php on line 181
R 插入符号中训练函数最终模型中的错误分类样本_R_Random Forest_Cross Validation_R Caret - Fatal编程技术网

R 插入符号中训练函数最终模型中的错误分类样本

R 插入符号中训练函数最终模型中的错误分类样本,r,random-forest,cross-validation,r-caret,R,Random Forest,Cross Validation,R Caret,Caret包中的train函数返回一个最终模型,我想在我的主数据框中找到误分类样本的行索引。我按照以下方式进行交叉验证: library(caret) train_control <- trainControl(method="cv", number=5,savePredictions = TRUE,classProbs = TRUE) output <- train(Species~., data=iris, trControl=train_control, method="rf

Caret包中的train函数返回一个最终模型,我想在我的主数据框中找到误分类样本的行索引。我按照以下方式进行交叉验证:

library(caret)
train_control <- trainControl(method="cv", number=5,savePredictions =  TRUE,classProbs = TRUE)
output <- train(Species~., data=iris, trControl=train_control, method="rf")
有没有办法找出哪些样本被错误分类?(上面混淆矩阵中的3个和4个样本)

试试这个:

library(dplyr)
output$pred %>% filter_("pred!=obs")
输出:

         pred        obs setosa versicolor virginica rowIndex mtry Resample
1   virginica versicolor      0      0.084     0.916       71    2    Fold1
2  versicolor  virginica      0      0.976     0.024      107    2    Fold1
3   virginica versicolor      0      0.074     0.926       71    3    Fold1
4  versicolor  virginica      0      0.990     0.010      107    3    Fold1
5  versicolor  virginica      0      0.504     0.496      130    3    Fold1
6   virginica versicolor      0      0.070     0.930       71    4    Fold1
7  versicolor  virginica      0      0.992     0.008      107    4    Fold1
8  versicolor  virginica      0      0.550     0.450      130    4    Fold1
9   virginica versicolor      0      0.244     0.756       78    2    Fold2
10  virginica versicolor      0      0.172     0.828       78    3    Fold2
11  virginica versicolor      0      0.196     0.804       78    4    Fold2
12 versicolor  virginica      0      0.922     0.078      120    2    Fold3
13 versicolor  virginica      0      0.616     0.384      135    2    Fold3
14 versicolor  virginica      0      0.928     0.072      120    3    Fold3
15 versicolor  virginica      0      0.612     0.388      135    3    Fold3
16 versicolor  virginica      0      0.930     0.070      120    4    Fold3
17 versicolor  virginica      0      0.566     0.434      135    4    Fold3
18  virginica versicolor      0      0.352     0.648       84    2    Fold5
19  virginica versicolor      0      0.316     0.684       84    3    Fold5
20  virginica versicolor      0      0.256     0.744       84    4    Fold5
请注意,
mtry
是在每次分割时随机抽样作为候选变量的变量数量,
Resample
列出了交叉验证折叠

让我们绘制错误分类的项目:

d <- output$pred %>% 
  filter_("pred!=obs") %>% 
  distinct(rowIndex) %>% 
  unlist() %>% sort()

print(unname(d))
# 71  78  84 107 120 130 134 135 139

ggplot(iris, aes(Sepal.Length, Sepal.Width, colour = Species)) + 
  geom_point() + 
  geom_point(data = iris[d, ], aes(x = Sepal.Length, y = Sepal.Width), 
             color = "black")

ggplot(iris, aes(Petal.Length, Petal.Width, colour = Species)) + 
  geom_point() + 
  geom_point(data = iris[d, ], aes(x = Petal.Length, y = Petal.Width), 
             color = "black")
d%
过滤器(“pred!=obs”)%>%
不同(行索引)%>%
取消列表()%>%sort()
打印(未命名(d))
# 71  78  84 107 120 130 134 135 139
ggplot(鸢尾,aes(萼片长度,萼片宽度,颜色=种))+
几何点()
几何点(数据=虹膜[d,],aes(x=萼片长度,y=萼片宽度),
color=“黑色”)
ggplot(鸢尾,aes(花瓣长度,花瓣宽度,颜色=物种))+
几何点()
几何点(数据=虹膜[d,],aes(x=花瓣长度,y=花瓣宽度),
color=“黑色”)


可以看出,这些图对我们的结果给出了直观的解释

另一种简单的方法是检查预测样本:

output$output$finalModel$predicted
然后你可以将预测的数据与你的主要虹膜数据进行比较

output$output$finalModel$predicted