
r - How to plot the recursive partitioning from the rpart package

Reposted. Author: 行者123. Updated: 2023-12-05 03:22:58

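I want to plot the partition of a two-dimensional covariate space constructed by recursive binary splitting. More precisely, I would like to write a function that replicates the following figure (taken from Elements of Statistical Learning, p. 306):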

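[Figure: partition of a two-dimensional covariate space by recursive binary splitting, from Elements of Statistical Learning, p. 306]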

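Shown above is a two-dimensional covariate space and the partition obtained by recursive binary splitting of that space with axis-aligned splits (also known as the CART algorithm). What I want to implement is a function that takes the output of the rpart function and generates such a plot.

Here is some example code: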

## Generating data.
set.seed(1975)

n <- 5000
p <- 2

X <- matrix(sample(seq(0, 1, by = 0.01), n * p, replace = TRUE), ncol = p)
Y <- X[, 1] + 2 * X[, 2] + rnorm(n)

## Building tree.
library(rpart)  # provides rpart() and rpart.control()

tree <- rpart(Y ~ ., data = data.frame(Y, X), method = "anova",
              control = rpart.control(cp = 0, maxdepth = 2))
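For reference, the structure of the fitted tree can be inspected through its `frame` component, which holds one row per node. The following is a minimal sketch that refits the same tree so that it runs on its own; the names `node_table` and `leaf_nodes` are illustrative only. It relies on rpart's node numbering, in which the children of node k are numbered 2k and 2k + 1.

```r
library(rpart)

# Same data and tree as in the example above.
set.seed(1975)
n <- 5000
p <- 2
X <- matrix(sample(seq(0, 1, by = 0.01), n * p, replace = TRUE), ncol = p)
Y <- X[, 1] + 2 * X[, 2] + rnorm(n)
tree <- rpart(Y ~ ., data = data.frame(Y, X), method = "anova",
              control = rpart.control(cp = 0, maxdepth = 2))

# One row per node; the row names are rpart's node numbers
# (children of node k are numbered 2k and 2k + 1).
node_table <- tree$frame[, c("var", "n", "yval")]
print(node_table)

# Terminal nodes are marked "<leaf>" in the var column.
leaf_nodes <- as.integer(rownames(tree$frame))[tree$frame$var == "<leaf>"]
print(leaf_nodes)
```

With `maxdepth = 2` the tree can have at most four leaves, so a partition plot of this tree has at most four rectangles.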

Navigating SO, I found this function:

rpart_splits <- function(fit, digits = getOption("digits")) {
  splits <- fit$splits
  if (!is.null(splits)) {
    ff <- fit$frame
    is.leaf <- ff$var == "<leaf>"
    n <- nrow(splits)
    nn <- ff$ncompete + ff$nsurrogate + !is.leaf
    ix <- cumsum(c(1L, nn))
    ix_prim <- unlist(mapply(ix, ix + c(ff$ncompete, 0), FUN = seq, SIMPLIFY = FALSE))
    type <- rep.int("surrogate", n)
    type[ix_prim[ix_prim <= n]] <- "primary"
    type[ix[ix <= n]] <- "main"
    left <- character(nrow(splits))
    side <- splits[, 2L]
    for (i in seq_along(left)) {
      left[i] <- if (side[i] == -1L)
        paste("<", format(signif(splits[i, 4L], digits)))
      else if (side[i] == 1L)
        paste(">=", format(signif(splits[i, 4L], digits)))
      else {
        catside <- fit$csplit[splits[i, 4L], 1:side[i]]
        paste(c("L", "-", "R")[catside], collapse = "", sep = "")
      }
    }
    cbind(data.frame(var = rownames(splits),
                     type = type,
                     node = rep(as.integer(row.names(ff)), times = nn),
                     ix = rep(seq_len(nrow(ff)), nn),
                     left = left),
          as.data.frame(splits, row.names = FALSE))
  }
}

Using this function, I am able to recover all the split variables and split points:

splits <- rpart_splits(tree)[rpart_splits(tree)$type == "main", ]
splits

#   var type node ix    left count ncat    improve index adj
# 1  X2 main    1  1 < 0.565  5000   -1 0.18110662 0.565   0
# 3  X2 main    2  2 < 0.265  2814   -1 0.06358597 0.265   0
# 6  X1 main    3  5 < 0.645  2186   -1 0.07645851 0.645   0

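The var column gives the split variable for each non-terminal node, and the left column gives the corresponding split point. However, I do not know how to use this information to produce the plot I want.

Of course, if you have an alternative strategy that does not involve rpart_splits, feel free to suggest it.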

Best answer

You can use the (unpublished) parttree package, which you can install from GitHub with:

remotes::install_github("grantmcdermott/parttree")

This allows:

library(parttree)
library(ggplot2)  # provides ggplot(), aes(), coord_cartesian(), etc.

ggplot() +
  geom_parttree(data = tree, aes(fill = path)) +
  coord_cartesian(xlim = c(0, 1), ylim = c(0, 1)) +
  scale_fill_brewer(palette = "Pastel1", name = "Partitions") +
  theme_bw(base_size = 16) +
  labs(x = "X2", y = "X1")

[Figure: the partition plot produced by geom_parttree, with X2 on the x-axis and X1 on the y-axis]

As an aside, the package also contains the function parttree, which returns something very similar to your rpart_splits function:

parttree(tree)
  node         Y                        path  xmin  xmax  ymin  ymax
1    4 0.7556079   X2 < 0.565 --> X2 < 0.265  -Inf 0.265  -Inf   Inf
2    5 1.3087679  X2 < 0.565 --> X2 >= 0.265 0.265 0.565  -Inf   Inf
3    6 1.8681143  X2 >= 0.565 --> X1 < 0.645 0.565   Inf  -Inf 0.645
4    7 2.4993361 X2 >= 0.565 --> X1 >= 0.645 0.565   Inf 0.645   Inf
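If installing a GitHub package is not an option, the same rectangles can also be derived from the main splits directly. The following is a minimal base R sketch, not part of rpart or parttree: `leaf_boxes` is a hypothetical helper, and it assumes numeric splits only, rpart's node numbering (children of node k are 2k and 2k + 1), and a covariate space bounded by [0, 1] as in the example data. The `splits` data frame is typed in by hand from the main splits printed in the question.

```r
# Hypothetical helper (not from rpart or parttree): turn main splits into
# one rectangle per leaf. Assumes numeric splits and rpart node numbering.
leaf_boxes <- function(splits, vars = c("X1", "X2"), lower = 0, upper = 1) {
  recurse <- function(node, lo, hi) {
    s <- splits[splits$node == node, , drop = FALSE]
    if (nrow(s) == 0L) {
      # No split recorded for this node, so it is a leaf: emit its rectangle.
      return(data.frame(node = node,
                        xmin = lo[[vars[1]]], xmax = hi[[vars[1]]],
                        ymin = lo[[vars[2]]], ymax = hi[[vars[2]]]))
    }
    v <- as.character(s$var[1])
    below_hi <- hi
    below_hi[[v]] <- s$index[1]  # upper bound of the region with v <  cut
    above_lo <- lo
    above_lo[[v]] <- s$index[1]  # lower bound of the region with v >= cut
    if (s$ncat[1] == -1) {
      # ncat == -1: observations with v < cut go to the left child (2 * node).
      rbind(recurse(2 * node, lo, below_hi),
            recurse(2 * node + 1, above_lo, hi))
    } else {
      # ncat == 1: observations with v >= cut go to the left child.
      rbind(recurse(2 * node, above_lo, hi),
            recurse(2 * node + 1, lo, below_hi))
    }
  }
  lo <- setNames(rep(lower, length(vars)), vars)
  hi <- setNames(rep(upper, length(vars)), vars)
  recurse(1, lo, hi)
}

# Main splits copied by hand from the rpart_splits() output in the question.
splits <- data.frame(var = c("X2", "X2", "X1"), node = c(1, 2, 3),
                     ncat = c(-1, -1, -1), index = c(0.565, 0.265, 0.645))
boxes <- leaf_boxes(splits)
print(boxes)

# Draw the partition with base graphics: X1 on the x-axis, X2 on the y-axis.
plot(NULL, xlim = c(0, 1), ylim = c(0, 1), xlab = "X1", ylab = "X2")
rect(boxes$xmin, boxes$ymin, boxes$xmax, boxes$ymax, border = "black")
text((boxes$xmin + boxes$xmax) / 2, (boxes$ymin + boxes$ymax) / 2,
     labels = paste("node", boxes$node))
```

Note that this sketch puts X1 on the x-axis, whereas the parttree plot above uses X2 there; swapping the order of `vars` flips the axes. With a real fit, the hand-typed data frame could be replaced by the `type == "main"` rows returned by rpart_splits, which carry the same `var`, `node`, `ncat`, and `index` columns.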

Regarding "r - How to plot the recursive partitioning from the rpart package", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/72587924/
