gpt4 book ai didi

python - 是否可以从 test_step() 函数保存文件?

转载 作者:行者123 更新时间:2023-12-05 05:48:51 24 4
gpt4 key购买 nike

我正在尝试使用 PyTorch Lightning 实现 MNIST 数字。

train 函数如下图所示

def train(epochs, train_loader, test_loader, model):
early_stopping = EarlyStopping('train_loss', mode='min', patience=5)
model_checkpoint = ModelCheckpoint(dirpath=model_path/'mnist_{epoch}-{train_loss:.2f}',monitor='train_loss', mode='min', save_top_k=3)
trainer = pl.Trainer(max_epochs=epochs, profiler=False, callbacks = [model_checkpoint],default_root_dir=model_path)
trainer.fit(model, train_dataloader=train_loader)
trainer.test(test_dataloaders=test_loader, ckpt_path=None)

test_step 函数如下所示

def test_step(self, test_batch):
x, y = test_batch
logits = self.forward(x)
loss = self.mean_squared_error_loss(logits.squeeze(-1), y.float())

# I want to calculate R2, MAPE, etc and want to save in a pandas df and
# need to return to the train function

self.log('test_loss', loss)
return {'test_loss': loss}

我可以使用 TorchMetrics 计算 R2、MAPE 等。但是,我不确定如何(或是否可能)将它们保存在整个测试数据集的 pandas df(或可能在列表中)中。我经历过这个 post但不确定我应该如何尝试!

如有任何建议,我们将不胜感激。

最佳答案

您可以在 test_epoch_end 中汇总测试结果:

def test_step(self, test_batch):
x, y = test_batch
logits = self.forward(x)
loss = self.mean_squared_error_loss(logits.squeeze(-1), y.float())


self.log('test_loss', loss)
return {'test_loss': loss, "logits":logits, "labels": y}

def test_epoch_end(self, outputs):
all_preds, all_labels = [], []
for output in outputs:
probs = list(output['logits'].cpu().detach().numpy()) # predicted values
labels = list(output['labels'].flatten().cpu().detach().numpy())
all_preds.extend(probs)
all_labels.extend(labels)

# you can calculate R2 here or save results as file
r2 = ...

请注意,这仅适用于单个 GPU。如果您使用多个 GPU,则需要一些函数来收集来自不同 GPU 的结果。

要获得模型预测,您需要在模型类中添加一个 predict_step()。

def predict_step(self, test_batch):
x, y = test_batch
logits = self.forward(x)
return {'logits': logits, 'labels':y}

然后运行:

outputs = trainer.predict(model, test_loader, return_predictions=True)

关于python - 是否可以从 test_step() 函数保存文件?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/70748858/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com