我正在尝试训练一个用于序列建模的Transformer模型。下面是一个独立的示例:
import torch
import torch.nn as nn
criterion = nn.MSELoss()
decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=12)
memory = torch.rand(10, 32, 512)
y = torch.rand(20, 32, 512)
start_token = torch.ones((1,32,512))
tgt_input = torch.cat((start_token,y[:-1,:]),axis=0)
optimizer = torch.optim.Adam(transformer_decoder.parameters())
###################Teacher forced
while(True):
optimizer.zero_grad()
out = transformer_decoder(tgt_input, memory, nn.Transformer.generate_square_subsequent_mask(20,20))
loss = criterion(out,y)
print("loss: ", loss.item())
loss.backward()
optimizer.step()
字符串
对于12层解码器,该模型在具有8GB内存的个人机器上工作正常。该模型是自回归的,并与转移的目标。鉴于我们提供了上述目标,我把这种设置称为“教师强迫”。
然而,在推理阶段,我们不会像上面那样提供目标,并且需要在运行时生成目标。此设置如下所示:
###################Non Teacher forced
while(True):
optimizer.zero_grad()
predictions = torch.ones((1,32,512))
for i in range(1,21):
predictions = torch.cat((predictions, transformer_decoder(tgt_input[:i], memory, nn.Transformer.generate_square_subsequent_mask(i,i))[-1].unsqueeze(0)),axis=0)
print("i: ", i, "predictions.shape: ", predictions.shape)
loss = criterion(predictions[1:],y)
print("loss: ", loss.item())
loss.backward()
optimizer.step()
型
我希望用混合训练策略训练模型,有,没有老师强迫。然而,非教师强制策略导致内存不足异常,并且不起作用。对于最终推理(测试),通常with torch.no_grad()
可以工作,但在训练中不能工作。谁能解释一下为什么这会导致内存瓶颈呢?
2条答案
按热度按时间sqxo8psd1#
这是因为计算图的滚动。对于教师强制模型,渐变不会在真实值之后传播。然而,对于非教师强制模型,它们反向传播,使梯度累积(类似于RNN)。
unftdfkk2#
您可以尝试检查模型的标题生成器部分,并尝试torch.no在不需要背景的代码部分使用www.example.com _grad()。这样可以保存一些内存。此外,您可以尝试以最佳方式减少最大序列长度(我通过反复试验进行了尝试)。我遇到了一个非常类似的问题,按照上面的步骤,我设法在一定程度上增加了批量大小。