使用以下代码片段记录指标时:
def training_epoch_end(self, outs):
    """Epoch-end hook for the training stage.

    Delegates to the shared epoch report with the "train" stage; ``outs``
    is not used here.
    """
    self.__common_epoch_end_report("train")
def validation_epoch_end(self, outs):
    """Epoch-end hook for the validation stage.

    Delegates to the shared epoch report with the "validation" stage;
    ``outs`` is not used here.
    """
    self.__common_epoch_end_report("validation")
以及:
def __common_epoch_end_report(self, mode: str):
    """Compute, report and reset the epoch-level accuracy and F1 of one stage.

    Each value is written twice:

    * to TensorBoard via ``add_scalars`` under the shared main tags
      "accuracy" and "F1", keyed by ``mode``;
    * to Lightning via ``self.log`` under ``accuracy\\<mode>`` and
      ``F1\\<mode>`` (the raw f-string keeps a literal backslash), which are
      the exact names a ``ModelCheckpoint(monitor=...)`` has to use.

    After reporting, the accuracy and F1 metrics are reset so that the next
    epoch starts from a clean state.

    Args:
        mode: one of ["train", "validation", "test"]
    """
    tb = self.logger.experiment
    # The confusion-matrix and MCC metrics are not reported here.
    acc, _conf, f1, _mcc = self.__select_metrics_by_mode(mode)
    acc_val = acc.compute()
    f1_val = f1.compute()
    tb.add_scalars('accuracy', {mode: acc_val}, global_step=self.current_epoch)
    tb.add_scalars('F1', {mode: f1_val}, global_step=self.current_epoch)
    self.log(fr'accuracy\{mode}', acc_val, on_step=False, on_epoch=True, prog_bar=True)
    self.log(fr'F1\{mode}', f1_val, on_step=False, on_epoch=True, prog_bar=True)
    # Only the computed values are passed to self.log (not the metric
    # objects), so Lightning will not reset these metrics for us. Without an
    # explicit reset their state keeps accumulating and every "epoch" value
    # would really be a running value over all epochs so far.
    # (Assumes torchmetrics-style metrics exposing compute()/reset().)
    acc.reset()
    f1.reset()
def __select_metrics_by_mode(self, mode):
    """Return the metric objects belonging to one stage.

    Args:
        mode: "train", "validation" or "test".

    Returns:
        Tuple ``(acc, conf, f1, mcc)`` of the accuracy, confusion-matrix, F1
        and MCC attributes registered for that stage.

    Raises:
        ValueError: if ``mode`` is not one of the supported stages.
    """
    # Attribute names per stage, in (acc, f1, mcc, conf) lookup order.
    stage_table = (
        ("train", ("train_acc", "train_f1", "_train_mcc", "train_confusion")),
        ("validation", ("val_acc", "val_f1", "_val_mcc", "val_confusion")),
        ("test", ("test_acc", "test_f1", "_test_mcc", "test_confusion")),
    )
    for stage, attr_names in stage_table:
        if mode == stage:
            accuracy, f1_score, mcc_score, confusion = [
                getattr(self, attr_name) for attr_name in attr_names
            ]
            return accuracy, confusion, f1_score, mcc_score
    raise ValueError("unsupported mode")
我确认 `__common_epoch_end_report` 确实分别以 `mode='train'`
和 `mode='validation'` 被调用了。
但是,只有以 `mode='train'` 记录的指标
可以被检查点回调(`ModelCheckpoint`)使用:
# Keep the 10 best checkpoints ranked by the validation F1 score.
# The monitor key must match, character for character, the name logged via
# self.log(fr'F1\{mode}', ...) in __common_epoch_end_report; the raw string
# keeps a literal backslash, so the key is F1\validation.
# NOTE(review): F1 is higher-is-better, while ModelCheckpoint's default
# mode is 'min'; mode='max' is needed unless it is among the elided args.
# NOTE(review): the reported error lists only train metrics, i.e. the
# callback evaluated the monitor before any validation metric had been
# logged. Confirm the validation loop actually runs every epoch (val
# dataloader passed to fit, check_val_every_n_epoch, limit_val_batches);
# depending on the Lightning version, save_on_train_epoch_end=False may be
# required so the check happens after validation — verify.
checkpoint_callback = ModelCheckpoint(
# (other arguments elided in the original snippet)
...
save_top_k=10,
monitor=fr'F1\validation'
)
# Register the checkpoint callback with the Trainer.
trainer = Trainer(
# (other Trainer arguments elided in the original snippet)
...
callbacks=[checkpoint_callback],
)
训练时收到以下错误:
pytorch_lightning.utilities.exceptions.MisconfigurationException: ModelCheckpoint(monitor='F1\validation') not found in the returned metrics: ['loss\\train_step', 'loss\\train_epoch', 'loss\\train', 'accuracy\\train', 'F1\\train'].
在 PyTorch Lightning 中,如何让 ModelCheckpoint 根据验证集指标来保存检查点(checkpoint)?