pytorch 我怎样才能从deberta模型中得到汇总输出?

nxagd54h  于 2023-02-04  发布在  其他
关注(0)|答案(1)|浏览(236)

谁能告诉我如何才能得到池输出形式德贝塔模型?
有人能告诉我如何从Deberta模型中获取池化输出吗?我希望将它从 DebertaModel 中用于我的分类模型,而不使用 *DebertaForSequenceClassification *

wribegjk

wribegjk1#

我是这样解的:

deberta_model = DebertaForSequenceClassification.from_pretrained("microsoft/deberta-base")
deberta_model.config.num_labels = 1

class DebrtaRegressor(nn.Module):

def __init__(self):
    
    super(DebrtaRegressor, self).__init__()

    self.deberta = deberta_model
    self.sigmoid1 = nn.Sigmoid()
    
    
    
def forward(self, input_ids, attention_masks):
    outputs = self.deberta(input_ids, attention_masks)
    outputs = outputs.logits[:, : 1]
    outputs = self.sigmoid1(outputs)
    return outputs

相关问题