From 6207842f7d9e0d82f864ccd29fc5443399f9e554 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 16 Oct 2023 14:42:01 +0200 Subject: [PATCH] Add scoring mechanism to sentiment task. --- spacy_llm/tasks/sentiment/registry.py | 7 +++++-- spacy_llm/tasks/sentiment/task.py | 9 +++++++-- spacy_llm/tasks/sentiment/util.py | 17 +++++++++++++++++ spacy_llm/tests/tasks/test_sentiment.py | 19 +++++++++++++++++++ 4 files changed, 48 insertions(+), 4 deletions(-) diff --git a/spacy_llm/tasks/sentiment/registry.py b/spacy_llm/tasks/sentiment/registry.py index 13f248c0..4b8ee055 100644 --- a/spacy_llm/tasks/sentiment/registry.py +++ b/spacy_llm/tasks/sentiment/registry.py @@ -1,10 +1,10 @@ from typing import Optional, Type from ...registry import registry -from ...ty import ExamplesConfigType, FewshotExample, TaskResponseParser +from ...ty import ExamplesConfigType, FewshotExample, Scorer, TaskResponseParser from .parser import parse_responses_v1 from .task import DEFAULT_SENTIMENT_TEMPLATE_V1, SentimentTask -from .util import SentimentExample +from .util import SentimentExample, score @registry.llm_tasks("spacy.Sentiment.v1") @@ -14,6 +14,7 @@ def make_sentiment_task( prompt_example_type: Optional[Type[FewshotExample]] = None, examples: ExamplesConfigType = None, field: str = "sentiment", + scorer: Optional[Scorer] = None, ): """Sentiment.v1 task factory. @@ -24,6 +25,7 @@ def make_sentiment_task( examples (Optional[Callable[[], Iterable[Any]]]): Optional callable that reads a file containing task examples for few-shot learning. If None is passed, then zero-shot learning will be used. field (str): The name of the doc extension in which to store the summary. + scorer (Optional[Scorer]): Scorer function. """ raw_examples = examples() if callable(examples) else examples example_type = prompt_example_type or SentimentExample @@ -37,4 +39,5 @@ def make_sentiment_task( prompt_example_type=example_type, prompt_examples=sentiment_examples, field=field, + scorer=scorer or score, ) diff --git a/spacy_llm/tasks/sentiment/task.py b/spacy_llm/tasks/sentiment/task.py index 54c82572..b32d3841 100644 --- a/spacy_llm/tasks/sentiment/task.py +++ b/spacy_llm/tasks/sentiment/task.py @@ -1,10 +1,10 @@ -from typing import Callable, Iterable, List, Optional, Type +from typing import Any, Callable, Dict, Iterable, List, Optional, Type from spacy.language import Language from spacy.tokens import Doc from spacy.training import Example -from ...ty import FewshotExample, Self, TaskResponseParser +from ...ty import FewshotExample, Scorer, Self, TaskResponseParser from ..builtin_task import BuiltinTask from ..templates import read_template from .util import SentimentExample @@ -20,6 +20,7 @@ def __init__( prompt_example_type: Type[FewshotExample], field: str, prompt_examples: Optional[List[SentimentExample]], + scorer: Scorer, ): """Sentiment analysis task. @@ -36,6 +37,7 @@ def __init__( prompt_examples=prompt_examples, ) self._field = field + self._scorer = scorer self._check_doc_extension() def _check_doc_extension(self): @@ -79,6 +81,9 @@ def parse_responses( yield doc + def scorer(self, examples: Iterable[Example]) -> Dict[str, Any]: + return self._scorer(examples, field=self._field) + @property def _cfg_keys(self) -> List[str]: return ["_template"] diff --git a/spacy_llm/tasks/sentiment/util.py b/spacy_llm/tasks/sentiment/util.py index f718ff57..c1ea098b 100644 --- a/spacy_llm/tasks/sentiment/util.py +++ b/spacy_llm/tasks/sentiment/util.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, Iterable + from spacy.training import Example from ...compat import Self @@ -14,3 +16,18 @@ def generate(cls, example: Example, **kwargs) -> Self: text=example.reference.text, score=getattr(example.reference._, kwargs["field"]), ) + + +def score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: + """Score lemmatization accuracy in examples. + examples (Iterable[Example]): Examples to score. + RETURNS (Dict[str, Any]): Dict with metric name -> score. + """ + score_diffs = [ + abs( + getattr(example.predicted._, kwargs["field"]) + - getattr(example.reference._, kwargs["field"]) + ) + for example in examples + ] + return {"acc_sentiment": sum(score_diffs) / len(score_diffs)} diff --git a/spacy_llm/tests/tasks/test_sentiment.py b/spacy_llm/tests/tasks/test_sentiment.py index 4b2bd63b..688c351c 100644 --- a/spacy_llm/tests/tasks/test_sentiment.py +++ b/spacy_llm/tests/tasks/test_sentiment.py @@ -1,8 +1,10 @@ from pathlib import Path +import numpy import pytest import spacy from confection import Config +from spacy.training import Example from spacy.util import make_tempdir from spacy_llm.registry import fewshot_reader, file_reader @@ -263,3 +265,20 @@ def test_external_template_actually_loads(): Sentiment: """.strip() ) + + +@pytest.mark.external +@pytest.mark.skipif(has_openai_key is False, reason="OpenAI API key not available") +def test_sentiment_score(request): + """Test scoring mechanism.""" + cfg = request.getfixturevalue("zeroshot_cfg_string") + orig_config = Config().from_str(cfg) + nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True) + + sent_diff = 0.2 + doc1 = nlp("This works well.") + doc2 = doc1.copy() + doc2._.sentiment -= sent_diff + assert numpy.isclose( + nlp.get_pipe("llm").score([Example(doc1, doc2)])["acc_sentiment"], sent_diff + )