pytorch 交叉编码器Transformer将每个输入收敛到相同的CLS嵌入

mrfwxfqh  于 2023-11-19  发布在  其他
关注(0)|答案(1)|浏览(99)

我尝试使用Huggingface transformers和PyTorch从“distilbert-base-uncased”开始创建一个交叉编码器模型。架构很简单:从连接的输入字符串中获取CLS嵌入(这由Huggingface tokenizer处理),然后将其通过最终FC线性层传递到1个输出logit。损失函数是内置的torch.nn.BCEWithLogitsLoss函数。
该模型未能在我的数据集上正确学习,而是快速将CLS嵌入(在最终FC线性层之前)收敛到每个输入的相同嵌入(事实上,其他令牌也收敛到相同的嵌入)。最后一层然后将此嵌入Map到训练样本中的阳性比率,这是假设恒定嵌入函数的预期行为。
为了调试的目的,我简单地给它提供了一个由相同的3个句子对组成的虚拟数据集(1个标记为阳性,另外2个标记为阴性)。同样的行为持续存在,但是当我冻结Transformer参数时(这样只有最后的FC被训练),模型正确地在数据点上过拟合,正如预期的那样。
我的模型架构:

class CrossEncoderModel(nn.Module):
    """
    Architecture:
    - Transformer
    - Final FC linear layer to one output for binary classification
    """

    def __init__(
        self, transformer_model: str, tokenizer: PreTrainedTokenizerFast
    ) -> None:
        super(ParagraphCrossEncoderModel, self).__init__()

        self.transformer = AutoModel.from_pretrained(transformer_model)
        print(type(self.transformer))
        self.transformer.resize_token_embeddings(len(tokenizer))
        self.fc = nn.Linear(self.transformer.config.hidden_size, 1)

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        cls_embedding = outputs.last_hidden_state[:, 0, :]
        logits = self.fc(cls_embedding).squeeze(-1)
        return logits

字符串
我的损失/更新:

loss = torch.nn.BCEWithLogitsLoss(logits, labels)
loss.backward()

optimizer.step()
optimizer.zero_grad()


我尝试改变学习率和批量大小,但都没有改变到相同CLS嵌入的收敛。我怀疑我的模型架构有问题,但我很难找到确切的问题。当我用手动目标替换损失函数时,这种行为仍然存在:

class TestLoss(nn.Module):
    def __init__(self):
        super(TestLoss, self).__init__()

    def forward(self, logits: Tensor, labels: Tensor) -> Tensor:
        return torch.sum(torch.abs(logits - torch.tensor([10.0, -10.0, -10.0]).to(device)))
    # still converges to all embeddings being the same

9njqaruj

9njqaruj1#

我知道你是如何处理来自Transformer输出的标记嵌入的。当使用交叉编码器模型时,你通常希望使用两个输入句子(或句子对)的嵌入,而不仅仅是CLS标记嵌入

class CrossEncoderModel(nn.Module):
    def __init__(self, transformer_model: str, tokenizer: PreTrainedTokenizerFast) -> None:
        super(CrossEncoderModel, self).__init__()

        self.transformer = AutoModel.from_pretrained(transformer_model)
        self.transformer.resize_token_embeddings(len(tokenizer))
        self.fc = nn.Linear(self.transformer.config.hidden_size, 1)
        self.layer_norm = nn.LayerNorm(self.transformer.config.hidden_size)

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # Use pooled output or concatenate and pool all embeddings
        pooled_output = outputs.pooler_output
        # Alternatively, concatenate and use mean pooling
        # all_embeddings = outputs.last_hidden_state
        # pooled_output = torch.mean(all_embeddings, dim=1)

        pooled_output = self.layer_norm(pooled_output)
        logits = self.fc(pooled_output).squeeze(-1)
        return logits

字符串
尝试使用池化策略,看看哪种策略最适合您的特定任务

相关问题