I followed this link, but that solution is implemented in Keras:
Cannot add CRF layer on top of BERT in keras for NER
Model description
Is it possible to add a simple custom pytorch-crf layer on top of a TokenClassification model? I expect that would make the model more robust.
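import torch
import torch.nn as nn
from torchcrf import CRF  # pip install pytorch-crf
from transformers import BertTokenizer, BertForTokenClassification

# id2label / label2id come from the dataset's tag set (defined elsewhere).
model_checkpoint = "dslim/bert-base-NER"
# add_prefix_space only matters for byte-level BPE tokenizers (RoBERTa, GPT-2), not for BERT.
tokenizer = BertTokenizer.from_pretrained(model_checkpoint)
bert_model = BertForTokenClassification.from_pretrained(model_checkpoint, id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True)

class BERT_CRF(nn.Module):
    def __init__(self, bert_model, num_labels):
        super().__init__()
        self.bert = bert_model
        self.dropout = nn.Dropout(0.25)
        hidden_size = bert_model.config.hidden_size  # 768 for bert-base
        # Emission scores are computed from the last 4 hidden layers, concatenated.
        self.classifier = nn.Linear(4 * hidden_size, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None, token_type_ids=None):
        # Hidden states are only returned when requested; without this flag
        # outputs holds just the logits, so outputs[1] does not exist.
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids, output_hidden_states=True)
        hidden = outputs.hidden_states  # embedding output + one tensor per encoder layer
        sequence_output = torch.cat((hidden[-1], hidden[-2], hidden[-3], hidden[-4]), -1)
        sequence_output = self.dropout(sequence_output)
        emission = self.classifier(sequence_output)  # [batch, seq_len, num_labels], e.g. [32, 256, 17]
        mask = attention_mask.bool()  # torchcrf requires the first token ([CLS]) to be unmasked
        prediction = self.crf.decode(emission, mask=mask)
        if labels is None:
            return prediction
        # Reshape only when labels are actually given.
        labels = labels.reshape(attention_mask.size(0), attention_mask.size(1))
        # Labels must be valid tag ids here: replace any -100 entries beforehand.
        # The CRF normalises internally, so it gets the same raw emissions as decode().
        loss = -self.crf(emission, labels, mask=mask, reduction='mean')
        return [loss, prediction]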
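To see what the CRF layer adds on top of the per-token emissions, here is a minimal pure-Python sketch of a linear-chain CRF for a single, unmasked sequence: the loss `-self.crf(...)` corresponds to the log-partition (forward algorithm) minus the score of the gold path, and `crf.decode` corresponds to Viterbi decoding. All function names below are hypothetical illustrations, not the pytorch-crf API, and the start/end transition scores that pytorch-crf also learns are left out for brevity.

```python
import itertools
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def path_score(emissions, transitions, tags):
    """Unnormalised score of one tag sequence: emissions plus transitions."""
    score = emissions[0][tags[0]]
    for t in range(1, len(tags)):
        score += transitions[tags[t - 1]][tags[t]] + emissions[t][tags[t]]
    return score


def log_partition(emissions, transitions):
    """Forward algorithm: log of the summed exp-scores of all tag paths."""
    num_tags = len(emissions[0])
    alpha = list(emissions[0])
    for t in range(1, len(emissions)):
        alpha = [
            logsumexp([alpha[i] + transitions[i][j] for i in range(num_tags)]) + emissions[t][j]
            for j in range(num_tags)
        ]
    return logsumexp(alpha)


def neg_log_likelihood(emissions, transitions, tags):
    """Negative log-likelihood of the gold tags (the CRF training loss)."""
    return log_partition(emissions, transitions) - path_score(emissions, transitions, tags)


def viterbi(emissions, transitions):
    """Highest-scoring tag path over the whole sequence."""
    num_tags = len(emissions[0])
    score = list(emissions[0])
    backpointers = []
    for t in range(1, len(emissions)):
        best_prev = [
            max(range(num_tags), key=lambda i: score[i] + transitions[i][j])
            for j in range(num_tags)
        ]
        score = [
            score[best_prev[j]] + transitions[best_prev[j]][j] + emissions[t][j]
            for j in range(num_tags)
        ]
        backpointers.append(best_prev)
    path = [max(range(num_tags), key=lambda j: score[j])]
    for best_prev in reversed(backpointers):
        path.append(best_prev[path[-1]])
    return path[::-1]


def brute_force_log_partition(emissions, transitions):
    """Check of the forward algorithm by enumerating every tag path."""
    num_tags = len(emissions[0])
    paths = itertools.product(range(num_tags), repeat=len(emissions))
    return logsumexp([path_score(emissions, transitions, list(p)) for p in paths])


# Toy example: 3 tokens, 2 tags (made-up scores).
emissions = [[1.0, 0.0], [0.0, 2.0], [1.5, 0.0]]
transitions = [[0.5, -1.0], [-1.0, 0.5]]  # transitions[prev][next]

greedy = [row.index(max(row)) for row in emissions]
best = viterbi(emissions, transitions)
print("greedy:", greedy, "viterbi:", best)
print("NLL of best path:", neg_log_likelihood(emissions, transitions, best))
```

With these numbers the per-token argmax of the emissions is `[0, 1, 0]`, but the transition scores make `[0, 0, 0]` the best full path. That kind of sequence-level consistency is what the CRF contributes over a plain token classifier.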
from transformers import Trainer, TrainingArguments

model = BERT_CRF(bert_model, num_labels=len(id2label))

args = TrainingArguments(
    "spanbert_crf_ner-pos2",
    # evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=1,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    # per_device_eval_batch_size=32,
    fp16=True,
    # bf16=True,  # Ampere GPU
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_data,
    # eval_dataset=train_data,
    # data_collator=data_collator,
    # compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)
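Hugging Face token-classification datasets usually label special tokens, padding and non-first sub-words with `-100` so that cross-entropy ignores them. A CRF cannot skip a label that way: pytorch-crf indexes its transition matrix with the tag ids, and it requires the first time step of every sequence to be unmasked. One simple workaround, sketched below with a hypothetical helper `prepare_crf_inputs`, is to replace `-100` with a real tag id (assumed here to be `0`, e.g. the `O` tag) and use `attention_mask` as the CRF mask. The trade-off is that those positions are then trained as that tag.

```python
import torch


def prepare_crf_inputs(labels, attention_mask, ignore_index=-100, fill_tag=0):
    """Hypothetical helper: make HF-style labels usable by torchcrf.CRF.

    labels:         LongTensor [batch, seq_len] that may contain ignore_index
    attention_mask: 0/1 tensor [batch, seq_len]
    Returns (safe_labels, mask): every label is a valid tag id, and the mask
    is a bool copy of attention_mask ([CLS] is normally 1, which satisfies
    the CRF's first-time-step check).
    """
    safe_labels = labels.masked_fill(labels == ignore_index, fill_tag)
    mask = attention_mask.bool()
    return safe_labels, mask


# Toy batch: "[CLS] w1 ##sub w2 [SEP]" and a shorter, padded "[CLS] w1 [SEP] <pad> <pad>".
labels = torch.tensor([[-100, 3, -100, 1, -100],
                       [-100, 2, -100, -100, -100]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])

safe_labels, mask = prepare_crf_inputs(labels, attention_mask)
print(safe_labels.tolist())
print(mask.tolist())
```

In `BERT_CRF.forward`, the same conversion would go right before the `self.crf(...)` call.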
I get an error on the line **sequence_output = torch.cat((outputs[1][-1], outputs[1][-2], outputs[1][-3], outputs[1][-4]),-1)**
because outputs = self.bert(input_ids, attention_mask=attention_mask)
only returns the token-classification logits. How can I get the hidden states, so that I can concatenate the last 4 of them by indexing outputs[1][-1], outputs[1][-2] and so on?
Or is there a simpler way to implement a BERT-CRF model?
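A quick way to check the hidden-state indexing without downloading `dslim/bert-base-NER` is a tiny, randomly initialised BERT; the sizes below are arbitrary assumptions. With `output_hidden_states=True`, `outputs.hidden_states` is a tuple holding the embedding output plus one tensor per encoder layer, so the last four layers are `hidden_states[-1]` to `hidden_states[-4]`.

```python
import torch
from transformers import BertConfig, BertForTokenClassification

torch.manual_seed(0)

# Tiny, randomly initialised BERT (arbitrary sizes) so nothing is downloaded.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=4,
    num_attention_heads=4,
    intermediate_size=64,
    num_labels=5,
)
model = BertForTokenClassification(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (2, 7))  # batch of 2, 7 tokens each
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    # Without output_hidden_states=True, outputs would only carry the logits.
    outputs = model(input_ids, attention_mask=attention_mask, output_hidden_states=True)

hidden_states = outputs.hidden_states  # embedding output + one tensor per layer
last_four = torch.cat([hidden_states[i] for i in (-1, -2, -3, -4)], dim=-1)

print(len(hidden_states))    # number of layers + 1
print(outputs.logits.shape)  # (batch, seq_len, num_labels)
print(last_four.shape)       # (batch, seq_len, 4 * hidden_size)
```

For `bert-base` the same concatenation gives 4 * 768 = 3072 features per token, which matches the `nn.Linear(4*768, num_labels)` classifier in the question.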
1 answer
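I was also looking for a way to try BERT + CRF. The code I wrote is based on the YouTube video below, and it works, so I'm sharing it with you. I think the relevant parts of the video are at about 15:35 and 18:24. https://www.youtube.com/watch?v=Yss8RRDBMzg&ab_channel=StartPythonClub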