gpt4 book ai didi

python - 如何根据sklearn中的列值拆分数据

转载 作者:行者123 更新时间:2023-12-03 17:00:48 28 4
gpt4 key购买 nike

我有一个包含以下列的数据文件

'顾客',
'calibrat' - 校准样本 = 1;验证样本 = 0;
'搅拌',
'churndep',
'收入',
'穆',

数据文件包含大约 40000 行,其中 20000 行的校准值为 1。我想将此数据拆分为

X1 = data.loc[:, data.columns != 'churn']
y1 = data.loc[:, data.columns == 'churn']
from imblearn.over_sampling import SMOTE
os = SMOTE(random_state=0)
X1_train, X1_test, y1_train, y1_test = train_test_split(X1, y1, test_size=0.3, random_state=0)

我想要的是,在我的 X1_train 中应该带有 calibrat =1 的校准数据
并且在 X1_test 中应该出现所有数据以进行 calibrat = 0 验证

最佳答案

sklearn.model_selection 除了 train_test_split 之外还有其他几个选项。其中之一,旨在解决您的要求。在这种情况下,您可以使用 GroupShuffleSplit ,如文档中所述,它提供随机训练/测试索引以根据第三方提供的组拆分数据。这在您进行交叉验证时很有用,并且您想在验证训练中多次拆分,确保集合由 group 字段拆分。对于这些情况,您还可以使用 GroupKFold ,这非常有用。

所以,调整你的例子,这是你可以做的。

假设你有例如:

from sklearn.model_selection import GroupShuffleSplit

cols = ['customer', 'calibrat', 'churn', 'churndep', 'revenue', 'mou',]
X = pd.DataFrame(np.random.rand(10, 6), columns=cols)
X['calibrat'] = np.random.choice([0,1], size=10)

print(X)

customer calibrat churn churndep revenue mou
0 0.523571 1 0.394896 0.933637 0.232630 0.103486
1 0.456720 1 0.850961 0.183556 0.885724 0.993898
2 0.411568 1 0.003360 0.774391 0.822560 0.840763
3 0.148390 0 0.115748 0.089891 0.842580 0.565432
4 0.505548 0 0.370198 0.566005 0.498009 0.601986
5 0.527433 0 0.550194 0.991227 0.516154 0.283175
6 0.983699 0 0.514049 0.958328 0.005034 0.050860
7 0.923172 0 0.531747 0.026763 0.450077 0.961465
8 0.344771 1 0.332537 0.046829 0.047598 0.324098
9 0.195655 0 0.903370 0.399686 0.170009 0.578925

y = X.pop('churn')

您现在可以实例化 GroupShuffleSplit ,就像使用 train_test_split 一样,唯一的区别是指定一个 group 列,该列将用于拆分 Xy 以便根据组值拆分组:
gs = GroupShuffleSplit(n_splits=2, train_size=.7, random_state=42)

如前所述,当您想分成多个组时,这更方便,通常用于交叉验证。如问题中所述,这只是一个如何获得两次拆分的示例:
train_ix, test_ix = next(gs.split(X, y, groups=X.calibrat))

X_train = X.loc[train_ix]
y_train = y.loc[train_ix]

X_test = X.loc[test_ix]
y_test = y.loc[test_ix]

给予:
print(X_train)

customer calibrat churndep revenue mou
3 0.148390 0 0.089891 0.842580 0.565432
4 0.505548 0 0.566005 0.498009 0.601986
5 0.527433 0 0.991227 0.516154 0.283175
6 0.983699 0 0.958328 0.005034 0.050860
7 0.923172 0 0.026763 0.450077 0.961465
9 0.195655 0 0.399686 0.170009 0.578925

print(X_test)

customer calibrat churndep revenue mou
0 0.523571 1 0.933637 0.232630 0.103486
1 0.456720 1 0.183556 0.885724 0.993898
2 0.411568 1 0.774391 0.822560 0.840763
8 0.344771 1 0.046829 0.047598 0.324098

关于python - 如何根据sklearn中的列值拆分数据,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61115535/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com