- ubuntu12.04环境下使用kvm ioctl接口实现最简单的虚拟机
- Ubuntu 通过无线网络安装Ubuntu Server启动系统后连接无线网络的方法
- 在Ubuntu上搭建网桥的方法
- ubuntu 虚拟机上网方式及相关配置详解
CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.
这篇CFSDN的博客文章PyTorch中clone()、detach()及相关扩展详解由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.
clone() 与 detach() 对比 。
Torch 为了提高速度,向量或是矩阵的赋值是指向同一内存的,这不同于 Matlab。如果需要保存旧的tensor即需要开辟新的存储地址而不是引用,可以用 clone() 进行深拷贝, 。
首先我们来打印出来clone()操作后的数据类型定义变化:
(1). 简单打印类型 。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
import
torch
a
=
torch.tensor(
1.0
, requires_grad
=
True
)
b
=
a.clone()
c
=
a.detach()
a.data
*
=
3
b
+
=
1
print
(a)
# tensor(3., requires_grad=True)
print
(b)
print
(c)
'''
输出结果:
tensor(3., requires_grad=True)
tensor(2., grad_fn=<AddBackward0>)
tensor(3.) # detach()后的值随着a的变化出现变化
'''
|
grad_fn=<CloneBackward>,表示clone后的返回值是个中间变量,因此支持梯度的回溯。clone操作在一定程度上可以视为是一个identity-mapping函数.
detach()操作后的tensor与原始tensor共享数据内存,当原始tensor在计算图中数值发生反向传播等更新之后,detach()的tensor值也发生了改变.
注意: 在pytorch中我们不要直接使用id是否相等来判断tensor是否共享内存,这只是充分条件,因为也许底层共享数据内存,但是仍然是新的tensor,比如detach(),如果我们直接打印id会出现以下情况.
1
2
3
4
5
6
7
8
|
import
torch as t
a
=
t.tensor([
1.0
,
2.0
], requires_grad
=
True
)
b
=
a.detach()
#c[:] = a.detach()
print
(
id
(a))
print
(
id
(b))
#140568935450520
140570337203616
|
显然直接打印出来的id不等,我们可以通过简单的赋值后观察数据变化进行判断.
(2). clone()的梯度回传 。
detach()函数可以返回一个完全相同的tensor,与旧的tensor共享内存,脱离计算图,不会牵扯梯度计算.
而clone充当中间变量,会将梯度传给源张量进行叠加,但是本身不保存其grad,即值为None 。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
|
import
torch
a
=
torch.tensor(
1.0
, requires_grad
=
True
)
a_
=
a.clone()
y
=
a
*
*
2
z
=
a
*
*
2
+
a_
*
3
y.backward()
print
(a.grad)
# 2
z.backward()
print
(a_.grad)
# None. 中间variable,无grad
print
(a.grad)
'''
输出:
tensor(2.)
None
tensor(7.) # 2*2+3=7
'''
|
使用torch.clone()获得的新tensor和原来的数据不再共享内存,但仍保留在计算图中,clone操作在不共享数据内存的同时支持梯度梯度传递与叠加,所以常用在神经网络中某个单元需要重复使用的场景下.
通常如果原tensor的requires_grad=True,则:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
|
import
torch
torch.manual_seed(
0
)
x
=
torch.tensor([
1.
,
2.
], requires_grad
=
True
)
clone_x
=
x.clone()
detach_x
=
x.detach()
clone_detach_x
=
x.clone().detach()
f
=
torch.nn.Linear(
2
,
1
)
y
=
f(x)
y.backward()
print
(x.grad)
print
(clone_x.requires_grad)
print
(clone_x.grad)
print
(detach_x.requires_grad)
print
(clone_detach_x.requires_grad)
'''
输出结果如下:
tensor([-0.0053, 0.3793])
True
None
False
False
'''
|
另一个比较特殊的是当源张量的 require_grad=False,clone后的张量 require_grad=True,此时不存在张量回传现象,可以得到clone后的张量求导.
如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
|
import
torch
a
=
torch.tensor(
1.0
)
a_
=
a.clone()
a_.requires_grad_()
#require_grad=True
y
=
a_
*
*
2
y.backward()
print
(a.grad)
# None
print
(a_.grad)
'''
输出:
None
tensor(2.)
'''
|
了解了两者的区别后我们常与其他函数进行搭配使用,实现数据拷贝后的其他需要.
比如我们经常使用view()函数对tensor进行reshape操作。返回的新Tensor与源Tensor可能有不同的size,但是是共享data的,即其中的一个发生变化,另外一个也会跟着改变.
需要注意的是view返回的Tensor与源Tensor是共享data的,但是依然是一个新的Tensor(因为Tensor除了包含data外还有一些其他属性),两者id(内存地址)并不一致.
1
2
3
4
5
|
x
=
torch.rand(
2
,
2
)
y
=
x.view(
4
)
x
+
=
1
print
(x)
print
(y)
# 也加了1
|
view() 仅仅是改变了对这个张量的观察角度,内部数据并未改变。这时候想返回一个真正新的副本(即不共享data内存)该怎么办呢?Pytorch还提供了一个reshape()可以改变形状,但是此函数并不能保证返回的是其拷贝,所以不推荐使用。推荐先用clone创造一个副本然后再使用view。参考此处 。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
x
=
torch.rand(
2
,
2
)
x_cp
=
x.clone().view(
4
)
x
+
=
1
print
(
id
(x))
print
(
id
(x_cp))
print
(x)
print
(x_cp)
'''
140568935036464
140568935035816
tensor([[0.4963, 0.7682],
[0.1320, 0.3074]])
tensor([[1.4963, 1.7682, 1.1320, 1.3074]])
'''
|
另外使用clone()会被记录在计算图中,即梯度回传到副本时也会传到源Tensor。在上一篇中有总结.
总结:
引用官方文档的话:如果你使用了in-place operation而没有报错的话,那么你可以确定你的梯度计算是正确的。另外尽量避免in-place的使用.
像y = x + y这样的运算会新开内存,然后将y指向新内存。我们可以使用Python自带的id函数进行验证:如果两个实例的ID相同,则它们所对应的内存地址相同.
到此这篇关于PyTorch中clone()、detach()及相关扩展详解的文章就介绍到这了,更多相关PyTorch中clone()、detach()及相关扩展内容请搜索我以前的文章或继续浏览下面的相关文章希望大家以后多多支持我! 。
原文链接:https://blog.csdn.net/weixin_43199584/article/details/106876679 。
最后此篇关于PyTorch中clone()、detach()及相关扩展详解的文章就讲到这里了,如果你想了解更多关于PyTorch中clone()、detach()及相关扩展详解的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。
我是 magento 的新手,目前我在 magento 安装期间遇到“必须加载 PHP 扩展 curl ”错误。你能帮帮我吗? 最佳答案 如果您的服务器上没有安装 curl,您可以键入以下命令之一来安
我在 macOS Mojave/macOS Big Sur/macOS Monterey/macOS Ventura 上使用最新的 php 版本 7.2 并收到类似错误 $composer requ
这个问题已经有答案了: Why generic type is not applicable for argument extends super class for both? (5 个回答) 已关
我正在使用 NightWatch.js 并进行一些 UI 测试,我想用一些额外的 desiredCapabilities 启动默认浏览器实例(即启用扩展并应用一些特定值)。 p> 注意:我可以执行这些
有人知道为什么我在 java 8 中使用此代码时没有服务器扩展名称吗: try { URL url = new URL(urlString); URLC
扩展提供给我的类(class)。为现有的类提供新功能。或扩展现有的mixin s 或虚拟类,任何东西都可以工作。 也许是这样的: class FlatButton {} // maybe no
我有一个关于使用 c 代码和 mod_wsgi 扩展 python 的问题。 我在 apache 服务器中有一个 django 应用程序,它查询 postgresql 数据库以生成报告。在某些报告中,
testcafe支持在Chrome浏览器中加载crx扩展吗? 如果是这样,请告诉我需要尝试什么方法。 我尝试了下面的代码,但没有成功 await t.eval(new Function(fs.read
这个问题已经有答案了: What is a raw type and why shouldn't we use it? (16 个回答) 已关闭 3 年前。 有什么区别: // 1 class A c
我正在编写一个 chrome 扩展来记录单击开始按钮后触发的请求。 这是我的文件:1. list .json { "manifest_version": 2, "name": "recorde
扩展是将较短的文本,例如一组提示或主题列表,输入到大型语言模型中,让模型生成更长的文本。我们可以利用这个特性让大语言模型生成基于某个主题的电子邮件或小论文。通过这种方式使用大语言模型,可以为工作与生活
我每天都在使用 vim 和 perforce 现在我的问题是,如果我想查看 perforce 文件修订版,则从命令模式下的 vim :!p4 打印文件#1 vim 试图让我获得缓冲区 #1。有没有办法
大家好,我有一个关于 NUnit 扩展(2.5.10)的问题。 我想做的是向 数据库。为此,我使用 Event 创建了 NUnit 扩展 听众。 我遇到的问题是公共(public)无效 TestFin
我有弹出窗口,而不是模态窗口。 如何通过单击页面的其他部分(不在窗口中)来关闭此窗口? 最佳答案 像这样的东西: function closeWin(e, t) { var el = win.
我通常非常谨慎地使用扩展方法。当我确实觉得有必要编写一个扩展方法时,有时我想重载该方法。我的问题是,您对调用其他扩展方法的扩展方法有何看法?不好的做法?感觉不对,但我无法真正定义原因。 例如,第二个
扩展 Ant Ant带有一组预定义的任务,但是你可以创建自己的任务,如下面的例子所示。 定制Ant 任务应扩展 org.apache.tools.ant.Task 类,同时也应该拓展 execut
我想要一个重定向所有请求的扩展: http://website.com/foo.js 到: http://localhost/myfoo.js 我无法使用主机文件将主机从 website.com 编辑
对于为什么 QChartView 放在 QTabWidget 中时会扩展,我有点迷惑。 这是 QChartView 未展开(因为它被隐藏)时应用程序的图片。 应用程序的黑色部分是 QOpenGLWid
如果在连接条件中使用 OR 运算符,如何优化以下查询以避免 SQL 调优方面的 OR 扩展? SELECT t1.A, t2.B, t1.C, t1.D, t2.E FROM t1 LEFT J
一旦加载插件的问题得到解决(在 .NET 中通过 MEF 的情况下),下一步要解决的是与它们的通信。简单的方法是实现一个接口(interface),使用插件实现,但有时插件只需要扩展应用程序的工作方式
我是一名优秀的程序员,十分优秀!