0

CrossEntropyLoss我有一些与忽略类一起使用的训练管道。

log_probs形状的模型输出(150, 3)- 意味着 3 个可能的类别,每批 150 个。

label_batchshape150torch.max(label_batch)== tensor(3, device='cuda:0'),意味着有一个额外的类被标记为3,它是忽略类。

损失处理得很好:

self._criterion = nn.CrossEntropyLoss(
    reduction='mean',
    ignore_index=3
)

但是准确度指标认为类3是有效的并且给出了非常错误的结果:

self.train_acc = pl.metrics.Accuracy()

self.train_acc.update(log_probs, label_batch)由于标签的错误结果3应该被忽略。


如何正确使用pl.metrics.Accuracy()忽略类?

4

1 回答 1

1

从 github 论坛https://github.com/PyTorchLightning/pytorch-lightning/discussions/6890的讨论帖中复制回复


准确度指标目前不支持它,但我们有一个开放的 PR 用于实现该确切功能 PyTorchLightning/metrics#155

目前,您可以改为计算混淆矩阵,然后基于此忽略一些类(请记住,在混淆矩阵的对角线上可以找到真正的正面/正确分类):

ignore_index = 3
metric = ConfusionMatrix(num_classes=3)
confmat = metric(preds, target)
confmat = confmat[:2,:2] # remove last column and row corresponding to class 3
acc = confmat.trace() / confmat.sum()
于 2021-04-09T08:34:32.307 回答