gpt4 book ai didi

python - Torch 数据集循环太远

转载 作者:太空宇宙 更新时间:2023-11-04 00:04:31 25 4
gpt4 key购买 nike

为什么这个数据集会尝试遍历最后一个元素

from torch.utils.data.dataset import Dataset
class DumbDataset(Dataset):
def __init__(self, dct):
self.dct = dct
self.mapping = dict(enumerate(dct))
def __getitem__(self, index):
return self.dct[self.mapping[index]]

def __len__(self):
print('called')
return len(self.dct)

ds = DumbDataset({'a': 'aword', 'b': 'another_words'})

for k in ds: print(k)

这引发了 KeyError: 2,我不明白,因为对象的长度是 2。迭代器用完后不应该得到 StopIteration 吗?

最佳答案

您的代码引发 KeyError 的原因是 Dataset does not implement __iter__(),因此当在 for 循环中使用时,Python 回退到从索引 0 开始并调用 __getitem__ 直到 IndexError 被引发,正如所讨论的 here .您可以修改 DumbDataset 使其在索引超出范围时引发 IndexError 以像这样工作

def __getitem__(self, index):
if index >= len(self): raise IndexError
return self.dct[self.mapping[index]]

然后是你的循环

for k in ds:
print(k)

将按您的预期工作。另一方面, torch 数据集的典型模板是您可以使用索引循环遍历它们

for i in range(len(ds)):
k = ds[k]
print(k)

或者您将它们包装在 DataLoader 中,它会批量返回元素

generator = DataLoader(ds)
for k in generator:
print(k)

关于python - Torch 数据集循环太远,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/54640906/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com