gpt4 book ai didi

python - 复杂数据集拆分 - StratifiedGroupShuffleSplit

转载 作者:太空宇宙 更新时间:2023-11-04 01:55:45 25 4
gpt4 key购买 nike

我有一个包含约 2m 个观测值的数据集,我需要按 60:20:20 的比例将其分成训练集、验证集和测试集。我的数据集的简化摘录如下所示:

+---------+------------+-----------+-----------+
| note_id | subject_id | category | note |
+---------+------------+-----------+-----------+
| 1 | 1 | ECG | blah ... |
| 2 | 1 | Discharge | blah ... |
| 3 | 1 | Nursing | blah ... |
| 4 | 2 | Nursing | blah ... |
| 5 | 2 | Nursing | blah ... |
| 6 | 3 | ECG | blah ... |
+---------+------------+-----------+-----------+

有多个类别——它们不是均衡的——所以我需要确保训练集、验证集和测试集都具有与原始数据集中相同的类别比例。这部分很好,我可以使用 sklearn 库中的 StratifiedShuffleSplit

但是,我还需要确保每个受试者的观察结果不会分散在训练、验证和测试数据集中。来自给定主题的所有观察结果都需要放在同一个桶中,以确保我训练的模型在验证/测试之前从未见过该主题。例如。 subject_id 1 的每个观察结果都应该在训练集中。

我想不出一种方法来确保按类别 进行分层拆分,防止subject_id 跨数据集的污染(需要一个更好的词),确保60:20:20 拆分并确保数据集以某种方式打乱。任何帮助,将不胜感激!

谢谢!


编辑:

我现在了解到,通过 GroupShuffleSplit 函数,sklearn 也可以实现按类别分组以及跨数据集拆分将组保持在一起。所以本质上,我需要的是组合的分层和分组洗牌拆分,即 StratifiedGroupShuffleSplit 不存在。 Github 问题:https://github.com/scikit-learn/scikit-learn/issues/12076

最佳答案

这在 scikit-learn 1.0 中用 StratifiedGroupKFold 解决了

在此示例中,您在洗牌后生成 3 折,将组保持在一起并进行分层(尽可能多)

import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

X = np.ones((30, 2))
y = np.array([0, 0, 1, 1, 1, 1, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
1, 1, 1, 0, 0, 0, 0, 1, 1, 1,])
groups = np.array([1, 1, 2, 2, 3, 3, 3, 4, 5, 5,
5, 5, 6, 6, 7, 8, 8, 9, 9, 9,
10, 11, 11, 12, 12, 12, 13, 13,
13, 13])
print("ORIGINAL POSITIVE RATIO:", y.mean())
cv = StratifiedGroupKFold(n_splits=3, shuffle=True)
for fold, (train_idxs, test_idxs) in enumerate(cv.split(X, y, groups)):
print("Fold :", fold)
print("TRAIN POSITIVE RATIO:", y[train_idxs].mean())
print("TEST POSITIVE RATIO :", y[test_idxs].mean())
print("TRAIN GROUPS :", set(groups[train_idxs]))
print("TEST GROUPS :", set(groups[test_idxs]))

在输出中,您可以看到折叠中阳性案例的比率接近原始阳性比率,并且同一组永远不会出现在两组中。当然,您拥有的组越少/越大(即,您的类(class)越不平衡)就越难保持接近原始类(class)分布。

输出:

ORIGINAL POSITIVE RATIO: 0.5
Fold : 0
TRAIN POSITIVE RATIO: 0.4375
TEST POSITIVE RATIO : 0.5714285714285714
TRAIN GROUPS : {1, 3, 4, 5, 6, 7, 10, 11}
TEST GROUPS : {2, 8, 9, 12, 13}
Fold : 1
TRAIN POSITIVE RATIO: 0.5
TEST POSITIVE RATIO : 0.5
TRAIN GROUPS : {2, 4, 5, 7, 8, 9, 11, 12, 13}
TEST GROUPS : {1, 10, 3, 6}
Fold : 2
TRAIN POSITIVE RATIO: 0.5454545454545454
TEST POSITIVE RATIO : 0.375
TRAIN GROUPS : {1, 2, 3, 6, 8, 9, 10, 12, 13}
TEST GROUPS : {11, 4, 5, 7}

关于python - 复杂数据集拆分 - StratifiedGroupShuffleSplit,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56872664/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com