假设我有一个来自 的神经网络对象torch.nn
,默认情况下requires_grad
是False
它的参数。我想把它改成True
. 但是以下幼稚的方法失败了:
From torch import nn
a = nn.Linear(1, 1)
a.state_dict()[‘weight’].requires_grad = True
print(a.state_dict()[‘weight’].requires_grad)
结果是False
。谁能解释问题是什么以及如何解决?谢谢!我的手电筒版本是 1.7.1。