I am trying to train an RNN with LSTM for machine translation. However, the BLEU score drops to zero from the very first batches and stays at that level throughout training, while the loss keeps decreasing sharply. What could be the problem?
The code:
import torch
import torch.nn as nn

# `device` is assumed to be defined elsewhere in the notebook, e.g.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class SimpleRNNTranslator(nn.Module):
    def __init__(self, inp_voc, out_voc, emb_size=64, hid_size=128):
        """
        My version of the simple RNN model; I use LSTM instead of GRU as in the baseline
        """
        super().__init__()
        self.inp_voc = inp_voc
        self.out_voc = out_voc
        self.emb_inp = nn.Embedding(len(inp_voc), emb_size)
        self.emb_out = nn.Embedding(len(out_voc), emb_size)
        self.encoder = nn.LSTM(emb_size, hid_size, batch_first=True)
        self.decoder = nn.LSTM(emb_size, hid_size, batch_first=True)
        self.decoder_start = nn.Linear(hid_size, hid_size)
        self.logits = nn.Linear(hid_size, len(out_voc))

    def forward(self, inp, out):
        """
        Apply the model in training mode
        """
        encoded_seq = self.encode(inp)
        decoded_seq, _ = self.decode(encoded_seq, out)
        return self.logits(decoded_seq)

    def encode(self, seq_in):
        """
        Take the input symbolic sequence, compute the initial hidden state for the decoder
        :param seq_in: matrix of input tokens [batch_size, seq_in_len]
        :return: initial hidden state for the decoder
        """
        embeddings = self.emb_inp(seq_in)
        output, (_, __) = self.encoder(embeddings)
        # the last state isn't actually the last one because of padding;
        # the next two lines find the true last state
        seq_lengths = (seq_in != self.inp_voc.eos_ix).sum(dim=-1)
        last_states = output[range(seq_lengths.shape[0]), seq_lengths]
        return self.decoder_start(last_states)

    def decode(self, hidden_state, seq_out, previous_state=None):
        """
        Take the output symbolic sequence, compute decoder states for every token in it
        :param hidden_state: matrix of initial hidden states [batch_size, hid_size]
        :param previous_state: matrix of previous cell states [batch_size, hid_size]
        :param seq_out: matrix of output tokens [batch_size, seq_out_len]
        :return: decoder outputs [batch_size, seq_out_len, hid_size] and the final cell state
        """
        if not torch.is_tensor(previous_state):
            previous_state = torch.randn(*hidden_state.shape).to(device)
        embeddings = self.emb_out(seq_out)
        outputs, (_, cn) = self.decoder(embeddings, (hidden_state[None, :, :], previous_state[None, :, :]))
        return outputs, cn

    def inference(self, inp_tokens, max_len):
        """
        Take input tokens and return ids of output words
        :param inp_tokens: matrix of input tokens; encoded into the initial state for the decoder
        """
        initial_state = self.encode(inp_tokens)
        states = [initial_state]
        outputs = [torch.full([initial_state.shape[0]], self.out_voc.bos_ix, dtype=torch.int, device=device)]
        cn = None
        for i in range(max_len):
            hidden_state, cn = self.decode(states[-1], outputs[-1][:, None], previous_state=cn)
            hidden_state, cn = hidden_state.squeeze(), cn.squeeze()
            outputs.append(self.logits(hidden_state).argmax(dim=-1))
            states.append(hidden_state)
        return torch.stack(outputs, dim=-1), torch.cat(states)

    def translate_lines(self, lines, max_len=100):
        """
        Take lines and return translations
        :param lines: list of lines in Russian
        """
        inp_tokens = self.inp_voc.to_matrix(lines).to(device)
        out_ids, states = self.inference(inp_tokens, max_len=max_len)
        return self.out_voc.to_lines(out_ids.cpu().numpy()), states
**How I compute BLEU:**
from nltk.translate.bleu_score import corpus_bleu

def compute_bleu(model, inp_lines, out_lines, bpe_sep='@@ ', **flags):
    """
    Estimates corpora-level BLEU score of model's translations given inp and reference out
    Note: if you're serious about reporting your results, use https://pypi.org/project/sacrebleu
    """
    with torch.no_grad():
        translations, _ = model.translate_lines(inp_lines, **flags)
        translations = [line.replace(bpe_sep, '') for line in translations]
        actual = [line.replace(bpe_sep, '') for line in out_lines]
        return corpus_bleu(
            [[ref.split()] for ref in actual],
            [trans.split() for trans in translations],
            smoothing_function=lambda precisions, **kw: [p + 1.0 / p.denominator for p in precisions]
        ) * 100
(Plots: BLEU score evaluated on the development dataset and training loss over the course of training.)
I think the problem might be related to how the LSTM works. At first I was not passing the cell state between sequence elements, only the hidden state. I fixed that, but it did not solve the problem.
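For reference, this is roughly how both LSTM states are carried across decoding steps with nn.LSTM (a generic sketch, not my actual code; the sizes and inputs below are placeholders):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=64, hidden_size=128, batch_first=True)
x = torch.randn(8, 1, 64)        # one decoding step: [batch, 1, emb_size]
state = None                     # (h, c) tuple; None makes PyTorch start from zeros

for _ in range(5):
    out, state = lstm(x, state)  # state = (h_n, c_n); both parts must be fed back
    x = torch.randn(8, 1, 64)    # placeholder for the next step's input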
1 Answer
You have probably forgotten to shift the target sequence when computing the loss.

At training time, the decoder sequence needs to be shifted so that the (n-1)-th token predicts the n-th token. For a sequence w1 w2 w3 w4 with a start-of-sentence token [BOS] and an end-of-sentence token [EOS], the decoder input should be [BOS] w1 w2 w3 w4 and the training target should be w1 w2 w3 w4 [EOS]. In general: feed the decoder the target sequence without its last token, and compute the loss with respect to the target sequence without its first token.

If you do not do this, the decoder receives [BOS] w1 w2 w3 w4 [EOS] as input and is asked to predict exactly the same sequence. The model quickly learns to copy the input tokens, so the loss drops rapidly, but the model never learns to translate.
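As a rough illustration of what the shifted training step could look like with the model from the question (a minimal sketch; the names model, inp and out and the use of F.cross_entropy are assumptions, and padding handling is omitted):

import torch.nn.functional as F

# inp: [batch_size, inp_len] source token ids
# out: [batch_size, out_len] target token ids, starting with [BOS] and ending with [EOS]
logits = model(inp, out[:, :-1])            # decoder input: target without its last token
targets = out[:, 1:]                        # loss target: target without its first token

loss = F.cross_entropy(
    logits.reshape(-1, logits.shape[-1]),   # [batch_size * (out_len - 1), vocab_size]
    targets.reshape(-1).long(),             # [batch_size * (out_len - 1)]
)

Padding positions would normally be masked out of the loss as well (for example via the ignore_index argument of cross_entropy), but the essential fix is the one-token shift between the decoder input and the loss target.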