gpt4 book ai didi

tensorflow - Pytorch如何得到损失函数的梯度两次

转载 作者:行者123 更新时间:2023-12-04 13:44:21 26 4
gpt4 key购买 nike

这是我正在尝试实现的内容:

我们根据 F(X) 计算损失, 照常。但我们也定义了“对抗性损失”,这是基于 F(X + e) 的损失。 . e定义为 dF(X)/dX乘以某个常数。损失和对抗性损失都是反向传播的总损失。

在 tensorflow 中,这部分(得到 dF(X)/dX)可以像下面这样编码:

  grad, = tf.gradients( loss, X )
grad = tf.stop_gradient(grad)
e = constant * grad

下面是我的pytorch代码:
class DocReaderModel(object):
def __init__(self, embedding=None, state_dict=None):
self.train_loss = AverageMeter()
self.embedding = embedding
self.network = DNetwork(opt, embedding)
self.optimizer = optim.SGD(parameters)

def adversarial_loss(self, batch, loss, embedding, y):
self.optimizer.zero_grad()
loss.backward(retain_graph=True)
grad = embedding.grad
grad.detach_()

perturb = F.normalize(grad, p=2)* 0.5
self.optimizer.zero_grad()
adv_embedding = embedding + perturb
network_temp = DNetwork(self.opt, adv_embedding) # This is how to get F(X)
network_temp.training = False
network_temp.cuda()
start, end, _ = network_temp(batch) # This is how to get F(X)
del network_temp # I even deleted this instance.
return F.cross_entropy(start, y[0]) + F.cross_entropy(end, y[1])

def update(self, batch):
self.network.train()
start, end, pred = self.network(batch)
loss = F.cross_entropy(start, y[0]) + F.cross_entropy(end, y[1])
loss_adv = self.adversarial_loss(batch, loss, self.network.lexicon_encoder.embedding.weight, y)
loss_total = loss + loss_adv

self.optimizer.zero_grad()
loss_total.backward()
self.optimizer.step()

我有几个问题:

1) 我用 grad.detach_() 替换了 tf.stop_gradient。这样对吗?

2) 我收到了 "RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time."所以我加了 retain_graph=Trueloss.backward .那个特定的错误消失了。
但是现在我在几个时期后收到内存错误( RuntimeError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1525909934016/work/aten/src/THC/generic/THCStorage.cu:58)。我怀疑我不必要地保留了图表。

有人可以让我知道 pytorch 在这方面的最佳实践吗?任何提示/甚至简短的评论都将受到高度赞赏。

最佳答案

我认为您正在尝试实现生成对抗网络 (GAN),但是从代码来看,我不理解也无法理解您要实现的目标,因为 GAN 有一些缺失的部分无法正常工作。我可以看到有一个鉴别器网络模块,DNetwork但缺少生成器网络模块。

如果猜测,当你说“损失函数两次”时,我假设你的意思是你有一个用于鉴别器网络的损失函数和另一个用于生成器网络的损失函数。如果是这样的话,让我分享一下我将如何实现一个基本的 GAN 模型。

举个例子,我们看看这个Wasserstein GAN Jupyter notebook

我将跳过不太重要的部分并在此处放大重要的部分:

  • 一、导入PyTorch库并设置
    # Set up batch size, image size, and size of noise vector:
    bs, sz, nz = 64, 64, 100 # nz is the size of the latent z vector for creating some random noise later
  • 构建鉴别器模块
    class DCGAN_D(nn.Module):
    def __init__(self):
    ... truncated, the usual neural nets stuffs, layers, etc ...
    def forward(self, input):
    ... truncated, the usual neural nets stuffs, layers, etc ...
  • 构建生成器模块
    class DCGAN_G(nn.Module):
    def __init__(self):
    ... truncated, the usual neural nets stuffs, layers, etc ...
    def forward(self, input):
    ... truncated, the usual neural nets stuffs, layers, etc ...
  • 把它们放在一起
    netG = DCGAN_G().cuda()
    netD = DCGAN_D().cuda()
  • 优化器需要被告知要优化哪些变量。模块自动跟踪其变量。
    optimizerD = optim.RMSprop(netD.parameters(), lr = 1e-4)
    optimizerG = optim.RMSprop(netG.parameters(), lr = 1e-4)
  • Discriminator 向前一步和向后一步

    在这里,网络可以在反向传播期间计算梯度,这取决于此函数的输入。所以,就我而言,我有 3 种损失;生成器损失,判别器真实图像损失,判别器假图像损失。对于 3 次不同的网络传递,我可以获得三次损失函数的梯度。
    def step_D(input, init_grad):
    # input can be from generator's generated image data or input image from dataset
    err = netD(input)
    err.backward(init_grad) # backward pass net to calculate gradient
    return err # loss
  • 控制可训练参数 [重要的]

    模型中的可训练参数是那些需要梯度的参数。
    def make_trainable(net, val):
    for p in net.parameters():
    p.requires_grad = val # note, i.e, this is later set to False below in netG update in the train loop.

    在 TensorFlow 中,这部分可以像下面这样编码:
    grad = tf.gradients(loss, X)
    grad = tf.stop_gradient(grad)

    所以,我认为这将回答你的第一个问题,“我用 grad.detach_() 替换了 tf.stop_gradient。这是正确的吗?”
  • 火车环线

  • 您可以在此处查看如何调用 3 个不同的损失函数。
        def train(niter, first=True):

    for epoch in range(niter):
    # Make iterable from PyTorch DataLoader
    data_iter = iter(dataloader)
    i = 0

    while i < n:
    ###########################
    # (1) Update D network
    ###########################
    make_trainable(netD, True)

    # train the discriminator d_iters times
    d_iters = 100

    j = 0

    while j < d_iters and i < n:
    j += 1
    i += 1

    # clamp parameters to a cube
    for p in netD.parameters():
    p.data.clamp_(-0.01, 0.01)

    data = next(data_iter)

    ##### train with real #####
    real_cpu, _ = data
    real_cpu = real_cpu.cuda()
    real = Variable( data[0].cuda() )
    netD.zero_grad()

    # Real image discriminator loss
    errD_real = step_D(real, one)

    ##### train with fake #####
    fake = netG(create_noise(real.size()[0]))
    input.data.resize_(real.size()).copy_(fake.data)

    # Fake image discriminator loss
    errD_fake = step_D(input, mone)

    # Discriminator loss
    errD = errD_real - errD_fake
    optimizerD.step()

    ###########################
    # (2) Update G network
    ###########################
    make_trainable(netD, False)
    netG.zero_grad()

    # Generator loss
    errG = step_D(netG(create_noise(bs)), one)
    optimizerG.step()

    print('[%d/%d][%d/%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f'
    % (epoch, niter, i, n,
    errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0]))

    "I was getting "RuntimeError: Trying to backward through the graph a second time..."



    PyTorch 有这种行为;在 .backward() 期间减少 GPU 内存使用量调用,所有中间结果(如果您喜欢保存的激活等)在不再需要时将被删除。因此,如果您尝试拨打 .backward()同样,中间结果不存在并且无法执行向后传递(并且您会看到您看到的错误)。

    这取决于您要尝试做什么。您可以拨打 .backward(retain_graph=True)进行不会删除中间结果的反向传递,因此您可以拨打 .backward()再次。除了最后一次向后调用之外的所有调用都应该具有 retain_graph=True选项。

    Can someone let me know pytorch's best practice on this



    正如您从上面的 PyTorch 代码以及在 PyTorch 中尝试保持 Pythonic 所做的事情所看到的那样,您可以从中了解 PyTorch 的最佳实践。

    关于tensorflow - Pytorch如何得到损失函数的梯度两次,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51578235/

    26 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com