python - How to dump a confusion matrix using the TensorBoard logger in pytorch-lightning?

Reposted. Author: 行者123. Updated: 2023-12-04 11:51:45

The official documentation only states:

    >>> import torch
    >>> from pytorch_lightning.metrics import ConfusionMatrix
    >>> target = torch.tensor([1, 1, 0, 0])
    >>> preds = torch.tensor([0, 1, 0, 0])
    >>> confmat = ConfusionMatrix(num_classes=2)
    >>> confmat(preds, target)
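For reference, the doctest above should produce a 2x2 matrix of counts. Assuming the common convention of rows = true label and columns = predicted label (the same orientation scikit-learn uses), the counting can be sketched in plain Python. The helper name confusion_counts is hypothetical and not part of Lightning:

```python
from typing import List, Sequence


def confusion_counts(preds: Sequence[int], target: Sequence[int], num_classes: int) -> List[List[int]]:
    """Count (true label, predicted label) pairs.

    Row index = true class, column index = predicted class (assumed convention).
    """
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(preds, target):
        matrix[t][p] += 1
    return matrix


# Same values as the doctest tensors, as plain lists.
target = [1, 1, 0, 0]
preds = [0, 1, 0, 0]
print(confusion_counts(preds, target, num_classes=2))  # [[2, 0], [1, 1]]
```

Both class-0 samples are predicted correctly (top-left 2), and one of the two class-1 samples is mistaken for class 0 (bottom-left 1).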
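This doesn't show how to use the metric within the framework.
My attempt (the methods are incomplete; only the relevant parts are shown):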
    def __init__(...):
        self.val_confusion = pl.metrics.classification.ConfusionMatrix(num_classes=self._config.n_clusters)

    def validation_step(self, batch, batch_index):
        ...
        log_probs = self.forward(orig_batch)
        loss = self._criterion(log_probs, label_batch)

        self.val_confusion.update(log_probs, label_batch)
        self.log('validation_confusion_step', self.val_confusion, on_step=True, on_epoch=False)

    def validation_step_end(self, outputs):
        return outputs

    def validation_epoch_end(self, outs):
        self.log('validation_confusion_epoch', self.val_confusion.compute())

After epoch 0, this gives:
    Traceback (most recent call last):
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 521, in train
        self.train_loop.run_training_epoch()
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\training_loop.py", line 588, in run_training_epoch
        self.trainer.run_evaluation(test_mode=False)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 613, in run_evaluation
        self.evaluation_loop.log_evaluation_step_metrics(output, batch_idx)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py", line 346, in log_evaluation_step_metrics
        self.__log_result_step_metrics(step_log_metrics, step_pbar_metrics, batch_idx)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py", line 350, in __log_result_step_metrics
        cached_batch_pbar_metrics, cached_batch_log_metrics = cached_results.update_logger_connector()
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 378, in update_logger_connector
        batch_log_metrics = self.get_latest_batch_log_metrics()
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 418, in get_latest_batch_log_metrics
        batch_log_metrics = self.run_batch_from_func_name("get_batch_log_metrics")
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 414, in run_batch_from_func_name
        results = [func(include_forked_originals=False) for func in results]
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 414, in <listcomp>
        results = [func(include_forked_originals=False) for func in results]
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 122, in get_batch_log_metrics
        return self.run_latest_batch_metrics_with_func_name("get_batch_log_metrics",
                                                            *args, **kwargs)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 115, in run_latest_batch_metrics_with_func_name
        for dl_idx in range(self.num_dataloaders)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 115, in <listcomp>
        for dl_idx in range(self.num_dataloaders)
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 100, in get_latest_from_func_name
        results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
      File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\core\step_result.py", line 298, in get_batch_log_metrics
        result[dl_key] = self[k]._forward_cache.detach()
    AttributeError: 'NoneType' object has no attribute 'detach'


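It does pass the validation sanity check before training.
The failure happens on returning from validation_step_end, which doesn't make much sense to me.
The exact same approach works for logging metrics such as accuracy.
How do I get a proper confusion matrix?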
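The last traceback frame fails on self[k]._forward_cache.detach() because _forward_cache is None. One plausible reading (an assumption, not confirmed in this thread) is that a stateful metric only fills that per-batch cache when the metric object is called, while validation_step above only calls update(). The class below is a hypothetical toy that mimics this behavior in plain Python; it is not Lightning's Metric class:

```python
from typing import List, Optional


class ToyConfusionMetric:
    """Hypothetical stand-in for a stateful metric (NOT Lightning's Metric)."""

    def __init__(self, num_classes: int) -> None:
        self.num_classes = num_classes
        self._forward_cache: Optional[List[List[int]]] = None
        self.reset()

    def reset(self) -> None:
        # Epoch-level state, accumulated across batches.
        self.matrix = [[0] * self.num_classes for _ in range(self.num_classes)]

    def update(self, preds: List[int], target: List[int]) -> None:
        # Accumulates state only; _forward_cache is left untouched.
        for p, t in zip(preds, target):
            self.matrix[t][p] += 1

    def compute(self) -> List[List[int]]:
        # Copy of the accumulated (epoch-level) matrix.
        return [row[:] for row in self.matrix]

    def __call__(self, preds: List[int], target: List[int]) -> List[List[int]]:
        # Calling the metric updates the accumulated state AND caches
        # the value computed on the current batch alone.
        self.update(preds, target)
        batch = ToyConfusionMetric(self.num_classes)
        batch.update(preds, target)
        self._forward_cache = batch.compute()
        return self._forward_cache


m = ToyConfusionMetric(num_classes=2)
m.update([0, 1], [1, 1])
print(m._forward_cache)  # None, the value the logger tripped over
m([0, 0], [0, 0])
print(m._forward_cache)  # [[2, 0], [0, 0]]  (this batch only)
print(m.compute())       # [[2, 0], [1, 1]]  (both batches)
```

Even with the cache filled, a full matrix is probably not something self.log can record as a single scalar, which fits with the answer below logging a figure instead.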

Best answer

You can report the figure using self.logger.experiment.add_figure(*tag*, *figure*).
The variable self.logger.experiment is actually a SummaryWriter (from PyTorch, not Lightning). This class has the method add_figure (documentation).
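Since add_figure takes a Matplotlib figure, any plotting code can produce it. As a hedged alternative to the seaborn heatmap used in the answer, the sketch below draws an annotated confusion matrix with Matplotlib alone. The helper name plot_confusion_matrix is hypothetical; it assumes Matplotlib is installed and selects the headless Agg backend so it runs without a display:

```python
from typing import Sequence

import matplotlib

matplotlib.use("Agg")  # headless backend; no window is opened
import matplotlib.pyplot as plt
from matplotlib.figure import Figure


def plot_confusion_matrix(matrix: Sequence[Sequence[int]]) -> Figure:
    """Return an annotated heatmap figure (rows = true, columns = predicted)."""
    n = len(matrix)
    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(matrix, cmap="Blues")
    fig.colorbar(im, ax=ax)
    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    # Write each count at the centre of its cell (x = column, y = row).
    for i in range(n):
        for j in range(n):
            ax.text(j, i, str(matrix[i][j]), ha="center", va="center")
    plt.close(fig)  # drop it from pyplot's state; the Figure object stays usable
    return fig


fig = plot_confusion_matrix([[2, 0], [1, 1]])
# Inside a LightningModule using the TensorBoard logger:
# self.logger.experiment.add_figure("Confusion matrix", fig, self.current_epoch)
```

Closing the figure before logging mirrors the answer's plt.close(fig_): it keeps pyplot from accumulating one open figure per epoch.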
You can use it as follows (MNIST example):

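    import matplotlib.pyplot as plt
    import pandas as pd
    import pytorch_lightning as pl
    import seaborn as sns
    import torch
    import torch.nn.functional as F

    # Methods of a pl.LightningModule whose forward() returns log-probabilities
    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = F.nll_loss(preds, y)
        return {'loss': loss, 'preds': preds, 'target': y}

    def validation_epoch_end(self, outputs):
        # Gather predictions and targets from every validation batch
        preds = torch.cat([tmp['preds'] for tmp in outputs])
        targets = torch.cat([tmp['target'] for tmp in outputs])
        confusion_matrix = pl.metrics.functional.confusion_matrix(preds, targets, num_classes=10)

        # .cpu() so this also works when validating on a GPU
        df_cm = pd.DataFrame(confusion_matrix.cpu().numpy(), index=range(10), columns=range(10))
        plt.figure(figsize=(10, 7))
        fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
        plt.close(fig_)

        # self.logger.experiment is the underlying SummaryWriter
        self.logger.experiment.add_figure("Confusion matrix", fig_, self.current_epoch)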
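If pl.metrics.functional.confusion_matrix is not available in your Lightning version (metrics were later moved into the separate torchmetrics package), the matrix in validation_epoch_end can also be computed in plain PyTorch. A minimal sketch, assuming preds holds log-probabilities of shape (N, num_classes) as in the MNIST example; confusion_from_log_probs is a hypothetical helper, not a Lightning API:

```python
import torch


def confusion_from_log_probs(log_probs: torch.Tensor, targets: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Rows = true class, columns = predicted class."""
    pred_labels = log_probs.argmax(dim=1)        # (N,) predicted class indices
    flat = targets * num_classes + pred_labels   # one index per (true, pred) pair
    counts = torch.bincount(flat, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


logits = torch.tensor([[2.0, 0.1], [0.3, 1.5], [1.0, 0.2], [0.9, 0.1]])
log_probs = torch.log_softmax(logits, dim=1)  # predicted labels: 0, 1, 0, 0
targets = torch.tensor([1, 1, 0, 0])
cm = confusion_from_log_probs(log_probs, targets, num_classes=2)
print(cm.tolist())  # [[2, 0], [1, 1]]
```

The result is an integer tensor on the same device as the inputs, so it still needs .cpu().numpy() before being wrapped in a pandas DataFrame.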

Original question on Stack Overflow: https://stackoverflow.com/questions/65498782/
