2

使用以下代码片段记录指标时:

    def training_epoch_end(self, outs):
        """Log the epoch-level metrics for the training phase.

        Delegates to the shared epoch-end report with mode "train".

        Args:
            outs: Not used by this hook.
        """
        self.__common_epoch_end_report("train")

    def validation_epoch_end(self, outs):
        """Log the epoch-level metrics for the validation phase.

        Delegates to the shared epoch-end report with mode "validation".

        Args:
            outs: Not used by this hook.
        """
        self.__common_epoch_end_report("validation")

   def __common_epoch_end_report(self, mode: str):
        """
        Compute the accuracy and F1 metrics for ``mode`` and log them twice.

        The values are written to the logger's experiment object through
        ``add_scalars`` (keyed by ``mode``, with ``global_step`` set to
        ``self.current_epoch``) and to ``self.log`` under a key made of the
        metric name, a backslash and the mode.

        Args:
            mode: one of ["train", "validation", "test"]
        """
        writer = self.logger.experiment

        # Only the first and third entries of the returned tuple (accuracy, F1)
        # are reported here; the other two are ignored.
        accuracy_metric, _conf, f1_metric, _mcc = self.__select_metrics_by_mode(mode)

        accuracy_value = accuracy_metric.compute()
        f1_value = f1_metric.compute()

        # Experiment-level scalars, one entry per mode under each tag.
        writer.add_scalars('accuracy', {mode: accuracy_value}, global_step=self.current_epoch)
        writer.add_scalars('F1', {mode: f1_value}, global_step=self.current_epoch)

        # Same values through self.log, epoch-level only and shown in the progress bar.
        log_options = dict(on_step=False, on_epoch=True, prog_bar=True)
        self.log(fr'accuracy\{mode}', accuracy_value, **log_options)
        self.log(fr'F1\{mode}', f1_value, **log_options)


    def __select_metrics_by_mode(self, mode):
        """Return the metric attributes that belong to the given mode.

        Args:
            mode: one of ["train", "validation", "test"]

        Returns:
            The tuple ``(acc, conf, f1, mcc)`` read from the attributes of
            ``self`` for that mode.

        Raises:
            ValueError: if ``mode`` is none of the supported values.
        """
        # Attribute-name prefix per mode; note that "validation" maps to "val".
        known_modes = (("train", "train"), ("validation", "val"), ("test", "test"))
        for candidate, prefix in known_modes:
            if mode == candidate:
                break
        else:
            raise ValueError("unsupported mode")

        # The MCC attribute is the only one carrying a leading underscore.
        acc = getattr(self, f"{prefix}_acc")
        f1 = getattr(self, f"{prefix}_f1")
        mcc = getattr(self, f"_{prefix}_mcc")
        conf = getattr(self, f"{prefix}_confusion")
        return acc, conf, f1, mcc

我已确认 `__common_epoch_end_report` 确实分别以 `mode='train'` 和 `mode='validation'` 被调用过。

但是,只有在 `train` 模式下记录的指标可用于保存检查点:

# Checkpoint callback that monitors the key "F1\validation"; the raw f-string
# keeps the backslash literal, the same key format passed to self.log.
# NOTE(review): F1 is normally maximised — confirm that mode="max" is among the
# elided arguments, otherwise the top-k ranking may be inverted.
checkpoint_callback = ModelCheckpoint(
        ...
        save_top_k=10,
        monitor=fr'F1\validation'
)
# Trainer that receives the checkpoint callback defined above.
trainer = Trainer(
      ...
        callbacks=[checkpoint_callback],
)

收到以下错误:

pytorch_lightning.utilities.exceptions.MisconfigurationException: ModelCheckpoint(monitor='F1\validation') not found in the returned metrics: ['loss\\train_step', 'loss\\train_epoch', 'loss\\train', 'accuracy\\train', 'F1\\train'].

如何在 PyTorch Lightning 中根据验证指标来保存检查点?

4

0 回答 0