gpt4 book ai didi

python - MNIST Pytorch 中的验证错误意外增加

转载 作者:行者123 更新时间:2023-12-01 01:22:45 28 4
gpt4 key购买 nike

我对整个领域有点陌生,因此决定研究 MNIST 数据集。我几乎改编了 https://github.com/pytorch/examples/blob/master/mnist/main.py 中的整个代码,只有一个重大变化:数据加载。我不想使用 Torchvision 中预加载的数据集。所以我用了MNIST in CSV .

我通过继承 Dataset 并创建一个新的数据加载器来加载 CSV 文件中的数据。相关代码如下:

mean = 33.318421449829934
sd = 78.56749081851163
# mean = 0.1307
# sd = 0.3081
import numpy as np
from torch.utils.data import Dataset, DataLoader

class dataset(Dataset):
def __init__(self, csv, transform=None):
data = pd.read_csv(csv, header=None)
self.X = np.array(data.iloc[:, 1:]).reshape(-1, 28, 28, 1).astype('float32')
self.Y = np.array(data.iloc[:, 0])

del data
self.transform = transform

def __len__(self):
return len(self.X)

def __getitem__(self, idx):
item = self.X[idx]
label = self.Y[idx]

if self.transform:
item = self.transform(item)

return (item, label)

import torchvision.transforms as transforms
trainData = dataset('mnist_train.csv', transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((mean,), (sd,))
]))
testData = dataset('mnist_test.csv', transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((mean,), (sd,))
]))

train_loader = DataLoader(dataset=trainData,
batch_size=10,
shuffle=True,
)
test_loader = DataLoader(dataset=testData,
batch_size=10,
shuffle=True,
)

然而,这段代码给了我你在图片中看到的绝对奇怪的训练错误图,以及 11% 的最终验证错误,因为它将所有内容分类为“7”。 Validation error graph

我设法将问题追溯到如何标准化数据,如果我使用示例代码中给出的值(0.1307和0.3081)进行变换。标准化,同时将数据读取为“uint8”类型,它就可以工作完美。请注意,这两种情况下提供的数据差异极小。对 0 到 1 的值按 0.1307 和 0.3081 进行归一化与对 0 到 255 的值按 33.31 和 78.56 进行归一化具有相同的效果。这些值甚至基本相同(黑色像素在第一种情况下对应于 -0.4241,在第一种情况下对应于 -0.4242)在第二个)。

如果您想查看可以清楚地看到此问题的 IPython Notebook,请查看 https://colab.research.google.com/drive/1W1qx7IADpnn5e5w97IcxVvmZAaMK9vL3

我无法理解是什么导致了这两种略有不同的数据加载方式的行为存在如此巨大的差异。任何帮助将不胜感激。

最佳答案

长话短说:你需要改变item = self.X[idx]item = self.X[idx].copy() .

长话短说:T.ToTensor()运行 torch.from_numpy ,它返回一个张量,该张量为 numpy 数组的内存设置别名 dataset.X 。和T.Normalize() works inplace ,因此每次抽取样本时都有 mean减去并除以 std ,导致数据集降级。

编辑:关于为什么它在原始 MNIST 加载器中工作,兔子洞甚至更深。 MNIST中的关键行是图像是 transformed进入PIL.Image实例。该操作声称仅在缓冲区不连续的情况下进行复制(在我们的例子中),但在 hood 下。它会检查它是否被跨步了(确实如此),然后复制它。因此,幸运的是,默认的 torchvision 管道涉及 T.Normalize() 的复制,从而就地操作。不会损坏内存 self.data我们的MNIST实例。

关于python - MNIST Pytorch 中的验证错误意外增加,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53652015/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com