gpt4 book ai didi

r - R 中 XGBoost 在分类变量值不完整的数据中的应用

转载 作者:行者123 更新时间:2023-11-30 09:41:30 25 4
gpt4 key购买 nike

您如何应对 XGBoost 在 R 中的应用?我有一个问题,因为当数据的分类类型列不包含其所有可能的值(模型解释的)时,我收到一个错误:“存储在 object 中的特色名称和 newdata 不同”。

我知道如何通过以不同的方式准备输入数据来解决这个问题,即通过添加足够数量的虚拟变量来覆盖我打算考虑的分类变量的所有可能值。例如。如果我想使用的特征 F 具有值“a”、“b”或“c”,我会使用特征 is_a、is_b 和 is_c 创建一个 XGBoost 模型。然后,如果在我想要应用模型的输入数据中,特征 F 仅带有“b”或“c”值,则我仍然使用这 3 个特征,每次观察时 is_c 等于 0。

但这不是我想要的方式,因为它通常看起来相当乏味,而且在使用不同的模型时我没有遇到类似的问题,例如通过 glm() 函数进行逻辑回归。

所以我的问题是:是否可以将 XGBoost 模型应用于包含具有不完整值的分类(因子)变量的观测值? 不完整在这里的意思是:模型并未考虑所有值。

我根据 mtcars 数据准备了一个示例来展示这种情况。假设我们想要一个预测齿轮箱类型的分类模型(自动或手动,“am”列)。可能的特征之一是权重(“wt”列),我们希望将权重数据用作因子类型特征,而不是连续类型特征。

library(xgboost)
library(dplyr)
library(dummies)

##### Example 0: wt as a continuous variable (no errors on data with incomplete values) #####
# Train:
data_train <- mtcars
model_matrix_train <- model.matrix(am ~ ., data = data_train)
xgb_data_train <- xgb.DMatrix(model_matrix_train, label = data_train$am)
param <- list(max_depth = 2, eta = 1, objective = "binary:logistic")
model_xgb <- xgb.train(param, xgb_data_train, nrounds = 100)

# Test on data with incomplete wt values:
data_test <- mtcars %>%
filter(wt < 4)
model_matrix_test <- model.matrix(am ~ ., data = data_test)
xgb_data_test <- xgb.DMatrix(model_matrix_test, label = data_test$am)
predict(model_xgb, newdata = xgb_data_test, type="prob")


##### Example 1: wt as a factor (error on data with incomplete values) #####
# Train:
data_train <- mtcars %>%
mutate(wt = factor(
case_when(
wt < 2 ~ "1_2",
wt < 3 ~ "2_3",
wt < 4 ~ "3_4",
wt < 5 ~ "4_5",
TRUE ~ "5_6"
))
)
model_matrix_train <- model.matrix(am ~ ., data = data_train)
xgb_data_train <- xgb.DMatrix(model_matrix_train, label = data_train$am)
param <- list(max_depth = 2, eta = 1, objective = "binary:logistic")
model_xgb <- xgb.train(param, xgb_data_train, nrounds = 100)

# Test on data with incomplete wt values:
data_test <- mtcars %>%
filter(wt < 4) %>%
mutate(wt = factor(
case_when(
wt < 2 ~ "1_2",
wt < 3 ~ "2_3",
wt < 4 ~ "3_4",
wt < 5 ~ "4_5",
TRUE ~ "5_6"
))
)
model_matrix_test <- model.matrix(am ~ ., data = data_test)
xgb_data_test <- xgb.DMatrix(model_matrix_test, label = data_test$am)
predict(model_xgb, newdata = xgb_data_test, type="prob") # ERROR

我还尝试对 wt 的所有相关情况使用虚拟变量(而不是将 wt 转换为因子变量)。结果与上面的示例1类似:

##### Example 2: wt as a dummy variable (error on data with incomplete values) #####
# Train:
data_train <- mtcars %>%
mutate(wt = factor(
case_when(
wt < 2 ~ "1_2",
wt < 3 ~ "2_3",
wt < 4 ~ "3_4",
wt < 5 ~ "4_5",
TRUE ~ "5_6"
))
)
data_train <- dummy.data.frame(data_train, "wt", sep = "_")
model_matrix_train <- model.matrix(am ~ ., data = data_train)
xgb_data_train <- xgb.DMatrix(model_matrix_train, label = data_train$am)
param <- list(max_depth = 2, eta = 1, objective = "binary:logistic")
model_xgb <- xgb.train(param, xgb_data_train, nrounds = 100)

# Test on data with incomplete wt values:
data_test <- mtcars %>%
filter(wt < 4) %>%
mutate(wt = factor(
case_when(
wt < 2 ~ "1_2",
wt < 3 ~ "2_3",
wt < 4 ~ "3_4",
wt < 5 ~ "4_5",
TRUE ~ "5_6"
))
)
data_test <- dummy.data.frame(data_test, "wt", sep = "_")
model_matrix_test <- model.matrix(am ~ ., data = data_test)
xgb_data_test <- xgb.DMatrix(model_matrix_test, label = data_test$am)
predict(model_xgb, newdata = xgb_data_test, type="prob") # ERROR

最佳答案

虽然输入数据中缺少特征的原因对于算法来说是合理的(没有可用的分类数据),但无论是因为数据不包含因子水平还是因为数据确实不完整而导致特征缺失,都没有区别(缺少一个功能)。

因此,我只能为您提供一种更快的方法来对新输入数据进行编码,以始终具有正确的功能级别:

data_test <- mtcars %>% 
filter(wt < 4) %>%
mutate(wt = factor(
case_when(
wt < 2 ~ "1_2",
wt < 3 ~ "2_3",
wt < 4 ~ "3_4",
wt < 5 ~ "4_5",
TRUE ~ "5_6"
), levels = c("1_2","2_3","3_4","4_5","5_6")) #instead of c(...) this could be variable with the stored factor levels from model creation
)

data_test <- (data_test %>% cbind(model.matrix(~ wt-1, data = .) %>% data.frame())

这做了两件重要的事情:

  1. 对因素级别进行编码

通过在因子转换中提供级别参数,您将获得所有相关级别。除了提供手动列表之外,您在创建原始模型时始终将适当的因子水平保存为变量。

  • 使用 cbind 和 model.matrix() 作为傻瓜
  • 不要使用 dummy.data.frame 函数,而是使用 model.matrix(),因为它会针对缺失因子水平自动编码为 0。

    关于r - R 中 XGBoost 在分类变量值不完整的数据中的应用,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58184250/

    25 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com