使用trafo调整SMOTE的K失败:'warning("k should be less than sample size!")'

使用trafo调整SMOTE的K失败:'warning("k should be less than sample size!")',r,machine-learning,mlr3,R,Machine Learning,Mlr3,我在对SMOTE{smotefamily}的K参数使用trafo函数时遇到问题。特别是,当最近邻数K大于或等于样本量时,会返回错误(warning("k should be less than sample size!")),并终止调优过程。在内部重采样过程中,用户无法保证K小于样本量。这必须在内部进行控制,例如,如果对于K的某个值有trafo_K = 2^K >= sample_size,那么就令trafo_K = sample_size - 1。我想知道是否有解决方案,或者是否已经有了 library("mlr3") # mlr3 base

我在使用
SMOTE{smotefamily}
K
参数的trafo函数时遇到问题。特别是,当最近邻数
K
大于或等于样本量时,会返回错误(
warning("k should be less than sample size!")
),并终止调优过程

在内部重新采样过程中,用户无法控制
K
小于样本大小。这必须在内部进行控制,例如,如果
trafo_K = 2^K >= sample_size
对于
K
的某个值,那么,比如说,
trafo_K = sample_size - 1

我想知道是否有解决方案,或者是否已经有了

library("mlr3") # mlr3 base package
library("mlr3misc") # contains some helper functions
library("mlr3pipelines") # create ML pipelines
library("mlr3tuning") # tuning ML algorithms
library("mlr3learners") # additional ML algorithms
library("mlr3viz") # autoplot for benchmarks
library("paradox") # hyperparameter space
library("OpenML") # to obtain data sets
library("smotefamily") # SMOTE algorithm for imbalance correction

# Query OpenML for curated binary classification data sets
# (see https://arxiv.org/abs/1708.03731v2): 2 classes, 1-100 features,
# 5000-10000 instances.
ds <- listOMLDataSets(
  number.of.classes = 2,
  number.of.features = c(1, 100),
  number.of.instances = c(5000, 10000)
)
# Keep only imbalanced data sets (minority share below 20%) with exactly one
# symbolic feature, because SMOTE cannot handle categorical features.
ds <- subset(
  ds,
  minority.class.size / number.of.instances < 0.2 &
    number.of.symbolic.features == 1
)
ds

# Download data set 980 and print its summary
d <- getOMLDataSet(980)
d

# Coerce the target column to a factor, then wrap the data in an mlr3
# classification task
data <- as.data.frame(d)
data[[d$target.features]] <- as.factor(data[[d$target.features]])
task <- TaskClassif$new(
  id = d$desc$name,
  backend = data,
  target = d$target.features
)
task

# Code above copied from https://mlr3gallery.mlr-org.com/posts/2020-03-30-imbalanced-data/

# Imbalance ratio: size of the largest class divided by size of the smallest.
# max()/min() always yield a single plain number. Subsetting the table with
# `== max(...)` / `== min(...)` instead returns a named table element and
# yields more than one value whenever several classes share the same count.
class_counts <- table(task$truth())
majority_to_minority_ratio <- max(class_counts) / min(class_counts)

# Pipe operator for SMOTE; dup_size is set to the rounded imbalance ratio
po_smote <- po("smote", dup_size = round(majority_to_minority_ratio))

# Random Forest learner (ranger) that predicts class probabilities
rf <- lrn("classif.ranger", predict_type = "prob")

# Pipeline of Random Forest learner with SMOTE:
# the SMOTE pipe operator feeds into the learner, which is wrapped as a
# pipe operator with id 'rf'
graph <- po_smote %>>%
  po('learner', rf, id = 'rf')
# Draw the pipeline graph (side effect only)
graph$plot()

# Graph learner: wrap the whole pipeline so it can be used like a learner.
# predict_type is passed to the constructor and then assigned again below.
rf_smote <- GraphLearner$new(graph, predict_type = 'prob')
rf_smote$predict_type <- 'prob'

# Parameter set in data table format (first four columns shown in the viewer)
ps_table <- as.data.table(rf_smote$param_set)
View(ps_table[, 1:4])

# Define parameter search space for the SMOTE parameters.
# Only the ids "smote.dup_size" and "smote.K" are tuned. The pattern is
# anchored and its dot escaped, because an unescaped "." in a regular
# expression matches any character. Both parameters are integers ranging
# from 1 to the rounded majority-to-minority ratio. No pipe operator is
# needed, so this no longer depends on magrittr being attached.
smote_ids <- grep("^smote\\.(dup_size|K)$", ps_table$id, value = TRUE)
param_set <- lapply(
  smote_ids,
  function(id) {
    ParamInt$new(id, lower = 1, upper = round(majority_to_minority_ratio))
  }
)
param_set <- ParamSet$new(param_set)

# Apply transformation function on SMOTE's K (= the number of nearest
# neighbours used for sampling new values; see SMOTE()).
#
# x:         named list of hyperparameter values (presumably one sampled
#            configuration -- confirm against the paradox trafo contract)
# param_set: second trafo argument; not used in the body
# Returns x with every entry whose name ends in ".K" replaced by round(3^K).
param_set$trafo <- function(x, param_set) {
  # The pattern is escaped and anchored: a bare '.K' regex matches ANY
  # character followed by "K" anywhere in the name. Looping over the matches
  # handles zero, one or several matching names without special cases.
  for (i in grep("\\.K$", names(x))) {
    x[[i]] <- round(3^x[[i]]) # Intentionally define a trafo that won't work
  }
  x
}

# Define and instantiate resampling strategy to be applied within pipeline
# (2-fold cross-validation; instantiating splits `task` into train/test sets)
cv <- rsmp("cv", folds = 2)
cv$instantiate(task)

# Set up tuning instance
# NOTE: `param_set` is passed positionally (unnamed) among named arguments,
# so R matches it to the first constructor formal that is not named here.
instance <- TuningInstance$new(
  task = task,
  learner = rf_smote,
  resampling = cv,
  measures = msr("classif.bbrier"),
  param_set,
  terminator = term("evals", n_evals = 3), 
  store_models = TRUE)
# Random search over the search space defined by `param_set`
tuner <- TunerRandomSearch$new()

# Tune pipe learner to find optimal SMOTE parameter values
# (with the deliberately oversized trafo on K this is the call that fails)
tuner$optimize(instance)
会话信息

R version 3.6.2 (2019-12-12)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 16299)

Matrix products: default

locale:
[1] LC_COLLATE=English_United Kingdom.1252  LC_CTYPE=English_United Kingdom.1252   
[3] LC_MONETARY=English_United Kingdom.1252 LC_NUMERIC=C                           
[5] LC_TIME=English_United Kingdom.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] smotefamily_1.3.1        OpenML_1.10              mlr3viz_0.1.1.9002      
 [4] mlr3tuning_0.1.2-9000    mlr3pipelines_0.1.2.9000 mlr3misc_0.2.0          
 [7] mlr3learners_0.2.0       mlr3filters_0.2.0.9000   mlr3_0.2.0-9000         
[10] paradox_0.2.0            yardstick_0.0.5          rsample_0.0.5           
[13] recipes_0.1.9            parsnip_0.0.5            infer_0.5.1             
[16] dials_0.0.4              scales_1.1.0             broom_0.5.4             
[19] tidymodels_0.0.3         reshape2_1.4.3           janitor_1.2.1           
[22] data.table_1.12.8        forcats_0.4.0            stringr_1.4.0           
[25] dplyr_0.8.4              purrr_0.3.3              readr_1.3.1             
[28] tidyr_1.0.2              tibble_3.0.1             ggplot2_3.3.0           
[31] tidyverse_1.3.0         

loaded via a namespace (and not attached):
  [1] utf8_1.1.4              tidyselect_1.0.0        lme4_1.1-21            
  [4] htmlwidgets_1.5.1       grid_3.6.2              ranger_0.12.1          
  [7] pROC_1.16.1             munsell_0.5.0           codetools_0.2-16       
 [10] bbotk_0.1               DT_0.12                 future_1.17.0          
 [13] miniUI_0.1.1.1          withr_2.2.0             colorspace_1.4-1       
 [16] knitr_1.28              uuid_0.1-4              rstudioapi_0.10        
 [19] stats4_3.6.2            bayesplot_1.7.1         listenv_0.8.0          
 [22] rstan_2.19.2            lgr_0.3.4               DiceDesign_1.8-1       
 [25] vctrs_0.2.4             generics_0.0.2          ipred_0.9-9            
 [28] xfun_0.12               R6_2.4.1                markdown_1.1           
 [31] mlr3measures_0.1.3-9000 rstanarm_2.19.2         lhs_1.0.1              
 [34] assertthat_0.2.1        promises_1.1.0          nnet_7.3-12            
 [37] gtable_0.3.0            globals_0.12.5          processx_3.4.1         
 [40] timeDate_3043.102       rlang_0.4.5             workflows_0.1.1        
 [43] BBmisc_1.11             splines_3.6.2           checkmate_2.0.0        
 [46] inline_0.3.15           yaml_2.2.1              modelr_0.1.5           
 [49] tidytext_0.2.2          threejs_0.3.3           crosstalk_1.0.0        
 [52] backports_1.1.6         httpuv_1.5.2            rsconnect_0.8.16       
 [55] tokenizers_0.2.1        tools_3.6.2             lava_1.6.6             
 [58] ellipsis_0.3.0          ggridges_0.5.2          Rcpp_1.0.4.6           
 [61] plyr_1.8.5              base64enc_0.1-3         visNetwork_2.0.9       
 [64] ps_1.3.0                prettyunits_1.1.1       rpart_4.1-15           
 [67] zoo_1.8-7               haven_2.2.0             fs_1.3.1               
 [70] furrr_0.1.0             magrittr_1.5            colourpicker_1.0       
 [73] reprex_0.3.0            GPfit_1.0-8             SnowballC_0.6.0        
 [76] packrat_0.5.0           matrixStats_0.55.0      tidyposterior_0.0.2    
 [79] hms_0.5.3               shinyjs_1.1             mime_0.8               
 [82] xtable_1.8-4            XML_3.99-0.3            tidypredict_0.4.3      
 [85] shinystan_2.5.0         readxl_1.3.1            gridExtra_2.3          
 [88] rstantools_2.0.0        compiler_3.6.2          crayon_1.3.4           
 [91] minqa_1.2.4             StanHeaders_2.21.0-1    htmltools_0.4.0        
 [94] later_1.0.0             lubridate_1.7.4         DBI_1.1.0              
 [97] dbplyr_1.4.2            MASS_7.3-51.4           boot_1.3-23            
[100] Matrix_1.2-18           cli_2.0.1               parallel_3.6.2         
[103] gower_0.2.1             igraph_1.2.4.2          pkgconfig_2.0.3        
[106] xml2_1.2.2              foreach_1.4.7           dygraphs_1.1.1.6       
[109] prodlim_2019.11.13      farff_1.1               rvest_0.3.5            
[112] snakecase_0.11.0        janeaustenr_0.1.5       callr_3.4.1            
[115] digest_0.6.25           cellranger_1.1.0        curl_4.3               
[118] shiny_1.4.0             gtools_3.8.1            nloptr_1.2.1           
[121] lifecycle_0.2.0         nlme_3.1-142            jsonlite_1.6.1         
[124] fansi_0.4.1             pillar_1.4.3            lattice_0.20-38        
[127] loo_2.2.0               fastmap_1.0.1           httr_1.4.1             
[130] pkgbuild_1.0.6          survival_3.1-8          glue_1.4.0             
[133] xts_0.12-0              FNN_1.1.3               shinythemes_1.1.2      
[136] iterators_1.0.12        class_7.3-15            stringi_1.4.4          
[139] memoise_1.1.0           future.apply_1.5.0     

非常感谢。

我找到了一个解决办法

如前所述,问题在于
SMOTE{smotefamily}
K
不能大于或等于样本量

我进入这个过程,发现
SMOTE{smotefamily}
使用
knearest{smotefamily}
,后者使用
knnx.index{FNN}
,后者反过来使用
get.knn{FNN}
, 它会返回错误
warning("k should be less than sample size!")
,终止
mlr3
中的调优过程

现在,在
SMOTE{smotefamily}
中,
knearest{smotefamily}
的三个参数是
P_set
P_set
K
。从
mlr3
重采样的角度来看, 数据帧
P_set
是训练数据交叉验证折叠的子集,经过过滤,仅包含少数类的记录。错误信息中的“样本量”指的是
P_set
的行数

因此,随着
K
通过诸如
某个整数^K
(例如
2^K
)之类的trafo而增大,可能会出现
K >= nrow(P_set)

我们需要确保
K
永远不会大于或等于
p_set

以下是我建议的解决方案:

  • 使用
    rsmp()
    定义cv重采样策略之前,先定义变量
    cv_folds
  • 在定义trafo之前,在
    rsmp()
    中定义CV重采样策略,其中
    folds = cv_folds
  • 实例化CV。现在,数据集分为训练数据和测试/验证数据
  • 在所有训练数据折叠中找到少数类的最小样本大小,并将其设置为
    K
    的阈值:

  • 我找到了一个解决办法

    如前所述,问题在于
    SMOTE{smotefamily}
    K
    不能大于或等于样本量

    我进入这个过程,发现
    SMOTE{smotefamily}
    使用
    knearest{smotefamily}
    ,后者使用
    knnx.index{FNN}
    ,后者反过来使用
    get.knn{FNN}
    , 它会返回错误
    warning("k should be less than sample size!")
    ,终止
    mlr3
    中的调优过程

    现在,在
    SMOTE{smotefamily}
    中,
    knearest{smotefamily}
    的三个参数是
    P_set
    P_set
    K
    。从
    mlr3
    重采样的角度来看, 数据帧
    P_set
    是训练数据交叉验证折叠的子集,经过过滤,仅包含少数类的记录。错误信息中的“样本量”指的是
    P_set
    的行数

    因此,随着
    K
    通过诸如
    某个整数^K
    (例如
    2^K
    )之类的trafo而增大,可能会出现
    K >= nrow(P_set)

    我们需要确保
    K
    永远不会大于或等于
    p_set

    以下是我建议的解决方案:

  • 使用
    rsmp()
    定义cv重采样策略之前,先定义变量
    cv_folds
  • 在定义trafo之前,在
    rsmp()
    中定义CV重采样策略,其中
    folds = cv_folds
  • 实例化CV。现在,数据集分为训练数据和测试/验证数据
  • 在所有训练数据折叠中找到少数类的最小样本大小,并将其设置为
    K
    的阈值:

  • 你建议什么样的解决方案?由于数据的特殊性,许多超参数组合是不可能的(导致错误),如您所示的示例。我建议在内部添加一个
    ifelse
    过程,以便当
    K
    超过样本大小时,将
    K
    设置为适当的最大值(可能是
    样本大小-1
    ?)。应返回警告,以便进程实际运行而不是终止。就目前的情况而言,这个过程实际上取决于你有多幸运。例如,试着把我问题中的代码运行几次。有时您会收到报告的错误,脚本会停止,有时您不会,脚本会运行。我自己的数据很不走运,所以我不能真正使用所说的
    trafo
    来处理我的数据。然后你建议在所有支持mlr3的算法中,如果数据允许,首先检查超参数组合,并调整它们,使其落在允许范围内并发出警告?听起来有很多工作要做。我认为在调优过程中跳过糟糕的超参数组合要容易得多。参见封装:我实际上是建议对SMOTE进行封装,而不是所有受支持的算法。无论如何,谢谢你提供的资源——我会看看我能做些什么。你建议什么样的解决方案?由于数据的特殊性,许多超参数组合是不可能的(结果是错误的),如您展示的示例中所示
    R version 3.6.2 (2019-12-12)
    Platform: x86_64-w64-mingw32/x64 (64-bit)
    Running under: Windows 10 x64 (build 16299)
    
    Matrix products: default
    
    locale:
    [1] LC_COLLATE=English_United Kingdom.1252  LC_CTYPE=English_United Kingdom.1252   
    [3] LC_MONETARY=English_United Kingdom.1252 LC_NUMERIC=C                           
    [5] LC_TIME=English_United Kingdom.1252    
    
    attached base packages:
    [1] stats     graphics  grDevices utils     datasets  methods   base     
    
    other attached packages:
     [1] smotefamily_1.3.1        OpenML_1.10              mlr3viz_0.1.1.9002      
     [4] mlr3tuning_0.1.2-9000    mlr3pipelines_0.1.2.9000 mlr3misc_0.2.0          
     [7] mlr3learners_0.2.0       mlr3filters_0.2.0.9000   mlr3_0.2.0-9000         
    [10] paradox_0.2.0            yardstick_0.0.5          rsample_0.0.5           
    [13] recipes_0.1.9            parsnip_0.0.5            infer_0.5.1             
    [16] dials_0.0.4              scales_1.1.0             broom_0.5.4             
    [19] tidymodels_0.0.3         reshape2_1.4.3           janitor_1.2.1           
    [22] data.table_1.12.8        forcats_0.4.0            stringr_1.4.0           
    [25] dplyr_0.8.4              purrr_0.3.3              readr_1.3.1             
    [28] tidyr_1.0.2              tibble_3.0.1             ggplot2_3.3.0           
    [31] tidyverse_1.3.0         
    
    loaded via a namespace (and not attached):
      [1] utf8_1.1.4              tidyselect_1.0.0        lme4_1.1-21            
      [4] htmlwidgets_1.5.1       grid_3.6.2              ranger_0.12.1          
      [7] pROC_1.16.1             munsell_0.5.0           codetools_0.2-16       
     [10] bbotk_0.1               DT_0.12                 future_1.17.0          
     [13] miniUI_0.1.1.1          withr_2.2.0             colorspace_1.4-1       
     [16] knitr_1.28              uuid_0.1-4              rstudioapi_0.10        
     [19] stats4_3.6.2            bayesplot_1.7.1         listenv_0.8.0          
     [22] rstan_2.19.2            lgr_0.3.4               DiceDesign_1.8-1       
     [25] vctrs_0.2.4             generics_0.0.2          ipred_0.9-9            
     [28] xfun_0.12               R6_2.4.1                markdown_1.1           
     [31] mlr3measures_0.1.3-9000 rstanarm_2.19.2         lhs_1.0.1              
     [34] assertthat_0.2.1        promises_1.1.0          nnet_7.3-12            
     [37] gtable_0.3.0            globals_0.12.5          processx_3.4.1         
     [40] timeDate_3043.102       rlang_0.4.5             workflows_0.1.1        
     [43] BBmisc_1.11             splines_3.6.2           checkmate_2.0.0        
     [46] inline_0.3.15           yaml_2.2.1              modelr_0.1.5           
     [49] tidytext_0.2.2          threejs_0.3.3           crosstalk_1.0.0        
     [52] backports_1.1.6         httpuv_1.5.2            rsconnect_0.8.16       
     [55] tokenizers_0.2.1        tools_3.6.2             lava_1.6.6             
     [58] ellipsis_0.3.0          ggridges_0.5.2          Rcpp_1.0.4.6           
     [61] plyr_1.8.5              base64enc_0.1-3         visNetwork_2.0.9       
     [64] ps_1.3.0                prettyunits_1.1.1       rpart_4.1-15           
     [67] zoo_1.8-7               haven_2.2.0             fs_1.3.1               
     [70] furrr_0.1.0             magrittr_1.5            colourpicker_1.0       
     [73] reprex_0.3.0            GPfit_1.0-8             SnowballC_0.6.0        
     [76] packrat_0.5.0           matrixStats_0.55.0      tidyposterior_0.0.2    
     [79] hms_0.5.3               shinyjs_1.1             mime_0.8               
     [82] xtable_1.8-4            XML_3.99-0.3            tidypredict_0.4.3      
     [85] shinystan_2.5.0         readxl_1.3.1            gridExtra_2.3          
     [88] rstantools_2.0.0        compiler_3.6.2          crayon_1.3.4           
     [91] minqa_1.2.4             StanHeaders_2.21.0-1    htmltools_0.4.0        
     [94] later_1.0.0             lubridate_1.7.4         DBI_1.1.0              
     [97] dbplyr_1.4.2            MASS_7.3-51.4           boot_1.3-23            
    [100] Matrix_1.2-18           cli_2.0.1               parallel_3.6.2         
    [103] gower_0.2.1             igraph_1.2.4.2          pkgconfig_2.0.3        
    [106] xml2_1.2.2              foreach_1.4.7           dygraphs_1.1.1.6       
    [109] prodlim_2019.11.13      farff_1.1               rvest_0.3.5            
    [112] snakecase_0.11.0        janeaustenr_0.1.5       callr_3.4.1            
    [115] digest_0.6.25           cellranger_1.1.0        curl_4.3               
    [118] shiny_1.4.0             gtools_3.8.1            nloptr_1.2.1           
    [121] lifecycle_0.2.0         nlme_3.1-142            jsonlite_1.6.1         
    [124] fansi_0.4.1             pillar_1.4.3            lattice_0.20-38        
    [127] loo_2.2.0               fastmap_1.0.1           httr_1.4.1             
    [130] pkgbuild_1.0.6          survival_3.1-8          glue_1.4.0             
    [133] xts_0.12-0              FNN_1.1.3               shinythemes_1.1.2      
    [136] iterators_1.0.12        class_7.3-15            stringi_1.4.4          
    [139] memoise_1.1.0           future.apply_1.5.0     
    
    # Smallest class count found in any training fold of the instantiated CV.
    # SMOTE's K has to stay strictly below this value.
    # - task$truth(rows = ...) looks the target up by row id, whereas indexing
    #   as.data.frame(task$data()) with the ids treats them as row positions.
    # - vapply() + min() replace the bind_cols / unique pipeline, so neither
    #   dplyr nor magrittr has to be attached.
    # - seq_len() is safe where 1:cv_folds is not (cv_folds == 0).
    smote_k_thresh <- min(vapply(
      seq_len(cv_folds),
      function(fold) {
        min(table(task$truth(rows = cv$train_set(fold))))
      },
      numeric(1)
    ))
    
    # Trafo for SMOTE's K that keeps the transformed value below smote_k_thresh.
    #
    # x:         named list of hyperparameter values (presumably one sampled
    #            configuration -- confirm against the paradox trafo contract)
    # param_set: second trafo argument; not used in the body
    # Returns x where every entry whose name ends in ".K" becomes round(2^K)
    # if that is below smote_k_thresh, otherwise a random integer drawn from
    # 1..(smote_k_thresh - 1).
    param_set$trafo <- function(x, param_set) {
      # Escaped, anchored pattern: a bare '.K' regex matches ANY character
      # followed by "K". The loop copes with zero, one or several matches.
      for (i in grep("\\.K$", names(x))) {
        k <- round(2^x[[i]])
        if (k >= smote_k_thresh) {
          # sample.int() always draws from 1..n; sample() changes meaning
          # when its first argument is a single number below 1.
          k <- sample.int(smote_k_thresh - 1, 1)
        }
        x[[i]] <- k
      }
      x
    }
    
    library("mlr3learners") # additional ML algorithms
    library("mlr3viz") # autoplot for benchmarks
    library("paradox") # hyperparameter space
    library("OpenML") # to obtain data sets
    library("smotefamily") # SMOTE algorithm for imbalance correction
    
    # Query OpenML for curated binary classification data sets
    # (see https://arxiv.org/abs/1708.03731v2): 2 classes, 1-100 features,
    # 5000-10000 instances.
    ds <- listOMLDataSets(
      number.of.classes = 2,
      number.of.features = c(1, 100),
      number.of.instances = c(5000, 10000)
    )
    # Keep only imbalanced data sets (minority share below 20%) with exactly
    # one symbolic feature, because SMOTE cannot handle categorical features.
    ds <- subset(
      ds,
      minority.class.size / number.of.instances < 0.2 &
        number.of.symbolic.features == 1
    )
    ds

    # Download data set 980 and print its summary
    d <- getOMLDataSet(980)
    d

    # Coerce the target column to a factor, then wrap the data in an mlr3
    # classification task
    data <- as.data.frame(d)
    data[[d$target.features]] <- as.factor(data[[d$target.features]])
    task <- TaskClassif$new(
      id = d$desc$name,
      backend = data,
      target = d$target.features
    )
    task
    
    # Code above copied from https://mlr3gallery.mlr-org.com/posts/2020-03-30-imbalanced-data/
    
    # Imbalance ratio: size of the largest class divided by size of the
    # smallest. max()/min() always yield a single plain number. Subsetting the
    # table with `== max(...)` / `== min(...)` instead returns a named table
    # element and yields more than one value whenever several classes share
    # the same count.
    class_counts <- table(task$truth())
    majority_to_minority_ratio <- max(class_counts) / min(class_counts)

    # Pipe operator for SMOTE; dup_size is set to the rounded imbalance ratio
    po_smote <- po("smote", dup_size = round(majority_to_minority_ratio))
    
    # Define and instantiate resampling strategy to be applied within pipeline
    # Do that BEFORE defining the trafo, which needs the fold composition
    cv_folds <- 2
    cv <- rsmp("cv", folds = cv_folds)
    cv$instantiate(task)

    # Calculate max possible value for k-nearest neighbours: the smallest
    # class count found in any training fold. K has to stay strictly below it.
    # - task$truth(rows = ...) looks the target up by row id, whereas indexing
    #   as.data.frame(task$data()) with the ids treats them as row positions.
    # - vapply() + min() replace the bind_cols / unique pipeline, so neither
    #   dplyr nor magrittr has to be attached.
    # - seq_len() is safe where 1:cv_folds is not (cv_folds == 0).
    smote_k_thresh <- min(vapply(
      seq_len(cv_folds),
      function(fold) {
        min(table(task$truth(rows = cv$train_set(fold))))
      },
      numeric(1)
    ))
    
    # Random Forest learner (ranger) that predicts class probabilities
    rf <- lrn("classif.ranger", predict_type = "prob")
    
    # Pipeline of Random Forest learner with SMOTE:
    # the SMOTE pipe operator feeds into the learner, which is wrapped as a
    # pipe operator with id 'rf'
    graph <- po_smote %>>%
      po('learner', rf, id = 'rf')
    # Draw the pipeline graph (side effect only)
    graph$plot()
    
    # Graph learner: wrap the whole pipeline so it can be used like a learner.
    # predict_type is passed to the constructor and then assigned again below.
    rf_smote <- GraphLearner$new(graph, predict_type = 'prob')
    rf_smote$predict_type <- 'prob'
    
    # Parameter set in data table format (first four columns shown in viewer)
    ps_table <- as.data.table(rf_smote$param_set)
    View(ps_table[, 1:4])

    # Define parameter search space for the SMOTE parameters.
    # Only the ids "smote.dup_size" and "smote.K" are tuned. The pattern is
    # anchored and its dot escaped, because an unescaped "." in a regular
    # expression matches any character. Both parameters are integers ranging
    # from 1 to the rounded majority-to-minority ratio. No pipe operator is
    # needed, so this no longer depends on magrittr being attached.
    smote_ids <- grep("^smote\\.(dup_size|K)$", ps_table$id, value = TRUE)
    param_set <- lapply(
      smote_ids,
      function(id) {
        ParamInt$new(id, lower = 1, upper = round(majority_to_minority_ratio))
      }
    )
    param_set <- ParamSet$new(param_set)
    
    # Apply transformation function on SMOTE's K while ensuring it never
    # equals or exceeds the sample size (smote_k_thresh).
    #
    # x:         named list of hyperparameter values (presumably one sampled
    #            configuration -- confirm against the paradox trafo contract)
    # param_set: second trafo argument; not used in the body
    # Returns x where every entry whose name ends in ".K" becomes round(5^K)
    # if that is below smote_k_thresh, otherwise a random integer drawn from
    # 1..(smote_k_thresh - 1).
    param_set$trafo <- function(x, param_set) {
      # Escaped, anchored pattern: a bare '.K' regex matches ANY character
      # followed by "K". The loop copes with zero, one or several matches.
      for (i in grep("\\.K$", names(x))) {
        k <- round(5^x[[i]]) # Try a large value here for the sake of the example
        if (k >= smote_k_thresh) {
          # sample.int() always draws from 1..n; sample() changes meaning
          # when its first argument is a single number below 1.
          k <- sample.int(smote_k_thresh - 1, 1)
        }
        x[[i]] <- k
      }
      x
    }
    
    # Set up tuning instance
    # NOTE: `param_set` is passed positionally (unnamed) among named arguments,
    # so R matches it to the first constructor formal that is not named here.
    instance <- TuningInstance$new(
      task = task,
      learner = rf_smote,
      resampling = cv,
      measures = msr("classif.bbrier"),
      param_set,
      terminator = term("evals", n_evals = 10), 
      store_models = TRUE)
    # Random search over the search space defined by `param_set`
    tuner <- TunerRandomSearch$new()
    
    # Tune pipe learner to find optimal SMOTE parameter values
    tuner$optimize(instance)
    
    # Here are the original K values
    instance$archive$data
    
    # And here are their transformations
    instance$archive$data$opt_x