使用PyTorch为较大批量累积梯度

ykejflvf  于 2023-01-17  发布在  其他
关注(0)|答案(1)|浏览(134)

为了模拟更大的批量,我希望能够在PyTorch中为模型每N个批次累积梯度,例如:

def train(model, optimizer, dataloader, num_epochs, N):
     for epoch_num in range(1, num_epochs+1):
         for batch_num, data in enumerate(dataloader):
             ims = data.to('cuda:0') 
             loss = model(ims)
             loss.backward()
             if batch_num % N == 0:
                 optimizer.step()
                 optimizer.zero_grad(set_to_none=True)

对于这种方法,我是否需要添加标志retain_graph=True,即

loss.backward(retain_graph=True)

以这种方式,每个向后调用的梯度是否简单地按每个参数求和?

ecr0jaav

ecr0jaav1#

如果你想在同一个计算图上进行多次反向传递,利用一次正向传递的中间结果,你需要设置retain_graph=True。例如,如果你在计算loss一次之后多次调用loss.backward(),或者如果您有来自图的不同部分的多个损耗要从其反向传播(可以找到here的一个很好的解释)。
在你的例子中,对于每一次向前传递,你只需要向后传播一次,所以你不需要在计算梯度后存储来自计算图的中间结果。
简而言之:

  • 图形中的 * 中间输出 * 在反向传递后清除,除非使用retain_graph=True明确保留。
    • 渐变 * 默认情况下累加,除非使用zero_grad明确清除。

相关问题