gpt4 book ai didi

scikit-learn - 使用 pytorch 和 sklearn 对 MNIST 数据集进行交叉验证

转载 作者:行者123 更新时间:2023-12-03 17:44:19 24 4
gpt4 key购买 nike

我是 pytorch 的新手,正在尝试实现一个前馈神经网络来对 mnist 数据集进行分类。我在尝试使用交叉验证时遇到了一些问题。我的数据具有以下形状:
x_train :torch.Size([45000, 784])y_train :torch.Size([45000])
我尝试使用 sklearn 中的 KFold。
kfold =KFold(n_splits=10)
这是我的训练方法的第一部分,我将数据分成几部分:

for  train_index, test_index in kfold.split(x_train, y_train): 
x_train_fold = x_train[train_index]
x_test_fold = x_test[test_index]
y_train_fold = y_train[train_index]
y_test_fold = y_test[test_index]
print(x_train_fold.shape)
for epoch in range(epochs):
...
y_train_fold 的索引变量是对的,它很简单: [ 0 1 2 ... 4497 4498 4499] ,但它不适用于 x_train_fold , 即 [ 4500 4501 4502 ... 44997 44998 44999] .测试折叠也是如此。

对于第一次迭代,我想要变量 x_train_fold成为前 4500 张图片,即具有形状 torch.Size([4500, 784]) ,但它的形状是 torch.Size([40500, 784])
关于如何做到这一点的任何提示?

最佳答案

我觉得你糊涂了!

暂时忽略第二个维度,当你有 45000 个点时,你使用 10 折交叉验证,每折的大小是多少? 45000/10 即 4500。

这意味着您的每个折叠将包含 4500 个数据点,其中一个折叠将用于测试,其余用于训练,即

For testing: one fold => 4500 data points => size: 4500
For training: remaining folds => 45000-4500 data points => size: 45000-4500=40500



因此,对于第一次迭代,前 4500 个数据点(对应于索引)将用于测试,其余用于训练。 (检查下图)

鉴于您的数据是 x_train: torch.Size([45000, 784])y_train: torch.Size([45000]) ,这就是您的代码的外观:
for train_index, test_index in kfold.split(x_train, y_train):  
print(train_index, test_index)

x_train_fold = x_train[train_index]
y_train_fold = y_train[train_index]
x_test_fold = x_train[test_index]
y_test_fold = y_train[test_index]

print(x_train_fold.shape, y_train_fold.shape)
print(x_test_fold.shape, y_test_fold.shape)
break

[ 4500 4501 4502 ... 44997 44998 44999] [ 0 1 2 ... 4497 4498 4499]
torch.Size([40500, 784]) torch.Size([40500])
torch.Size([4500, 784]) torch.Size([4500])

所以,当你说

I want the variable x_train_fold to be the first 4500 picture... shape torch.Size([4500, 784]).



你错了。此大小对应于 x_test_fold .在第一次迭代中,基于10折, x_train_fold将有 40500 个点,因此它的大小应该是 torch.Size([40500, 784]) .

K-fold validation image

关于scikit-learn - 使用 pytorch 和 sklearn 对 MNIST 数据集进行交叉验证,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58996242/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com