self.b) * (self.c) 在哪里 self.a, self.b, and self.c 是可学习的参数。 我的问题是-6ren">
gpt4 book ai didi

python - 当我使用 "if statement"之类的函数时,PyTorch optimizer.step() 不会更新权重

转载 作者:行者123 更新时间:2023-12-03 18:49:27 24 4
gpt4 key购买 nike

我的模型需要学习某些参数来解决这个函数:

self.a * (r > self.b) * (self.c) 
在哪里
self.a, self.b, and self.c  
是可学习的参数。
我的问题是 b在梯度更新期间不会改变。我认为这是因为该功能是不连续的。虽然,它是一个阶梯函数,所以我不知道如何修改它。
任何想法/提示将不胜感激
我的代码是
import torch
import torch.nn as nn
import torch.optim as optim

class model(nn.Module):
def __init__(self):
super(model, self).__init__()
self.a = torch.nn.Parameter(torch.rand(1, requires_grad=True))
self.b = torch.nn.Parameter(torch.rand(1, requires_grad=True))
self.c = torch.nn.Parameter(torch.rand(1, requires_grad=True))

model_net = model()
#function to learn = 5 * (r > 2) * (3)
optimizer = optim.Adam(model_net.parameters(), lr = 0.1)

for epoch in range(10):
for r in range(10):
optimizer.zero_grad()
loss = 5 * (r > 2) * (3) - model_net.a * (r > model_net.b) * (model_net.c)
loss.backward()
optimizer.step()
print(model_net.a)
print(model_net.b)
print(model_net.c)
print()
更新 1:
我找到了 this问题非常相似。这个人也有一个不连续的功能,似乎可以使用 tanh 解决。反而。不过,我还没有找到使用相同方法的方法。
我也经历了 tutorial @seraph 建议的,但我没有在那里找到不连续的功能。

最佳答案

正如其他人提到的,我们需要对非连续函数进行近似。这个怎么样?

import torch
import torch.nn as nn
import torch.optim as optim

class model(nn.Module):
def __init__(self):
super(model, self).__init__()
self.a = torch.nn.Parameter(torch.rand(1, requires_grad=True))
self.b = torch.nn.Parameter(torch.rand(1, requires_grad=True))
self.c = torch.nn.Parameter(torch.rand(1, requires_grad=True))

model_net = model()
#function to learn = 5 * (r > 2) * (3)
optimizer = optim.Adam(model_net.parameters(), lr = 0.1)

for epoch in range(10):
for r in range(10):
optimizer.zero_grad()
loss = 5 * (r > 2) * (3) - model_net.a * torch.tanh((r - model_net.b)) * (model_net.c) #this is the change -- you can try tanh/sigmoid/etc and see which one works better for you
loss.backward()
optimizer.step()
#print(model_net.a)
print(model_net.b)
#print(model_net.c)
print()

关于python - 当我使用 "if statement"之类的函数时,PyTorch optimizer.step() 不会更新权重,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67172032/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com