我在 PyTorch 中实现了一些 RL,并且不得不编写自己的 mse_loss 函数(我在 Stackoverflow 上找到了该函数;))。损失函数为:
def mse_loss(input_, target_):
return torch.sum(
(input_ - target_) * (input_ - target_)) / input_.data.nelement()
现在,在我的训练循环中,第一个输入类似于:
tensor([-1.7610e+10]), tensor([-6.5097e+10])
有了这个输入,我会得到错误:
Unable to get repr for <class 'torch.Tensor'>
计算a = (input_ - target_)
工作正常,而b = a * a
分别b = torch.pow(a, 2)
会因上面提到的错误而失败。
有谁知道解决这个问题?
非常感谢!
更新:我刚刚尝试使用torch.nn.functional.mse_loss
它会导致相同的错误..