如何使用 MxNet metrics api计算带有矢量标签的多类逻辑回归分类器的准确性?以下是标签的示例:
Class1: [1,0,0,0]
Class2: [0,1,0,0]
Class3: [0,0,1,0]
Class4: [0,0,0,1]
使用此函数的天真方法会产生错误的结果,因为 argmax 会将模型输出压缩为具有最大概率值的索引
def evaluate_accuracy(data_iterator, ctx, net):
acc = mx.metric.Accuracy()
for i, (data, label) in enumerate(data_iterator):
data = data.as_in_context(ctx)
label = label.as_in_context(ctx)
out = net(data)
p = nd.argmax(out, axis=1)
acc.update(preds=p, labels=label)
return acc.get()[1]
我目前的解决方案有点hacky:
def evaluate_accuracy(data_iterator, ctx, net):
acc = mx.metric.Accuracy()
for i, (data, label) in enumerate(data_iterator):
data = data.as_in_context(ctx)
label = label.as_in_context(ctx)
out = net(data)
p = nd.argmax(out, axis=1)
l = nd.argmax(label, axis=1)
acc.update(preds=p, labels=l)
return acc.get()[1]