Can you release the hyper-parameter of NER task? #223

Closed
Albert-Ma opened this issue Dec 3, 2018 · 30 comments

@Albert-Ma
Copy link

Albert-Ma commented Dec 3, 2018

My NER results are not as good as those reported in your paper.

@qiuwei
Copy link

qiuwei commented Dec 5, 2018

@Albert-XY I heard that some important details are not mentioned in the paper: they used document context instead of sentence context to get the BERT activations for the words.

I did a hyperparameter sweep and the best result I got on CoNLL-2003 is 0.907 (F1 on test), which is quite far from the numbers reported in the paper.

@Albert-Ma
Copy link
Author

Albert-Ma commented Dec 5, 2018

Could you please tell me the hyper-parameters you changed and their values?

Did you add an additional output layer for classification, or just use a softmax classifier?

Thanks
@qiuwei

@qiuwei
Copy link

qiuwei commented Dec 7, 2018

@Albert-XY I replaced the softmax layer with a CRF layer, and applied dropout with rate 0.1 before the CRF layer.
The batch size is 32. The learning rate is 2e-5.

That's the best I could get with fine-tuning.

If I use BERT activations as input to a 2-layer BiLSTM-CRF model, I can get 91.7 on the test set, still quite far from the SOTA.
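
For readers who want to try this setup, here is a minimal sketch of a BERT tagger with dropout 0.1 and a CRF output layer as described above. It is an illustration only, not the commenter's actual code, and it assumes the Hugging Face `transformers` library and the `pytorch-crf` package; the batch size of 32 and learning rate of 2e-5 would be set in the training loop.

```python
# Minimal sketch (not the original code): BERT encoder -> dropout(0.1) -> linear emissions -> CRF.
# Assumes the `transformers` and `pytorch-crf` (torchcrf) packages are installed.
import torch.nn as nn
from transformers import BertModel
from torchcrf import CRF  # pip install pytorch-crf


class BertCrfTagger(nn.Module):
    def __init__(self, num_labels, model_name="bert-base-cased"):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.1)  # dropout applied before the CRF layer
        self.emissions = nn.Linear(self.bert.config.hidden_size, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        hidden = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state
        scores = self.emissions(self.dropout(hidden))
        mask = attention_mask.bool()
        if labels is not None:
            # labels must provide a valid tag id (e.g. 'O') at every unmasked
            # position, including [CLS]/[SEP]; the loss is the negative
            # log-likelihood of the gold tag sequence under the CRF
            return -self.crf(scores, labels, mask=mask, reduction="mean")
        # Viterbi decoding: returns a list of tag-id sequences, one per sentence
        return self.crf.decode(scores, mask=mask)
```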

@dsindex
Copy link

dsindex commented Dec 7, 2018

@Albert-XY @qiuwei

i have also had a similar experience.

https://github.com/dsindex/BERT-BiLSTM-CRF-NER

i can't find a way to push the test f1 score over 91.3 with the official conlleval.pl
(dev f1 score 96.0 by tf_metrics.py), even though i tried the following:

  • use fine-tuning or feature-based model
  • add multi-layer bilstm, crf decoder
  • use cased or uncased model
  • add dropout
  • change batch size
  • change learning rate
  • increase epoch(train_steps)

maybe i missed some important details, or my implementation is wrong?
i am very curious about the f1 score (dev f1 96.4, test f1 92.4) in the paper (https://arxiv.org/pdf/1810.04805.pdf).
is it really possible to reproduce?

@qiuwei
Copy link

qiuwei commented Dec 10, 2018

@dsindex I heard that they fed the document-level context to the language model to get the activation for each word.
However, I am not clear about how the document-level context was used. I did some preliminary experiments following this idea, i.e., concatenating the sentences before and after, but couldn't get any further improvement.
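
For concreteness, here is one possible reading of that idea as a sketch (my assumption, not a recipe confirmed by the authors): grow a window of neighbouring sentences around the target sentence and remember which slice of the window belongs to it, so that only those activations are kept for tagging.

```python
# Sketch of one interpretation of "document context" (an assumption, not the
# paper's confirmed recipe). Lengths are counted in words for simplicity; a
# real implementation would count WordPiece tokens against the 512 limit.
def build_context_window(sentences, i, max_tokens=512):
    """sentences: list of token lists for one document; i: index of the target sentence."""
    window = list(sentences[i])
    start = 0  # offset of the target sentence inside the window
    left, right = i - 1, i + 1
    while left >= 0 or right < len(sentences):
        if left >= 0 and len(window) + len(sentences[left]) <= max_tokens:
            window = list(sentences[left]) + window   # prepend previous sentence
            start += len(sentences[left])
            left -= 1
        elif right < len(sentences) and len(window) + len(sentences[right]) <= max_tokens:
            window = window + list(sentences[right])  # append following sentence
            right += 1
        else:
            break
    # encode `window` with BERT, then keep activations [start:end) for the target sentence
    return window, start, start + len(sentences[i])
```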

I hope the authors could release more details about the NER experiment. @jacobdevlin-google

@dsindex
Copy link

dsindex commented Dec 11, 2018

here is the best result so far.

dev eval_f = 0.9627948 (token-based)
test eval_f = 0.92653006(token-based)
test f1(entity-based) = 0.9155

i suspect the f1 score (dev f1 96.4, test f1 92.4) from the paper was calculated token-based, not entity-based.

[update]
when it comes to evaluating, 'eval_batch_size' and 'predict_batch_size' are 128.
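
To make the token-based vs entity-based distinction concrete, here is a small illustration (not code from this thread) assuming the `seqeval` and `scikit-learn` packages: token-level micro F1 credits every correctly tagged token, while CoNLL-style entity F1 only credits exactly matched spans, so the token-level number is typically higher.

```python
# Illustration of token-based vs entity-based F1 (assumes `seqeval` and `scikit-learn`).
from seqeval.metrics import f1_score as entity_f1   # CoNLL-style span/entity F1
from sklearn.metrics import f1_score as token_f1    # per-token F1

y_true = [["B-PER", "I-PER", "O", "B-LOC"]]
y_pred = [["B-PER", "O",     "O", "B-LOC"]]

# entity-based: the two-token PER span is only partially predicted, so it does
# not count; 1 of 2 entities is correct -> F1 = 0.5
print(entity_f1(y_true, y_pred))

# token-based: 3 of 4 tags match -> micro F1 = 0.75, which looks much better
flat_true = [t for seq in y_true for t in seq]
flat_pred = [t for seq in y_pred for t in seq]
print(token_f1(flat_true, flat_pred, average="micro"))
```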

@Albert-Ma
Copy link
Author

here is the best result so far.

dev eval_f = 0.9627948 (token-based)
test eval_f = 0.92653006(token-based)
test f1(entity-based) = 0.9155

i suspect the f1 score (dev f1 96.4, test f1 92.4) from the paper was calculated token-based, not entity-based.

Was this result obtained with the official evaluation script 'conlleval.pl'?

@dsindex
Copy link

dsindex commented Dec 11, 2018

@Albert-XY

dev eval_f = 0.9627948 (token-based)
test eval_f = 0.92653006(token-based)

those scores were reported during training.
(estimator.evaluate())
==> by https://github.com/dsindex/BERT-BiLSTM-CRF-NER/blob/master/tf_metrics.py

def metric_fn(label_ids, pred_ids, per_example_loss, input_mask):
    # ['<pad>'] + ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC", "X"]
    indices = [2, 3, 4, 5, 6, 7, 8, 9]
    precision = tf_metrics.precision(label_ids, pred_ids, num_labels, indices, input_mask)
    recall = tf_metrics.recall(label_ids, pred_ids, num_labels, indices, input_mask)
    f = tf_metrics.f1(label_ids, pred_ids, num_labels, indices, input_mask)
    accuracy = tf.metrics.accuracy(label_ids, pred_ids, input_mask)
    loss = tf.metrics.mean(per_example_loss)
    return {
        'eval_precision': precision,
        'eval_recall': recall,
        'eval_f': f,
        'eval_accuracy': accuracy,
        'eval_loss': loss,
    }
test f1(entity-based) = 0.9155

this score was calculated after converting the predictions.
==> by https://github.com/dsindex/BERT-BiLSTM-CRF-NER/blob/master/conlleval.pl

the converted prediction data looks like the sample below ('X' tags are skipped).

S NN B-NP O O
- : O O O
J NNP B-NP B-LOC B-MISC
GE VB B-VP O O
L NNP B-NP O B-LOC
W NNP I-NP O O
, , O O O
CH NNP B-NP B-PER O
IN IN B-PP O O
S DT B-NP O O
DE NN I-NP O O
. . O O O

Na NNP B-NP B-PER B-PER
La NNP I-NP I-PER I-PER

AL NNP B-NP B-LOC B-LOC
, , O O O
United NNP B-NP B-LOC B-LOC
Arab NNP I-NP I-LOC I-LOC
Emirates NNPS I-NP I-LOC I-LOC
1996 CD I-NP O O

Japan NNP B-NP B-LOC B-LOC
began VBD B-VP O O
the DT B-NP O O
defence NN I-NP O O
of IN B-PP O O
their PRP$ B-NP O O
Asian JJ I-NP B-MISC B-MISC
Cup NNP I-NP I-MISC I-MISC
title NN I-NP O O
with IN B-PP O O
a DT B-NP O O
lucky JJ I-NP O O
2 CD I-NP O O
win VBP B-VP O O
against IN B-PP O O
Syria NNP B-NP B-LOC B-LOC
in IN B-PP O O
a DT B-NP O O
Group NNP I-NP O O
C NNP I-NP O I-MISC
championship NN I-NP O O
match NN I-NP O O
on IN B-PP O O
Friday NNP B-NP O O
. . O O O

...

@Albert-Ma
Copy link
Author

@Albert-XY

dev eval_f = 0.9627948 (token-based)
test eval_f = 0.92653006(token-based)

those scores were reported during training.
(estimator.evaluate())
==> by https://github.com/dsindex/BERT-BiLSTM-CRF-NER/blob/master/tf_metrics.py

test f1(entity-based) = 0.9155

this score was calculated after converting the predictions.
==> by https://github.com/dsindex/BERT-BiLSTM-CRF-NER/blob/master/conlleval.pl

Your code has a bug. It misses some token labels when predicting.

For example:

token1 tag
token2 tag
token3 [nothing]

Please check the predicted results carefully.

@dsindex
Copy link

dsindex commented Dec 13, 2018

@Albert-XY

oh~ could you point out where in the code the bug comes from?

i checked the converted prediction file.

$ wc -l NERdata/test.txt
50350 NERdata/test.txt

$ wc -l pred.txt
49888

$ grep "DOCSTART" NERdata/test.txt | wc -l
231

because i removed the '-DOCSTART-' lines (each together with its following blank line).

49888 + 2*231 = 50350

so the number of output lines matches.
i also checked for empty labels, but every line of pred.txt has 5 columns.

hmm...

@Albert-Ma
Copy link
Author

so the number of output lines matches.
i also checked for empty labels, but every line of pred.txt has 5 columns.

The line counts do match, because some lines contain just '\n' and no label. I found two such cases, and there may be around 40 in total. Please check label_test.txt.

In label_test.txt:

line 1596: the token 'Pakistan' is followed only by '\n', with no label.
line 3017: the tokens 'Barbarians - 15 - Tim...' are missing one label.

....

I'm checking the code...

@dsindex
Copy link

dsindex commented Dec 17, 2018

i got 92.16% with the BERT-large model (note that this is not an average over runs)

lowercase='False'
bert_model_dir=${CDIR}/cased_L-24_H-1024_A-16

python bert_lstm_ner.py   \
        --task_name="NER"  \
        --do_train=True   \
        --use_feature_based=False \
        --do_predict=True \
        --use_crf=True \
        --data_dir=${CDIR}/NERdata  \
        --vocab_file=${bert_model_dir}/vocab.txt  \
        --do_lower_case=${lowercase} \
        --bert_config_file=${bert_model_dir}/bert_config.json \
        --init_checkpoint=${bert_model_dir}/bert_model.ckpt   \
        --max_seq_length=150   \
        --lstm_size=256 \
        --train_batch_size=16   \
        --eval_batch_size=32   \
        --predict_batch_size=32   \
        --bert_dropout_rate=0.2 \
        --bilstm_dropout_rate=0.2 \
        --learning_rate=2e-5   \
        --num_train_epochs=100   \
        --data_config_path=${CDIR}/data.conf \
        --output_dir=${CDIR}/output/result_dir/
processed 46435 tokens with 5648 phrases; found: 5663 phrases; correct: 5212.
accuracy:  98.47%; precision:  92.04%; recall:  92.28%; FB1:  92.16
              LOC: precision:  93.18%; recall:  93.35%; FB1:  93.26  1671
             MISC: precision:  83.53%; recall:  82.34%; FB1:  82.93  692
              ORG: precision:  90.57%; recall:  91.39%; FB1:  90.98  1676
              PER: precision:  96.00%; recall:  96.41%; FB1:  96.20  1624

@Albert-XY @qiuwei

@nreimers
Copy link

@dsindex
If I understand you correctly, your 92.16 F1 (test set) is from a system that uses a BiLSTM-CRF on top.

I tried it with the simple classifier described in the paper, but sadly only achieved about 90% F1 (cased BERT-base model, using the official conlleval.pl script).

Has anyone had any success improving performance when using a simple classifier as in the paper?

@dsindex
Copy link

dsindex commented Jan 10, 2019

@nreimers

i had the same question. a simple softmax layer on top of bert did not perform well, far from the reported score.

i guess ‘autoML’ might find a way to get there?
or maybe there are other resources and network architectures the paper didn’t mention.

@nreimers
Copy link

I'm afraid that either an important aspect is missing or, possibly, that a different F1-score metric was used in the paper, for example one that computes the score based on tokens and not on spans. NER sadly suffers from poor reproducibility: http://science-miner.com/a-reproducibility-study-on-neural-ner/

I would love to see the official implementation. It's a bit of a pity to have a paper claiming a simple state-of-the-art architecture that cannot be reproduced.

@r-wheeler
Copy link

r-wheeler commented May 4, 2019

@dsindex Digging up an old thread, but from the paper:

To make this compatible with WordPiece
tokenization, we feed each CoNLL-tokenized
input word into our WordPiece tokenizer and
use the hidden state corresponding to the first
sub-token as input to the classifier.

Where no prediction is made for X. Since 
the WordPiece tokenization boundaries are a
known part of the input, this is done for both
training and test.

This makes it sound a bit different from having the CRF layer learn the X token and then simply removing it after running .predict when writing the predictions to a file for evaluation.
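
As an illustration of the paper's "first sub-token" scheme, here is a small sketch using a Hugging Face fast tokenizer (an assumption on my part; the original TensorFlow implementation is not public): labels are attached only to the first WordPiece of each word, and continuation pieces get an ignore index so no prediction or loss is associated with them.

```python
# Sketch of first-sub-token label alignment (assumes a Hugging Face fast tokenizer).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

words  = ["Japan", "began", "the", "defence"]
labels = ["B-LOC", "O", "O", "O"]

enc = tokenizer(words, is_split_into_words=True)
aligned, prev_word = [], None
for word_id in enc.word_ids():
    if word_id is None:              # special tokens: [CLS], [SEP]
        aligned.append(-100)         # ignored by the loss
    elif word_id != prev_word:       # first sub-token of a word
        aligned.append(labels[word_id])
    else:                            # continuation sub-token ("##...")
        aligned.append(-100)         # no prediction is made for it
    prev_word = word_id

# In practice the string labels would be mapped to ids; printed here for clarity.
print(list(zip(enc.tokens(), aligned)))
```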

@congchan
Copy link

congchan commented May 8, 2019

@dsindex Digging up an old thread, but from the paper:

To make this compatible with WordPiece
tokenization, we feed each CoNLL-tokenized
input word into our WordPiece tokenizer and
use the hidden state corresponding to the first
sub-token as input to the classifier.

Where no prediction is made for X.

I wonder how this is handled in training. Does it mean that the loss for the "##..." tokens with the "X" label is never considered during training?

@dsindex
Copy link

dsindex commented May 8, 2019

@congchan

no, it means that ‘##...’ tokens with ‘X’ tags are also trained, along with the actual tags (e.g., PER, LOC, ...).
but the accuracy for the ‘X’ tag is almost perfect, so the negative effect from those tags should be very small.
i am not sure whether the authors used the same setting or not.

@1664403370
Copy link

here is the best result so far.

dev eval_f = 0.9627948 (token-based)
test eval_f = 0.92653006(token-based)
test f1(entity-based) = 0.9155

i suspect the f1 score (dev f1 96.4, test f1 92.4) from the paper was calculated token-based, not entity-based.

[update]
when it comes to evaluating, 'eval_batch_size' and 'predict_batch_size' are 128.

Why are my results all zero? Are the parameters not being passed in? What should I do? Thank you.

@ghaddarAbs
Copy link

In order to reproduce the CoNLL score reported in the BERT paper (92.4 for bert-base and 92.8 for bert-large), one trick is to apply a truecaser on article titles (all-upper-case sentences) as a preprocessing step for the CoNLL train/dev/test sets. This can be done with the following method.

# https://github.com/daltonfury42/truecase
# pip install truecase
import truecase
import re


# original tokens
# ['FULL', 'FEES', '1.875', 'REOFFER', '99.32', 'SPREAD', '+20', 'BP']

def truecase_sentence(tokens):
    # indices of purely alphabetic tokens
    word_lst = [(w, idx) for idx, w in enumerate(tokens) if all(c.isalpha() for c in w)]
    # of those, the ones written entirely in upper case
    lst = [w for w, _ in word_lst if re.match(r'\b[A-Z\.\-]+\b', w)]

    # only truecase when every alphabetic token is upper case (i.e. an all-caps title)
    if len(lst) and len(lst) == len(word_lst):
        parts = truecase.get_true_case(' '.join(lst)).split()

        # the truecaser has its own tokenization ...
        # skip if the number of words doesn't match
        if len(parts) != len(word_lst):
            return tokens

        for (w, idx), nw in zip(word_lst, parts):
            tokens[idx] = nw

    return tokens

# truecased tokens
# ['Full', 'fees', '1.875', 'Reoffer', '99.32', 'spread', '+20', 'BP']

Also, I found it useful to use a very small learning rate (5e-6), a large batch size (128), and a high number of epochs (>40).

With these configurations and preprocessing, I was able to reach 92.8 with bert-large.

@dsindex
Copy link

dsindex commented Jun 28, 2020

@ghaddarAbs

thank you so much~!

https://github.com/dsindex/ntagger/tree/master/data/conll2003_truecase

with the conversion to truecased conll2003 data, i got consistent improvements.

https://github.com/dsindex/ntagger#conll-2003-english-1

BERT-large, BiLSTM :  91.32  -> 91.89
BERT-large-squad, BiLSTM : 91.75 -> 92.17
SpanBERT-large, BiLSTM : 91.39 -> 92.01
RoBERTa-large : 91.83 -> 91.90

but it is still very hard to reach 92.8%.

@ghaddarAbs
Copy link

ghaddarAbs commented Jun 28, 2020

@dsindex

What are the HPs you are using? ... In my case, I can reach 92.8 on the test with:

  • spanBert-large
  • lr= 5e-6
  • train_epoch= 40
  • batch_size= 96
  • max_seq_len = 64 (you need to intelligently split the sequence)
  • dropout (hidden\attention) = .2
  • crf (gather first sub tokens and apply crf)

to use large batch_size (96 or 128) for fine-tuning you can use the method below:

import math
import re

def split_long_sequence(tokenizer, tokens, tags, max_seq_len=64):

    # 1 because of the [CLS] token
    count, punct_index, tag_window = 1, [], []
    # pad the tag list so the window check below works at the boundaries
    # (assumes tag id 0 is the 'O' tag)
    tmp_tags = [0, 0] + tags + [0, 0]

    for idx, token in enumerate(tokens):
        bert_lst = tokenizer.tokenize(token)
        count += len(bert_lst)

        # punctuation positions are fallback split points
        if re.match(r'[^a-zA-Z0-9]', token):
            punct_index.append(idx)

        # positions surrounded by 'O' tags are safe split points
        t_idx = idx + 2
        if idx and all(t == 0 for t in tmp_tags[t_idx-2:t_idx+2]):
            tag_window.append(idx)

    if count < max_seq_len:
        return [(tokens, tags)]

    # prefer a split point inside an all-'O' window, else punctuation, else the middle
    pick_lst = tag_window if tag_window else punct_index
    if not pick_lst:
        mid = len(tokens) // 2
    else:
        index_lst = [(i, math.fabs(i - len(tokens)//2)) for i in pick_lst]
        index_lst.sort(key=lambda x: x[1])
        mid = index_lst[0][0]

    l1 = split_long_sequence(tokenizer, tokens[:mid], tags[:mid], max_seq_len)
    l2 = split_long_sequence(tokenizer, tokens[mid:], tags[mid:], max_seq_len)

    return l1 + l2


# usage:
seq_lst = split_long_sequence(tokenizer, tokens, tags)
for sent_part_num, (tok_lst, tag_lst) in enumerate(seq_lst):
    if not tok_lst:
        continue
    # after decoding just sum up the sentence parts

@dsindex
Copy link

dsindex commented Jun 28, 2020

@ghaddarAbs

* BERT-large, BiLSTM :  91.32  -> 91.89
  - batch : 16
  - lr : 1e-5
  - epoch : 10
  - n_ctx(max_seq_len) : 180
  - dropout : 0.1
  - lstm : bi-lstm 2 layers, 200d
  - print out the prediction result to a file and evaluate by conlleval.pl official script.
    - https://github.com/dsindex/ntagger/blob/master/evaluate.py#L143
    - https://github.com/dsindex/ntagger/blob/master/etc/conlleval.pl

* BERT-large-squad, BiLSTM : 91.75 -> 92.17
  - same as above 
  - except, --use_transformers_optimizer --warmup_epoch=0 --weight_decay=0.0 --epoch=30

* SpanBERT-large, BiLSTM : 91.39 -> 92.01
  - same as above
  - except, --use_transformers_optimizer --warmup_epoch=0 --weight_decay=0.0 --epoch=30

* RoBERTa-large : 91.83 -> 91.90
  - same as above
  - except, --use_transformers_optimizer --warmup_epoch=0 --weight_decay=0.0 --epoch=30
  - and --bert_disable_lstm (do not use lstm layer)

i think the differences are the larger batch size, the smaller sequence length, and the final CRF layer.
in the case of CRF, i did experiments with bert-base and got the best result with LSTM only,
so i do not use a CRF layer.

[screenshot: 2020-06-29 1:07 AM]

@ghaddarAbs
Copy link

@dsindex My results were for the fine-tuning approach, not the feature-based approach with an LSTM.

@yzhangcs
Copy link

yzhangcs commented Dec 1, 2020

@ghaddarAbs

* BERT-large, BiLSTM :  91.32  -> 91.89
* BERT-large-squad, BiLSTM : 91.75 -> 92.17
* SpanBERT-large, BiLSTM : 91.39 -> 92.01
* RoBERTa-large : 91.83 -> 91.90

Are the above results the average of several runs?

@dsindex
Copy link

dsindex commented Dec 2, 2020

@yzhangcs

those are not average scores, but a fixed seed yields very similar scores.

https://github.com/dsindex/ntagger#conll-2003-english-1

  • sample commands
$ python preprocess.py --config=configs/config-bert.json --data_dir=data/conll2003 --bert_model_name_or_path=./embeddings/bert-large-cased

$ python train.py --config=configs/config-bert.json --data_dir=data/conll2003 --save_path=pytorch-model-bert.pt --bert_model_name_or_path=./embeddings/bert-large-cased --bert_output_dir=bert-checkpoint --batch_size=16 --lr=1e-5 --epoch=10

$ python evaluate.py --config=configs/config-bert.json --data_dir=data/conll2003 --model_path=pytorch-model-bert.pt --bert_output_dir=bert-checkpoint

$ cd data/conll2003; perl ../../etc/conlleval.pl < test.txt.pred ; cd ../..

@yzhangcs
Copy link

yzhangcs commented Dec 2, 2020

@dsindex Thank u!

@Albert-Ma
Copy link
Author

Huggingface's transformers repo released run_ner.py. Has anyone tried it and reproduced a better result?

@BCWang93
Copy link

Huggingface's transformers repo released run_ner.py. Has anyone tried it and reproduced a better result?

@Albert-Ma, hi, can you get the score of 92.4 using the huggingface code?

@RichardScottOZ
Copy link

Huggingface's transformers repo released run_ner.py. Has anyone tried it and reproduced a better result?

@Albert-Ma, hi, can you get the score of 92.4 using the huggingface code?

page not found for that one?

JayRvanDam pushed a commit to JayRvanDam/BERTmodelclone that referenced this issue Dec 7, 2023
---
language: en
datasets:
- conll2003
---
# bert-large-NER

## Model description

**bert-large-NER** is a fine-tuned BERT-Large model that is ready to use for **Named Entity Recognition** and achieves **state-of-the-art performance** for the NER task. It has been trained to recognize four types of entities: location (LOC), organization (ORG), person (PER), and miscellaneous (MISC).

Specifically, this model is a *bert-large-cased* model that was fine-tuned on the English version of the standard [CoNLL-2003 Named Entity Recognition](https://www.aclweb.org/anthology/W03-0419.pdf) dataset. 

If you'd like to use a smaller model fine-tuned on the same dataset, a [bert-base-cased](https://huggingface.co/dslim/bert-base-NER) version is also available. 

## Intended uses & limitations

#### How to use

You can use this model with Transformers *pipeline* for NER.

```python
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")

nlp = pipeline("ner", model=model, tokenizer=tokenizer)
example = "My name is Wolfgang and I live in Berlin"

ner_results = nlp(example)
print(ner_results)
```

#### Limitations and bias

This model is limited by its training dataset of entity-annotated news articles from a specific span of time. This may not generalize well for all use cases in different domains. Furthermore, the model occasionally tags subword tokens as entities, and post-processing of results may be necessary to handle those cases.

## Training data

This model was fine-tuned on the English version of the standard [CoNLL-2003 Named Entity Recognition](https://www.aclweb.org/anthology/W03-0419.pdf) dataset.

The training dataset distinguishes between the beginning and continuation of an entity so that if there are back-to-back entities of the same type, the model can output where the second entity begins. As in the dataset, each token will be classified as one of the following classes:

Abbreviation|Description
-|-
O|Outside of a named entity
B-MISC|Beginning of a miscellaneous entity right after another miscellaneous entity
I-MISC|Miscellaneous entity
B-PER|Beginning of a person’s name right after another person’s name
I-PER|Person’s name
B-ORG|Beginning of an organization right after another organization
I-ORG|Organization
B-LOC|Beginning of a location right after another location
I-LOC|Location


### CoNLL-2003 English Dataset Statistics
This dataset was derived from the Reuters corpus which consists of Reuters news stories. You can read more about how this dataset was created in the CoNLL-2003 paper. 
#### # of training examples per entity type
Dataset|LOC|MISC|ORG|PER
-|-|-|-|-
Train|7140|3438|6321|6600
Dev|1837|922|1341|1842
Test|1668|702|1661|1617
#### # of articles/sentences/tokens per dataset
Dataset |Articles |Sentences |Tokens
-|-|-|-
Train |946 |14,987 |203,621
Dev |216 |3,466 |51,362
Test |231 |3,684 |46,435

## Training procedure

This model was trained on a single NVIDIA V100 GPU with the recommended hyperparameters from the [original BERT paper](https://arxiv.org/pdf/1810.04805), which trained & evaluated the model on the CoNLL-2003 NER task.

## Eval results
metric|dev|test
-|-|-
f1 |95.1 |91.3
precision |95.0 |90.7
recall |95.3 |91.9

The test metrics are a little lower than the official Google BERT results which encoded document context & experimented with CRF. More on replicating the original results [here](google-research/bert#223).

### BibTeX entry and citation info

```
@Article{DBLP:journals/corr/abs-1810-04805,
  author    = {Jacob Devlin and
               Ming{-}Wei Chang and
               Kenton Lee and
               Kristina Toutanova},
  title     = {{BERT:} Pre-training of Deep Bidirectional Transformers for Language
               Understanding},
  journal   = {CoRR},
  volume    = {abs/1810.04805},
  year      = {2018},
  url       = {http://arxiv.org/abs/1810.04805},
  archivePrefix = {arXiv},
  eprint    = {1810.04805},
  timestamp = {Tue, 30 Oct 2018 20:39:56 +0100},
  biburl    = {https://dblp.org/rec/journals/corr/abs-1810-04805.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
```
@inproceedings{tjong-kim-sang-de-meulder-2003-introduction,
    title = "Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition",
    author = "Tjong Kim Sang, Erik F.  and
      De Meulder, Fien",
    booktitle = "Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003",
    year = "2003",
    url = "https://www.aclweb.org/anthology/W03-0419",
    pages = "142--147",
}
```