
python - Scikit-learn train_test_split with indices

Reposted. Author: IT老高. Updated: 2023-10-28 20:37:28

How can I get the original indices of the data when using train_test_split()?

Here is what I have:

# sklearn.cross_validation was removed in scikit-learn 0.20; train_test_split now lives in model_selection
from sklearn.model_selection import train_test_split
import numpy as np
data = np.reshape(np.random.randn(20), (10, 2))  # 10 training examples
labels = np.random.randint(2, size=10)  # 10 labels
x1, x2, y1, y2 = train_test_split(data, labels, test_size=0.2)

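But this does not give me the indices of the original data. One workaround is to attach the indices to the data (e.g. data = [(i, d) for i, d in enumerate(data)]), pass that to train_test_split, and then unpack it again afterwards. Is there a cleaner solution?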
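For reference, the enumerate-based workaround described in the question might look like the sketch below. It assumes random data of the same shape as in the question; `random_state=0` and the names `idx_train`, `x_train`, etc. are illustrative choices, not part of the original post. It relies on train_test_split accepting a plain Python list and splitting it row by row like an array.

```python
import numpy as np
from sklearn.model_selection import train_test_split

data = np.random.randn(10, 2)           # 10 samples, 2 features (illustrative data)
labels = np.random.randint(2, size=10)  # 10 binary labels

# Pair every row with its original position: [(0, row0), (1, row1), ...]
indexed = list(enumerate(data))

# The list of (index, row) pairs is split consistently with the labels
train_pairs, test_pairs, y_train, y_test = train_test_split(
    indexed, labels, test_size=0.2, random_state=0
)

# "Unpack again": separate the original indices from the rows
idx_train = [i for i, _ in train_pairs]
idx_test = [i for i, _ in test_pairs]
x_train = np.array([row for _, row in train_pairs])
x_test = np.array([row for _, row in test_pairs])
```

It works, but the pack/unpack steps are exactly the boilerplate the question hopes to avoid.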

Best answer

You can use a pandas DataFrame or Series as Julien suggested, but if you want to stick to plain numpy, you can pass an additional array of indices:

from sklearn.model_selection import train_test_split
import numpy as np
n_samples, n_features, n_classes = 10, 2, 2
data = np.random.randn(n_samples, n_features) # 10 training examples
labels = np.random.randint(n_classes, size=n_samples) # 10 labels
indices = np.arange(n_samples)  # original row positions 0 .. n_samples - 1
(
data_train,
data_test,
labels_train,
labels_test,
indices_train,
indices_test,
) = train_test_split(data, labels, indices, test_size=0.2)  # all three arrays are split with the same row selection
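The pandas route mentioned above can be sketched as follows, assuming the data is wrapped in a DataFrame with its default RangeIndex. The column names `f0`/`f1`, the label name, and `random_state=0` are hypothetical. train_test_split selects rows by position, but pandas objects keep their index labels, so `.index` on each split recovers the original row numbers without any extra array.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

n_samples = 10
# Hypothetical data; the default RangeIndex records each row's original position
df = pd.DataFrame(np.random.randn(n_samples, 2), columns=["f0", "f1"])
y = pd.Series(np.random.randint(2, size=n_samples), name="label")

X_train, X_test, y_train, y_test = train_test_split(
    df, y, test_size=0.2, random_state=0
)

# The preserved index labels are the original row numbers of each split
idx_train = X_train.index.to_numpy()
idx_test = X_test.index.to_numpy()
```

With the numpy version from the answer, the same round trip holds: `data[indices_train]` equals `data_train`.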

A similar question about python - Scikit-learn train_test_split with indices can be found on Stack Overflow: https://stackoverflow.com/questions/31521170/
