-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathner.py
39 lines (27 loc) · 1002 Bytes
/
ner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from flair.models import SequenceTagger
from flair.data import Sentence, Span
from typing import List, Optional, Dict, Any
from ner_entity import NerEntity
from dataclasses import asdict
from pathlib import Path
from os import path
def load_flair_fr() -> SequenceTagger:
    """Load the ``fr`` (French) NER tagger from disk.

    The model file ``models/fr-ner-wikiner-0.4.pt`` is resolved against the
    process's current working directory, not against this file's location.

    Returns:
        The result of ``SequenceTagger.load`` for that model file.
    """
    working_dir = Path.cwd()
    model_file = path.join(working_dir, 'models/fr-ner-wikiner-0.4.pt')
    return SequenceTagger.load(model_file)
def load_flair_en() -> SequenceTagger:
    """Load the ``en`` (English) NER tagger from disk.

    The model file ``models/en-ner-fast-conll03-v0.4.pt`` is resolved against
    the process's current working directory, not against this file's location.

    Returns:
        The result of ``SequenceTagger.load`` for that model file.
    """
    working_dir = Path.cwd()
    model_file = path.join(working_dir, 'models/en-ner-fast-conll03-v0.4.pt')
    return SequenceTagger.load(model_file)
# Both taggers are loaded eagerly, once, when this module is imported, and
# kept as module-level instances that get_flair() hands back to callers.
# NOTE: importing this module therefore reads both model files from disk.
flair_fr = load_flair_fr()
flair_en = load_flair_en()
def get_flair(language: str) -> Optional[SequenceTagger]:
    """Return the module-level tagger registered for ``language``.

    Args:
        language: Language code; compared for exact equality against
            ``"fr"`` and ``"en"``, in that order.

    Returns:
        ``flair_fr`` for ``"fr"``, ``flair_en`` for ``"en"``, and ``None``
        for any other value.
    """
    known_taggers = (
        ("fr", flair_fr),
        ("en", flair_en),
    )
    for code, tagger in known_taggers:
        if language == code:
            return tagger
    # No tagger is registered for this language code.
    return None
def evaluate(tagger: SequenceTagger, content: str) -> Dict[Any, Any]:
    """Run ``tagger`` over ``content`` and return the NER spans as dicts.

    Args:
        tagger: Tagger whose ``predict`` method is called on the sentence
            built from ``content``.
        content: Text wrapped in a ``Sentence`` before prediction.

    Returns:
        A dict with the single key ``"entities"``, whose value is a list
        containing ``asdict(NerEntity.from_span(span))`` for each span
        returned by ``sentence.get_spans('ner')``, in that order.
    """
    sentence = Sentence(content)
    # The return value of predict() is ignored; the spans are read back from
    # ``sentence`` afterwards.
    tagger.predict(sentence)

    ner_spans = sentence.get_spans('ner')
    return {
        "entities": [asdict(NerEntity.from_span(span)) for span in ner_spans],
    }