gpt4 book ai didi

r - 如何保存 Tidymodels Lightgbm 模型以供重用

转载 作者:行者123 更新时间:2023-12-05 03:25:21 32 4
gpt4 key购买 nike

我有以下代码,用于创建一个使用 lightgbm 模型的 tidymodels 工作流。但是,当我尝试将拟合后的工作流保存为 .rds 文件、重新读取并进行预测时,出现了一些问题:

library(AmesHousing)
library(treesnip)
library(lightgbm)
library(tidymodels)
tidymodels_prefer()

### Model ###

# data
# Load the Ames housing data with snake_case column names, then keep only the
# outcome (sale_price) and the predictors used by the model.
data <- make_ames() %>%
  janitor::clean_names()

data <- subset(
  data,
  select = c(
    sale_price, bedroom_abv_gr, bsmt_full_bath, bsmt_half_bath,
    enclosed_porch, fireplaces, full_bath, half_bath, kitchen_abv_gr,
    garage_area, garage_cars, gr_liv_area, lot_area, lot_frontage,
    year_built, year_remod_add, year_sold
  )
)

# Add a character row identifier as the first column. row_number() replaces
# the original 1:nrow(data), which would produce c(1, 0) for an empty data
# frame instead of zero ids.
data <- data %>%
  mutate(id = as.character(row_number())) %>%
  select(id, everything())

# model specification

# Boosted-tree regression spec with fixed hyperparameters, fitted through the
# "lightgbm" engine; `objective` is supplied as an engine argument.
lgbm_model <- set_engine(
  boost_tree(
    mode = "regression",
    mtry = 7,
    trees = 347,
    min_n = 10,
    tree_depth = 12,
    learn_rate = 0.0106430579211173,
    loss_reduction = 0.000337948798058139
  ),
  "lightgbm",
  objective = "regression"
)

# recipe and workflow

# Preprocessing recipe: `id` gets the "ID" role so it stays in the data but is
# not treated as a predictor, and predictors correlated above 0.7 are removed.
# The recipe is intentionally left untrained (no prep()): the workflow trains
# it as part of fit(), and add_recipe() expects an untrained recipe.
lgbm_recipe <- recipe(sale_price ~ ., data = data) %>%
  update_role(id, new_role = "ID") %>%
  step_corr(all_predictors(), threshold = 0.7)

# Bundle preprocessing and model so they are fitted and applied together.
lgbm_workflow <- workflow() %>%
  add_recipe(lgbm_recipe) %>%
  add_model(lgbm_model)

# fit workflow
# Train the recipe and the model together on the full data set.
fit_lgbm_workflow <- fit(lgbm_workflow, data = data)

# predict
# Score the same rows with the outcome column dropped.
data_predict <- subset(data, select = -sale_price)
predict(fit_lgbm_workflow, new_data = data_predict)


### CASE 1: Save the workflow with saveRDS()
# Round-trip the whole fitted workflow through a single .rds file.

saveRDS(object = fit_lgbm_workflow, file = "lgbm_workflow.rds")
new_lgbm_workflow <- readRDS(file = "lgbm_workflow.rds")

# Predict with the reloaded workflow.
# Reported error: Attempting to use a Booster which no longer exists

predict(new_lgbm_workflow, new_data = data_predict)



### CASE 2: Save the workflow and the fitted model separately
# The engine-level booster is pulled out of the parsnip fit and written with
# lightgbm's own serializer; the workflow itself still goes through saveRDS().

fitted_model <- (fit_lgbm_workflow %>% extract_fit_parsnip())$fit
saveRDS(object = fit_lgbm_workflow, file = "lgbm_workflow.rds")
lightgbm::saveRDS.lgb.Booster(object = fitted_model, file = "lgbm_model.rds")


# Reload both files and put the booster back into the workflow.
# NOTE(review): this assigns to `$fit$fit`, whereas the accepted answer further
# down restores the booster at `$fit$fit$fit` -- presumably this overwrites the
# parsnip fit rather than the engine fit inside it; confirm.
new_lgbm_workflow <- readRDS(file = "lgbm_workflow.rds")
new_lgbm_model <- lightgbm::readRDS.lgb.Booster(file = "lgbm_model.rds")
new_lgbm_workflow$fit$fit <- new_lgbm_model


# Predict with the patched workflow.
# Reported error: cannot predict on data of class ‘tbl_df’‘tbl’‘data.frame’

predict(new_lgbm_workflow, new_data = data_predict)

似乎只有使用 lightgbm 模型的工作流存在这个问题。对于其他类型的模型(随机森林、xgboost、glm 等),我可以用 saveRDS() 保存拟合后的工作流,用 readRDS() 读取,然后对新数据进行预测,一切正常。

对于案例 2,底层的预测函数显然变成了 predict.lgb.Booster(),它要求以 matrix 作为输入。但是我的 id 变量是 character 类型,而 matrix 中所有列的类型必须相同。

有没有办法保存整个工作流程以供将来使用?

最佳答案

经过大量挖掘,我在这个已关闭的 issue(closed issue)中找到了解决方案。

library(tidymodels)
#> Warning: package 'tidymodels' was built under R version 4.2.1
#> Warning: package 'broom' was built under R version 4.2.1
#> Warning: package 'scales' was built under R version 4.2.1
#> Warning: package 'infer' was built under R version 4.2.1
#> Warning: package 'modeldata' was built under R version 4.2.1
#> Warning: package 'parsnip' was built under R version 4.2.1
#> Warning: package 'rsample' was built under R version 4.2.1
#> Warning: package 'tibble' was built under R version 4.2.1
#> Warning: package 'workflows' was built under R version 4.2.1
#> Warning: package 'workflowsets' was built under R version 4.2.1
library(bonsai)
library(lightgbm)
#> Warning: package 'lightgbm' was built under R version 4.2.1
#> Loading required package: R6
#>
#> Attaching package: 'lightgbm'
#> The following object is masked from 'package:dplyr':
#>
#> slice

# data

# Load the Ames housing data with snake_case column names, then keep only the
# outcome (sale_price) and the predictors used by the model.
data <- modeldata::ames %>%
  janitor::clean_names()

data <- subset(
  data,
  select = c(
    sale_price, bedroom_abv_gr, bsmt_full_bath, bsmt_half_bath,
    enclosed_porch, fireplaces, full_bath, half_bath, kitchen_abv_gr,
    garage_area, garage_cars, gr_liv_area, lot_area, lot_frontage,
    year_built, year_remod_add, year_sold
  )
)

# Add a character row identifier as the first column. row_number() replaces
# the original 1:nrow(data), which would produce c(1, 0) for an empty data
# frame instead of zero ids.
data <- data %>%
  mutate(id = as.character(row_number())) %>%
  select(id, everything())

# model specification

# Boosted-tree regression spec with fixed hyperparameters, fitted through the
# "lightgbm" engine; `objective` is supplied as an engine argument.
lgbm_model <- set_engine(
  boost_tree(
    mode = "regression",
    mtry = 7,
    trees = 347,
    min_n = 10,
    tree_depth = 12,
    learn_rate = 0.0106430579211173,
    loss_reduction = 0.000337948798058139
  ),
  "lightgbm",
  objective = "regression"
)

# recipe and workflow

# Untrained recipe: `id` is assigned the "ID" role (so it is not used as a
# predictor) and predictors correlated above 0.7 are dropped.
lgbm_recipe <- step_corr(
  update_role(recipe(sale_price ~ ., data = data), id, new_role = "ID"),
  all_predictors(),
  threshold = 0.7
)

# Combine the preprocessor and the model spec into a single workflow.
lgbm_workflow <- workflow() %>%
  add_recipe(lgbm_recipe) %>%
  add_model(lgbm_model)

# fit workflow
# Train the recipe and the model together on the full data set.
fit_lgbm_workflow <- fit(lgbm_workflow, data = data)

# predict
# Score the same rows with the outcome column dropped.
data_predict <- subset(data, select = -sale_price)
predict(fit_lgbm_workflow, new_data = data_predict)
#> # A tibble: 2,930 × 1
#> .pred
#> <dbl>
#> 1 201911.
#> 2 124695.
#> 3 138983.
#> 4 221095.
#> 5 198972.
#> 6 188613.
#> 7 198730.
#> 8 170893.
#> 9 243899.
#> 10 196875.
#> # … with 2,920 more rows

# Persist the trained workflow and its lgb.Booster as two separate files.
# NOTE(review): the remark at the end of this page says newer {lightgbm}
# builds support plain saveRDS()/readRDS() on boosters -- check whether the
# *.lgb.Booster helpers are still needed for the installed version.
workflow_path <- "lgbm_wflw.rds"
booster_path <- "lgbm_booster.rds"

saveRDS(fit_lgbm_workflow, workflow_path)
lgbm_booster <- extract_fit_engine(fit_lgbm_workflow)
saveRDS.lgb.Booster(lgbm_booster, booster_path)

# Reload the workflow, then put the separately stored booster back into the
# engine-fit slot (workflow -> parsnip fit -> engine fit).
new_lgbm_wflow <- readRDS(workflow_path)
new_lgbm_wflow$fit$fit$fit <- readRDS.lgb.Booster(booster_path)

# The restored workflow reproduces the predictions of the original fit.
predict(new_lgbm_wflow, data_predict)
#> # A tibble: 2,930 × 1
#> .pred
#> <dbl>
#> 1 201911.
#> 2 124695.
#> 3 138983.
#> 4 221095.
#> 5 198972.
#> 6 188613.
#> 7 198730.
#> 8 170893.
#> 9 243899.
#> 10 196875.
#> # … with 2,920 more rows

由 reprex v2.0.2 创建于 2022-09-07

在我上面的 reprex 中,我是用工作流(workflow)来拟合模型的。如果您是用 parsnip 对象来拟合,请改用下面这种方法:


# Write the parsnip fit and its underlying booster to separate files
# (`bonsai_fit`, `path1` and `path2` are placeholders supplied by the reader).
saveRDS.lgb.Booster(extract_fit_engine(bonsai_fit), path2)
saveRDS(bonsai_fit, path1)

# Read both back and re-attach the booster as the engine fit.
bonsai_fit_engine_read <- readRDS.lgb.Booster(path2)
bonsai_fit_read <- readRDS(path1)
bonsai_fit_read$fit <- bonsai_fit_engine_read

更多详情请参阅 this comment。

值得庆幸的是(silver lining):

只想在此对话中补充一点:自 2021 年 12 月以来,{lightgbm} 的开发版本已支持直接对 {lightgbm} 模型使用 readRDS()/saveRDS()。

关于r - 如何保存 Tidymodels Lightgbm 模型以供重用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/72027360/

32 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com