gpt4 book ai didi

r - Tidymodels:如何从训练数据中获得额外的重要性

转载 作者:行者123 更新时间:2023-12-05 09:27:20 26 4
gpt4 key购买 nike

我有以下代码,我在其中对不同的 mtry 和 min_n 进行了一些网格搜索。我知道如何提取精度最高的参数(参见第二个代码框)。如何提取训练数据集中每个特征的重要性?我在网上找到的指南显示了如何使用“last_fit”仅在测试数据集中执行此操作。例如。指南:https://www.tidymodels.org/start/case-study/#data-split

set.seed(seed_number)
data_split <- initial_split(node_strength,prop = 0.8,strata = Group)

train <- training(data_split)
test <- testing(data_split)
train_folds <- vfold_cv(train,v = 10)


rfc <- rand_forest(mode = "classification", mtry = tune(),
min_n = tune(), trees = 1500) %>%
set_engine("ranger", num.threads = 48, importance = "impurity")

rfc_recipe <- recipe(data = train, Group~.)

rfc_workflow <- workflow() %>% add_model(rfc) %>%
add_recipe(rfc_recipe)

rfc_result <- rfc_workflow %>%
tune_grid(train_folds, grid = 40, control = control_grid(save_pred = TRUE),
metrics = metric_set(accuracy))

.

best <- 
rfc_result %>%
select_best(metric = "accuracy")

最佳答案

为此,您需要创建自定义extract 函数,如outlined in this documentation .

对于随机森林变量重要性,您的函数将如下所示:

get_rf_imp <- function(x) {
x %>%
extract_fit_parsnip() %>%
vip::vi()
}

然后您可以像这样将它应用于您的重采样(请注意您会得到一个新的 .extracts 列):

library(tidymodels)
data(cells, package = "modeldata")

set.seed(123)
cell_split <- cells %>% select(-case) %>%
initial_split(strata = class)
cell_train <- training(cell_split)
cell_test <- testing(cell_split)
folds <- vfold_cv(cell_train)

rf_spec <- rand_forest(mode = "classification") %>%
set_engine("ranger", importance = "impurity")

ctrl_imp <- control_grid(extract = get_rf_imp)

cells_res <-
workflow(class ~ ., rf_spec) %>%
fit_resamples(folds, control = ctrl_imp)
cells_res
#> # Resampling results
#> # 10-fold cross-validation
#> # A tibble: 10 × 5
#> splits id .metrics .notes .extracts
#> <list> <chr> <list> <list> <list>
#> 1 <split [1362/152]> Fold01 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#> 2 <split [1362/152]> Fold02 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#> 3 <split [1362/152]> Fold03 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#> 4 <split [1362/152]> Fold04 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#> 5 <split [1363/151]> Fold05 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#> 6 <split [1363/151]> Fold06 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#> 7 <split [1363/151]> Fold07 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#> 8 <split [1363/151]> Fold08 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#> 9 <split [1363/151]> Fold09 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>
#> 10 <split [1363/151]> Fold10 <tibble [2 × 4]> <tibble [0 × 3]> <tibble [1 × 2]>

reprex package 创建于 2022-06-19 (v2.0.1)

一旦你有了这些变量重要性分数提取,你可以unnest()它们(现在,你必须这样做两次,因为它嵌套很深)然后你可以总结和可视化你喜欢:

cells_res %>%
select(id, .extracts) %>%
unnest(.extracts) %>%
unnest(.extracts) %>%
group_by(Variable) %>%
summarise(Mean = mean(Importance),
Variance = sd(Importance)) %>%
slice_max(Mean, n = 15) %>%
ggplot(aes(Mean, reorder(Variable, Mean))) +
geom_crossbar(aes(xmin = Mean - Variance, xmax = Mean + Variance)) +
labs(x = "Variable importance", y = NULL)

reprex package 创建于 2022-06-19 (v2.0.1)

关于r - Tidymodels:如何从训练数据中获得额外的重要性,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/72538064/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com