gpt4 book ai didi

python - 推导pytorch网络的结构

转载 作者:太空宇宙 更新时间:2023-11-04 01:49:57 24 4
gpt4 key购买 nike

对于我的用例,我需要能够采用 pytorch 模块并解释模块中层的顺序,以便我可以以某种文件格式在层之间创建“连接”。现在假设我有一个简单的模块,如下所示

class mymodel(nn.Module):
def __init__(self, input_channels):
super(mymodel, self).__init__()
self.fc = nn.Linear(input_channels, input_channels)
def forward(self, x):
out = self.fc(x)
out += x
return out


if __name__ == "__main__":
net = mymodel(5)

for mod in net.modules():
print(mod)

这里的输出结果是:

mymodel(
(fc): Linear(in_features=5, out_features=5, bias=True)
)
Linear(in_features=5, out_features=5, bias=True)

如您所见,关于加等于操作或加操作的信息未被捕获,因为它不是转发函数中的 nnmodule。我的目标是能够从 pytorch 模块对象创建一个图形连接,以在 json 中说这样的话:

layers {
"fc": {
"inputTensor" : "t0",
"outputTensor": "t1"
}
"addOp" : {
"inputTensor" : "t1",
"outputTensor" : "t2"
}
}

输入张量名称是任意的,但它捕获了图形的本质和层之间的连接。

我的问题是,有没有办法从 pytorch 对象中提取信息?我正在考虑使用 .modules() 但后来意识到手写操作不会以这种方式作为模块捕获。我想如果一切都是 nn.module 那么 .modules() 可能会给我网络层安排。在这里寻求帮助。我希望能够知道张量之间的联系以创建上述格式。

最佳答案

您要查找的信息并未存储在 nn.Module 中,而是存储在输出张量的 grad_fn 属性中:

model = mymodel(channels)
pred = model(torch.rand((1, channels))
pred.grad_fn # all the information is in the computation graph of the output tensor

提取此信息并非易事。你可能想看看 torchvizgrad_fn 信息中绘制漂亮图形的包。

关于python - 推导pytorch网络的结构,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58253003/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com