gpt4 book ai didi

python - 修复 torchvision 变换的随机种子

转载 作者:行者123 更新时间:2023-12-01 06:39:37 25 4
gpt4 key购买 nike

我使用一些类似于以下的代码 - 用于数据增强:

    from torchvision import transforms

#...

augmentation = transforms.Compose([
transforms.RandomApply([
transforms.RandomRotation([-30, 30])
], p=0.5),
transforms.RandomHorizontalFlip(p=0.5),
])

在测试过程中,我想修复随机值,以在每次更改模型训练设置时重现相同的随机参数。我该怎么做?

我想做一些类似于np.random.seed(0)的事情,所以每次我第一次用概率调用随机函数时,它都会以相同的旋转角度和概率运行。换句话说,如果我根本不更改代码,那么当我重新运行它时,它必须重现相同的结果。

或者,我可以分离变换,使用p=1,将角度minmax固定为特定值并使用numpy随机数生成结果,但我的问题是我是否可以保持上面的代码不变。

最佳答案

在数据集类的 __getitem__ 中创建一个 numpy 随机种子。

def __getitem__(self, index):      
img = io.imread(self.labels.iloc[index,0])
target = self.labels.iloc[index,1]

seed = np.random.randint(2147483647) # make a seed with numpy generator
random.seed(seed) # apply this seed to img transforms
if self.transform is not None:
img = self.transform(img)

random.seed(seed) # apply this seed to target transforms
if self.target_transform is not None:
target = self.target_transform(target)

return img, target

关于python - 修复 torchvision 变换的随机种子,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59516181/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com