CrossEntropyLoss
我有一些与忽略类一起使用的训练管道。
log_probs
形状的模型输出(150, 3)
- 意味着 3 个可能的类别,每批 150 个。
是label_batch
shape150
和torch.max(label_batch)
== tensor(3, device='cuda:0')
,意味着有一个额外的类被标记为3
,它是忽略类。
损失处理得很好:
self._criterion = nn.CrossEntropyLoss(
reduction='mean',
ignore_index=3
)
但是准确度指标认为类3
是有效的并且给出了非常错误的结果:
self.train_acc = pl.metrics.Accuracy()
self.train_acc.update(log_probs, label_batch)
由于标签的错误结果3
应该被忽略。
如何正确使用pl.metrics.Accuracy()
忽略类?