gpt4 book ai didi

python - 如何将关键字参数传递给 pre-forward 钩子(Hook)使用的 forward?

转载 作者:行者123 更新时间:2023-12-04 12:15:56 29 4
gpt4 key购买 nike

给定一个手电筒的nn.Module带有预前钩,例如

import torch
import torch.nn as nn

class NeoEmbeddings(nn.Embedding):
def __init__(self, num_embeddings:int, embedding_dim:int, padding_idx=-1):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.register_forward_pre_hook(self.neo_genesis)

@staticmethod
def neo_genesis(self, input, higgs_bosson=0):
if higgs_bosson:
input = input + higgs_bosson
return input

在进入实际的 forward() 之前,可以让输入张量经过一些操作。功能,例如
>>> x = NeoEmbeddings(10, 5, 1)
>>> x.forward(torch.tensor([0,2,5,8]))
tensor([[-1.6449, 0.5832, -0.0165, -1.3329, 0.6878],
[-0.3262, 0.5844, 0.6917, 0.1268, 2.1363],
[ 1.0772, 0.1748, -0.7131, 0.7405, 1.5733],
[ 0.7651, 0.4619, 0.4388, -0.2752, -0.3018]],
grad_fn=<EmbeddingBackward>)

>>> print(x._forward_pre_hooks)
OrderedDict([(25, <function NeoEmbeddings.neo_genesis at 0x1208d10d0>)])

我们如何传递 pre-forward 钩子(Hook)需要但默认不接受的参数( *args**kwargs ) forward()功能?

无需修改/覆盖 forward()功能,这是不可能的:
>>> x = NeoEmbeddings(10, 5, 1)
>>> x.forward(torch.tensor([0,2,5,8]), higgs_bosson=2)

----------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-102-8705a40a3cc2> in <module>
1 x = NeoEmbeddings(10, 5, 1)
----> 2 x.forward(torch.tensor([0,2,5,8]), higgs_bosson=2)

TypeError: forward() got an unexpected keyword argument 'higgs_bosson'

最佳答案

Torchscript 不兼容(从 1.2.0 开始)

首先,你的例子torch.nn.Module有一些小错误(可能是意外)。

其次,可以通过任何东西转发和register_forward_pre_hook只会得到将传递给您的参数torch.nn.Module (无论是图层还是模型或其他任何东西)。你确实做不到不修改forward打电话,但你为什么要避免这种情况?您可以简单地将参数转发给基本函数,如下所示:

import torch


class NeoEmbeddings(torch.nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx=-1):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.register_forward_pre_hook(NeoEmbeddings.neo_genesis)

# First argument should be named something like module, as that's what
# you are registering this hook to
@staticmethod
def neo_genesis(module, inputs): # No need for self as first argument
net_input, higgs_bosson = inputs # Simply unpack tuple here
return net_input

def forward(self, inputs, higgs_bosson):
# Do whatever you want here with both arguments, you can ignore
# higgs_bosson if it's only needed in the hook as done here
return super().forward(inputs)


if __name__ == "__main__":
x = NeoEmbeddings(10, 5, 1)
# You should call () instead of forward so the hooks register appropriately
print(x(torch.tensor([0, 2, 5, 8]), 1))

你不能用更简洁的方式来做,但限制是基类 forward方法,而不是钩子(Hook)本身(而且我不希望它更简洁,因为它会变得不可读 IMO)。

火炬脚本兼容

如果你想使用 torchscript(在 1.2.0 上测试)你可以使用组合而不是继承。您只需更改两行代码,您的代码可能如下所示:
import torch

# Inherit from Module and register embedding as submodule
class NeoEmbeddings(torch.nn.Module):
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx=-1):
super().__init__()
# Just use it as a container inside your own class
self._embedding = torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx)
self.register_forward_pre_hook(NeoEmbeddings.neo_genesis)

@staticmethod
def neo_genesis(module, inputs):
net_input, higgs_bosson = inputs
return net_input

def forward(self, inputs: torch.Tensor, higgs_bosson: torch.Tensor):
return self._embedding(inputs)


if __name__ == "__main__":
x = torch.jit.script(NeoEmbeddings(10, 5, 1))
# All arguments must be tensors in torchscript
print(x(torch.tensor([0, 2, 5, 8]), torch.tensor([1])))

关于python - 如何将关键字参数传递给 pre-forward 钩子(Hook)使用的 forward?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57703808/

29 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com