R语言：如何使用 tidymodels 和 vip 获取随机森林的变量重要性图

vwhgwdsa  于 2023-03-15  发布在  其他
关注(0)|答案(1)|浏览(318)

我使用的是 mushroom 数据集：https://archive.ics.uci.edu/ml/datasets/mushroom

  • 我使用Tidymodels指定了随机森林的以下配方、模型和工作流:
# Preprocessing recipe. It is deliberately left untrained: the add_recipe()
# documentation asks for a recipe that has not been prep()ed, because the
# workflow trains it itself (on each resample and again in last_fit()).
df_recipe_mixt <- recipe(
  class ~ cap_diameter + cap_color + does_bruise_or_bleed + gill_color +
    stem_height + stem_width + stem_color + has_ring + habitat + season,
  data = df_train
) |>
  # Centre and scale numeric predictors only (never the outcome).
  step_normalize(all_numeric_predictors()) |>
  # One-hot/dummy encode the categorical predictors.
  step_dummy(all_nominal_predictors())

# Random forest spec with mtry/trees left for tuning.
# `importance = "impurity"` is an engine argument forwarded to
# ranger::ranger(); without it ranger computes no variable importance, so
# vip::vi()/vip::vip() have nothing to report for the fitted model.
rf_mod <- rand_forest() |>
  set_engine("ranger", importance = "impurity") |>
  set_mode("classification") |>
  set_args(mtry = tune(), trees = tune())

rf_wf <- workflow() |>
  add_model(rf_mod) |>
  add_recipe(df_recipe_mixt)
  • 然后我对模型进行了调参，并将其应用于测试数据：
# Tune in parallel, leaving one core free. detectCores() may return 1 or NA,
# so never ask doParallel for fewer than one worker.
n_cores <- parallel::detectCores(logical = TRUE)
registerDoParallel(cores = max(1L, n_cores - 1L, na.rm = TRUE))

rf_params <- extract_parameter_set_dials(rf_wf) |>
  update(mtry = mtry(c(1, 5)), trees = trees(c(50, 500)))

rf_grid <- grid_regular(rf_params, levels = c(mtry = 5, trees = 3))

tic("random forest model tuning ")

tune_res_rf <- tune_grid(
  rf_wf,
  resamples = df_folds,
  grid = rf_grid,
  metrics = metric_set(accuracy)
)

toc()

stopImplicitCluster()

autoplot(tune_res_rf) + dark_mode(theme_minimal())

rf_best <- tune_res_rf |> select_best(metric = "accuracy")

# Winning hyper-parameters (mtry, trees).
print(rf_best)

rf_final_wf <- rf_wf |>
  finalize_workflow(rf_best)

# Variable importance is only needed for the final fit, so enable it here
# instead of paying for it on every resample during tuning. set_engine()
# keeps the finalised mtry/trees but replaces any previous engine arguments.
rf_final_wf <- rf_final_wf |>
  update_model(
    extract_spec_parsnip(rf_final_wf) |>
      set_engine("ranger", importance = "impurity")
  )

# Keep the whole last_fit() result: piping straight into
# collect_predictions() throws away the fitted workflow that vip needs.
rf_last_fit <- last_fit(rf_final_wf, split = df_split)
rf_res <- rf_last_fit |> collect_predictions()

# Variable-importance plot for the final ranger model.
rf_last_fit |>
  extract_workflow() |>
  extract_fit_parsnip() |>
  vip::vip()

模型运行良好，我得到了所有指标、混淆矩阵以及 ROC 曲线。然而，我找不到获取变量重要性图的方法（最好使用 vip）。

92dk7w1h

92dk7w1h1#

要从 ranger 模型中获取变量重要性，需要指定它计算哪种 importance 指标。在 tidymodels 中，可以在 set_engine() 中设置 importance = "impurity"（ranger 也支持 "permutation" 等取值），该参数会被传递给底层的 {ranger} 函数。
另一个展示该做法的示例见：https://www.tidymodels.org/start/case-study/
我还把配方（recipe）改为使用 all_numeric_predictors() 和 all_nominal_predictors()，因为这通常才是您真正想要的选择器（它们不会选中结果变量）。
（这个 reprex 与您的代码略有不同，因为您没有给出 df_split 的创建方式。）

library(tidymodels)

# initial_split(), vfold_cv() and ranger are all random; fix the seed so the
# reprex is reproducible.
set.seed(2023)

# agaricus-lepiota.data has no header row and 23 columns: the edibility class
# (e = edible, p = poisonous) first, then the 22 attributes. Leaving `class`
# out of this vector shifts every name one column to the left.
mushroom_col_names <- c(
  "class",
  "cap_shape", "cap_surface", "cap_color", "bruises", "odor", "gill_attachment",
  "gill_spacing", "gill_size", "gill_color", "stalk_shape", "stalk_root",
  "stalk_surface_above_ring", "stalk_surface_below_ring",
  "stalk_color_above_ring", "stalk_color_below_ring", "veil_type", "veil_color",
  "ring_number", "ring_type", "spore_print_color", "population", "habitat"
)

mushrooms <- readr::read_csv(
  "https://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data",
  col_names = mushroom_col_names,
  # Every column is a one-letter code; read them all as text (otherwise the
  # t/f `bruises` column is guessed as logical) and treat "?" as missing.
  col_types = readr::cols(.default = readr::col_character()),
  na = "?"
) |>
  # veil_type takes a single value in this data, so it carries no information.
  select(-veil_type) |>
  mutate(across(everything(), factor))

df_split <- initial_split(mushrooms, strata = class)

df_train <- training(df_split)
df_folds <- vfold_cv(df_train, v = 2, strata = class)

# All predictors are categorical. The recipe is left untrained because the
# workflow preps it itself on each resample.
df_recipe_mixt <- recipe(class ~ ., data = df_train) |>
  step_novel(all_nominal_predictors()) |>
  step_unknown(all_nominal_predictors()) |>
  step_dummy(all_nominal_predictors())

# `importance = "impurity"` is forwarded to ranger::ranger(); without it the
# fitted model has no variable importance for vip to report.
rf_mod <- rand_forest() |>
  set_engine("ranger", importance = "impurity") |>
  set_mode("classification") |>
  set_args(mtry = tune(), trees = tune())

rf_wf <- workflow() |>
  add_model(rf_mod) |>
  add_recipe(df_recipe_mixt)

rf_params <- extract_parameter_set_dials(rf_wf) |>
  update(mtry = mtry(c(1, 5)), trees = trees(c(50, 500)))

rf_grid <- grid_regular(rf_params, levels = c(mtry = 2, trees = 2))

tune_res_rf <- tune_grid(
  rf_wf,
  resamples = df_folds,
  grid = rf_grid,
  metrics = metric_set(accuracy)
)

rf_best <- tune_res_rf |> select_best(metric = "accuracy")

rf_final_wf <- rf_wf |>
  finalize_workflow(rf_best)

rf_res <- last_fit(rf_final_wf, split = df_split)

rf_fit <- rf_res |>
  extract_workflow() |>
  extract_fit_parsnip()

# Importance scores as a table ...
vip::vi(rf_fit)

# ... and as the plot the question asked for.
vip::vip(rf_fit, num_features = 15)

# Output not reproduced: exact scores depend on the seed. The earlier output
# of this reprex (top features `gill_attachment_n`, `gill_color_n`, `odor`,
# ...) came from the header that was one name short, so those labels were
# really odor, gill_size, bruises, ... and the "cap_shape" outcome was the
# edibility class.

创建于2023年3月14日,使用reprex v2.0.2

相关问题