Skip to content

Commit

Permalink
#63 - Add Transformer-based NER classifier using Hugging Face models
Browse files Browse the repository at this point in the history
* Add Transformer-based NER classifier using Hugging Face models
* Add test file to transformers classifier
* Fix the transformer classifier
* Update dependencies
  • Loading branch information
lfcc1 authored Sep 29, 2024
1 parent 02a7f7e commit c683233
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 2 deletions.
42 changes: 42 additions & 0 deletions ariadne/contrib/transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Licensed to the Technische Universität Darmstadt under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The Technische Universität Darmstadt
# licenses this file to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
from ariadne.classifier import Classifier
from ariadne.contrib.inception_util import create_prediction
from cassis import Cas

class TransformerNerClassifier(Classifier):
    """Named-entity recommender backed by a Hugging Face token-classification model."""

    def __init__(self, model_name: str):
        """Load the tokenizer, model and NER pipeline for ``model_name``.

        Args:
            model_name: Identifier handed to ``from_pretrained`` for both the
                tokenizer and the token-classification model.
        """
        super().__init__()
        # Tokenizer is created with a 512-token maximum length.
        # NOTE(review): text beyond that limit is presumably truncated by the
        # pipeline, so entities late in long documents may be missed - confirm.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512)
        self.model = AutoModelForTokenClassification.from_pretrained(model_name)
        self.ner_pipeline = pipeline(
            "ner",
            model=self.model,
            tokenizer=self.tokenizer,
            aggregation_strategy="first",
        )

    def predict(
        self, cas: Cas, layer: str, feature: str, project_id: str, document_id: str, user_id: str
    ) -> None:
        """Run the NER pipeline over the CAS text and add one prediction per entity.

        Args:
            cas: CAS whose ``sofa_string`` is tagged; predictions are added to it in place.
            layer: Annotation layer name forwarded to ``create_prediction``.
            feature: Feature name forwarded to ``create_prediction``; it receives the label.
            project_id: Unused by this classifier.
            document_id: Unused by this classifier.
            user_id: Unused by this classifier.
        """
        document_text = cas.sofa_string
        # Nothing to tag when the CAS has no (or empty) document text; skip the
        # model call instead of handing it None or an empty string.
        if not document_text:
            return

        for entity in self.ner_pipeline(document_text):
            # The pipeline reports character offsets and the aggregated label.
            cas_prediction = create_prediction(
                cas, layer, feature, entity["start"], entity["end"], entity["entity_group"]
            )
            cas.add(cas_prediction)

5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,15 @@
"lightgbm~=4.2.0",
"diskcache~=5.2.1",
"simalign~=0.4",
"flair>=0.13.1"
"flair>=0.13.1",
"transformers[torch]~=4.41.1", # TransformerNerClassifier
]

# Packages needed to run the test suite; each requirement is listed once.
test_dependencies = [
    "tox",
    "pytest",
    "codecov",
    "pytest-cov",
]

dev_dependencies = [
Expand Down
34 changes: 34 additions & 0 deletions tests/test_transformer_recommender.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Licensed to the Technische Universität Darmstadt under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The Technische Universität Darmstadt
# licenses this file to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

pytest.importorskip("transformers")

from ariadne.contrib.transformers import TransformerNerClassifier
from tests.util import load_obama, PREDICTED_TYPE, PREDICTED_FEATURE, PROJECT_ID, USER


def test_predict_ner(tmpdir_factory):
    """Check that predicting on the Obama CAS produces labelled annotations."""
    obama_cas = load_obama()
    classifier = TransformerNerClassifier("lfcc/lusa_events")

    classifier.predict(
        obama_cas, PREDICTED_TYPE, PREDICTED_FEATURE, PROJECT_ID, "doc_42", USER
    )
    found = list(obama_cas.select(PREDICTED_TYPE))

    # At least one prediction must have been added to the CAS.
    assert len(found)

    # Every prediction must carry a value in the predicted feature.
    for annotation in found:
        feature_value = getattr(annotation, PREDICTED_FEATURE)
        assert feature_value is not None

0 comments on commit c683233

Please sign in to comment.