gpt4 book ai didi

PyTorch Dataloader - 枚举时列表不可调用错误

转载 作者:行者123 更新时间:2023-12-01 21:59:02 25 4
gpt4 key购买 nike

当遍历 PyTorch 数据加载器时,例如

# define dataset, dataloader
train_data = datasets.ImageFolder(data_dir + '/train', transform=train_transforms)
test_data = datasets.ImageFolder(data_dir + '/test', transform=test_transforms)
trainloader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=64)

# define model, optimizer, loss
# not included - irrelevant to the question

for ii, (inputs, labels) in enumerate(trainloader):

# Move input and label tensors to the GPU
inputs, labels = inputs.to(device), labels.to(device)

start = time.time()

outputs = model.forward(inputs)
loss = criterion(outputs, labels)
loss.backward()

我在这一行收到一个TypeError: 'list' object is not callable

for ii, (inputs, labels) in enumerate(trainloader):

我忘记了什么蠢事?

最佳答案

您是否记得在您的转换列表中调用 transforms.Compose

在这一行

train_data = datasets.ImageFolder(data_dir + '/train', transform=train_transforms)

transform 参数需要一个可调用对象,而不是列表。

因此,例如,这是错误的:

train_transforms = [
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]

应该是这样的

train_transforms = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

关于PyTorch Dataloader - 枚举时列表不可调用错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54431671/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com