
Commit

#53 - Add convenience method for the creation of relation recommendations

- Add method
- Add test for simalign recommender
reckart committed Jan 7, 2024
1 parent 54c4b7a commit 3cab6bc
Showing 5 changed files with 148 additions and 24 deletions.
67 changes: 66 additions & 1 deletion ariadne/contrib/inception_util.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 from typing import Optional
 
+import deprecation
 from cassis import Cas
 from cassis.typesystem import FeatureStructure
@@ -26,6 +27,7 @@
 FEATURE_NAME_AUTO_ACCEPT_MODE_SUFFIX = "_auto_accept"
 
 
+@deprecation.deprecated(details="Use create_span_prediction()")
 def create_prediction(
     cas: Cas,
     layer: str,
@@ -36,9 +38,23 @@ def create_prediction(
     score: Optional[int] = None,
     score_explanation: Optional[str] = None,
     auto_accept: Optional[bool] = None,
+) -> FeatureStructure:
+    return create_span_prediction(cas, layer, feature, begin, end, label, score, score_explanation, auto_accept)
+
+
+def create_span_prediction(
+    cas: Cas,
+    layer: str,
+    feature: str,
+    begin: int,
+    end: int,
+    label: str,
+    score: Optional[int] = None,
+    score_explanation: Optional[str] = None,
+    auto_accept: Optional[bool] = None,
 ) -> FeatureStructure:
     """
-    Create a prediction
+    Create a span prediction
 
     :param cas: the annotated document
     :param layer: the layer on which to create the prediction
@@ -66,3 +82,52 @@ def create_prediction(
         prediction[f"{feature}{FEATURE_NAME_AUTO_ACCEPT_MODE_SUFFIX}"] = auto_accept
 
     return prediction
+
+
+def create_relation_prediction(
+    cas: Cas,
+    layer: str,
+    feature: str,
+    source: FeatureStructure,
+    target: FeatureStructure,
+    label: str,
+    score: Optional[int] = None,
+    score_explanation: Optional[str] = None,
+    auto_accept: Optional[bool] = None,
+) -> FeatureStructure:
+    """
+    Create a relation prediction
+
+    :param cas: the annotated document
+    :param layer: the layer on which to create the prediction
+    :param feature: the feature to predict
+    :param source: the source of the relation
+    :param target: the target of the relation
+    :param label: the predicted label
+    :param score: the score
+    :param score_explanation: a rationale for the score / prediction
+    :param auto_accept: whether the prediction should be automatically accepted
+    :return: the prediction annotation
+    """
+    AnnotationType = cas.typesystem.get_type(layer)
+
+    fields = {
+        "begin": target.begin,
+        "end": target.end,
+        "Governor": source,
+        "Dependent": target,
+        IS_PREDICTION: True,
+        feature: label,
+    }
+    prediction = AnnotationType(**fields)
+
+    if score is not None:
+        prediction[f"{feature}{FEATURE_NAME_SCORE_SUFFIX}"] = score
+
+    if score_explanation is not None:
+        prediction[f"{feature}{FEATURE_NAME_SCORE_EXPLANATION_SUFFIX}"] = score_explanation
+
+    if auto_accept is not None:
+        prediction[f"{feature}{FEATURE_NAME_AUTO_ACCEPT_MODE_SUFFIX}"] = auto_accept
+
+    return prediction
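
For context, here is a minimal, self-contained sketch of how the two helpers can be called from a recommender. It is not part of this commit; the "custom.Span" / "custom.Relation" layer names and the "value" feature are hypothetical stand-ins for a project's real layers, and the type-system setup mirrors the new test below.

# Hypothetical usage sketch (not part of this commit); layer and feature
# names are made up for illustration.
from cassis import Cas, TypeSystem
from cassis.typesystem import TYPE_NAME_ANNOTATION, TYPE_NAME_BOOLEAN, TYPE_NAME_STRING

from ariadne.contrib.inception_util import IS_PREDICTION, create_relation_prediction, create_span_prediction

typesystem = TypeSystem()
Span = typesystem.create_type("custom.Span")
typesystem.create_feature(Span, "value", TYPE_NAME_STRING)
typesystem.create_feature(Span, IS_PREDICTION, TYPE_NAME_BOOLEAN)
Relation = typesystem.create_type("custom.Relation")
typesystem.create_feature(Relation, "Governor", TYPE_NAME_ANNOTATION)
typesystem.create_feature(Relation, "Dependent", TYPE_NAME_ANNOTATION)
typesystem.create_feature(Relation, "value", TYPE_NAME_STRING)
typesystem.create_feature(Relation, IS_PREDICTION, TYPE_NAME_BOOLEAN)

cas = Cas(typesystem)
cas.sofa_string = "Alice met Bob."

# Two span predictions over "Alice" and "Bob".
alice = create_span_prediction(cas, Span.name, "value", 0, 5, "PER")
bob = create_span_prediction(cas, Span.name, "value", 10, 13, "PER")
cas.add(alice)
cas.add(bob)

# A relation prediction from "Alice" to "Bob"; as the diff shows, the
# relation annotation itself is anchored at the target's begin/end offsets.
relation = create_relation_prediction(cas, Relation.name, "value", alice, bob, "met")
cas.add(relation)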
36 changes: 15 additions & 21 deletions ariadne/contrib/simalign.py
@@ -18,10 +18,20 @@
 from simalign import SentenceAligner
 
 from ariadne.classifier import Classifier
-from ariadne.contrib.inception_util import SENTENCE_TYPE
+from ariadne.contrib.inception_util import SENTENCE_TYPE, create_relation_prediction
+
+SPAN_ANNOTATION_TYPE = "webanno.custom.Base"
+
 
 class SimAligner(Classifier):
+    """
+    Alignment of words in two sentences.
+    The recommender assumes that there are exactly two sentences in the CAS.
+    For each of the tokens, there must be an annotation of type `webanno.custom.Base`.
+    The recommender will then predict relations between these base annotations.
+    """
+
     def __init__(self):
         super().__init__()
@@ -30,34 +40,18 @@ def __init__(self):
     def predict(self, cas: Cas, layer: str, feature: str, project_id: str, document_id: str, user_id: str):
         sentences = cas.select(SENTENCE_TYPE)
 
-        src_tokens = cas.select_covered("webanno.custom.Base", sentences[0])
-        trg_tokens = cas.select_covered("webanno.custom.Base", sentences[1])
+        src_tokens = cas.select_covered(SPAN_ANNOTATION_TYPE, sentences[0])
+        trg_tokens = cas.select_covered(SPAN_ANNOTATION_TYPE, sentences[1])
 
         src_sentence = [e.get_covered_text() for e in src_tokens]
         trg_sentence = [e.get_covered_text() for e in trg_tokens]
 
-        print(src_sentence)
-        print(trg_sentence)
-
         alignments = self._aligner.get_word_aligns(src_sentence, trg_sentence)
 
-        Relation = cas.typesystem.get_type(layer)
-        print(list(Relation.all_features))
-
         for matching_method in alignments:
             for source_idx, target_idx in alignments[matching_method]:
-                src = src_tokens[source_idx]
-                target = trg_tokens[target_idx]
-                prediction = Relation(
-                    Governor=src,
-                    Dependent=target,
-                    begin=target.begin,
-                    end=target.end,
-                    inception_internal_predicted=True,
+                prediction = create_relation_prediction(
+                    cas, layer, feature, src_tokens[source_idx], trg_tokens[target_idx], ""
                 )
-                # setattr(prediction, feature, f"{src.get_covered_text()} -> {target.get_covered_text()}")
-                setattr(prediction, feature, "")
-                print(source_idx, target_idx, prediction)
-
                 cas.add_annotation(prediction)
             break
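
For reference, this is the SimAlign API the recommender builds on. The example below is adapted from SimAlign's README; the constructor arguments are an assumption, since SimAligner's __init__ body is folded away in this diff.

# Adapted from the SimAlign README (not part of this commit); the
# constructor arguments here are assumptions about how the aligner is set up.
from simalign import SentenceAligner

aligner = SentenceAligner(model="bert", token_type="bpe", matching_methods="mai")

src_sentence = ["This", "is", "a", "test", "."]
trg_sentence = ["Das", "ist", "ein", "Test", "."]

# get_word_aligns returns a dict keyed by matching method ("mwmf", "inter",
# "itermax" for methods "mai"), each value a list of
# (source_index, target_index) token pairs.
alignments = aligner.get_word_aligns(src_sentence, trg_sentence)
for matching_method, pairs in alignments.items():
    print(matching_method, ":", pairs)
# e.g. mwmf : [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]

Note the `break` at the end of predict(): only the first matching method's pairs are turned into recommendations.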
4 changes: 3 additions & 1 deletion setup.py
@@ -36,6 +36,7 @@
     "dkpro-cassis>=0.7.6",
     "joblib",
     "gunicorn",
+    "deprecation",
 ]
 
 contrib_dependencies = [
@@ -48,7 +49,8 @@
     "jieba~=0.42",
     "sentence-transformers~=2.2.2",
     "lightgbm~=4.2.0",
-    "diskcache~=5.2.1"
+    "diskcache~=5.2.1",
+    "simalign~=0.4"
 ]
 
 test_dependencies = [
2 changes: 1 addition & 1 deletion tests/test_sbert_sentence_classifier.py
@@ -17,7 +17,7 @@
 
 import pytest
 
-#pytest.importorskip("lightgbm.LGBMClassifier")
+# pytest.importorskip("lightgbm.LGBMClassifier")
 
 from ariadne.contrib.sbert import SbertSentenceClassifier
63 changes: 63 additions & 0 deletions tests/test_simalign.py
@@ -0,0 +1,63 @@
# Licensed to the Technische Universität Darmstadt under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The Technische Universität Darmstadt
# licenses this file to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from cassis import TypeSystem, Cas
from cassis.typesystem import TYPE_NAME_STRING, TYPE_NAME_ANNOTATION, TYPE_NAME_BOOLEAN
from ariadne.contrib.inception_util import SENTENCE_TYPE, IS_PREDICTION
from ariadne.contrib.simalign import SPAN_ANNOTATION_TYPE, SimAligner

RELATION_ANNOTATION_TYPE = "custom.Relation"


def test_predict():
typesystem = TypeSystem()
Sentence = typesystem.create_type(SENTENCE_TYPE)
Span = typesystem.create_type(SPAN_ANNOTATION_TYPE)
Relation = typesystem.create_type(RELATION_ANNOTATION_TYPE)
typesystem.create_feature(Relation, "Governor", TYPE_NAME_ANNOTATION)
typesystem.create_feature(Relation, "Dependent", TYPE_NAME_ANNOTATION)
typesystem.create_feature(Relation, "value", TYPE_NAME_STRING)
typesystem.create_feature(Relation, IS_PREDICTION, TYPE_NAME_BOOLEAN)

cas = Cas(typesystem)
cas.sofa_string = "I do like the color red. Red is the color that I like."
for start, end in tokenize(cas.sofa_string):
cas.add(Span(**{"begin": start, "end": end}))
cas.add(Sentence(**{"begin": 0, "end": 24}))
cas.add(Sentence(**{"begin": 25, "end": 54}))

sut = SimAligner()
sut.predict(cas, Relation.name, "value", None, None, None)

pairs = [
(r.get("Governor").get_covered_text(), r.get("Dependent").get_covered_text()) for r in cas.select(Relation)
]
assert set(pairs) == set([
("red", "Red"),
("do", "is"),
("the", "the"),
("color", "color"),
("I", "I"),
("like", "like"),
(".", "."),
])


def tokenize(string):
positions = []
for match in re.compile(r"\w+|\S").finditer(string):
positions.append((match.start(), match.end()))
return positions
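
As an aside, the tokenize helper yields character offsets rather than token strings; a quick illustration (hypothetical snippet, not part of the test file):

import re

text = "I do like the color red."

# \w+ matches runs of word characters; \S picks up the remaining single
# non-space characters such as punctuation.
offsets = [(m.start(), m.end()) for m in re.compile(r"\w+|\S").finditer(text)]
print([text[s:e] for s, e in offsets])
# ['I', 'do', 'like', 'the', 'color', 'red', '.']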
