gpt4 book ai didi

python - 如何在 Pytorch 中使用 torchvision.transforms 进行分割任务的数据增强?

转载 作者:行者123 更新时间:2023-12-02 08:18:31 38 4
gpt4 key购买 nike

我对 PyTorch 中执行的数据增强有点困惑。

因为我们正在处理分割任务,所以我们需要数据和掩码来进行相同的数据增强,但其中一些是随机的,例如随机旋转。

Keras 提供了随机种子保证数据和掩码执行相同的操作,如以下代码所示:

    data_gen_args = dict(featurewise_center=True,
featurewise_std_normalization=True,
rotation_range=25,
horizontal_flip=True,
vertical_flip=True)


image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)

seed = 1
image_generator = image_datagen.flow(train_data, seed=seed, batch_size=1)
mask_generator = mask_datagen.flow(train_label, seed=seed, batch_size=1)

train_generator = zip(image_generator, mask_generator)

我在Pytorch官方文档中没有找到类似的描述,所以不知道如何保证data和mask能够同步处理。

Pytorch确实提供了这样的功能,但我想将其应用到自定义Dataloader中。

例如:

def __getitem__(self, index):
img = np.zeros((self.im_ht, self.im_wd, channel_size))
mask = np.zeros((self.im_ht, self.im_wd, channel_size))

temp_img = np.load(Image_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')
temp_label = np.load(Label_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')

for i in range(channel_size):
img[:,:,i] = temp_img[self.count[index] + i]
mask[:,:,i] = temp_label[self.count[index] + i]

if self.transforms:
img = np.uint8(img)
mask = np.uint8(mask)
img = self.transforms(img)
mask = self.transforms(mask)

return img, mask

这种情况下,img和mask会分别进行变换,因为随机旋转等一些操作是随机的,所以mask和image的对应关系可能会改变。换句话说,图像可能已经旋转,但蒙版却没有这样做。

编辑1

我用的是augmentations.py中的方法,但是我得到了一个错误::

Traceback (most recent call last):
File "test_transform.py", line 87, in <module>
for batch_idx, image, mask in enumerate(train_loader):
File "/home/dirk/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 314, in __next__
batch = self.collate_fn([self.dataset[i] for i in indices])
File "/home/dirk/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 314, in <listcomp>
batch = self.collate_fn([self.dataset[i] for i in indices])
File "/home/dirk/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataset.py", line 103, in __getitem__
return self.dataset[self.indices[idx]]
File "/home/dirk/home/data/dirk/segmentation_unet_pytorch/data.py", line 164, in __getitem__
img, mask = self.transforms(img, mask)
File "/home/dirk/home/data/dirk/segmentation_unet_pytorch/augmentations.py", line 17, in __call__
img, mask = a(img, mask)
TypeError: __call__() takes 2 positional arguments but 3 were given

这是我的__getitem__()代码:

data_transforms = {
'train': Compose([
RandomHorizontallyFlip(),
RandomRotate(degree=25),
transforms.ToTensor()
]),
}

train_set = DatasetUnetForTestTransform(fold=args.fold, random_index=args.random_index,transforms=data_transforms['train'])

# __getitem__ in class DatasetUnetForTestTransform
def __getitem__(self, index):
img = np.zeros((self.im_ht, self.im_wd, channel_size))
mask = np.zeros((self.im_ht, self.im_wd, channel_size))
temp_img = np.load(Label_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')
temp_label = np.load(Label_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')
temp_img, temp_label = crop_data_label_from_0(temp_img, temp_label)
for i in range(channel_size):
img[:,:,i] = temp_img[self.count[index] + i]
mask[:,:,i] = temp_label[self.count[index] + i]

if self.transforms:
img = T.ToPILImage()(np.uint8(img))
mask = T.ToPILImage()(np.uint8(mask))
img, mask = self.transforms(img, mask)

img = T.ToTensor()(img).copy()
mask = T.ToTensor()(mask).copy()
return img, mask

编辑2

我发现ToTensor之后,相同标签之间的骰子变成了255而不是1,如何修复它?

# Dice computation
def DSC_computation(label, pred):
pred_sum = pred.sum()
label_sum = label.sum()
inter_sum = np.logical_and(pred, label).sum()
return 2 * float(inter_sum) / (pred_sum + label_sum)

请随时询问是否需要更多代码来解释问题。

最佳答案

需要输入参数(例如 RandomCrop)的转换有一个 get_param 方法,该方法将返回该特定转换的参数。然后可以使用变换的功能接口(interface)将其应用于图像和蒙版:

from torchvision import transforms
import torchvision.transforms.functional as F

i, j, h, w = transforms.RandomCrop.get_params(input, (100, 100))
input = F.crop(input, i, j, h, w)
target = F.crop(target, i, j, h, w)

示例可在此处获取: https://github.com/pytorch/vision/releases/tag/v0.2.0

此处提供 VOC 和 COCO 的完整示例: https://github.com/pytorch/vision/blob/master/references/segmentation/transforms.py https://github.com/pytorch/vision/blob/master/references/segmentation/train.py

关于错误,

ToTensor() 未重写以处理其他掩码参数,因此它不能位于 data_transforms 中。此外,__getitem__ 在返回 imgmask 之前对它们进行 ToTensor

data_transforms = {
'train': Compose([
RandomHorizontallyFlip(),
RandomRotate(degree=25),
#transforms.ToTensor() => remove this line
]),
}

关于python - 如何在 Pytorch 中使用 torchvision.transforms 进行分割任务的数据增强?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58215056/

38 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com