gpt4 book ai didi

r - Tidymodels:仅当概率为 75% 或更高时才分类为 TRUE

转载 作者:行者123 更新时间:2023-12-05 02:01:07 26 4
gpt4 key购买 nike

我有一个二元分类问题,使用了随机森林和逻辑回归。根据 conf_matcollect_metrics()collect_predictions 的结果,我想更改我的模型,仅当模型“确定”时才分类为 TRUE ”说 75% 或更高的概率。我只是不知道在哪里指定此更改。如果有人能给我提示,那就太好了。我的直觉告诉我它应该在模型规范中的某个地方,例如在这里的某个地方,但也许我错了。

canc_rf_model <- rand_forest(
mtry = tune(),
min_n = tune(),
trees = 500) %>%
set_engine("ranger") %>%
set_mode("classification")

canc_log_model <- logistic_reg() %>%
set_engine("glm") %>%
set_mode("classification")

非常感谢您!

最佳答案

硬类预测来自底层 ranger::predictions() 函数,而不是来自 功能,因此配件本身无需做太多工作。

但是,如果您愿意,可以在拟合后非常流畅地更改它。让我们做一个示例分类模型:

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip

data("ad_data")
alz <- ad_data

# data splitting
set.seed(100)
alz_split <- initial_split(alz, strata = Class, prop = .9)
alz_train <- training(alz_split)
alz_test <- testing(alz_split)

# data resampling
set.seed(100)
alz_folds <-
vfold_cv(alz_train, v = 10, strata = Class)

rf_mod <-
rand_forest(trees = 1e3) %>%
set_engine("ranger") %>%
set_mode("classification")

rf_wf <-
workflow() %>%
add_formula(Class ~ .) %>%
add_model(rf_mod)

set.seed(100)
rf_preds <- rf_wf %>%
fit_resamples(
resamples = alz_folds,
control = control_resamples(save_pred = TRUE)) %>%
collect_predictions()

这是默认的混淆矩阵:

rf_preds %>%
conf_mat(Class, .pred_class)
#> Truth
#> Prediction Impaired Control
#> Impaired 37 5
#> Control 45 213

您可以使用 probably包后处理你的类概率估计并覆盖默认值:

library(probably)
#>
#> Attaching package: 'probably'
#> The following objects are masked from 'package:base':
#>
#> as.factor, as.ordered

rf_preds %>%
mutate(.pred_class = make_two_class_pred(.pred_Impaired,
levels(rf_preds$Class),
threshold = 0.75),
.pred_class = factor(.pred_class, levels = levels(rf_preds$Class))) %>%
conf_mat(Class, .pred_class)
#> Truth
#> Prediction Impaired Control
#> Impaired 0 0
#> Control 82 218

reprex package 创建于 2021-03-23 (v1.0.0)

关于r - Tidymodels:仅当概率为 75% 或更高时才分类为 TRUE,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66759453/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com