gpt4 book ai didi

pytorch - 如何将 LBFGS 优化器与 pytorch ignite 一起使用?

转载 作者:行者123 更新时间:2023-12-02 03:10:16 26 4
gpt4 key购买 nike

我最近开始使用 Ignite,我发现它非常有趣。我想使用 torch.optim 模块中的 LBFGS 算法作为优化器来训练模型。

这是我的代码:

from ignite.engine import Events, Engine, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import RootMeanSquaredError, Loss
from ignite.handlers import EarlyStopping

D_in, H, D_out = 5, 10, 1
model = simpleNN(D_in, H, D_out) # a simple MLP with 1 Hidden Layer
model.double()
train_loader, val_loader = get_data_loaders(i)

optimizer = torch.optim.LBFGS(model.parameters(), lr=1)
loss_func = torch.nn.MSELoss()

#Ignite
trainer = create_supervised_trainer(model, optimizer, loss_func)
evaluator = create_supervised_evaluator(model, metrics={'RMSE': RootMeanSquaredError(),'LOSS': Loss(loss_func)})

@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
print("Epoch[{}] Loss: {:.5f}".format(engine.state.epoch, len(train_loader), engine.state.output))

def score_function(engine):
val_loss = engine.state.metrics['RMSE']
print("VAL_LOSS: {:.5f}".format(val_loss))
return -val_loss

handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
evaluator.add_event_handler(Events.COMPLETED, handler)

trainer.run(train_loader, max_epochs=100)

引发的错误是:TypeError:step() 缺少 1 个必需的位置参数:“closure”

我知道需要为 LBFGS 的实现定义闭包,所以我的问题是如何使用 ignite 来实现?还是有其他方法可以做到这一点?

最佳答案

实现方式是这样的:

    from ignite.engine import Engine

model = ...
optimizer = torch.optim.LBFGS(model.parameters(), lr=1)
criterion =

def update_fn(engine, batch):
model.train()
x, y = batch
# pass to device if needed as here: https://github.com/pytorch/ignite/blob/40d815930d7801b21acfecfa21cd2641a5a50249/ignite/engine/__init__.py#L45
def closure():
y_pred = model(x)
loss = criterion(y_pred, y)
optimizer.zero_grad()
loss.backward()
return loss

optimizer.step(closure)

trainer = Engine(update_fn)

# everything else is the same

Source

关于pytorch - 如何将 LBFGS 优化器与 pytorch ignite 一起使用?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/57806980/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com