CrossEntropyLoss我有一些与忽略类一起使用的训练管道。
log_probs形状的模型输出(150, 3)- 意味着 3 个可能的类别,每批 150 个。
是label_batchshape150和torch.max(label_batch)== tensor(3, device='cuda:0'),意味着有一个额外的类被标记为3,它是忽略类。
损失处理得很好:
self._criterion = nn.CrossEntropyLoss(
reduction='mean',
ignore_index=3
)
但是准确度指标认为类3是有效的并且给出了非常错误的结果:
self.train_acc = pl.metrics.Accuracy()
self.train_acc.update(log_probs, label_batch)由于标签的错误结果3应该被忽略。
如何正确使用pl.metrics.Accuracy()忽略类?