PyTorch: loss drops sharply during training of a seq2seq RNN for machine translation, while the BLEU score stays at zero

tvz2xvvm · asked 2023-01-20

I am trying to train an RNN with LSTM cells for machine translation. However, the BLEU score drops to zero within the first batches and stays at zero for the whole training run, while the loss decreases sharply. What could be the problem?

Code:

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SimpleRNNTranslator(nn.Module):
    def __init__(self, inp_voc, out_voc, emb_size=64, hid_size=128):
        """
        My version of simple RNN model, I use LSTM instead of GRU as in the baseline
        """
        super().__init__()
        
        self.inp_voc = inp_voc
        self.out_voc = out_voc
        
        self.emb_inp = nn.Embedding(len(inp_voc), emb_size)
        self.emb_out = nn.Embedding(len(out_voc), emb_size)
        
        self.encoder = nn.LSTM(emb_size, hid_size, batch_first=True)
        self.decoder = nn.LSTM(emb_size, hid_size, batch_first=True)
        
        self.decoder_start = nn.Linear(hid_size, hid_size)
        self.logits = nn.Linear(hid_size, len(out_voc))
        
    def forward(self, inp, out):
        """
        Apply model in training mode
        """
        encoded_seq = self.encode(inp)
        decoded_seq, _ = self.decode(encoded_seq, out)
        return self.logits(decoded_seq)
    
    def encode(self, seq_in):
        """
        Take input symbolic sequence, compute initial hidden state for decoder
        :param seq_in: matrix of input tokens [batch_size, seq_in_len]
        :return: initial hidden state for the decoder
        """
        embeddings = self.emb_inp(seq_in)
        output, (_, __) = self.encoder(embeddings)
    
        # the last output position isn't actually the last token because of padding; the next two lines find the true last state
        seq_lengths = (seq_in != self.inp_voc.eos_ix).sum(dim=-1)
        
        last_states = output[range(seq_lengths.shape[0]), seq_lengths]
        
        return self.decoder_start(last_states)
    
    def decode(self, hidden_state, seq_out, previous_state=None):
        """
        Take output symbolic sequence, compute logits for every token in sequence
        :param hidden_state: matrix of initial hidden state for the decoder [batch_size, hid_size]
        :param seq_out: matrix of output tokens [batch_size, seq_out_len]
        :param previous_state: matrix of previous cell state [batch_size, hid_size]
        :return: logits for every token in sequence [batch_size, seq_len, out_voc]
        """
        if not torch.is_tensor(previous_state):
            previous_state = torch.randn(*hidden_state.shape).to(device)
            
        embeddings = self.emb_out(seq_out)
        outputs, (_, cn) = self.decoder(embeddings, (hidden_state[None, :, :], previous_state[None, :, :]))
        
        return outputs, cn
    
    def inference(self, inp_tokens, max_len):
        """
        Take initial state and return ids for out words
        :param initial_state: initial_state for a decoder, produced by encoder with input tokens
        """
        initial_state = self.encode(inp_tokens)
        states = [initial_state]
        outputs = [torch.full([initial_state.shape[0]], self.out_voc.bos_ix, dtype=torch.int, device=device)]
        
        cn = None
        
        for i in range(max_len):
            hidden_state, cn = self.decode(states[-1], outputs[-1][:, None], previous_state=cn)
            hidden_state, cn = hidden_state.squeeze(), cn.squeeze()
            outputs.append(self.logits(hidden_state).argmax(dim=-1))
            states.append(hidden_state)

        
        return torch.stack(outputs, dim=-1), torch.cat(states)
            
    
    def translate_lines(self, lines, max_len=100):
        """
        Take lines and return translation
        :param lines: list of lines in Russian
        """
        inp_tokens = self.inp_voc.to_matrix(lines).to(device)
        out_ids, states = self.inference(inp_tokens, max_len=max_len)
        return self.out_voc.to_lines(out_ids.cpu().numpy()), states

**How I compute BLEU:**
from nltk.translate.bleu_score import corpus_bleu
def compute_bleu(model, inp_lines, out_lines, bpe_sep='@@ ', **flags):
    """
    Estimates corpora-level BLEU score of model's translations given inp and reference out
    Note: if you're serious about reporting your results, use https://pypi.org/project/sacrebleu
    """
    with torch.no_grad():
        translations, _ = model.translate_lines(inp_lines, **flags)
        translations = [line.replace(bpe_sep, '') for line in translations]
        actual = [line.replace(bpe_sep, '') for line in out_lines]
        return corpus_bleu(
            [[ref.split()] for ref in actual],
            [trans.split() for trans in translations],
            smoothing_function=lambda precisions, **kw: [p + 1.0 / p.denominator for p in precisions]
            ) * 100
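
For context, a typical evaluation call might look like the following sketch (dev_inp_lines and dev_out_lines are hypothetical lists of BPE-encoded lines from the development set, not names from the original code):

# hypothetical evaluation call on a held-out development set
dev_bleu = compute_bleu(model, dev_inp_lines, dev_out_lines)
print(f"dev BLEU: {dev_bleu:.2f}")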

[Plots from training: BLEU score evaluated on the development dataset, and the loss]
I think the issue might be related to how the LSTM works. At first I wasn't passing the cell state between sequence steps, only the hidden state. I fixed that, but it didn't solve the problem.

ekqde3dh #1

You probably forgot to shift the target sequence when computing the loss.
At training time, the decoder sequence needs to be shifted so that the (*n*-1)-th position predicts the *n*-th word. For a sequence w1 w2 w3 w4 with a beginning-of-sentence token [BOS] and an end-of-sentence token [EOS], it looks like this:

BOS w1  w2  w3  w4
↓   ↓   ↓   ↓   ↓
▯ → ▯ → ▯ → ▯ → ▯  
↓   ↓   ↓   ↓   ↓
w1  w2  w3  w4  EOS

In general: feed the decoder the target sequence without its last token, and compute the loss against the target sequence without its first token.
If you don't do this, the decoder looks like this:

w1  w2  w3  w4
↓   ↓   ↓   ↓
▯ → ▯ → ▯ → ▯
↓   ↓   ↓   ↓
w1  w2  w3  w4

The model quickly learns to copy the input tokens, so the loss decreases rapidly, but the model never learns to translate.
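
In code, the fix is roughly the following sketch. This is not the asker's actual training loop: train_src_lines and train_ref_lines are hypothetical lists of BPE-encoded lines, model is a trained SimpleRNNTranslator instance, and inp_voc, out_voc, device are assumed to exist as standalone names matching the attributes used in the code above.

import torch.nn.functional as F

# build token matrices for one batch (hypothetical input lists)
inp = inp_voc.to_matrix(train_src_lines).to(device)   # [batch, inp_len]
out = out_voc.to_matrix(train_ref_lines).to(device)   # [batch, out_len], BOS ... EOS

logits = model(inp, out[:, :-1])                       # decoder sees  BOS w1 ... w_{n-1}
targets = out[:, 1:]                                   # loss targets  w1 ... w_n EOS

loss = F.cross_entropy(
    logits.reshape(-1, logits.shape[-1]),              # [batch * (out_len - 1), vocab]
    targets.reshape(-1).long(),                        # [batch * (out_len - 1)]
)
loss.backward()

With this shift, the only way to reduce the loss is to actually predict the next word. As a refinement you would also mask padding positions (the encode method suggests padding reuses EOS, so a length-based mask that keeps the first EOS is safer than ignore_index), but the shift itself is what should make BLEU move away from zero.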
