Model description
Is it possible to add a simple custom pytorch-crf layer on top of a TokenClassification model? It would make the model more robust.
There should be a simple notebook tutorial that teaches us how to add our own custom layer on top of Hugging Face models for:
- Classification
- Token Classification (BIO)
For example, take dslim/bert-base-NER and add a CRF layer (`from torchcrf import CRF`) on top of it.
I am planning to do this, but I don’t know how to get this feature coded. Any leads or notebook examples would be helpful. Here is my attempt so far:
```python
import torch
import torch.nn as nn
from torch.nn.functional import log_softmax as log_soft
from torchcrf import CRF
from transformers import (BertConfig, BertForTokenClassification, BertTokenizer,
                          Trainer, TrainingArguments)

model_checkpoint = "dslim/bert-base-NER"
tokenizer = BertTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
config = BertConfig.from_pretrained(model_checkpoint, output_hidden_states=True)

# id2label / label2id map my NER tag set (defined elsewhere in the notebook)
bert_model = BertForTokenClassification.from_pretrained(
    model_checkpoint, id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True
)


class BERT_CRF(nn.Module):
    def __init__(self, bert_model, num_labels):
        super(BERT_CRF, self).__init__()
        self.bert = bert_model
        self.dropout = nn.Dropout(0.25)
        # classify over the concatenation of the last 4 hidden states (4 * 768)
        self.classifier = nn.Linear(4 * 768, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None, token_type_ids=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        # concatenate the last 4 hidden states  <-- this is the line that raises the error
        sequence_output = torch.cat(
            (outputs[1][-1], outputs[1][-2], outputs[1][-3], outputs[1][-4]), -1
        )
        sequence_output = self.dropout(sequence_output)
        emission = self.classifier(sequence_output)  # [32, 256, 17]
        if labels is not None:
            labels = labels.reshape(attention_mask.size()[0], attention_mask.size()[1])
            loss = -self.crf(log_soft(emission, 2), labels,
                             mask=attention_mask.type(torch.uint8), reduction='mean')
            prediction = self.crf.decode(emission, mask=attention_mask.type(torch.uint8))
            return [loss, prediction]
        else:
            prediction = self.crf.decode(emission, mask=attention_mask.type(torch.uint8))
            return prediction
```
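To feed this into the `Trainer` below, I instantiate the wrapper roughly like this (just a sketch on my part; I am assuming the CRF should use the same number of labels as the underlying token-classification head):

```python
# Sketch: wrap the pretrained token-classification model with the CRF head.
# Assumption: the CRF uses the same label count as the underlying model.
model = BERT_CRF(bert_model, num_labels=bert_model.config.num_labels)
```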
```python
args = TrainingArguments(
    "spanbert_crf_ner-pos2",
    # evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=1,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    # per_device_eval_batch_size=32,
    fp16=True,
    # bf16=True,  # Ampere GPU
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_data,
    # eval_dataset=train_data,
    # data_collator=data_collator,
    # compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)
```
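Assuming the setup above is correct, training would then be started the usual way:

```python
trainer.train()
```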
I get an error on the line `sequence_output = torch.cat((outputs[1][-1], outputs[1][-2], outputs[1][-3], outputs[1][-4]), -1)`, because `outputs = self.bert(input_ids, attention_mask=attention_mask)` only gives the logits for token classification. How can I get the hidden states, so that I can concatenate the last 4 hidden states, i.e. do `outputs[1][-1]`?
Or is there an easier way to implement a BERT-CRF model?
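For reference, this is the kind of change I am imagining inside `forward` (just a sketch of my guess, not something I have verified; it assumes the model returns hidden states when `output_hidden_states=True` is passed at call time):

```python
# Guess / sketch: request hidden states explicitly in the forward call and read
# them from outputs.hidden_states instead of indexing outputs[1].
enc = tokenizer("John lives in New York", return_tensors="pt")
outputs = bert_model(**enc, output_hidden_states=True)
hidden_states = outputs.hidden_states  # tuple: embedding output + one tensor per layer
sequence_output = torch.cat(
    (hidden_states[-1], hidden_states[-2], hidden_states[-3], hidden_states[-4]), dim=-1
)  # [batch, seq_len, 4 * 768]
```

I am not sure whether requesting them per call like this, or loading the model with the `config` that already has `output_hidden_states=True`, is the cleaner option.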