Pytorch BERT输入渐变

sirbozc5  于 2022-11-09  发布在  其他
关注(0)|答案(2)|浏览(173)

我试图从pytorch中的一个BERT模型得到输入梯度。我该怎么做呢?假设,y' = BertModel(x)。我试图找到$d(loss(y,y'))/dx$

bqf10yzr

bqf10yzr1#

Bert模型的一个问题是,您的输入大多包含令牌ID而不是令牌嵌入,这使得获取梯度变得困难,因为令牌ID和令牌嵌入之间的关系是不连续的。要解决这个问题,您可以使用令牌嵌入。


# get your batch data: token_id, mask and labels

token_ids, mask, labels = batch

# get your token embeddings

token_embeds=BertModel.bert.get_input_embeddings().weight[token_ids].clone()

# track gradient of token embeddings

token_embeds.requires_grad=True

# get model output that contains loss value

outs = BertModel(inputs_embeds=inputs_embeds,labels=labels)
loss=outs.loss

在得到损耗值后,可以在this answer中使用torch.autograd.grad或后退函数

loss.backward()
grad=token_embeds.grad
vfwfrxfs

vfwfrxfs2#

您可以使用torch.autograd.graddocumentation):

y_pred = BertModel(x)
out = loss_func(y_label, y_pred)  # not necessary a scalar!
grad = torch.autograd.grad(
    outputs=out,
    inputs=x,
    grad_outputs=torch.ones(out.size()).to(device), # or simply None if out is a scalar
    retain_graph=False,
    create_graph=False,
    only_inputs=True)[0]

如果要使用grad计算损失并向后应用(通常用于计算渐变补偿),则应将retain_graphcreate_graph传递给True。否则,请将其保留为False以保存内存和时间。

相关问题