gpt4 book ai didi

python - 如何在pytorch对象检测中添加转换

转载 作者:行者123 更新时间:2023-12-05 06:09:11 25 4
gpt4 key购买 nike

我是 PyTorch 新手,正在阅读 PyTorch 对象检测文档教程 pytorch docx .在他们的协作版本中,我进行了以下更改以添加一些转换技术。

  1. 首先修改类PennFudanDataset(torch.utils.data.Dataset)的__getitem__方法
if self.transforms is not None:
img = self.transforms(img)
target = T.ToTensor()(target)
return img, target

In actual documentation it is
if self.transforms is not None:
img, target = self.transforms(img, target)

其次,在 get_transform(train) 函数处。

def get_transform(train):
if train:
transformed = T.Compose([
T.ToTensor(),
T.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),
T.ColorJitter(brightness=[0.1, 0.2], contrast=[0.1, 0.2], saturation=[0, 0.2], hue=[0,0.5])
])
return transformed

else:
return T.ToTensor()

**In the documentation it is-**
def get_transform(train):
transforms = []
transforms.append(T.ToTensor())
if train:
transforms.append(T.RandomHorizontalFlip(0.5))
return T.Compose(transforms)

在执行代码时,出现以下错误。我无法理解我做错了什么。

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataset.py", line 272, in __getitem__
return self.dataset[self.indices[idx]]
File "<ipython-input-41-94e93ff7a132>", line 72, in __getitem__
target = T.ToTensor()(target)
File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 104, in __call__
return F.to_tensor(pic)
File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py", line 64, in to_tensor
raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
TypeError: pic should be PIL Image or ndarray. Got <class 'dict'>

最佳答案

我相信 Pytorch 转换仅适用于图像(在本例中为 PIL 图像或 np 数组),而不适用于标签(根据跟踪是字典)。因此,我认为您不需要像 __getitem__ 函数中的这一行 target = T.ToTensor()(target) 那样“张紧”标签。

关于python - 如何在pytorch对象检测中添加转换,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/64905441/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com