PyTorch loss function not training

soat7uwm · asked on 2023-05-29

We are training a question-answering model on the SQuAD v2 dataset.
It is a RoBERTa encoder with a classifier on top. Predicting the answer span works very well. However, we want to add a front classifier that predicts whether the question is answerable at all (as suggested in the paper "Retrospective Reader for Machine Reading Comprehension").
With the model below, the start and end losses both decrease, but answerable_loss does not train.
What we intend to do is:

  • take the first element of the encoded input
  • predict (answerable, unanswerable) from this element
  • softmax over (answerable, unanswerable) to get a certainty percentage
  • compute the loss with cross entropy

(PS: note some tricky dimension issues caused by batching; see the shape sketch below.)
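To make those batching dimensions concrete, here is a small standalone shape check of the steps listed above (dummy tensors only; batch size 4, sequence length 384, and hidden size 768 are assumptions for illustration, not values taken from the real run):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden = torch.randn(4, 384, 768)       # encoder last_hidden_state: [batch, seq_len, hidden]
first_token = hidden[:, 0]              # first (<s>) token per example: [batch, hidden]
answerable_head = nn.Linear(768, 2)     # scores for (answerable, unanswerable)
logits = answerable_head(first_token)   # [batch, 2]
probs = F.softmax(logits, dim=1)        # certainty percentages per example: [batch, 2]
print(first_token.shape, logits.shape, probs.shape)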
As far as I can tell, this is what we should be doing, and it is what we are doing in the code. But clearly it is not working...
I cannot seem to find my mistake. Why is this happening?

class LSTMFrontVerifier(nn.Module):
    name = "lstm-front-verifier"
    
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
    
        self.answerable = nn.Linear(encoder.config.hidden_size, 2)
        self.classifier = nn.Linear(encoder.config.hidden_size, 2)
    
    def forward(self, input_ids, attention_mask=None, start_positions=None, end_positions=None):
        outputs = self.encoder(input_ids, attention_mask=attention_mask)
        lstm_output = outputs.last_hidden_state
    
        logits = self.classifier(lstm_output)
    
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits, end_logits = start_logits.squeeze(-1), end_logits.squeeze(-1)
    
        # Given answer-ability
        unanswerabe = torch.logical_and(start_positions == 0, end_positions == 0).float()
    
        start_loss = F.cross_entropy(start_logits, start_positions)
        end_loss = F.cross_entropy(end_logits, end_positions)
    
        # Predict answerability
        pred_answerable = self.answerable(lstm_output[:, 0])
        answerable_pred = F.softmax(pred_answerable, dim=1)
        answerable_loss = F.cross_entropy(answerable_pred[:, 0], 1-unanswerabe) + \
                          F.cross_entropy(answerable_pred[:, 1], unanswerabe)
    
        print((start_loss + end_loss).item(), answerable_loss.item())
        loss = start_loss + end_loss + answerable_loss
    
        return loss

    def forward_ex(self, example):
        input_ids = example["input_ids"].to(device)
        start_positions = example["start_positions"].to(device) if "start_positions" in example else None
        end_positions = example["end_positions"].to(device) if "end_positions" in example else None
        attention_mask = example["attention_mask"].to(device) if "attention_mask" in example else None
        return self.forward(input_ids, attention_mask, start_positions, end_positions)

For reproducibility, here is a minimal code example:

import torch
from transformers import AutoTokenizer
from datasets import load_dataset
from torch.utils.data.dataloader import DataLoader
import torch.nn.functional as F
import torch.nn as nn
from transformers import AutoModel

max_train_examples = 1000
max_length = 384
stride = 128
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("roberta-base")

def preprocess_examples(examples):
    questions = [q.strip() for q in examples["question"]]

    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        if not answer["answer_start"]:
            start_positions.append(0)
            end_positions.append(0)
            continue

        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

def convert_to_tensors(examples):
    return {k: torch.tensor([x[k] for x in examples]) for k in examples[0]}

squad = load_dataset('squad_v2')

squad["train"] = squad["train"].select(range(max_train_examples))
tokenized_datasets = squad.map(
    preprocess_examples,
    batched=True,
    remove_columns=squad["train"].column_names,
)

train_dataloader = DataLoader(tokenized_datasets["train"], batch_size=8, collate_fn=convert_to_tensors, shuffle=True)    

roberta = AutoModel.from_pretrained("roberta-base").to(device)
model = LSTMFrontVerifier(roberta).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-3)

model.train()
optimizer.zero_grad()  # Reset gradients tensors

for batch, x in enumerate(train_dataloader):
    # Compute prediction error
    loss = model.forward_ex(x)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

Packages:

  1. Install PyTorch with GPU support: https://pytorch.org/get-started/locally/
  2. pip install transformers datasets

Some output of print((start_loss + end_loss).item(), answerable_loss.item()):
11.970989227294922 16.621864318847656
11.742887496948242 16.640228271484375
11.448827743530273 16.66130828857422
11.063633918762207 16.665538787841797
10.49462890625 16.663175582885742
10.176043510437012 16.661867141723633
10.13321590423584 16.646312713623047
10.05152702331543 16.64898681640625
9.76708698272705 16.648393630981445
9.409551620483398 16.700483322143555
8.939659118652344 16.641441345214844
9.03275203704834 16.647899627685547
8.160870552062988 16.63787078857422
8.27975845336914 16.641223907470703
7.900142669677734 16.64410972595215
6.427922248840332 16.644954681396484
6.4332380294799805 16.643535614013672
6.626171112060547 16.642642974853516
4.79502010345459 16.640335083007812
6.948017120361328 16.641925811767578
5.472411632537842 16.642606735229492
6.458420753479004 16.63710594177246
6.552549362182617 16.637182235717773
4.95197868347168 16.637977600097656
5.235410690307617 16.643829345703125
4.700412750244141 16.63840675354004
4.11396598815918 16.646831512451172
5.13016414642334 16.643505096435547
3.7867109775543213 16.63637924194336
5.582259178161621 16.643115997314453
5.7655229568481445 16.64023208618164
5.085046768188477 16.63158416748047
4.153951644897461 16.63810920715332
4.100613594055176 16.644237518310547
4.206878662109375 16.636249542236328
3.450410842895508 16.635835647583008
4.827783584594727 16.63633918762207
2.2874913215637207 16.644474029541016
2.3297667503356934 16.647319793701172
3.4870200157165527 16.652259826660156
3.31907320022583 16.6363582611084
4.377845764160156 16.637659072875977
3.427989959716797 16.635705947875977
4.224106311798096 16.640310287475586

7rfyedvj1#

The cross-entropy loss is implemented incorrectly. The correct snippet for computing the answerability cross-entropy loss is:

pred_answerable = self.answerable(lstm_output[:, 0])
answerable_loss = F.cross_entropy(pred_answerable, (1-unanswerabe).long())
answerability_prediction = F.softmax(pred_answerable, dim=1)
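The reason the original version stays flat: F.cross_entropy applies log_softmax internally and expects raw logits of shape [batch, num_classes] together with integer class indices as the target. Applying softmax first and then passing a single probability column with float targets does not compute the intended binary cross-entropy, which is why the printed answerable_loss barely moves from ~16.64 and gives the answerability head no useful gradient. A minimal check of the expected call signature (hypothetical values, not from the SQuAD data):

import torch
import torch.nn.functional as F

logits = torch.randn(8, 2, requires_grad=True)     # raw scores from the answerability head
labels = torch.tensor([1, 0, 1, 1, 0, 1, 0, 1])    # integer class indices, not floats
loss = F.cross_entropy(logits, labels)             # log_softmax + NLL applied internally
loss.backward()
print(loss.item(), logits.grad.abs().sum().item())  # a non-trivial gradient reaches the head

With the corrected snippet, class index 1 corresponds to "answerable" because the target is (1 - unanswerabe).long(); the softmax is only needed afterwards for reporting the certainty percentage, not for the loss.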
