- ubuntu12.04环境下使用kvm ioctl接口实现最简单的虚拟机
- Ubuntu 通过无线网络安装Ubuntu Server启动系统后连接无线网络的方法
- 在Ubuntu上搭建网桥的方法
- ubuntu 虚拟机上网方式及相关配置详解
CFSDN坚持开源创造价值,我们致力于搭建一个资源共享平台,让每一个IT人在这里找到属于你的精彩世界.
这篇CFSDN的博客文章Pytorch反向传播中的细节-计算梯度时的默认累加操作由作者收集整理,如果你对这篇文章有兴趣,记得点赞哟.
今天学习pytorch实现简单的线性回归,发现了pytorch的反向传播时计算梯度采用的累加机制, 于是百度来一下,好多博客都说了累加机制,但是好多都没有说明这个累加机制到底会有啥影响, 所以我趁着自己练习的一个例子正好直观的看一下以及如何解决:
先附上试验代码来感受一下:
torch.manual_seed(6)lr = 0.01 # 学习率result = []# 创建训练数据x = torch.rand(20, 1) * 10y = 2 * x + (5 + torch.randn(20, 1)) # 构建线性回归函数w = torch.randn((1), requires_grad=True)b = torch.zeros((1), requires_grad=True)# 这里是迭代过程,为了看pytorch的反向传播计算梯度的细节,我先迭代两次for iteration in range(2): # 前向传播 wx = torch.mul(w, x) y_pred = torch.add(wx, b) # 计算 MSE loss loss = (0.5 * (y - y_pred) ** 2).mean() # 反向传播 loss.backward() # 这里看一下反向传播计算的梯度 print("w.grad:", w.grad) print("b.grad:", b.grad) # 更新参数 b.data.sub_(lr * b.grad) w.data.sub_(lr * w.grad)
上面的代码比较简单,迭代了两次, 看一下计算的梯度结果:
w.grad: tensor([-74.6261]) b.grad: tensor([-12.5532]) w.grad: tensor([-122.9075]) b.grad: tensor([-20.9364]) 。
然后我稍微加两行代码, 就是在反向传播上面,我手动添加梯度清零操作的代码,再感受一下结果:
torch.manual_seed(6)lr = 0.01result = []# 创建训练数据x = torch.rand(20, 1) * 10#print(x)y = 2 * x + (5 + torch.randn(20, 1)) #print(y)# 构建线性回归函数w = torch.randn((1), requires_grad=True)#print(w)b = torch.zeros((1), requires_grad=True)#print(b)for iteration in range(2): # 前向传播 wx = torch.mul(w, x) y_pred = torch.add(wx, b) # 计算 MSE loss loss = (0.5 * (y - y_pred) ** 2).mean() # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0 if iteration > 0: w.grad.data.zero_() b.grad.data.zero_() # 反向传播 loss.backward() # 看一下梯度 print("w.grad:", w.grad) print("b.grad:", b.grad) # 更新参数 b.data.sub_(lr * b.grad) w.data.sub_(lr * w.grad)
w.grad: tensor([-74.6261]) b.grad: tensor([-12.5532]) w.grad: tensor([-48.2813]) b.grad: tensor([-8.3831]) 。
从上面可以发现,pytorch在反向传播的时候,确实是默认累加上了上一次求的梯度, 如果不想让上一次的梯度影响自己本次梯度计算的话,需要手动的清零.
但是, 如果不进行手动清零的话,会有什么后果呢? 我在这次线性回归试验中,遇到的后果就是loss值反复的震荡不收敛。下面感受一下:
torch.manual_seed(6)lr = 0.01result = []# 创建训练数据x = torch.rand(20, 1) * 10#print(x)y = 2 * x + (5 + torch.randn(20, 1)) #print(y)# 构建线性回归函数w = torch.randn((1), requires_grad=True)#print(w)b = torch.zeros((1), requires_grad=True)#print(b)for iteration in range(1000): # 前向传播 wx = torch.mul(w, x) y_pred = torch.add(wx, b) # 计算 MSE loss loss = (0.5 * (y - y_pred) ** 2).mean()# print("iteration {}: loss {}".format(iteration, loss)) result.append(loss) # 由于pytorch反向传播中,梯度是累加的,所以如果不想先前的梯度影响当前梯度的计算,需要手动清0 #if iteration > 0: # w.grad.data.zero_() # b.grad.data.zero_() # 反向传播 loss.backward() # 更新参数 b.data.sub_(lr * b.grad) w.data.sub_(lr * w.grad) if loss.data.numpy() < 1: break plt.plot(result)
上面的代码中,我没有进行手动清零,迭代1000次, 把每一次的loss放到来result中, 然后画出图像,感受一下结果:
接下来,我把手动清零的注释打开,进行每次迭代之后的手动清零操作,得到的结果:
可以看到,这个才是理想中的反向传播求导,然后更新参数后得到的loss值的变化.
这次主要是记录一下,pytorch在进行反向传播计算梯度的时候的累加机制到底是什么样子? 至于为什么采用这种机制,我也搜了一下,大部分给出的结果是这样子的:
但是如果不想累加的话,可以采用手动清零的方式,只需要在每次迭代时加上即可 。
w.grad.data.zero_()b.grad.data.zero_()
另外, 在搜索资料的时候,在一篇博客上看到两个不错的线性回归时pytorch的计算图在这里借用一下:
以上为个人经验,希望能给大家一个参考,也希望大家多多支持我.
原文链接:https://blog.csdn.net/wuzhongqiang/article/details/102572324 。
最后此篇关于Pytorch反向传播中的细节-计算梯度时的默认累加操作的文章就讲到这里了,如果你想了解更多关于Pytorch反向传播中的细节-计算梯度时的默认累加操作的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。
我想做的是让 JTextPane 在 JPanel 中占用尽可能多的空间。对于我使用的 UpdateInfoPanel: public class UpdateInfoPanel extends JP
我在 JPanel 中有一个 JTextArea,我想将其与 JScrollPane 一起使用。我正在使用 GridBagLayout。当我运行它时,框架似乎为 JScrollPane 腾出了空间,但
我想在 xcode 中实现以下功能。 我有一个 View Controller 。在这个 UIViewController 中,我有一个 UITabBar。它们下面是一个 UIView。将 UITab
有谁知道Firebird 2.5有没有类似于SQL中“STUFF”函数的功能? 我有一个包含父用户记录的表,另一个表包含与父相关的子用户记录。我希望能够提取用户拥有的“ROLES”的逗号分隔字符串,而
我想使用 JSON 作为 mirth channel 的输入和输出,例如详细信息保存在数据库中或创建 HL7 消息。 简而言之,输入为 JSON 解析它并输出为任何格式。 最佳答案 var objec
通常我会使用 R 并执行 merge.by,但这个文件似乎太大了,部门中的任何一台计算机都无法处理它! (任何从事遗传学工作的人的附加信息)本质上,插补似乎删除了 snp ID 的 rs 数字,我只剩
我有一个以前可能被问过的问题,但我很难找到正确的描述。我希望有人能帮助我。 在下面的代码中,我设置了varprice,我想添加javascript变量accu_id以通过rails在我的数据库中查找记
我有一个简单的 SVG 文件,在 Firefox 中可以正常查看 - 它的一些包装文本使用 foreignObject 包含一些 HTML - 文本包装在 div 中:
所以我正在为学校编写一个 Ruby 程序,如果某个值是 1 或 3,则将 bool 值更改为 true,如果是 0 或 2,则更改为 false。由于我有 Java 背景,所以我认为这段代码应该有效:
我做了什么: 我在这些账户之间创建了 VPC 对等连接 互联网网关也连接到每个 VPC 还配置了路由表(以允许来自双方的流量) 情况1: 当这两个 VPC 在同一个账户中时,我成功测试了从另一个 La
我有一个名为 contacts 的表: user_id contact_id 10294 10295 10294 10293 10293 10294 102
我正在使用 Magento 中的新模板。为避免重复代码,我想为每个产品预览使用相同的子模板。 特别是我做了这样一个展示: $products = Mage::getModel('catalog/pro
“for”是否总是检查协议(protocol)中定义的每个函数中第一个参数的类型? 编辑(改写): 当协议(protocol)方法只有一个参数时,根据该单个参数的类型(直接或任意)找到实现。当协议(p
我想从我的 PHP 代码中调用 JavaScript 函数。我通过使用以下方法实现了这一点: echo ' drawChart($id); '; 这工作正常,但我想从我的 PHP 代码中获取数据,我使
这个问题已经有答案了: Event binding on dynamically created elements? (23 个回答) 已关闭 5 年前。 我有一个动态表单,我想在其中附加一些其他 h
我正在尝试找到一种解决方案,以在 componentDidMount 中的映射项上使用 setState。 我正在使用 GraphQL连同 Gatsby返回许多 data 项目,但要求在特定的 pat
我在 ScrollView 中有一个 View 。只要用户按住该 View ,我想每 80 毫秒调用一次方法。这是我已经实现的: final Runnable vibrate = new Runnab
我用 jni 开发了一个 android 应用程序。我在 GetStringUTFChars 的 dvmDecodeIndirectRef 中得到了一个 dvmabort。我只中止了一次。 为什么会这
当我到达我的 Activity 时,我调用 FragmentPagerAdapter 来处理我的不同选项卡。在我的一个选项卡中,我想显示一个 RecyclerView,但他从未出现过,有了断点,我看到
当我按下 Activity 中的按钮时,会弹出一个 DialogFragment。在对话框 fragment 中,有一个看起来像普通 ListView 的 RecyclerView。 我想要的行为是当
我是一名优秀的程序员,十分优秀!