gpt4 book ai didi

PyTorch:使用 backward() 时,如何只保留图形的一部分?

转载 作者:行者123 更新时间:2023-12-05 09:16:20 27 4
gpt4 key购买 nike

我有一个 PyTorch 计算图,它由一个执行某些计算的子图组成,这个计算的结果(我们称它为 x)然后被分支到另外两个子图中。这两个子图的每一个都会产生一些标量结果(我们称它们为 y1y2)。我想对这两个结果中的每一个做一个反向传递(也就是说,我想累加两个子图的梯度。我不想执行实际的优化步骤)。

现在,由于内存是一个问题,我想按以下顺序执行操作:首先,计算x。然后,计算y1,并执行y1.backward(),同时(这是关键点)保留导致x的图code>,但将图形从 x 释放到 y1。然后,计算y2,并执行y2.backward()

换句话说,为了在不牺牲太多速度的情况下节省内存,我想保留 x 而不需要重新计算它,但我想删除所有从 x 开始的计算> 到 y1,因为我不再需要它们了。

问题是函数 backward() 的参数 retain_graph 将保留通向 y1 的整个图,而我需要只保留图表中通向 x 的部分。

这是我理想中想要的示例:

import torch

w = torch.tensor(1.0)
w.requires_grad_(True)

# sub-graph for calculating `x`
x = w+10

# sub-graph for calculating `y1`
x1 = x*x
y1 = x1*x1
y1.backward(retain_graph=x) # this would not work, since retain_graph is a boolean and can either retain the entire graph or free it.

# sub-graph for calculating `y2`
x2 = torch.sqrt(x)
y2 = x2/2
y2.backward()

如何做到这一点?

最佳答案

参数 retain_graph 将保留整个图,而不仅仅是一个子图。但是,我们可以使用垃圾收集来释放图中不需要的部分。通过删除从 xy1 对子图的所有引用,该子图将被释放:

import torch

w = torch.tensor(1.0)
w.requires_grad_(True)

# sub-graph for calculating `x`
x = w+10

# sub-graph for calculating `y1`
x1 = x*x
y1 = x1*x1
y1.backward(retain_graph=True) # all graph is retained

# remove unneeded parts of graph. Note that these parts will be freed from memory (even if they were on GPU), due to python's garbage collection
y1 = None
x1 = None

# sub-graph for calculating `y2`
x2 = torch.sqrt(x)
y2 = x2/2
y2.backward()

关于PyTorch:使用 backward() 时,如何只保留图形的一部分?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50741344/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com