使用trafo调整SMOTE的K失败:警告("k应小于样本量!")
使用trafo调整SMOTE的K失败:警告("k应小于样本量!")。标签:r, machine-learning, mlr3, R, Machine Learning, Mlr3。我在使用`SMOTE{smotefamily}`的`K`参数的trafo函数时遇到问题。特别是,当最近邻数`K`大于或等于样本大小时,将返回错误(`警告("K应小于样本大小!")`),并终止调整过程。在内部重新采样过程中,用户无法控制`K`小于样本大小。这必须在内部进行控制,例如,如果对于`K`的某个值有`trafo_K = 2^K >= sample_size`,那么,比如说,令`trafo_K = sample_size - 1`。我想知道是否有解决方案,或者是否已经有了。 library("mlr3") # mlr3 base
我在使用`SMOTE{smotefamily}`的`K`参数的trafo函数时遇到问题。特别是,当最近邻数`K`大于或等于样本大小时,将返回错误(`警告("K应小于样本大小!")`),并终止调整过程。
在内部重新采样过程中,用户无法控制`K`小于样本大小。这必须在内部进行控制,例如,如果对于`K`的某个值有`trafo_K = 2^K >= sample_size`,那么,比如说,令`trafo_K = sample_size - 1`。
我想知道是否有解决方案,或者是否已经有了。
library("mlr3") # mlr3 base package
library("mlr3misc") # contains some helper functions
library("mlr3pipelines") # create ML pipelines
library("mlr3tuning") # tuning ML algorithms
library("mlr3learners") # additional ML algorithms
library("mlr3viz") # autoplot for benchmarks
library("paradox") # hyperparameter space
library("OpenML") # to obtain data sets
library("smotefamily") # SMOTE algorithm for imbalance correction
# Data ----
# Query OpenML for curated binary classification data sets whose number of
# features and instances fall into the ranges below
# (curation described in https://arxiv.org/abs/1708.03731v2).
ds <- listOMLDataSets(
  number.of.classes = 2,
  number.of.features = c(1, 100),
  number.of.instances = c(5000, 10000)
)
# Keep imbalanced data sets (minority share below 0.2) that have exactly one
# symbolic feature (presumably the target), because SMOTE cannot handle
# categorical features.
ds <- subset(
  ds,
  minority.class.size / number.of.instances < 0.2 &
    number.of.symbolic.features == 1
)
ds
# Download OpenML data set 980 and print it.
d <- getOMLDataSet(980)
d
# Convert to a data.frame, coerce the target column to a factor and wrap the
# result in an mlr3 classification task.
data <- as.data.frame(d)
data[[d$target.features]] <- as.factor(data[[d$target.features]])
task <- TaskClassif$new(
  id = d$desc$name,
  backend = data,
  target = d$target.features
)
task
# Code above copied from
# https://mlr3gallery.mlr-org.com/posts/2020-03-30-imbalanced-data/
# Class balance of the full task: ratio of the largest to the smallest class
# count. max() / min() always yield a single plain number. Subsetting with
# `class_counts[class_counts == max(class_counts)]` returns one element per
# tied class (and keeps the table class), so tied counts would have turned
# dup_size below into a vector.
class_counts <- table(task$truth())
majority_to_minority_ratio <- max(class_counts) / min(class_counts)
# Pipe operator for SMOTE, with dup_size set to the rounded class ratio.
po_smote <- po("smote", dup_size = round(majority_to_minority_ratio))
# Learner pipeline ----
# Probability-predicting random forest (ranger), wrapped as a PipeOp with id
# "rf" and chained after the SMOTE PipeOp; the resulting graph is plotted.
rf <- lrn("classif.ranger", predict_type = "prob")
graph <- po_smote %>>% po("learner", rf, id = "rf")
graph$plot()
# Expose the whole graph as a single learner predicting probabilities. The
# second predict_type assignment repeats the constructor argument.
rf_smote <- GraphLearner$new(graph, predict_type = "prob")
rf_smote$predict_type <- "prob"
# Search space ----
# Hyperparameters of the graph learner as a data.table. View() needs a data
# viewer, so it is only called in an interactive session.
ps_table <- as.data.table(rf_smote$param_set)
if (interactive()) {
  View(ps_table[, 1:4])
}
# Define parameter search space for the SMOTE parameters: smote.K and
# smote.dup_size, both integers in [1, round(majority_to_minority_ratio)].
# The pattern is anchored and the dot escaped; in the earlier regexes
# 'smote.', '.K' and '.dup_size' the dot matched any character. Base R is
# used instead of %>%, which none of the visible library() calls attach.
smote_ids <- grep("^smote\\.(K|dup_size)$", ps_table$id, value = TRUE)
param_set <- ParamSet$new(lapply(smote_ids, function(id) {
  ParamInt$new(id, lower = 1, upper = round(majority_to_minority_ratio))
}))
# Apply a transformation function to SMOTE's K (the number of nearest
# neighbours used for sampling new values, see smotefamily::SMOTE()).
# K is mapped to round(3^K), intentionally left unbounded so that large
# values reach the minority-class sample size and reproduce the FNN
# sample-size warning that aborts tuning.
param_set$trafo <- function(x, param_set) {
  # Match K by its literal ".K" suffix (in grepl('.K', ...) the dot was a
  # wildcard) and transform every match; x[[index]] with several indices
  # would have been recursive indexing.
  k_index <- which(grepl("\\.K$", names(x)))
  for (i in k_index) {
    x[[i]] <- round(3^x[[i]]) # Intentionally define a trafo that won't work
  }
  x
}
# Tuning ----
# Two-fold cross-validation, instantiated once on the task so the same
# train/test splits are reused for every evaluated configuration.
cv <- rsmp("cv", folds = 2)
cv$instantiate(task)
# Tuning instance: SMOTE + random forest graph learner, binary Brier score,
# stop after 3 evaluations, keep the fitted models. param_set stays
# positional as in the original call.
# NOTE(review): its formal name differs between mlr3tuning versions -- confirm
# before passing it by name.
instance <- TuningInstance$new(
  task = task,
  learner = rf_smote,
  resampling = cv,
  measures = msr("classif.bbrier"),
  param_set,
  terminator = term("evals", n_evals = 3),
  store_models = TRUE
)
# Random search proposes the SMOTE configurations to evaluate.
tuner <- TunerRandomSearch$new()
tuner$optimize(instance)
会话信息
R version 3.6.2 (2019-12-12)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 16299)
Matrix products: default
locale:
[1] LC_COLLATE=English_United Kingdom.1252 LC_CTYPE=English_United Kingdom.1252
[3] LC_MONETARY=English_United Kingdom.1252 LC_NUMERIC=C
[5] LC_TIME=English_United Kingdom.1252
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] smotefamily_1.3.1 OpenML_1.10 mlr3viz_0.1.1.9002
[4] mlr3tuning_0.1.2-9000 mlr3pipelines_0.1.2.9000 mlr3misc_0.2.0
[7] mlr3learners_0.2.0 mlr3filters_0.2.0.9000 mlr3_0.2.0-9000
[10] paradox_0.2.0 yardstick_0.0.5 rsample_0.0.5
[13] recipes_0.1.9 parsnip_0.0.5 infer_0.5.1
[16] dials_0.0.4 scales_1.1.0 broom_0.5.4
[19] tidymodels_0.0.3 reshape2_1.4.3 janitor_1.2.1
[22] data.table_1.12.8 forcats_0.4.0 stringr_1.4.0
[25] dplyr_0.8.4 purrr_0.3.3 readr_1.3.1
[28] tidyr_1.0.2 tibble_3.0.1 ggplot2_3.3.0
[31] tidyverse_1.3.0
loaded via a namespace (and not attached):
[1] utf8_1.1.4 tidyselect_1.0.0 lme4_1.1-21
[4] htmlwidgets_1.5.1 grid_3.6.2 ranger_0.12.1
[7] pROC_1.16.1 munsell_0.5.0 codetools_0.2-16
[10] bbotk_0.1 DT_0.12 future_1.17.0
[13] miniUI_0.1.1.1 withr_2.2.0 colorspace_1.4-1
[16] knitr_1.28 uuid_0.1-4 rstudioapi_0.10
[19] stats4_3.6.2 bayesplot_1.7.1 listenv_0.8.0
[22] rstan_2.19.2 lgr_0.3.4 DiceDesign_1.8-1
[25] vctrs_0.2.4 generics_0.0.2 ipred_0.9-9
[28] xfun_0.12 R6_2.4.1 markdown_1.1
[31] mlr3measures_0.1.3-9000 rstanarm_2.19.2 lhs_1.0.1
[34] assertthat_0.2.1 promises_1.1.0 nnet_7.3-12
[37] gtable_0.3.0 globals_0.12.5 processx_3.4.1
[40] timeDate_3043.102 rlang_0.4.5 workflows_0.1.1
[43] BBmisc_1.11 splines_3.6.2 checkmate_2.0.0
[46] inline_0.3.15 yaml_2.2.1 modelr_0.1.5
[49] tidytext_0.2.2 threejs_0.3.3 crosstalk_1.0.0
[52] backports_1.1.6 httpuv_1.5.2 rsconnect_0.8.16
[55] tokenizers_0.2.1 tools_3.6.2 lava_1.6.6
[58] ellipsis_0.3.0 ggridges_0.5.2 Rcpp_1.0.4.6
[61] plyr_1.8.5 base64enc_0.1-3 visNetwork_2.0.9
[64] ps_1.3.0 prettyunits_1.1.1 rpart_4.1-15
[67] zoo_1.8-7 haven_2.2.0 fs_1.3.1
[70] furrr_0.1.0 magrittr_1.5 colourpicker_1.0
[73] reprex_0.3.0 GPfit_1.0-8 SnowballC_0.6.0
[76] packrat_0.5.0 matrixStats_0.55.0 tidyposterior_0.0.2
[79] hms_0.5.3 shinyjs_1.1 mime_0.8
[82] xtable_1.8-4 XML_3.99-0.3 tidypredict_0.4.3
[85] shinystan_2.5.0 readxl_1.3.1 gridExtra_2.3
[88] rstantools_2.0.0 compiler_3.6.2 crayon_1.3.4
[91] minqa_1.2.4 StanHeaders_2.21.0-1 htmltools_0.4.0
[94] later_1.0.0 lubridate_1.7.4 DBI_1.1.0
[97] dbplyr_1.4.2 MASS_7.3-51.4 boot_1.3-23
[100] Matrix_1.2-18 cli_2.0.1 parallel_3.6.2
[103] gower_0.2.1 igraph_1.2.4.2 pkgconfig_2.0.3
[106] xml2_1.2.2 foreach_1.4.7 dygraphs_1.1.1.6
[109] prodlim_2019.11.13 farff_1.1 rvest_0.3.5
[112] snakecase_0.11.0 janeaustenr_0.1.5 callr_3.4.1
[115] digest_0.6.25 cellranger_1.1.0 curl_4.3
[118] shiny_1.4.0 gtools_3.8.1 nloptr_1.2.1
[121] lifecycle_0.2.0 nlme_3.1-142 jsonlite_1.6.1
[124] fansi_0.4.1 pillar_1.4.3 lattice_0.20-38
[127] loo_2.2.0 fastmap_1.0.1 httr_1.4.1
[130] pkgbuild_1.0.6 survival_3.1-8 glue_1.4.0
[133] xts_0.12-0 FNN_1.1.3 shinythemes_1.1.2
[136] iterators_1.0.12 class_7.3-15 stringi_1.4.4
[139] memoise_1.1.0 future.apply_1.5.0
非常感谢。我找到了一个解决办法。如前所述,问题在于`SMOTE{smotefamily}`的`K`不能大于或等于样本量。
我深入查看了调用过程,发现`SMOTE{smotefamily}`使用`knearest{smotefamily}`,后者使用`knnx.index{FNN}`,后者又使用`get.knn{FNN}`;正是这里返回了错误`警告("k应小于样本大小!")`,从而终止了`mlr3`中的调优过程。
现在,在`SMOTE{smotefamily}`中,传给`knearest{smotefamily}`的三个参数是`P_set`、`P_set`和`K`。从`mlr3`重采样的角度来看,数据框`P_set`是交叉验证某一折训练数据的子集,经过过滤,只包含少数类的记录。错误信息中所说的"样本量"指的就是`P_set`的行数。
因此,当`K`通过诸如`某个整数^K`(例如`2^K`)之类的trafo增大时,就可能出现`K >= nrow(P_set)`。我们需要确保`K`永远不会大于或等于`nrow(P_set)`。
以下是我建议的解决方案:
1. 在用`rsmp()`定义CV重采样策略之前,先定义变量`cv_folds`;
2. 用`rsmp()`定义CV重采样策略,其中`folds = cv_folds`;
3. 计算`K`的阈值:
我找到了一个解决办法。如前所述,问题在于`SMOTE{smotefamily}`的`K`不能大于或等于样本量。
我深入查看了调用过程,发现`SMOTE{smotefamily}`使用`knearest{smotefamily}`,后者使用`knnx.index{FNN}`,后者又使用`get.knn{FNN}`;正是这里返回了错误`警告("k应小于样本大小!")`,从而终止了`mlr3`中的调优过程。
现在,在`SMOTE{smotefamily}`中,传给`knearest{smotefamily}`的三个参数是`P_set`、`P_set`和`K`。从`mlr3`重采样的角度来看,数据框`P_set`是交叉验证某一折训练数据的子集,经过过滤,只包含少数类的记录。错误信息中所说的"样本量"指的就是`P_set`的行数。
因此,当`K`通过诸如`某个整数^K`(例如`2^K`)之类的trafo增大时,就可能出现`K >= nrow(P_set)`。我们需要确保`K`永远不会大于或等于`nrow(P_set)`。
以下是我建议的解决方案:
1. 在用`rsmp()`定义CV重采样策略之前,先定义变量`cv_folds`;
2. 用`rsmp()`定义CV重采样策略,其中`folds = cv_folds`;
3. 计算`K`的阈值:
你建议什么样的解决方案?由于数据的特殊性,许多超参数组合是不可能的(会导致错误),就像您展示的示例那样。
我建议在内部添加一个`ifelse`过程,使得当`K`超过样本大小时,将`K`设置为适当的最大值(可能是样本大小减1?)。同时应返回一个警告,这样进程能够继续运行而不是终止。就目前的情况而言,这个过程能否跑通完全取决于运气。例如,试着把我问题中的代码运行几次:有时您会遇到所报告的错误并且脚本停止,有时则不会,脚本能够运行完。我自己的数据运气不好,所以我实际上无法对我的数据使用上述trafo。
那么你是建议在所有受mlr3支持的算法中,先检查超参数组合是否为数据所允许,并在必要时调整它们使其落在允许范围内并发出警告?听起来工作量很大。我认为在调优过程中跳过不合适的超参数组合要容易得多。参见封装(encapsulation):
我实际上是建议只对SMOTE这样做,而不是对所有受支持的算法。无论如何,谢谢你提供的资源——我会看看我能做些什么。
你建议什么样的解决方案?由于数据的特殊性,许多超参数组合是不可能的(会导致错误),就像您展示的示例那样。
R version 3.6.2 (2019-12-12)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 16299)
Matrix products: default
locale:
[1] LC_COLLATE=English_United Kingdom.1252 LC_CTYPE=English_United Kingdom.1252
[3] LC_MONETARY=English_United Kingdom.1252 LC_NUMERIC=C
[5] LC_TIME=English_United Kingdom.1252
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] smotefamily_1.3.1 OpenML_1.10 mlr3viz_0.1.1.9002
[4] mlr3tuning_0.1.2-9000 mlr3pipelines_0.1.2.9000 mlr3misc_0.2.0
[7] mlr3learners_0.2.0 mlr3filters_0.2.0.9000 mlr3_0.2.0-9000
[10] paradox_0.2.0 yardstick_0.0.5 rsample_0.0.5
[13] recipes_0.1.9 parsnip_0.0.5 infer_0.5.1
[16] dials_0.0.4 scales_1.1.0 broom_0.5.4
[19] tidymodels_0.0.3 reshape2_1.4.3 janitor_1.2.1
[22] data.table_1.12.8 forcats_0.4.0 stringr_1.4.0
[25] dplyr_0.8.4 purrr_0.3.3 readr_1.3.1
[28] tidyr_1.0.2 tibble_3.0.1 ggplot2_3.3.0
[31] tidyverse_1.3.0
loaded via a namespace (and not attached):
[1] utf8_1.1.4 tidyselect_1.0.0 lme4_1.1-21
[4] htmlwidgets_1.5.1 grid_3.6.2 ranger_0.12.1
[7] pROC_1.16.1 munsell_0.5.0 codetools_0.2-16
[10] bbotk_0.1 DT_0.12 future_1.17.0
[13] miniUI_0.1.1.1 withr_2.2.0 colorspace_1.4-1
[16] knitr_1.28 uuid_0.1-4 rstudioapi_0.10
[19] stats4_3.6.2 bayesplot_1.7.1 listenv_0.8.0
[22] rstan_2.19.2 lgr_0.3.4 DiceDesign_1.8-1
[25] vctrs_0.2.4 generics_0.0.2 ipred_0.9-9
[28] xfun_0.12 R6_2.4.1 markdown_1.1
[31] mlr3measures_0.1.3-9000 rstanarm_2.19.2 lhs_1.0.1
[34] assertthat_0.2.1 promises_1.1.0 nnet_7.3-12
[37] gtable_0.3.0 globals_0.12.5 processx_3.4.1
[40] timeDate_3043.102 rlang_0.4.5 workflows_0.1.1
[43] BBmisc_1.11 splines_3.6.2 checkmate_2.0.0
[46] inline_0.3.15 yaml_2.2.1 modelr_0.1.5
[49] tidytext_0.2.2 threejs_0.3.3 crosstalk_1.0.0
[52] backports_1.1.6 httpuv_1.5.2 rsconnect_0.8.16
[55] tokenizers_0.2.1 tools_3.6.2 lava_1.6.6
[58] ellipsis_0.3.0 ggridges_0.5.2 Rcpp_1.0.4.6
[61] plyr_1.8.5 base64enc_0.1-3 visNetwork_2.0.9
[64] ps_1.3.0 prettyunits_1.1.1 rpart_4.1-15
[67] zoo_1.8-7 haven_2.2.0 fs_1.3.1
[70] furrr_0.1.0 magrittr_1.5 colourpicker_1.0
[73] reprex_0.3.0 GPfit_1.0-8 SnowballC_0.6.0
[76] packrat_0.5.0 matrixStats_0.55.0 tidyposterior_0.0.2
[79] hms_0.5.3 shinyjs_1.1 mime_0.8
[82] xtable_1.8-4 XML_3.99-0.3 tidypredict_0.4.3
[85] shinystan_2.5.0 readxl_1.3.1 gridExtra_2.3
[88] rstantools_2.0.0 compiler_3.6.2 crayon_1.3.4
[91] minqa_1.2.4 StanHeaders_2.21.0-1 htmltools_0.4.0
[94] later_1.0.0 lubridate_1.7.4 DBI_1.1.0
[97] dbplyr_1.4.2 MASS_7.3-51.4 boot_1.3-23
[100] Matrix_1.2-18 cli_2.0.1 parallel_3.6.2
[103] gower_0.2.1 igraph_1.2.4.2 pkgconfig_2.0.3
[106] xml2_1.2.2 foreach_1.4.7 dygraphs_1.1.1.6
[109] prodlim_2019.11.13 farff_1.1 rvest_0.3.5
[112] snakecase_0.11.0 janeaustenr_0.1.5 callr_3.4.1
[115] digest_0.6.25 cellranger_1.1.0 curl_4.3
[118] shiny_1.4.0 gtools_3.8.1 nloptr_1.2.1
[121] lifecycle_0.2.0 nlme_3.1-142 jsonlite_1.6.1
[124] fansi_0.4.1 pillar_1.4.3 lattice_0.20-38
[127] loo_2.2.0 fastmap_1.0.1 httr_1.4.1
[130] pkgbuild_1.0.6 survival_3.1-8 glue_1.4.0
[133] xts_0.12-0 FNN_1.1.3 shinythemes_1.1.2
[136] iterators_1.0.12 class_7.3-15 stringi_1.4.4
[139] memoise_1.1.0 future.apply_1.5.0
# Upper bound (exclusive) for SMOTE's K: the smallest class count in any
# training fold of the instantiated resampling `cv`. SMOTE searches nearest
# neighbours among the minority-class rows of each training set, so K must
# stay strictly below this value.
# task$truth() looks rows up by row id. Indexing as.data.frame(task$data())
# with the ids from cv$train_set() treated them as row positions, which is
# only correct when the row ids are 1..n. vapply() + min() also replace the
# dplyr-only bind_cols() and the no-op unique().
smote_k_thresh <- min(vapply(
  seq_len(cv$iters),
  function(i) min(table(task$truth(cv$train_set(i)))),
  integer(1)
))
# Transformation of SMOTE's K: map K to round(2^K) and cap it at
# smote_k_thresh - 1 (looked up when the trafo runs) so it never equals or
# exceeds the minority-class sample size of any training fold.
# The cap is deterministic (trafo_K = sample_size - 1, as proposed in the
# question), so a configuration always maps to the same K. A random
# replacement via sample(smote_k_thresh - 1, 1) also returned 0 when
# smote_k_thresh was 1, because sample() draws from the value itself when
# it is a single number below 1.
param_set$trafo <- function(x, param_set) {
  # Match K by its literal ".K" suffix (the dot in '.K' was a wildcard).
  k_index <- which(grepl("\\.K$", names(x)))
  if (length(k_index) > 0) {
    k_max <- smote_k_thresh - 1
    if (k_max < 1) {
      stop("SMOTE needs at least 2 minority-class rows per training fold; ",
           "found ", smote_k_thresh, ".", call. = FALSE)
    }
    for (i in k_index) {
      x[[i]] <- min(round(2^x[[i]]), k_max)
    }
  }
  x
}
library("mlr3learners") # additional ML algorithms
library("mlr3viz") # autoplot for benchmarks
library("paradox") # hyperparameter space
library("OpenML") # to obtain data sets
library("smotefamily") # SMOTE algorithm for imbalance correction
# Data ----
# Query OpenML for curated binary classification data sets whose number of
# features and instances fall into the ranges below
# (curation described in https://arxiv.org/abs/1708.03731v2).
ds <- listOMLDataSets(
  number.of.classes = 2,
  number.of.features = c(1, 100),
  number.of.instances = c(5000, 10000)
)
# Keep imbalanced data sets (minority share below 0.2) that have exactly one
# symbolic feature (presumably the target), because SMOTE cannot handle
# categorical features.
ds <- subset(
  ds,
  minority.class.size / number.of.instances < 0.2 &
    number.of.symbolic.features == 1
)
ds
# Download OpenML data set 980 and print it.
d <- getOMLDataSet(980)
d
# Convert to a data.frame, coerce the target column to a factor and wrap the
# result in an mlr3 classification task.
data <- as.data.frame(d)
data[[d$target.features]] <- as.factor(data[[d$target.features]])
task <- TaskClassif$new(
  id = d$desc$name,
  backend = data,
  target = d$target.features
)
task
# Code above copied from
# https://mlr3gallery.mlr-org.com/posts/2020-03-30-imbalanced-data/
# Class balance of the full task: ratio of the largest to the smallest class
# count. max() / min() always yield a single plain number. Subsetting with
# `class_counts[class_counts == max(class_counts)]` returns one element per
# tied class (and keeps the table class), so tied counts would have turned
# dup_size below into a vector.
class_counts <- table(task$truth())
majority_to_minority_ratio <- max(class_counts) / min(class_counts)
# Pipe operator for SMOTE, with dup_size set to the rounded class ratio.
po_smote <- po("smote", dup_size = round(majority_to_minority_ratio))
# Resampling ----
# Define and instantiate the CV strategy used inside tuning. smote_k_thresh
# below is computed from these instantiated folds and read by the trafo when
# it runs, so this must happen before tuning (do it BEFORE defining the trafo).
cv_folds <- 2
cv <- rsmp("cv", folds = cv_folds)
cv$instantiate(task)
# Upper bound (exclusive) for SMOTE's K: the smallest class count in any
# training fold. SMOTE searches nearest neighbours among the minority-class
# rows of each training set, so K must stay strictly below this value.
# task$truth() looks rows up by row id. Indexing as.data.frame(task$data())
# with the ids from cv$train_set() treated them as row positions, which is
# only correct when the row ids are 1..n. vapply() + min() also replace the
# dplyr-only bind_cols() and the no-op unique().
smote_k_thresh <- min(vapply(
  seq_len(cv$iters),
  function(i) min(table(task$truth(cv$train_set(i)))),
  integer(1)
))
# Learner pipeline ----
# Probability-predicting random forest (ranger), wrapped as a PipeOp with id
# "rf" and chained after the SMOTE PipeOp; the resulting graph is plotted.
rf <- lrn("classif.ranger", predict_type = "prob")
graph <- po_smote %>>% po("learner", rf, id = "rf")
graph$plot()
# Expose the whole graph as a single learner predicting probabilities. The
# second predict_type assignment repeats the constructor argument.
rf_smote <- GraphLearner$new(graph, predict_type = "prob")
rf_smote$predict_type <- "prob"
# Search space ----
# Hyperparameters of the graph learner as a data.table. View() needs a data
# viewer, so it is only called in an interactive session.
ps_table <- as.data.table(rf_smote$param_set)
if (interactive()) {
  View(ps_table[, 1:4])
}
# Define parameter search space for the SMOTE parameters: smote.K and
# smote.dup_size, both integers in [1, round(majority_to_minority_ratio)].
# The pattern is anchored and the dot escaped; in the earlier regexes
# 'smote.', '.K' and '.dup_size' the dot matched any character. Base R is
# used instead of %>%, which none of the visible library() calls attach.
smote_ids <- grep("^smote\\.(K|dup_size)$", ps_table$id, value = TRUE)
param_set <- ParamSet$new(lapply(smote_ids, function(id) {
  ParamInt$new(id, lower = 1, upper = round(majority_to_minority_ratio))
}))
# Apply a transformation function to SMOTE's K while ensuring it never equals
# or exceeds the sample size: K is mapped to round(5^K) (a deliberately large
# value for the sake of the example) and capped at smote_k_thresh - 1, which
# is looked up when the trafo runs.
# The cap is deterministic (trafo_K = sample_size - 1, as proposed in the
# question), so a configuration always maps to the same K. A random
# replacement via sample(smote_k_thresh - 1, 1) also returned 0 when
# smote_k_thresh was 1, because sample() draws from the value itself when
# it is a single number below 1.
param_set$trafo <- function(x, param_set) {
  # Match K by its literal ".K" suffix (the dot in '.K' was a wildcard).
  k_index <- which(grepl("\\.K$", names(x)))
  if (length(k_index) > 0) {
    k_max <- smote_k_thresh - 1
    if (k_max < 1) {
      stop("SMOTE needs at least 2 minority-class rows per training fold; ",
           "found ", smote_k_thresh, ".", call. = FALSE)
    }
    for (i in k_index) {
      # Try a large value here for the sake of the example
      x[[i]] <- min(round(5^x[[i]]), k_max)
    }
  }
  x
}
# Tuning ----
# Tuning instance on the pre-instantiated CV splits in `cv`: SMOTE + random
# forest graph learner, binary Brier score, stop after 10 evaluations, keep
# the fitted models. param_set stays positional as in the original call.
# NOTE(review): its formal name differs between mlr3tuning versions -- confirm
# before passing it by name.
instance <- TuningInstance$new(
  task = task,
  learner = rf_smote,
  resampling = cv,
  measures = msr("classif.bbrier"),
  param_set,
  terminator = term("evals", n_evals = 10),
  store_models = TRUE
)
# Random search proposes the SMOTE configurations to evaluate.
tuner <- TunerRandomSearch$new()
tuner$optimize(instance)
# Tuning archive: printed in full, it shows the original (untransformed) K
# values proposed by the tuner ...
tuning_archive <- instance$archive$data
tuning_archive
# ... and its opt_x column shows their transformations.
tuning_archive$opt_x