作者热门文章
- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
如何从预训练的 PyTorch 模型(例如 ResNet 或 VGG)中提取特定层的特征,而无需再次进行前向传递?
最佳答案
编辑: there's a new feature in torchvision v0.11.0 that allows extracting features .
例如,如果您想从层 layer4.2.relu_2
中提取特征,您可以这样做:
import torch
from torchvision.models import resnet50
from torchvision.models.feature_extraction import create_feature_extractor
x = torch.rand(1, 3, 224, 224)
model = resnet50()
return_nodes = {
"layer4.2.relu_2": "layer4"
}
model2 = create_feature_extractor(model, return_nodes=return_nodes)
intermediate_outputs = model2(x)
您可以注册forward hook在您想要的特定图层上。像这样的东西:
def some_specific_layer_hook(module, input_, output):
pass # the value is in 'output'
model.some_specific_layer.register_forward_hook(some_specific_layer_hook)
model(some_input)
例如,要获取 ResNet 中的 res5c
输出,您可能需要使用 nonlocal
变量(或 Python 2 中的 global
) :
res5c_output = None
def res5c_hook(module, input_, output):
nonlocal res5c_output
res5c_output = output
resnet.layer4.register_forward_hook(res5c_hook)
resnet(some_input)
# Then, use `res5c_output`.
关于python - 如何从 PyTorch 模型的特定层获取输出?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52796121/
我是一名优秀的程序员,十分优秀!