pytorch 具有自回归Transformer解码的存储器瓶颈

wz8daaqr  于 2023-08-05  发布在  其他
关注(0)|答案(2)|浏览(116)

我正在尝试训练一个用于序列建模的Transformer模型。下面是一个独立的示例:

import torch
import torch.nn as nn

criterion = nn.MSELoss()

decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=12)
memory = torch.rand(10, 32, 512)
y = torch.rand(20, 32, 512)

start_token = torch.ones((1,32,512))
tgt_input = torch.cat((start_token,y[:-1,:]),axis=0)

optimizer = torch.optim.Adam(transformer_decoder.parameters())

###################Teacher forced
while(True):
    optimizer.zero_grad()
    out = transformer_decoder(tgt_input, memory, nn.Transformer.generate_square_subsequent_mask(20,20))

    loss = criterion(out,y)
    print("loss: ", loss.item())
    
    loss.backward()
    optimizer.step()

字符串
对于12层解码器,该模型在具有8GB内存的个人机器上工作正常。该模型是自回归的,并与转移的目标。鉴于我们提供了上述目标,我把这种设置称为“教师强迫”。
然而,在推理阶段,我们不会像上面那样提供目标,并且需要在运行时生成目标。此设置如下所示:

###################Non Teacher forced
while(True):
    optimizer.zero_grad()
    predictions = torch.ones((1,32,512))
    for i in range(1,21):
        predictions = torch.cat((predictions, transformer_decoder(tgt_input[:i], memory, nn.Transformer.generate_square_subsequent_mask(i,i))[-1].unsqueeze(0)),axis=0)
        print("i: ", i, "predictions.shape: ", predictions.shape)
        
    loss = criterion(predictions[1:],y)
    print("loss: ", loss.item())
    
    loss.backward()
    optimizer.step()


我希望用混合训练策略训练模型,有,没有老师强迫。然而,非教师强制策略导致内存不足异常,并且不起作用。对于最终推理(测试),通常with torch.no_grad()可以工作,但在训练中不能工作。谁能解释一下为什么这会导致内存瓶颈呢?

sqxo8psd

sqxo8psd1#

这是因为计算图的滚动。对于教师强制模型,渐变不会在真实值之后传播。然而,对于非教师强制模型,它们反向传播,使梯度累积(类似于RNN)。

unftdfkk

unftdfkk2#

您可以尝试检查模型的标题生成器部分,并尝试torch.no在不需要背景的代码部分使用www.example.com _grad()。这样可以保存一些内存。此外,您可以尝试以最佳方式减少最大序列长度(我通过反复试验进行了尝试)。我遇到了一个非常类似的问题,按照上面的步骤,我设法在一定程度上增加了批量大小。

相关问题