Skip to content

Commit

Permalink
Merge pull request #469 from tomaarsen/absa_predict_gold_aspects
Browse files Browse the repository at this point in the history
[`ABSA`] Predict with a gold aspect dataset
  • Loading branch information
tomaarsen authored Jan 11, 2024
2 parents 6ef9482 + 38e9075 commit 3e3d828
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 1 deletion.
97 changes: 96 additions & 1 deletion src/setfit/span/modeling.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import copy
import os
import re
import tempfile
import types
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

import torch
from datasets import Dataset
from huggingface_hub.utils import SoftTemporaryDirectory

from setfit.utils import set_docstring
Expand Down Expand Up @@ -148,7 +151,99 @@ class AbsaModel:
aspect_model: AspectModel
polarity_model: PolarityModel

def predict(self, inputs: Union[str, List[str]]) -> List[Dict[str, Any]]:
def gold_aspect_spans_to_aspects_list(
    self, inputs: Dataset
) -> Tuple[List[Any], List[List[slice]], List[int]]:
    """Map gold aspect annotations onto token slices of the spaCy docs.

    Samples are grouped by their ``text`` (groups ordered by first appearance,
    samples inside a group in their original relative order), so that every
    unique text is passed to ``self.aspect_extractor`` only once. For each
    sample, the ``ordinal``-th occurrence of ``span`` in the doc text is located
    and converted to a token span with ``doc.char_span``.

    Args:
        inputs: An iterable of samples (mappings) with the keys ``text``,
            ``span`` and ``ordinal``.

    Returns:
        A tuple ``(docs, aspects_list, skipped_indices)`` where

        * ``docs`` are the docs returned by ``self.aspect_extractor``, one per
          unique text,
        * ``aspects_list`` holds, per doc, the token slices of its resolved
          gold aspects,
        * ``skipped_indices`` are the positions of the samples that could not
          be resolved. These positions refer to the *grouped* sample order
          described above, not necessarily to the row order of ``inputs``.
    """
    # First group inputs by text. The samples themselves are left untouched,
    # only the keys that are needed later on are read from them.
    grouped_data = defaultdict(list)
    for sample in inputs:
        grouped_data[sample["text"]].append(sample)

    # Get the spaCy docs
    docs, _ = self.aspect_extractor(grouped_data.keys())

    # Get the aspect spans for each doc by matching gold spans to the spaCy tokens
    aspects_list = []
    skipped_indices = []
    index = -1
    for doc, samples in zip(docs, grouped_data.values()):
        doc_aspects = []
        aspects_list.append(doc_aspects)
        for sample in samples:
            index += 1

            # Find the `ordinal`-th occurrence of the aspect term, and stop
            # scanning as soon as it has been found.
            match = None
            for i, candidate in enumerate(re.finditer(re.escape(sample["span"]), doc.text)):
                if i == sample["ordinal"]:
                    match = candidate
                    break

            if match is None:
                # Without this bookkeeping the sample would vanish silently and
                # the predictions would no longer line up with the samples.
                logger.warning(
                    f"Aspect term {sample['span']!r} with ordinal {sample['ordinal']} could not be found in {doc.text!r}. "
                    "Skipping this sample."
                )
                skipped_indices.append(index)
                continue

            span = doc.char_span(match.start(), match.end())
            if span is None:
                logger.warning(
                    f"Aspect term {sample['span']!r} with ordinal {sample['ordinal']}, isn't a token in {doc.text!r} according to spaCy. "
                    "Skipping this sample."
                )
                skipped_indices.append(index)
                continue

            doc_aspects.append(slice(span.start, span.end))
    return docs, aspects_list, skipped_indices

def predict_dataset(self, inputs: Dataset) -> Dataset:
    """Predict the polarity of the gold aspects in ``inputs``.

    Args:
        inputs: A dataset with the columns ``text`` and ``span`` and optionally
            ``ordinal``. If ``ordinal`` is missing, it is added with ``0`` for
            every row.

    Returns:
        The dataset with an additional ``pred_polarity`` column. Rows whose
        aspect was reported as skipped get ``None`` as their prediction.

    Raises:
        ValueError: If ``text`` or ``span`` is missing, if ``pred_polarity``
            already exists, or if the number of predictions does not match
            the number of rows.
    """
    column_names = set(inputs.column_names)
    if not column_names >= {"text", "span"}:
        raise ValueError(
            "`inputs` must be either a `str`, a `List[str]`, or a `datasets.Dataset` with columns `text` and `span` and optionally `ordinal`. "
            f"Found a dataset with these columns: {inputs.column_names}."
        )
    if "pred_polarity" in column_names:
        raise ValueError(
            "`predict_dataset` wants to add a `pred_polarity` column, but the input dataset already contains that column."
        )
    if "ordinal" not in column_names:
        inputs = inputs.add_column("ordinal", [0] * len(inputs))

    docs, aspects_list, skipped_indices = self.gold_aspect_spans_to_aspects_list(inputs)

    # Flatten the per-doc predictions (linear, unlike `sum(lists, [])`).
    polarity_list = [
        polarity for doc_polarities in self.polarity_model(docs, aspects_list) for polarity in doc_polarities
    ]
    # Re-insert placeholders for the skipped samples. Ascending order is
    # required, as each index refers to a position in the final list.
    for index in sorted(skipped_indices):
        polarity_list.insert(index, None)

    if len(polarity_list) != len(inputs):
        raise ValueError(
            f"Expected {len(inputs)} polarity predictions, one per row, but got {len(polarity_list)}."
        )

    # `gold_aspect_spans_to_aspects_list` groups the samples by text (groups in
    # order of first appearance, rows within a group in their original order),
    # so the predictions are in that grouped order. Map them back onto the
    # original rows; a stable sort on the group rank reproduces the grouping.
    texts = list(inputs["text"])
    group_rank = {}
    for text in texts:
        group_rank.setdefault(text, len(group_rank))
    grouped_rows = sorted(range(len(texts)), key=lambda row: group_rank[texts[row]])

    pred_polarity = [None] * len(texts)
    for row, polarity in zip(grouped_rows, polarity_list):
        pred_polarity[row] = polarity
    return inputs.add_column("pred_polarity", pred_polarity)

def predict(self, inputs: Union[str, List[str], Dataset]) -> Union[List[Dict[str, Any]], Dataset]:
"""Predicts aspects & their polarities of the given inputs.
Example::
>>> from setfit import AbsaModel
>>> model = AbsaModel.from_pretrained(
... "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
... "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
... )
>>> model.predict("The food and wine are just exquisite.")
[{'span': 'food', 'polarity': 'positive'}, {'span': 'wine', 'polarity': 'positive'}]
>>> from setfit import AbsaModel
>>> from datasets import load_dataset
>>> model = AbsaModel.from_pretrained(
... "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect",
... "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity",
... )
>>> dataset = load_dataset("tomaarsen/setfit-absa-semeval-restaurants", split="train")
>>> model.predict(dataset)
Dataset({
features: ['text', 'span', 'label', 'ordinal', 'pred_polarity'],
num_rows: 3693
})
Args:
inputs (Union[str, List[str], Dataset]): Either a sentence, a list of sentences,
or a dataset with columns `text` and `span` and optionally `ordinal`. This dataset
contains gold aspects, and we only predict the polarities for them.
Returns:
Union[List[Dict[str, Any]], Dataset]: Either a list of dictionaries with keys `span`
and `polarity` if the input was a sentence or a list of sentences, or a dataset with
columns `text`, `span`, `ordinal`, and `pred_polarity`.
"""
if isinstance(inputs, Dataset):
return self.predict_dataset(inputs)

is_str = isinstance(inputs, str)
inputs_list = [inputs] if is_str else inputs
docs, aspects_list = self.aspect_extractor(inputs_list)
Expand Down
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ def absa_model() -> AbsaModel:
return AbsaModel.from_pretrained("sentence-transformers/paraphrase-albert-small-v2", spacy_model="en_core_web_sm")


@pytest.fixture()
def trained_absa_model() -> AbsaModel:
    """Provide an `AbsaModel` loaded from a pretrained aspect and polarity checkpoint."""
    aspect_checkpoint = "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-aspect"
    polarity_checkpoint = "tomaarsen/setfit-absa-bge-small-en-v1.5-restaurants-polarity"
    return AbsaModel.from_pretrained(aspect_checkpoint, polarity_checkpoint)


@pytest.fixture()
def absa_dataset() -> Dataset:
texts = [
Expand Down
81 changes: 81 additions & 0 deletions tests/span/test_modeling.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import json
import re
from pathlib import Path
from tempfile import TemporaryDirectory

import pytest
import torch
from datasets import Dataset
from pytest import LogCaptureFixture

from setfit import AbsaModel
Expand Down Expand Up @@ -144,3 +146,82 @@ def test_load_model_on_device(device):
assert model.device.type == device
assert model.polarity_model.device.type == device
assert model.aspect_model.device.type == device


def test_predict_dataset(trained_absa_model: AbsaModel):
    """`predict` on a gold-aspect dataset returns a dataset with a `pred_polarity` column."""
    staff_text = "But the staff was so horrible to us."
    teodora_text = "To be completely fair, the only redeeming factor was the food, which was above average, but couldn't make up for all the other deficiencies of Teodora."
    kitchen_text = "The food is uniformly exceptional, with a very capable kitchen which will proudly whip up whatever you feel like eating, whether it's on the menu or not."
    texts = [staff_text, teodora_text, kitchen_text, kitchen_text, kitchen_text]
    spans = ["staff", "food", "food", "kitchen", "menu"]

    # Case 1: all columns present, including the explicit `ordinal` column.
    full_columns = {
        "text": texts,
        "span": spans,
        "label": ["negative", "positive", "positive", "positive", "neutral"],
        "ordinal": [0, 0, 0, 0, 0],
    }
    predicted = trained_absa_model.predict(Dataset.from_dict(full_columns))
    assert isinstance(predicted, Dataset)
    assert set(predicted.column_names) == {"pred_polarity", "text", "span", "label", "ordinal"}

    # Case 2: only `text` and `span`; the `ordinal` column is optional.
    minimal_columns = {"text": texts, "span": spans}
    predicted = trained_absa_model.predict(Dataset.from_dict(minimal_columns))
    assert isinstance(predicted, Dataset)
    assert "pred_polarity" in predicted.column_names


def test_predict_dataset_errors(trained_absa_model: AbsaModel):
    """`predict` rejects datasets with missing or conflicting columns via `ValueError`."""
    staff_text = "But the staff was so horrible to us."
    teodora_text = "To be completely fair, the only redeeming factor was the food, which was above average, but couldn't make up for all the other deficiencies of Teodora."
    kitchen_text = "The food is uniformly exceptional, with a very capable kitchen which will proudly whip up whatever you feel like eating, whether it's on the menu or not."
    texts = [staff_text, teodora_text, kitchen_text, kitchen_text, kitchen_text]

    # A dataset without the `span` column is not a valid gold-aspect dataset.
    missing_span = Dataset.from_dict({"text": texts})
    missing_span_message = re.escape(
        "`inputs` must be either a `str`, a `List[str]`, or a `datasets.Dataset` with columns `text` and `span` and optionally `ordinal`. "
        "Found a dataset with these columns: ['text']."
    )
    with pytest.raises(ValueError, match=missing_span_message):
        trained_absa_model.predict(missing_span)

    # A dataset that already has a `pred_polarity` column must not be overwritten.
    already_predicted = Dataset.from_dict(
        {
            "text": texts,
            "span": ["staff", "food", "food", "kitchen", "menu"],
            "pred_polarity": ["negative", "positive", "positive", "positive", "neutral"],
        }
    )
    existing_column_message = re.escape(
        "`predict_dataset` wants to add a `pred_polarity` column, but the input dataset already contains that column."
    )
    with pytest.raises(ValueError, match=existing_column_message):
        trained_absa_model.predict(already_predicted)

0 comments on commit 3e3d828

Please sign in to comment.