- ubuntu12.04环境下使用kvm ioctl接口实现最简单的虚拟机
- Ubuntu 通过无线网络安装Ubuntu Server启动系统后连接无线网络的方法
- 在Ubuntu上搭建网桥的方法
- ubuntu 虚拟机上网方式及相关配置详解
CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.
这篇CFSDN的博客文章PyTorch训练LSTM时loss.backward()报错的解决方案由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.
训练用PyTorch编写的LSTM或RNN时,在loss.backward()上报错:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time. 。
千万别改成loss.backward(retain_graph=True),会导致显卡内存随着训练一直增加直到OOM:
RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 10.73 GiB total capacity; 9.79 GiB already allocated; 13.62 MiB free; 162.76 MiB cached) 。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
LSRM
/
RNN模块初始化时定义好hidden,每次forward都要加上
self
.hidden
=
self
.init_hidden():
Class LSTMClassifier(nn.Module):
def
__init__(
self
, embedding_dim, hidden_dim):
# 此次省略其它代码
self
.rnn_cell
=
nn.LSTM(embedding_dim, hidden_dim)
self
.hidden
=
self
.init_hidden()
# 此次省略其它代码
def
init_hidden(
self
):
# 开始时刻, 没有隐状态
# 关于维度设置的详情,请参考 Pytorch 文档
# 各个维度的含义是 (Seguence, minibatch_size, hidden_dim)
return
(torch.zeros(
1
,
1
,
self
.hidden_dim),
torch.zeros(
1
,
1
,
self
.hidden_dim))
def
forward(
self
, x):
# 此次省略其它代码
self
.hidden
=
self
.init_hidden()
# 就是加上这句!!!!
out,
self
.hidden
=
self
.rnn_cell(x,
self
.hidden)
# 此次省略其它代码
return
out
|
或者其它模块每次调用这个模块时,其它模块的forward()都对这个LSTM模块init_hidden()一下.
1
2
3
4
5
6
7
8
9
10
11
12
|
Class LSTM_Model(nn.Module):
def
__init__(
self
, embedding_dim, hidden_dim):
# 此次省略其它代码
self
.rnn
=
LSTMClassifier(embedding_dim, hidden_dim)
# 此次省略其它代码
def
forward(
self
, x):
# 此次省略其它代码
self
.rnn.hidden
=
self
.rnn.init_hidden()
# 就是加上这句!!!!
out
=
self
.rnn(x)
# 此次省略其它代码
return
out
|
这是因为:
根据 官方tutorial,在 loss 反向传播的时候,pytorch 试图把 hidden state 也反向传播,但是在新的一轮 batch 的时候 hidden state 已经被内存释放了,所以需要每个 batch 重新 init (clean out hidden state), 或者 detach,从而切断反向传播.
补充:pytorch:在执行loss.backward()时out of memory报错 。
在自己编写SurfNet网络的过程中,出现了这个问题,查阅资料后,将得到的解决方法汇总如下 。
1、reduce batch size, all the way down to 1 。
2、remove everything to CPU leaving only the network on the GPU 。
3、remove validation code, and only executing the training code 。
4、reduce the size of the network (I reduced it significantly: details below) 。
5、I tried scaling the magnitude of the loss that is backpropagating as well to a much smaller value 。
在训练时,在每一个step后面加上:
1
|
torch.cuda.empty_cache()
|
在每一个验证时的step之后加上代码:
1
|
with torch.no_grad()
|
不要在循环训练中累积历史记录 。
1
2
3
4
5
6
7
8
|
total_loss
=
0
for
i
in
range
(
10000
):
optimizer.zero_grad()
output
=
model(
input
)
loss
=
criterion(output)
loss.backward()
optimizer.step()
total_loss
+
=
loss
|
total_loss在循环中进行了累计,因为loss是一个具有autograd历史的可微变量。你可以通过编写total_loss += float(loss)来解决这个问题.
本人遇到这个问题的原因是,自己构建的模型输入到全连接层中的特征图拉伸为1维向量时太大导致的,加入pool层或者其他方法将最后的卷积层输出的特征图尺寸减小即可.
以上为个人经验,希望能给大家一个参考,也希望大家多多支持我.
原文链接:https://blog.csdn.net/qq_31375855/article/details/107568057 。
最后此篇关于PyTorch训练LSTM时loss.backward()报错的解决方案的文章就讲到这里了,如果你想了解更多关于PyTorch训练LSTM时loss.backward()报错的解决方案的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。
我是学习深度学习的新手,我一直在努力理解 Pytorch 中的“.backward()”是做什么的,因为它几乎完成了那里的大部分工作。因此,我试图详细了解反向函数的作用,因此,我将尝试逐步编写该函数的
就目前而言,这个问题不适合我们的问答形式。我们希望答案得到事实、引用或专业知识的支持,但这个问题可能会引起辩论、争论、投票或扩展讨论。如果您觉得这个问题可以改进并可能重新打开,visit the he
我正在开发一个同时具有 GUI(图形)和 API(脚本)界面的应用程序。我们的产品有一个非常大的安装基础。许多客户投入了大量时间和精力来编写使用我们产品的脚本。 在我们所有的设计和实现中,我们(可以理
我从事的项目以源代码和二进制形式免费分发,因为我们的许多用户需要专门为他们的系统编译它。这需要一定程度的考虑,以保持与旧主机系统(主要是它们的编译器)的向后兼容性。 其中一些最糟糕的,例如 GCC 3
我需要在 Python 中创建一个树状结构。我有一个函数 get(parentId),它返回具有该父级的对象列表——我认为应该递归地完成。 结果应该是这样的:["root object", ["chi
我正在尝试为移动菜单制作动画,因此当点击它时三条纹应该变成“X”,当关闭菜单时它应该恢复为三条纹。 “X”的动画效果非常好,但关闭时它会跳回三条纹,而不是平滑地过渡回它。这是因为我当然完全删除了 x
训练用PyTorch编写的LSTM或RNN时,在loss.backward()上报错: RuntimeError: Trying to backward through the graph
为什么参数到atan2功能是“倒退”的?即,为什么它接受 y, x 形式的坐标而不是标准 x, y ? 最佳答案 因为它类似于atan(y / x) , 与 y作为分子和 x作为分母。 关于math
data.table很棒,因为我可以进行滚动连接,甚至可以在组内进行滚动连接! library(data.table) set.seed(42) metrics metrics[calendar,r
我正在开发一个 grails 插件,它添加了一个新的标签库并呈现了一个模板。 我的问题是这个模板输出一个 JSON 并且我不需要它被编码。 如果我使用 raw() 函数,它工作正常,但这与 grail
我想匹配字符串中的数字:'abc@2003,或其他@2017'我想通过 match 函数获得结果 [2003, 2007]。 let strReg = 'abc@2003, or something
Douglas Crockford,因此,JSLint 真的不喜欢 for 循环。大多数时候我同意,并使用 [].forEach 和 Object.keys(obj) 遍历数组或字典。 但是,在某些情
这个问题在这里已经有了答案: Difference between null == x and x == null? [duplicate] (5 个答案) 关闭 7 年前。 我正在查看一些代码,发
我有一个 PyTorch 计算图,它由一个执行某些计算的子图组成,这个计算的结果(我们称它为 x)然后被分支到另外两个子图中。这两个子图的每一个都会产生一些标量结果(我们称它们为 y1 和 y2)。我
不是从前面读取文件,而是可以向后读取它吗?这样输出是从文件的后面到文件的前面。 编辑:最后一行首先显示,而不是完全向后显示。 最佳答案 这是可能的,但很麻烦。 Lua API 提供 seek函数来设置
我用 Storyboard 创建了我的应用程序,直到很晚才意识到 iOS4.3 设备将不受支持。是否有一个选项可以将 Storyboard View 复制到 xib 中或在运行时以编程方式执行此操作?
我为那些犯了同样错误的人发布了这个问题。尝试计算梯度时出现此错误: criterion = torch.nn.CrossEntropyLoss() loss = criterion(y_hat, y_
我一直在阅读this article并且在他们的一节中指出: Lenses compose backwards. Can't we make (.) behave like functions? Yo
编辑:这对我来说是“倒退” - 我可能缺少一些直觉 给定一个 glm 变换函数,例如 glm::translate,两个参数首先是一个矩阵 m,然后是一个用于平移的 vector v。 直观上,我希望
我正在使用调查图/图遍历 compile group: 'org.jgrapht', name: 'jgrapht-core', version: '1.1.0' 使用下面的代码我可以创建一个简单的图
我是一名优秀的程序员,十分优秀!