fix: fix entity aggregation bug for NER detection
It looks like this is because we’re using the “FIRST” aggregation
strategy with a tokenizer that is not word-aware: the pipeline falls
back to a heuristic (the presence of spaces before/after the word)
that fails here.
Indeed, XLM-RoBERTa does not use the same tokenizer as RoBERTa: it
uses a Unigram model (instead of BPE), which is not word-aware.
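To make the failure concrete, here is a minimal sketch of the upstream fallback heuristic (the subword test is copied from the code removed below); the sentence and offsets are illustrative, assuming the Punctuation() pre-tokenizer splits “,” and “.” into standalone tokens:

```python
sentence = "flour, sugar."
# Token offsets after punctuation splitting: "flour" "," "sugar" "."
offsets = [(0, 5), (5, 6), (7, 12), (12, 13)]

for start_ind, end_ind in offsets:
    # Upstream fallback: a token is a subword iff not preceded by a space.
    is_subword = start_ind > 0 and " " not in sentence[start_ind - 1 : start_ind + 1]
    print(repr(sentence[start_ind:end_ind]), is_subword)

# 'flour' False, ',' True, 'sugar' False, '.' True
# Both punctuation tokens are wrongly flagged as subwords of the previous
# word, so FIRST glues the final dot onto the entity.
```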

Another issue with the “FIRST” aggregation strategy is that the ending
dot after the ingredient list is predicted as part of the ingredient
list, even though it’s absent from the non-aggregated prediction.
Switching to the “SIMPLE” strategy (a strategy without an
error-correction mechanism) removes this issue, but two subwords
belonging to the same word are then sometimes predicted as belonging
to two different entities.
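A toy illustration of that failure mode (tokens and labels are made up): with B-/I- grouping and no word-level correction, two subwords that both receive a “B-” label open two entity groups:

```python
# One word "ingrédients" split into two subwords, both tagged B-ING.
preds = [("▁ingré", "B-ING"), ("dients", "B-ING")]

# SIMPLE grouping: every B-* label starts a new entity group.
groups: list[list[str]] = []
for token, label in preds:
    if label.startswith("B-") or not groups:
        groups.append([])
    groups[-1].append(token)
print(len(groups))  # 2 entities for a single word

# FIRST avoids this by forcing all subwords to take the first subword's
# label, but that only works if subwords are identified correctly.
```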
A more in-depth analysis of the TokenClassificationPipeline reveals
that the issue comes from the Punctuation() pre-tokenizer we added: it
was not part of the original tokenizer, and the heuristic doesn’t take
it into account, leading to incorrect detections.
I updated the heuristic to use the `word_ids` provided by the
tokenizer to determine whether a token is a subword (with respect to
the pre-tokenization output).
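For reference, a minimal sketch of the new test on a fast tokenizer; `xlm-roberta-base` is used here as a stand-in for the actual ingredient-detection model (which additionally configures the Punctuation() pre-tokenizer):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
encoding = tokenizer("flour, sugar.")

previous_word_id = None
for token, word_id in zip(encoding.tokens(), encoding.word_ids()):
    # Tokens share a word id exactly when they come from the same
    # pre-tokenized word; special tokens get a word id of None.
    is_subword = word_id is not None and word_id == previous_word_id
    print(token, word_id, is_subword)
    previous_word_id = word_id
```

Because `word_ids` is computed from the pre-tokenization output itself, the test stays exact no matter which pre-tokenizers are configured.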
raphael0202 committed Sep 13, 2024
1 parent 6eae9d5 commit 5f2b94c
Showing 2 changed files with 60 additions and 109 deletions.
1 change: 1 addition & 0 deletions robotoff/prediction/ingredient_list/__init__.py
@@ -182,6 +182,7 @@ def predict_batch(
"input_ids": batch_encoding.input_ids[idx],
"offset_mapping": batch_encoding.offset_mapping[idx],
"special_tokens_mask": batch_encoding.special_tokens_mask[idx],
"word_ids": batch_encoding.word_ids(idx),
}
pipeline_output = pipeline.postprocess(model_outputs, aggregation_strategy)

168 changes: 59 additions & 109 deletions robotoff/prediction/ingredient_list/transformers_pipeline.py
@@ -1,6 +1,6 @@
"""
This file has been copied and adapted from
-https://github.com/huggingface/transformers/blob/v4.25.1/src/transformers/pipelines/token_classification.py
+https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/pipelines/token_classification.py
The code is under Apache-2.0 license:
https://github.com/huggingface/transformers/blob/main/LICENSE
@@ -16,10 +16,16 @@
- remove unnecessary code (everything that is not related to post-processing)
- `postprocess` now accepts a single sample (instead of a batched sample of
size 1)
+Furthermore, some modifications were made to allow proper aggregation of entities in
+the TokenClassificationPipeline, for the XLM-RoBERTa model with a custom pre-tokenizer
+(addition of Punctuation()), for the ingredient detection model.
+All significant differences from the original file are marked with the comment
+"DIFF-ORIGINAL".
"""

import enum
-import warnings
from typing import List, Optional, Tuple

import numpy as np
@@ -42,60 +48,6 @@ def __init__(self, tokenizer, id2label):
self.tokenizer = tokenizer
self.id2label = id2label

-def _sanitize_parameters(
-self,
-ignore_labels=None,
-grouped_entities: Optional[bool] = None,
-ignore_subwords: Optional[bool] = None,
-aggregation_strategy: Optional[AggregationStrategy] = None,
-offset_mapping: Optional[List[Tuple[int, int]]] = None,
-):
-
-preprocess_params = {}
-if offset_mapping is not None:
-preprocess_params["offset_mapping"] = offset_mapping
-
-postprocess_params = {}
-if grouped_entities is not None or ignore_subwords is not None:
-if grouped_entities and ignore_subwords:
-aggregation_strategy = AggregationStrategy.FIRST
-elif grouped_entities and not ignore_subwords:
-aggregation_strategy = AggregationStrategy.SIMPLE
-else:
-aggregation_strategy = AggregationStrategy.NONE
-
-if grouped_entities is not None:
-warnings.warn(
-"`grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to"
-f' `aggregation_strategy="{aggregation_strategy}"` instead.'
-)
-if ignore_subwords is not None:
-warnings.warn(
-"`ignore_subwords` is deprecated and will be removed in version v5.0.0, defaulted to"
-f' `aggregation_strategy="{aggregation_strategy}"` instead.'
-)
-
-if aggregation_strategy is not None:
-if isinstance(aggregation_strategy, str):
-aggregation_strategy = AggregationStrategy[aggregation_strategy.upper()]
-if (
-aggregation_strategy
-in {
-AggregationStrategy.FIRST,
-AggregationStrategy.MAX,
-AggregationStrategy.AVERAGE,
-}
-and not self.tokenizer.is_fast
-):
-raise ValueError(
-"Slow tokenizers cannot handle subwords. Please set the `aggregation_strategy` option"
-'to `"simple"` or use a fast tokenizer.'
-)
-postprocess_params["aggregation_strategy"] = aggregation_strategy
-if ignore_labels is not None:
-postprocess_params["ignore_labels"] = ignore_labels
-return preprocess_params, {}, postprocess_params

def postprocess(
self,
model_outputs,
@@ -113,6 +65,7 @@ def postprocess(
else None
)
special_tokens_mask = model_outputs["special_tokens_mask"]
+word_ids = model_outputs["word_ids"] # DIFF-ORIGINAL

maxes = np.max(logits, axis=-1, keepdims=True)
shifted_exp = np.exp(logits - maxes)
@@ -121,12 +74,12 @@
pre_entities = self.gather_pre_entities(
sentence,
input_ids,
+word_ids,
scores,
offset_mapping,
special_tokens_mask,
aggregation_strategy,
)
-grouped_entities = self.aggregate(pre_entities, aggregation_strategy)
+grouped_entities = self.aggregate(pre_entities, aggregation_strategy, sentence)
# Filter anything that is in self.ignore_labels
entities = [
entity
@@ -140,49 +93,30 @@ def gather_pre_entities(
self,
sentence: str,
input_ids: np.ndarray,
+word_ids: list[Optional[int]],
scores: np.ndarray,
offset_mapping: Optional[List[Tuple[int, int]]],
special_tokens_mask: np.ndarray,
aggregation_strategy: AggregationStrategy,
) -> List[dict]:
"""Fuse various numpy arrays into dicts with all the information
needed for aggregation"""
"""Fuse various numpy arrays into dicts with all the information needed for
aggregation"""
pre_entities = []
+previous_word_id = None # DIFF-ORIGINAL
for idx, token_scores in enumerate(scores):
-# Filter special_tokens, they should only occur
-# at the sentence boundaries since we're not encoding pairs of
-# sentences so we don't have to keep track of those.
+# DIFF-ORIGINAL: idx may be out of bounds if the input_ids are padded
+word_id = word_ids[idx] if idx < len(word_ids) else None # DIFF-ORIGINAL
+# Filter special_tokens
if special_tokens_mask[idx]:
+previous_word_id = word_id # DIFF-ORIGINAL
continue

word = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
if offset_mapping is not None:
start_ind, end_ind = offset_mapping[idx]
word_ref = sentence[start_ind:end_ind]
-if getattr(
-self.tokenizer._tokenizer.model, "continuing_subword_prefix", None
-):
-# This is a BPE, word aware tokenizer, there is a correct
-# way to fuse tokens
-is_subword = len(word) != len(word_ref)
-else:
-# This is a fallback heuristic. This will fail most likely
-# on any kind of text + punctuation mixtures that will be
-# considered "words". Non word aware models cannot do
-# better than this unfortunately.
-if aggregation_strategy in {
-AggregationStrategy.FIRST,
-AggregationStrategy.AVERAGE,
-AggregationStrategy.MAX,
-}:
-warnings.warn(
-"Tokenizer does not support real words, using fallback heuristic",
-UserWarning,
-)
-is_subword = (
-start_ind > 0
-and " " not in sentence[start_ind - 1 : start_ind + 1]
-)
+is_subword = word_id == previous_word_id # DIFF-ORIGINAL
+# DIFF-ORIGINAL: we removed here fallback heuristic used for
+# subword detection

if int(input_ids[idx]) == self.tokenizer.unk_token_id:
word = word_ref
@@ -192,6 +126,7 @@ def gather_pre_entities(
end_ind = None
is_subword = False

+previous_word_id = word_id # DIFF-ORIGINAL
pre_entity = {
"word": word,
"scores": token_scores,
@@ -204,7 +139,10 @@
return pre_entities

def aggregate(
-self, pre_entities: List[dict], aggregation_strategy: AggregationStrategy
+self,
+pre_entities: List[dict],
+aggregation_strategy: AggregationStrategy,
+sentence: str,
) -> List[dict]:
if aggregation_strategy in {
AggregationStrategy.NONE,
@@ -228,7 +166,7 @@ def aggregate(

if aggregation_strategy == AggregationStrategy.NONE:
return entities
-return self.group_entities(entities)
+return self.group_entities(entities, sentence)

def aggregate_word(
self, entities: List[dict], aggregation_strategy: AggregationStrategy
@@ -268,12 +206,11 @@ def aggregate_words(
self, entities: List[dict], aggregation_strategy: AggregationStrategy
) -> List[dict]:
"""
-Override tokens from a given word that disagree to force agreement on
-word boundaries.
+Override tokens from a given word that disagree to force agreement on word
+boundaries.
-Example: micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT will be
-rewritten with first strategy as microsoft|
-company| B-ENT I-ENT
+Example: micro|soft| com|pany| B-ENT I-NAME I-ENT I-ENT will be rewritten with
+first strategy as microsoft|company| B-ENT I-ENT
"""
if aggregation_strategy in {
AggregationStrategy.NONE,
@@ -296,27 +233,32 @@ def aggregate_words(
)
word_group = [entity]
# Last item
-word_entities.append(self.aggregate_word(word_group, aggregation_strategy)) # type: ignore
+if word_group is not None:
+word_entities.append(self.aggregate_word(word_group, aggregation_strategy))
return word_entities

-def group_sub_entities(self, entities: List[dict]) -> dict:
+def group_sub_entities(
+self, entities: List[dict], sentence: str # DIFF-ORIGINAL
+) -> dict:
"""
Group together the adjacent tokens with the same entity predicted.
Args:
entities (`dict`): The entities predicted by the pipeline.
+sentence (`str`): The sentence to predict on.
"""
# Get the first entity in the entity group
entity = entities[0]["entity"].split("-")[-1]
scores = np.nanmean([entity["score"] for entity in entities])
-tokens = [entity["word"] for entity in entities]

+# DIFF-ORIGINAL
+start = entities[0]["start"]
+end = entities[-1]["end"]
entity_group = {
"entity_group": entity,
"score": np.mean(scores),
"word": self.tokenizer.convert_tokens_to_string(tokens),
"start": entities[0]["start"],
"end": entities[-1]["end"],
"word": sentence[start:end], # DIFF-ORIGINAL
"start": start,
"end": end,
}
return entity_group

@@ -334,13 +276,15 @@ def get_tag(self, entity_name: str) -> Tuple[str, str]:
tag = entity_name
return bi, tag

-def group_entities(self, entities: List[dict]) -> List[dict]:
+def group_entities(
+self, entities: List[dict], sentence: str # DIFF-ORIGINAL
+) -> list[dict]:
"""
-Find and group together the adjacent tokens with the same entity
-predicted.
+Find and group together the adjacent tokens with the same entity predicted.
Args:
entities (`dict`): The entities predicted by the pipeline.
+sentence (`str`): The sentence to predict on.
"""

entity_groups = []
@@ -351,8 +295,8 @@ def group_entities(self, entities: List[dict]) -> List[dict]:
entity_group_disagg.append(entity)
continue

-# If the current entity is similar and adjacent to the previous
-# entity, append it to the disaggregated entity group
+# If the current entity is similar and adjacent to the previous entity,
+# append it to the disaggregated entity group
# The split is meant to account for the "B" and "I" prefixes
# Shouldn't merge if both entities are B-type
bi, tag = self.get_tag(entity["entity"])
@@ -364,10 +308,16 @@
else:
# If the current entity is different from the previous entity
# aggregate the disaggregated entity group
-entity_groups.append(self.group_sub_entities(entity_group_disagg))
+entity_groups.append(
+self.group_sub_entities(
+entity_group_disagg, sentence
+) # DIFF-ORIGINAL
+)
entity_group_disagg = [entity]
if entity_group_disagg:
# it's the last entity, add it to the entity groups
-entity_groups.append(self.group_sub_entities(entity_group_disagg))
+entity_groups.append(
+self.group_sub_entities(entity_group_disagg, sentence) # DIFF-ORIGINAL
+)

return entity_groups
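Side note on the `group_sub_entities` change: slicing the sentence with the aggregated character offsets returns the exact surface form of the entity, whereas the removed `convert_tokens_to_string` call re-joined tokens and could diverge from the input text. A small sketch with hypothetical pre-entities:

```python
sentence = "Ingrédients: farine de blé, sucre."
# Pre-entities spanning "farine de blé"; offsets, scores and labels are made up.
entities = [
    {"entity": "B-ING", "score": 0.98, "word": "▁farine", "start": 13, "end": 19},
    {"entity": "I-ING", "score": 0.97, "word": "▁de", "start": 20, "end": 22},
    {"entity": "I-ING", "score": 0.96, "word": "▁blé", "start": 23, "end": 26},
]
start, end = entities[0]["start"], entities[-1]["end"]
print(sentence[start:end])  # "farine de blé", an exact slice of the input
```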
