我从文档中得到了(显然是错误的)这样的印象:torch.no_grad()
,作为一个上下文管理器,应该把所有的东西都变成requires_grad=False
。事实上,这正是我打算使用torch.no_grad()
的目的,作为一个方便的上下文管理器,用来示例化一堆我希望保持不变的东西(通过训练)。但这似乎只是torch.Tensor
的情况;它似乎不会影响torch.nn.Module
,如以下示例代码所示:
with torch.no_grad():
linear = torch.nn.Linear(2, 3)
for p in linear.parameters():
print(p.requires_grad)
这将输出:
True
True
在我看来,这有点违反直觉。这是预期的行为吗?如果是,为什么?是否有一个类似的方便的上下文管理器,我可以确信我在它下面示例化的任何东西都不需要梯度?
1条答案
按热度按时间nqwrtyyt1#
这是预期的行为,但我同意文档中有些不清楚。请注意,文档中说:
在这种模式下,每次计算的结果都将requires_grad = False,即使输入的requires_grad = True。
这个上下文禁用了在上下文中所做的任何计算的输出上的渐变。从技术上讲,声明/创建一个层不是计算,所以参数的
requires_grad
是True
。但是,对于在这个上下文中所做的任何计算,您将无法计算渐变。计算输出的requires_grad
将是False
。这可能是最好的解释扩展您的代码片段如下:即使层参数的
requires_grad
为True
,也无法计算渐变,因为输出为requires_grad
False
。