Pytorch 'torch.no_grad()'不影响模块

k5hmc34c  于 2023-01-09  发布在  其他
关注(0)|答案(1)|浏览(142)

我从文档中得到了(显然是错误的)这样的印象:torch.no_grad(),作为一个上下文管理器,应该把所有的东西都变成requires_grad=False。事实上,这正是我打算使用torch.no_grad()的目的,作为一个方便的上下文管理器,用来示例化一堆我希望保持不变的东西(通过训练)。但这似乎只是torch.Tensor的情况;它似乎不会影响torch.nn.Module,如以下示例代码所示:

with torch.no_grad():
    linear = torch.nn.Linear(2, 3)
for p in linear.parameters():
    print(p.requires_grad)

这将输出:

True
True

在我看来,这有点违反直觉。这是预期的行为吗?如果是,为什么?是否有一个类似的方便的上下文管理器,我可以确信我在它下面示例化的任何东西都不需要梯度?

nqwrtyyt

nqwrtyyt1#

这是预期的行为,但我同意文档中有些不清楚。请注意,文档中说:
在这种模式下,每次计算的结果都将requires_grad = False,即使输入的requires_grad = True。
这个上下文禁用了在上下文中所做的任何计算的输出上的渐变。从技术上讲,声明/创建一个层不是计算,所以参数的requires_gradTrue。但是,对于在这个上下文中所做的任何计算,您将无法计算渐变。计算输出的requires_grad将是False。这可能是最好的解释扩展您的代码片段如下:

with torch.no_grad():
     linear = torch.nn.Linear(2, 3)
     for p in linear.parameters():
         print(p.requires_grad)
     out  = linear(torch.rand(10,2))
     print(out.requires_grad)
out = linear(torch.rand(10,2)) 
print(out.requires_grad)
True
True
False
True

即使层参数的requires_gradTrue,也无法计算渐变,因为输出为requires_gradFalse

相关问题