gpt4 book ai didi

python-3.x - PyTorch:用于训练和测试/验证的不同前向方法

转载 作者:行者123 更新时间:2023-12-04 02:42:31 25 4
gpt4 key购买 nike

我目前正在尝试扩展 a model这是基于 FairSeq/PyTorch 的。在训练期间,我需要训练两个编码器:一个带有目标样本,一个带有源样本。

所以当前的 forward 函数看起来像这样:

def forward(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
return decoder_out

并基于此 this idea我想要这样的东西:
def forward_test(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
return decoder_out

def forward_train(self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs):
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
concat = some_concatination_func(encoder_out, autoencoder_out)
decoder_out = self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)
return decoder_out

有没有办法做到这一点?

编辑:
这些是我的约束,因为我需要扩展 FairseqEncoderDecoderModel:
@register_model('transformer_mass')
class TransformerMASSModel(FairseqEncoderDecoderModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)

编辑2:
在 Fairseq 中传递给 forward 函数的参数可以通过实现你自己的 Criterion 来改变,例如见 CrossEntropyCriterion , 其中 sample['net_input']传递给 __call__模型的函数,它调用 forward方法。

最佳答案

首先你应该始终使用和定义 forward 不是您在 torch.nn.Module 上调用的其他方法实例。

绝对不要重载eval()trsvchn所示因为它是由 PyTorch (see here) 定义的评估方法。 此方法允许将模型内的层置于评估模式(例如,对层的特定更改,如 DropoutBatchNorm 的推理模式)。

此外,您应该使用 __call__ 调用它魔术方法。为什么?因为钩子(Hook)和其他 PyTorch 特定的东西是正确注册的。

其次,不要使用一些外部的mode @Anant Mittal 建议的字符串变量 .就是这样 train PyTorch 中的变量是 for,标准是通过它来区分模型是否在 eval模式或 train模式。

话虽这么说,你最好这样做:

import torch


class Network(torch.nn.Module):
def __init__(self):
super().__init__()
...

# You could split it into two functions but both should be called by forward
def forward(
self, src_tokens=None, src_lengths=None, prev_output_tokens=None, **kwargs
):
encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
if self.train:
return self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
autoencoder_out = self.encoder(tgt_tokens, src_lengths=src_lengths, **kwargs)
concat = some_concatination_func(encoder_out, autoencoder_out)
return self.decoder(prev_output_tokens, encoder_out=concat, **kwargs)

您可以(并且可以说应该)将上述内容拆分为两个单独的方法,但这并不算太糟糕,因为该函数相当短且可读。如果可能的话,只要坚持 PyTorch 的处理方式,而不是一些临时解决方案。不,反向传播不会有问题,为什么会有一个?

关于python-3.x - PyTorch:用于训练和测试/验证的不同前向方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58655207/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com