gpt4 book ai didi

pytorch - PyTorch Lightning 是否是整个 epoch 的平均指标?

转载 作者:行者123 更新时间:2023-12-04 13:28:44 95 4
gpt4 key购买 nike

我正在查看 PyTorch-Lightning 上提供的示例官方文档 https://pytorch-lightning.readthedocs.io/en/0.9.0/lightning-module.html .
这里的损失和度量是根据混凝土批次计算的。但是当记录一个人对特定批次的准确性不感兴趣时​​,它可能很小而且不具有代表性,而是所有时期的平均值。我是否理解正确,有一些代码对所有批次执行平均,通过 epoch?

 import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM

class ClassificationTask(pl.LightningModule):

def __init__(self, model):
super().__init__()
self.model = model

def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
return pl.TrainResult(loss)

def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
acc = FM.accuracy(y_hat, y)
result = pl.EvalResult(checkpoint_on=loss)
result.log_dict({'val_acc': acc, 'val_loss': loss})
return result

def test_step(self, batch, batch_idx):
result = self.validation_step(batch, batch_idx)
result.rename_keys({'val_acc': 'test_acc', 'val_loss': 'test_loss'})
return result

def configure_optimizers(self):
return torch.optim.Adam(self.model.parameters(), lr=0.02)

最佳答案

如果你想平均这个时期的指标,你需要告诉 LightningModule你已经子类这样做了。有几种不同的方法可以做到这一点,例如:

  • 调用 result.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) as shown in the docson_epoch=True以便训练损失在整个 epoch 中取平均值。即:
  •  def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    result = pl.TrainResult(loss)
    result.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return result
  • 或者,您可以拨打 log LightningModule上的方法本身:self.log("train_loss", loss, on_epoch=True, sync_dist=True) (可选地传递 sync_dist=True 以减少跨加速器)。

  • 你会想要在 validation_step 中做类似的事情获取聚合的 val-set 指标或在 validation_epoch_end 中自己实现聚合方法。

    关于pytorch - PyTorch Lightning 是否是整个 epoch 的平均指标?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66516486/

    95 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com