# get your batch data: token_id, mask and labels
token_ids, mask, labels = batch
# get your token embeddings
token_embeds=BertModel.bert.get_input_embeddings().weight[token_ids].clone()
# track gradient of token embeddings
token_embeds.requires_grad=True
# get model output that contains loss value
outs = BertModel(inputs_embeds=inputs_embeds,labels=labels)
loss=outs.loss
y_pred = BertModel(x)
out = loss_func(y_label, y_pred) # not necessary a scalar!
grad = torch.autograd.grad(
outputs=out,
inputs=x,
grad_outputs=torch.ones(out.size()).to(device), # or simply None if out is a scalar
retain_graph=False,
create_graph=False,
only_inputs=True)[0]
2条答案
按热度按时间bqf10yzr1#
Bert模型的一个问题是,您的输入大多包含令牌ID而不是令牌嵌入,这使得获取梯度变得困难,因为令牌ID和令牌嵌入之间的关系是不连续的。要解决这个问题,您可以使用令牌嵌入。
在得到损耗值后,可以在this answer中使用torch.autograd.grad或后退函数
vfwfrxfs2#
您可以使用
torch.autograd.grad
(documentation):如果要使用
grad
计算损失并向后应用(通常用于计算渐变补偿),则应将retain_graph
和create_graph
传递给True
。否则,请将其保留为False
以保存内存和时间。