1

我已经尝试使用一个热编码来使用狗品种数据集对大约 120 个类进行多类分类。也使用 resnet18。但是当我运行代码时会显示以下错误。请帮我解决问题。

我的模型的代码如下所示:

model = torchvision.models.resnet18()
op = torch.optim.Adam(model.parameters(),lr=0.001)
crit = nn.NLLLoss()
model.fc = nn.Sequential(
    nn.Linear(512,120),
    nn.Dropout(inplace=True),
    nn.ReLU(),
    nn.LogSoftmax())

for i,(x,y) in enumerate(train_dl):
    # prepare one-hot vector
    y_oh=torch.zeors(y.shape[0],120)
    y_oh.scatter_(1, y.unsqueeze(1), 1)

    # do the prediction
    y_hat=model(x)
    y_=torch.max(y_hat)

    loss=crit(y,y_)
    op.zero_grad()
    loss.backward()
    op.step()

错误:

RuntimeError Traceback (most recent call last) <ipython-input-190-46a21ead759a> in <module>
       6 
       7     y_hat=model(x)
 ----> 8     loss=crit(y_oh,y_hat)
       9     op.zero_grad()
      10     loss.backward()

***RuntimeError: 1D target tensor expected, multi-target not supported***
4

1 回答 1

0

NLLLoss您正在使用的期望真实目标类的索引。顺便提一句。您不必将目标转换为 one-hot 向量并直接使用y张量。

另请注意,NLLLoss期望分布输出分布在对数域中,即使用nn.LogSoftmax而不是nn.Softmax

于 2020-05-11T07:24:22.520 回答