gpt4 book ai didi

python - 有没有办法检索在随机 Torchvision 变换中使用的特定参数?

转载 作者:行者123 更新时间:2023-12-04 15:07:11 26 4
gpt4 key购买 nike

我可以在训练期间通过应用随机变换(旋转/平移/重新缩放)来增加我的数据,但我不知道选择的值。
我需要知道应用了哪些值。我可以手动设置这些值,但是我失去了 Torch Vision 转换提供的很多好处。
是否有一种简单的方法可以让这些值以一种明智的方式在培训期间应用?
这是一个例子。我希望能够打印出旋转角度,在每个图像上应用平移/重新缩放:

import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms


RandAffine = transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.8, 1.2))

rotate = transforms.RandomRotation(degrees=45)
shift = RandAffine
composed = transforms.Compose([rotate,
shift])

# Apply each of the above transforms on sample.
fig = plt.figure()
sample = np.zeros((28,28))
sample[5:15,7:20] = 255
sample = transforms.ToPILImage()(sample.astype(np.uint8))
title = ['None', 'Rot','Aff','Comp']
for i, tsfrm in enumerate([None,rotate, shift, composed]):
if tsfrm:
t_sample = tsfrm(sample)
else:
t_sample = sample
ax = plt.subplot(1, 5, i + 2)
plt.tight_layout()
ax.set_title(title[i])
ax.imshow(np.reshape(np.array(list(t_sample.getdata())), (-1,28)), cmap='gray')

plt.show()

最佳答案

恐怕没有简单的方法可以解决这个问题:Torchvision 的随机变换实用程序的构建方式是在调用时对变换参数进行采样。它们是唯一的随机变换,因为 (1) 用户无法访问使用的参数和 (2) 相同的随机变换是 不是 可重复。
从 Torchvision 0.8.0 开始,随机变换通常使用两个主要功能构建:

  • get_params :这将基于变换的超参数进行采样(您在初始化变换运算符时提供的内容,即参数的值范围)
  • forward :应用转换时执行的函数。重要的部分是它从 get_params 获取参数。然后使用关联的确定性函数将其应用于输入。对于 RandomRotation , F.rotate 会被调用。同样, RandomAffine 将使用 F.affine .

  • 您的问题的一种解决方案是从 get_params 中采样参数。自己并调用功能性 - 确定性 - API。所以你不会使用 RandomRotation , RandomAffine ,也没有任何其他 Random*转型。

    例如,让我们看看 T.RandomRotation (为了简洁起见,我删除了评论)。
    class RandomRotation(torch.nn.Module):
    def __init__(
    self, degrees, interpolation=InterpolationMode.NEAREST, expand=False,
    center=None, fill=None, resample=None):
    # ...

    @staticmethod
    def get_params(degrees: List[float]) -> float:
    angle = float(torch.empty(1).uniform_(float(degrees[0]), \
    float(degrees[1])).item())
    return angle

    def forward(self, img):
    fill = self.fill
    if isinstance(img, Tensor):
    if isinstance(fill, (int, float)):
    fill = [float(fill)] * F._get_image_num_channels(img)
    else:
    fill = [float(f) for f in fill]
    angle = self.get_params(self.degrees)

    return F.rotate(img, angle, self.resample, self.expand, self.center, fill)

    def __repr__(self):
    # ...
    考虑到这一点,这里有一个可能的覆盖来修改 T.RandomRotation :
    class RandomRotation(T.RandomRotation):
    def __init__(*args, **kwargs):
    super(RandomRotation, self).__init__(*args, **kwargs) # let super do all the work

    self.angle = self.get_params(self.degrees) # initialize your random parameters

    def forward(self): # override T.RandomRotation's forward
    fill = self.fill
    if isinstance(img, Tensor):
    if isinstance(fill, (int, float)):
    fill = [float(fill)] * F._get_image_num_channels(img)
    else:
    fill = [float(f) for f in fill]

    return F.rotate(img, self.angle, self.resample, self.expand, self.center, fill)
    我基本上复制了 T.RandomRotationforward函数,唯一不同的是参数在 __init__中采样(即一次)而不是在 forward 内(即每次通话)。 Torchvision 的实现涵盖了所有情况,您通常不需要复制完整的 forward .在某些情况下,您可以直接调用功能版本。例如,如果您不需要设置 fill参数,您可以丢弃该部分并仅使用:
    class RandomRotation(T.RandomRotation):
    def __init__(*args, **kwargs):
    super(RandomRotation, self).__init__(*args, **kwargs) # let super do all the work

    self.angle = self.get_params(self.degrees) # initialize your random parameters

    def forward(self): # override T.RandomRotation's forward
    return F.rotate(img, self.angle, self.resample, self.expand, self.center)

    如果您想覆盖其他随机变换,您可以查看 the source code . API 是不言自明的,您应该不会在为每个转换实现覆盖时遇到太多问题。

    关于python - 有没有办法检索在随机 Torchvision 变换中使用的特定参数?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65906171/

    26 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com