0

如何使用 MxNet metrics api计算带有矢量标签的多类逻辑回归分类器的准确性?以下是标签的示例:

Class1: [1,0,0,0]
Class2: [0,1,0,0]
Class3: [0,0,1,0]
Class4: [0,0,0,1]

使用此函数的天真方法会产生错误的结果,因为 argmax 会将模型输出压缩为具有最大概率值的索引

def evaluate_accuracy(data_iterator, ctx, net):
    acc = mx.metric.Accuracy()
    for i, (data, label) in enumerate(data_iterator):
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        out = net(data)
        p = nd.argmax(out, axis=1)
        acc.update(preds=p, labels=label)
    return acc.get()[1]

我目前的解决方案有点hacky:

def evaluate_accuracy(data_iterator, ctx, net):
    acc = mx.metric.Accuracy()
    for i, (data, label) in enumerate(data_iterator):
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        out = net(data)
        p = nd.argmax(out, axis=1)
        l = nd.argmax(label, axis=1)
        acc.update(preds=p, labels=l)
    return acc.get()[1]
4

1 回答 1

1

准确度指标很棘手。它并不能真正将单热编码标签作为基本事实。

我觉得这有点违反直觉,但是您需要将非一个热编码标签作为基本事实传递,而是实际类(例如,2 而不是 [0,0,1,0])。否则,准确性将无法按您期望的方式工作。在此处查看我之前的回复 -为什么 MXNet 报告的验证准确性不正确?

此外,MxNet 期望类从 0 开始。所以,如果你有从 1 开始的类,那么你需要通过减去 1 来调整所有类。

于 2018-03-12T18:36:19.590 回答