gpt4 book ai didi

python - 如何在 Pytorch 中可视化网络?

转载 作者:太空狗 更新时间:2023-10-29 21:04:34 44 4
gpt4 key购买 nike

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.models.vgg import model_urls
from torchviz import make_dot

batch_size = 3
learning_rate =0.0002
epoch = 50

resnet = models.resnet50(pretrained=True)
print resnet
make_dot(resnet)

我想从 pytorch 模型中可视化 resnet。我该怎么做?我尝试使用 torchviz 但出现错误:

'ResNet' object has no attribute 'grad_fn'

最佳答案

这是使用不同工具的三种不同图形可视化。

为了生成示例可视化效果,我将使用一个简单的 RNN 来执行从 online tutorial 中获取的情绪分析。 :

class RNN(nn.Module):

def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):

super().__init__()
self.embedding = nn.Embedding(input_dim, embedding_dim)
self.rnn = nn.RNN(embedding_dim, hidden_dim)
self.fc = nn.Linear(hidden_dim, output_dim)

def forward(self, text):

embedding = self.embedding(text)
output, hidden = self.rnn(embedding)

return self.fc(hidden.squeeze(0))

如果您print() 模型,这里是输出。

RNN(
(embedding): Embedding(25002, 100)
(rnn): RNN(100, 256)
(fc): Linear(in_features=256, out_features=1, bias=True)
)

以下是三种不同可视化工具的结果。

对于所有这些,您需要有可以通过模型的 forward() 方法的虚拟输入。获取此输入的一种简单方法是从您的 Dataloader 中检索一个批处理,如下所示:

batch = next(iter(dataloader_train))
yhat = model(batch.text) # Give dummy batch to forward().

torch 可视化

https://github.com/szagoruyko/pytorchviz

我相信此工具使用向后传递生成其图形,因此所有框都使用 PyTorch 组件进行反向传播。

from torchviz import make_dot

make_dot(yhat, params=dict(list(model.named_parameters()))).render("rnn_torchviz", format="png")

此工具生成以下输出文件:

torchviz output

这是唯一明确提到我的模型中的三个层的输出,embeddingrnnfc。运算符名称取自反向传递,因此其中一些难以理解。

隐藏层

https://github.com/waleedka/hiddenlayer

我相信这个工具使用前向传递。

import hiddenlayer as hl

transforms = [ hl.transforms.Prune('Constant') ] # Removes Constant nodes from graph.

graph = hl.build_graph(model, batch.text, transforms=transforms)
graph.theme = hl.graph.THEMES['blue'].copy()
graph.save('rnn_hiddenlayer', format='png')

这是输出。我喜欢蓝色的阴影。

hiddenlayer output

我发现输出的细节太多,混淆了我的架构。例如,为什么多次提到 unsqueeze

耐创

https://github.com/lutzroeder/netron

此工具是适用于 Mac、Windows 和 Linux 的桌面应用程序。它依赖于首先导出到 ONNX format 的模型.然后应用程序读取 ONNX 文件并呈现它。然后可以选择将模型导出到图像文件。

input_names = ['Sentence']
output_names = ['yhat']
torch.onnx.export(model, batch.text, 'rnn.onnx', input_names=input_names, output_names=output_names)

这是模型在应用程序中的样子。我认为这个工具非常灵巧:您可以缩放和平移,还可以钻取图层和运算符。我发现的唯一缺点是它只能进行垂直布局。

Netron screenshot

关于python - 如何在 Pytorch 中可视化网络?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52468956/

44 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com