gpt4 book ai didi

python-3.x - DataLoader类错误Pytorch

转载 作者:行者123 更新时间:2023-12-03 12:29:17 25 4
gpt4 key购买 nike

我是pytorch初学者用户,并且正在尝试使用dataloader。

实际上,我正在尝试将其实现到我的网络中,但是加载需要很长时间。因此,我调试了网络,以查看网络本身是否有问题,但事实证明,它与我的数据加载器类有关。这是代码:

 from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd

class DiabetesDataset(Dataset):

def __init__(self, csv):
self.xy = pd.read_csv(csv)

def __len__(self):
return len(self.xy)

def __getitem__(self, index):
self.x_data = torch.Tensor(xy.iloc[:, 0:-1].values)
self.y_data = torch.Tensor(xy.iloc[:, [-1]].values)
return self.x_data[index], self.y_data[index]

dataset = DiabetesDataset("trial.csv")
train_loader = DataLoader(dataset=dataset,
batch_size=1,
shuffle=True,
num_workers=2)`

for a in train_loader:
print(a)


为了验证数据加载器是否造成了所有延迟,我创建了一个虚拟csv文件,该文件包含2列1s和2s,每列总共10个样本。然后,我循环了train_loader对象,该对象已经超过1小时,并且仍在运行,考虑到样本量较小且批处理大小设置为1。

我不确定我的代码是什么错误,这是导致此问题的原因。

任何意见/输入,不胜感激!

最佳答案

您的代码中存在一些错误-您可以检查一下是否可行(通过玩具示例在我的计算机上运行):

from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import torch


class DiabetesDataset(Dataset):

def __init__(self, csv):
self.xy = pd.read_csv(csv)

def __len__(self):
return len(self.xy)

def __getitem__(self, index):
x_data = torch.Tensor(self.xy.iloc[:, 0:-1].values)
y_data = torch.Tensor(self.xy.iloc[:, [-1]].values)
return x_data[index], y_data[index]


dataset = DiabetesDataset("trial.csv")


train_loader = DataLoader(
dataset=dataset,
batch_size=1,
shuffle=True,
num_workers=2)

if __name__ == '__main__':
for a in train_loader:
print(a)


编辑:您的代码不起作用,因为您缺少 self方法(self.xy.iloc ...)中的 __getitem__,并且因为脚本末尾没有 if __name__ == '__main__。对于第二个错误,请参见 RuntimeError on windows trying python multiprocessing

关于python-3.x - DataLoader类错误Pytorch,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54898145/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com