gpt4 book ai didi

python - PyTorch:不一致的预训练 VGG 输出

转载 作者:太空宇宙 更新时间:2023-11-04 01:59:36 26 4
gpt4 key购买 nike

当使用 torchvision.models 模块加载预训练的 VGG 网络并使用它对任意 RGB 图像进行分类时,网络的输出在每次调用之间明显不同。为什么会这样?根据我的理解,VGG 前向传播的任何部分都不应该是不确定的。

这是一个 MCVE:

import torch
from torchvision.models import vgg16

vgg = vgg16(pretrained=True)

img = torch.randn(1, 3, 256, 256)

torch.all(torch.eq(vgg(img), vgg(img))) # result is 0, but why?

最佳答案

vgg16 有一个 nn.Dropout层,在训练期间,随机丢弃其输入的 50%。在测试期间,您应该通过将网络模式设置为“评估”模式来“关闭”此行为:

vgg.eval()
torch.all(torch.eq(vgg(img), vgg(img)))
Out[73]: tensor(1, dtype=torch.uint8)

请注意,还有其他具有随机行为和不同行为的层用于训练和评估(例如,BatchNorm)。因此,切换到 eval() 很重要评估训练模型之前的模式。

关于python - PyTorch:不一致的预训练 VGG 输出,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56003198/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com