1

假设我有一个来自 的神经网络对象torch.nn,默认情况下requires_gradFalse它的参数。我想把它改成True. 但是以下幼稚的方法失败了:

From torch import nn
a = nn.Linear(1, 1)
a.state_dict()[‘weight’].requires_grad = True
print(a.state_dict()[‘weight’].requires_grad)

结果是False。谁能解释问题是什么以及如何解决?谢谢!我的手电筒版本是 1.7.1。

4

1 回答 1

1

默认情况下,可训练 nn对象参数将具有requires_grad=True. 您可以通过执行以下操作来验证:

import torch.nn as nn

layer = nn.Linear(1, 1)

for param in layer.parameters():
    print(param.requires_grad)

# or use
print(layer.weight.requires_grad)
print(layer.bias.requires_grad)

改变requires_grad状态:

for param in layer.parameters():
    param.requires_grad = False # or True
于 2021-04-09T07:53:15.773 回答