Two scripts for training and testing BERT-like pretrained models on Named Entity Recognition (NER) tasks.
- Train (finetune) a BERT-based NER model (from Hugging Face or a model stored on your computer) on a dataset formatted in CoNLL-like format.
- Use the trained model to predict NER tags on test data (producing an evaluation report) or on new data (producing a tagged version of the data).
The scripts handle sequences longer than the model's maximum input length by splitting them into multiple chunks with a context overlap.
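As a rough illustration of the idea (a hypothetical sketch, not the scripts' actual code; the function name and defaults are made up, with 512 standing in for a typical BERT maximum length):

def chunk_with_overlap(tokens, max_len=512, overlap=50):
    """Yield (start_offset, chunk) windows covering `tokens`, where
    consecutive windows share `overlap` tokens of context."""
    step = max_len - overlap
    for start in range(0, len(tokens), step):
        yield start, tokens[start:start + max_len]
        if start + max_len >= len(tokens):
            break

# A 1000-token sequence with max_len=512 and overlap=50 yields windows
# starting at token offsets 0, 462, and 924.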
The input data should follow a CoNLL-like format:
- Each line contains a token in the first position and its corresponding tag in the last position, separated by a space.
- Sentences are separated by blank lines.
Example:
Barack B-PER
Obama I-PER
was O
born O
in O
Hawaii B-LOC
. O
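Reading this format is straightforward; here is a minimal reader for illustration (a sketch, not the scripts' own loader):

def read_conll(path):
    """Return a list of (tokens, tags) sentence pairs from a CoNLL-like file."""
    sentences, tokens, tags = [], [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:                 # blank line ends a sentence
                if tokens:
                    sentences.append((tokens, tags))
                    tokens, tags = [], []
                continue
            fields = line.split()
            tokens.append(fields[0])     # token in the first position
            tags.append(fields[-1])      # tag in the last position
    if tokens:                           # handle a missing trailing blank line
        sentences.append((tokens, tags))
    return sentences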
usage: train_ner_tagger.py [-h] --data_dir DATA_DIR [--data_file_suffix DATA_FILE_SUFFIX] --model_dir MODEL_DIR [--model_name MODEL_NAME] [--k K] [--skip_cv] [--skip_train] [--learning_rate LEARNING_RATE]
[--per_device_train_batch_size PER_DEVICE_TRAIN_BATCH_SIZE] [--per_device_eval_batch_size PER_DEVICE_EVAL_BATCH_SIZE] [--num_train_epochs NUM_TRAIN_EPOCHS] [--weight_decay WEIGHT_DECAY]
Train a token classifier using Hugging Face models.
options:
-h, --help show this help message and exit
--data_dir DATA_DIR Directory containing training data files.
--data_file_suffix DATA_FILE_SUFFIX
Extension of training data files (default=.txt).
--model_dir MODEL_DIR
Directory to save the final trained model and/or results.
--model_name MODEL_NAME
Name of a pretrained Hugging Face model or path to a local model to use (default=bert-base-cased).
--k K Number of folds for cross-validation (default=10).
--skip_cv Do not do cross-validation (default=do cv).
--skip_train Do not train on the whole training data (default=do training).
--learning_rate LEARNING_RATE
Learning rate for the optimizer (default: 2e-5).
--per_device_train_batch_size PER_DEVICE_TRAIN_BATCH_SIZE
Batch size for training (default: 32).
--per_device_eval_batch_size PER_DEVICE_EVAL_BATCH_SIZE
Batch size for evaluation (default: 32).
--num_train_epochs NUM_TRAIN_EPOCHS
Number of training epochs (default: 10).
--weight_decay WEIGHT_DECAY
Weight decay for optimizer (default: 0.01).
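In Hugging Face terms, training with the defaults above boils down to roughly the following (a sketch under assumptions: the toy dataset stands in for your CoNLL-style files, every subword inherits its word's tag, and the real script may organize this differently):

from datasets import Dataset
from transformers import (AutoModelForTokenClassification, AutoTokenizer,
                          DataCollatorForTokenClassification, Trainer,
                          TrainingArguments)

labels = ["B-LOC", "B-MISC", "B-ORG", "B-PER",
          "I-LOC", "I-MISC", "I-ORG", "I-PER", "O"]
label2id = {l: i for i, l in enumerate(labels)}

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModelForTokenClassification.from_pretrained(
    "bert-base-cased", num_labels=len(labels),
    id2label={i: l for i, l in enumerate(labels)}, label2id=label2id)

# Toy stand-in for the training files (assumption).
raw = Dataset.from_dict({
    "tokens": [["Barack", "Obama", "was", "born", "in", "Hawaii", "."]],
    "ner_tags": [["B-PER", "I-PER", "O", "O", "O", "B-LOC", "O"]],
})

def tokenize_and_align(batch):
    # Tokenize pre-split words and map each subword to its word's tag;
    # special tokens get -100 so the loss ignores them.
    enc = tokenizer(batch["tokens"], is_split_into_words=True, truncation=True)
    enc["labels"] = [
        [-100 if w is None else label2id[tags[w]]
         for w in enc.word_ids(batch_index=i)]
        for i, tags in enumerate(batch["ner_tags"])
    ]
    return enc

train_dataset = raw.map(tokenize_and_align, batched=True,
                        remove_columns=raw.column_names)

args = TrainingArguments(
    output_dir="models/bert_conllpp",
    learning_rate=2e-5,               # --learning_rate
    per_device_train_batch_size=32,   # --per_device_train_batch_size
    per_device_eval_batch_size=32,    # --per_device_eval_batch_size
    num_train_epochs=10,              # --num_train_epochs
    weight_decay=0.01,                # --weight_decay
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset,
                  data_collator=DataCollatorForTokenClassification(tokenizer))
trainer.train()
trainer.save_model("models/bert_conllpp")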
usage: apply_ner_tagger.py [-h] --model_name MODEL_NAME --data_dir DATA_DIR [--data_file_suffix DATA_FILE_SUFFIX] [--ignore_unknown_tags] --output_dir OUTPUT_DIR [--overlap OVERLAP]
Predict tags for input text files using a trained token classification model.
options:
-h, --help show this help message and exit
--model_name MODEL_NAME
Path to a directory containing the trained model or name of a Hugging Face model.
--data_dir DATA_DIR Directory containing input files to be tagged.
--data_file_suffix DATA_FILE_SUFFIX
Extension of input files.
--ignore_unknown_tags
Ignore tags in input files that are not defined in the model (default=throw error).
--output_dir OUTPUT_DIR
Directory to save the output files with predicted tags.
--overlap OVERLAP Context length for sequences longer than model max length (default=50)
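At prediction time the core step, in plain Hugging Face code, looks roughly like this (a sketch; keeping the prediction of each word's first subword is one common convention and may not be exactly what the script does):

import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

model_dir = "models/bert_conllpp"   # directory written by train_ner_tagger.py
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForTokenClassification.from_pretrained(model_dir).eval()

words = ["Barack", "Obama", "was", "born", "in", "Hawaii", "."]
enc = tokenizer(words, is_split_into_words=True, return_tensors="pt")
with torch.no_grad():
    pred_ids = model(**enc).logits[0].argmax(dim=-1).tolist()

# Keep one prediction per word: the label of its first subword.
tags, seen = [], set()
for pos, word_id in enumerate(enc.word_ids(batch_index=0)):
    if word_id is not None and word_id not in seen:
        seen.add(word_id)
        tags.append(model.config.id2label[pred_ids[pos]])
print(list(zip(words, tags)))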
This is an example of training and testing a BERT NER model on the CoNLL++ dataset.
The following command trains a BERT model on the CoNLL++ dataset:
CUDA_VISIBLE_DEVICES=3 python train_ner_tagger.py \
--data_dir ./data \
--data_file_suffix train.txt \
--model_dir models/bert_conllpp
Example output:
204567 tokens, 14987 sentences read from 1 files.
Training fold 1...
{'loss': 1.4032, 'grad_norm': 3.16690731048584, 'learning_rate': 1.9952606635071093e-05, 'epoch': 0.02}
{'loss': 0.6847, 'grad_norm': 1.0225292444229126, 'learning_rate': 1.990521327014218e-05, 'epoch': 0.05}
[... 10 fold validation ...]
{'loss': 0.0008, 'grad_norm': 0.01256786659359932, 'learning_rate': 9.478672985781992e-08, 'epoch': 9.95}
{'loss': 0.0029, 'grad_norm': 0.05868620052933693, 'learning_rate': 4.739336492890996e-08, 'epoch': 9.98}
{'loss': 0.0009, 'grad_norm': 0.016589513048529625, 'learning_rate': 0.0, 'epoch': 10.0}
{'eval_loss': 0.033828623592853546, 'eval_runtime': 0.5779, 'eval_samples_per_second': 2592.008, 'eval_steps_per_second': 81.325, 'epoch': 10.0}
{'train_runtime': 210.2245, 'train_samples_per_second': 641.647, 'train_steps_per_second': 20.074, 'train_loss': 0.02340272987770111, 'epoch': 10.0}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4220/4220 [03:30<00:00, 20.07it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 47/47 [00:00<00:00, 85.92it/s]
Metrics for fold 10: {'eval_loss': 0.027988383546471596, 'eval_runtime': 0.556, 'eval_samples_per_second': 2694.247, 'eval_steps_per_second': 84.532, 'epoch': 10.0}
Confusion Matrix for fold 10 (labels=['B-LOC', 'B-MISC', 'B-ORG', 'B-PER', 'I-LOC', 'I-MISC', 'I-ORG', 'I-PER', 'O']):
[[680, 5, 5, 3, 0, 0, 0, 0, 5],
 [6, 326, 5, 1, 0, 4, 0, 0, 7],
 [9, 6, 594, 5, 0, 0, 4, 0, 7],
 [1, 2, 5, 638, 0, 0, 0, 2, 7],
 [0, 0, 0, 0, 100, 2, 4, 1, 2],
 [0, 2, 0, 0, 3, 100, 0, 0, 3],
 [0, 0, 0, 0, 2, 3, 333, 2, 3],
 [0, 0, 0, 0, 0, 1, 1, 445, 2],
 [2, 8, 4, 1, 0, 7, 5, 0, 16842]]
Cross-validation report saved to models/bert_conllpp/cross_validation_report.json
Retraining on the entire dataset...
{'loss': 0.1346, 'grad_norm': 0.4300181269645691, 'learning_rate': 1.7867803837953093e-05, 'epoch': 1.07}
[... final training on the whole training dataset ...]
{'train_runtime': 199.9011, 'train_samples_per_second': 749.721, 'train_steps_per_second': 23.462, 'train_loss': 0.021958918129203163, 'epoch': 10.0}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4690/4690 [03:19<00:00, 23.46it/s]
Final model and tokenizer saved to models/bert_conllpp
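The per-fold training above presumably comes from a sentence-level k-fold split; sketched with scikit-learn (an assumption about the script's internals, reusing the hypothetical read_conll reader sketched earlier):

from sklearn.model_selection import KFold

sentences = read_conll("data/train.txt")   # hypothetical path
kf = KFold(n_splits=10, shuffle=True, random_state=0)
for fold, (train_idx, eval_idx) in enumerate(kf.split(sentences), start=1):
    print(f"Training fold {fold}...")
    train_split = [sentences[i] for i in train_idx]
    eval_split = [sentences[i] for i in eval_idx]
    # Train a fresh model on train_split, evaluate on eval_split, and
    # collect the metrics and confusion matrix reported for each fold.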
The finetuned model is saved in models/bert_conllpp, and the cross_validation_report.json file in that directory contains the cross-validation results:
cat models/bert_conllpp/cross_validation_report.json
{
"fold_metrics": [
{
"fold": 1,
"metrics": {
"eval_loss": 0.03650137037038803,
"eval_runtime": 1.0708,
"eval_samples_per_second": 1399.861,
"eval_steps_per_second": 43.892,
"epoch": 10.0
},
"confusion_matrix": [
[ 668, 5, 6, 3, 0, 0, 1, 1, 2],
[ 1, 326, 7, 1, 0, 4, 0, 0, 7],
[ 14, 5, 580, 4, 0, 0, 3, 0, 8],
[ 2, 2, 7, 657, 1, 0, 1, 1, 7],
[ 0, 0, 0, 0, 102, 0, 8, 0, 4],
[ 1, 3, 0, 0, 0, 121, 3, 1, 8],
[ 0, 0, 1, 0, 2, 9, 330, 0, 9],
[ 0, 0, 0, 3, 0, 0, 2, 444, 0],
[ 4, 7, 3, 1, 1, 5, 8, 4, 17963]
],
"labels": [ "B-LOC", "B-MISC", "B-ORG", "B-PER", "I-LOC", "I-MISC", "I-ORG", "I-PER", "O"]
},
[... fold metrics repeated for all the folds ...]
],
"average_metrics": {
"eval_loss": 0.036608549393713476,
"eval_runtime": 0.7434000000000001,
"eval_samples_per_second": 2117.7788,
"eval_steps_per_second": 66.4155,
"epoch": 10.0
},
"aggregated_confusion_matrix": [
[ 6956, 52, 61, 26, 4, 0, 8, 1, 32],
[ 43, 3203, 64, 17, 0, 34, 3, 1, 73],
[ 106, 80, 5963, 64, 1, 1, 24, 0, 82],
[ 22, 20, 59, 6433, 1, 1, 2, 15, 47],
[ 5, 1, 1, 0, 1066, 16, 40, 7, 21],
[ 4, 24, 1, 0, 6, 1022, 35, 4, 59],
[ 6, 0, 17, 1, 28, 34, 3527, 16, 75],
[ 1, 0, 0, 11, 5, 5, 16, 4480, 10],
[ 33, 100, 82, 24, 17, 76, 78, 8, 170106]
],
"labels": [ "B-LOC", "B-MISC", "B-ORG", "B-PER", "I-LOC", "I-MISC", "I-ORG", "I-PER", "O"]
}
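Per-label precision, recall, and F1 can be recovered from the aggregated confusion matrix in this report; assuming rows are gold labels and columns are predictions, a small script like the following does it:

import json

with open("models/bert_conllpp/cross_validation_report.json") as f:
    report = json.load(f)

cm = report["aggregated_confusion_matrix"]
for i, label in enumerate(report["labels"]):
    tp = cm[i][i]
    fp = sum(row[i] for row in cm) - tp   # column sum minus the diagonal
    fn = sum(cm[i]) - tp                  # row sum minus the diagonal
    p = tp / (tp + fp)
    r = tp / (tp + fn)
    print(f"{label}: P={p:.3f} R={r:.3f} F1={2 * p * r / (p + r):.3f}")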
The following command tests the trained model on a test dataset and writes predictions to an output directory:
CUDA_VISIBLE_DEVICES=3 python apply_ner_tagger.py \
--data_dir ./data \
--data_file_suffix test.txt \
--model_name models/bert_conllpp \
--output_dir ./output_conllpp_annotated
Example output:
46666 tokens, 3684 sentences read from 1 files.
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 231/231 [00:01<00:00, 121.99it/s]
Labels:
['B-LOC', 'B-MISC', 'B-ORG', 'B-PER', 'I-LOC', 'I-MISC', 'I-ORG', 'I-PER', 'O']
Confusion Matrix:
[[ 1575 18 28 1 1 1 9 0 13]
[ 19 623 36 9 0 6 1 0 29]
[ 44 37 1579 21 0 0 7 0 27]
[ 10 2 27 1567 0 0 0 0 12]
[ 1 0 1 2 244 1 6 0 4]
[ 1 8 1 0 5 182 19 1 37]
[ 1 0 4 1 20 14 807 3 32]
[ 0 0 0 2 1 1 2 1155 0]
[ 10 28 48 10 4 36 30 1 38241]]
Classification Report:
precision recall f1-score support
B-LOC 0.95 0.96 0.95 1646
B-MISC 0.87 0.86 0.87 723
B-ORG 0.92 0.92 0.92 1715
B-PER 0.97 0.97 0.97 1618
I-LOC 0.89 0.94 0.91 259
I-MISC 0.76 0.72 0.74 254
I-ORG 0.92 0.91 0.92 882
I-PER 1.00 0.99 1.00 1161
O 1.00 1.00 1.00 38408
accuracy 0.99 46666
macro avg 0.92 0.92 0.92 46666
weighted avg 0.99 0.99 0.99 46666
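A report of this shape can be produced with scikit-learn's classification_report on the flattened gold and predicted tag sequences; as a toy illustration (the script's exact evaluation call is an assumption):

from sklearn.metrics import classification_report, confusion_matrix

gold = ["B-PER", "I-PER", "O", "O", "O", "B-LOC", "O"]
pred = ["B-PER", "I-PER", "O", "O", "O", "B-LOC", "O"]
labels = sorted(set(gold) | set(pred))
print(confusion_matrix(gold, pred, labels=labels))
print(classification_report(gold, pred, labels=labels, digits=2))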
See LICENSE