diff --git a/podium/datasets/dataset_abc.py b/podium/datasets/dataset_abc.py index 3f40f6ef..56eaa4a8 100644 --- a/podium/datasets/dataset_abc.py +++ b/podium/datasets/dataset_abc.py @@ -117,7 +117,7 @@ def finalize_fields(self, *datasets: "DatasetABC") -> None: for example in dataset: for field in fields_to_build: _, tokenized = example[field.name] - field.update_vocab(tokenized) + field.update_numericalizer(tokenized) for field in self.fields: field.finalize() diff --git a/podium/preproc/__init__.py b/podium/preproc/__init__.py index 66c141af..5bbe654a 100644 --- a/podium/preproc/__init__.py +++ b/podium/preproc/__init__.py @@ -9,6 +9,7 @@ TextCleanUp, ) from .lemmatizer import CroatianLemmatizer +from .numericalizer_abc import NumericalizerABC from .sentencizers import SpacySentencizer from .stemmer import CroatianStemmer from .tokenizers import get_tokenizer diff --git a/podium/preproc/numericalizer_abc.py b/podium/preproc/numericalizer_abc.py new file mode 100644 index 00000000..b24e3f3a --- /dev/null +++ b/podium/preproc/numericalizer_abc.py @@ -0,0 +1,93 @@ +from abc import ABC, abstractmethod +from typing import List + +import numpy as np + + +class NumericalizerABC(ABC): + """ABC that contains the interface for Podium numericalizers. Numericalizers are used + to transform tokens into vectors or any other custom datatype during batching. + + Attributes + ---------- + finalized: bool + Whether this numericalizer was finalized and is able to be used for + numericalization. + """ + + def __init__(self, eager=True): + """Initialises the Numericalizer. + + Parameters + ---------- + eager: bool + Whether the Numericalizer is to be updated during loading of the dataset, or + after all data is loaded. + + """ + self._finalized = False + self._eager = eager + + @abstractmethod + def numericalize(self, tokens: List[str]) -> np.ndarray: + """Converts `tokens` into a numericalized format used in batches. 
+ Numericalizations are most often numpy vectors, but any custom datatype is + supported. + + Parameters + ---------- + tokens: List[str] + A list of strings that represent the tokens of this data point. Can also be + any other datatype, as long as this Numericalizer supports it. + + Returns + ------- + Numericalization used in batches. Numericalizations are most often numpy vectors, + but any custom datatype is supported. + """ + pass + + def finalize(self): + """Finalizes the Numericalizer and prepares it for numericalization. + This method must be overridden in classes that require finalization before + numericalization. The override must call `mark_finalized` after successful + completion.""" + self.mark_finalized() + pass + + def update(self, tokens: List[str]) -> None: + """Updates this Numericalizer with a single data point. Numericalizers that need + to be updated example by example must override this method. Numericalizers that + are eager get updated during the dataset loading process, while non-eager ones get + updated after loading is finished, after all eager numericalizers were fully + updated. + + Parameters + ---------- + tokens: List[str] + A list of strings that represent the tokens of this data point. Can also be + any other datatype, as long as this Numericalizer supports it. + + """ + pass + + def mark_finalized(self) -> None: + """Marks the Numericalizer as finalized. This method must be called after finalization + completes successfully.""" + self._finalized = True + + @property + def finalized(self) -> bool: + """Whether this Numericalizer was finalized and is ready for numericalization.""" + return self._finalized + + @property + def eager(self) -> bool: + """Whether this Numericalizer is eager. 
Numericalizers that + are eager get updated during the dataset loading process, while non-eager ones get + updated after loading is finished, after all eager numericalizers were fully + updated.""" + return self._eager + + def __call__(self, tokens: List[str]) -> np.ndarray: + return self.numericalize(tokens) diff --git a/podium/storage/field.py b/podium/storage/field.py index d1595b27..a4cc1429 100644 --- a/podium/storage/field.py +++ b/podium/storage/field.py @@ -6,6 +6,7 @@ import numpy as np +from podium.preproc import NumericalizerABC from podium.preproc.tokenizers import get_tokenizer from podium.storage.vocab import Vocab @@ -13,7 +14,8 @@ PretokenizationHookType = Callable[[Any], Any] PosttokenizationHookType = Callable[[Any, List[str]], Tuple[Any, List[str]]] TokenizerType = Optional[Union[str, Callable[[Any], List[str]]]] -NumericalizerType = Callable[[str], Union[int, float]] +NumericalizerCallableType = Callable[[str], Union[int, float]] +NumericalizerType = Union[NumericalizerABC, NumericalizerCallableType] class PretokenizationPipeline: @@ -205,6 +207,16 @@ def remove_pretokenize_hooks(self): self._pretokenization_pipeline.clear() +class NumericalizerCallableWrapper(NumericalizerABC): + def __init__(self, numericalizer: NumericalizerType): + super().__init__(eager=True) + self._wrapped_numericalizer = numericalizer + + def numericalize(self, tokens: List[str]) -> np.ndarray: + numericalized = [self._wrapped_numericalizer(tok) for tok in tokens] + return np.array(numericalized) + + class Field: """Holds the preprocessing and numericalization logic for a single field of a dataset. 
@@ -321,12 +333,16 @@ def __init__( else: self._tokenizer = get_tokenizer(tokenizer) - if isinstance(numericalizer, Vocab): - self._vocab = numericalizer - self._numericalizer = self.vocab.__getitem__ - else: - self._vocab = None + if isinstance(numericalizer, NumericalizerABC) or numericalizer is None: self._numericalizer = numericalizer + elif isinstance(numericalizer, Callable): + self._numericalizer = NumericalizerCallableWrapper(numericalizer) + else: + err_msg = ( + f"Field {name}: unsupported numericalizer type " + f'"{type(numericalizer).__name__}"' + ) + raise TypeError(err_msg) self._keep_raw = keep_raw @@ -385,12 +401,20 @@ def eager(self): whether this field has a Vocab and whether that Vocab is marked as eager """ - return self.vocab is not None and self.vocab.eager + # Pretend to be eager if no numericalizer provided + return self._numericalizer is None or self._numericalizer.eager @property def vocab(self): """""" - return self._vocab + if not self.use_vocab: + numericalizer_type = type(self._numericalizer).__name__ + err_msg = ( + f'Field "{self.name}" has no vocab, numericalizer type is ' + f"{numericalizer_type}." + ) + raise TypeError(err_msg) + return self._numericalizer @property def use_vocab(self): @@ -402,7 +426,7 @@ def use_vocab(self): Whether the field uses a vocab or not. """ - return self.vocab is not None + return isinstance(self._numericalizer, Vocab) @property def is_target(self): @@ -547,6 +571,7 @@ def preprocess( # Preprocess the raw input # TODO keep unprocessed or processed raw? + # Keeping processed for now, may change in the future processed_raw = self._run_pretokenization_hooks(data) tokenized = ( self._tokenizer(processed_raw) @@ -556,7 +581,7 @@ def preprocess( return (self._process_tokens(processed_raw, tokenized),) - def update_vocab(self, tokenized: List[str]): + def update_numericalizer(self, tokenized: Union[str, List[str]]) -> None: """Updates the vocab with a data point in its tokenized form. 
If the field does not do tokenization, @@ -567,11 +592,11 @@ def update_vocab(self, tokenized: List[str]): updated with. """ - if not self.use_vocab: + if self._numericalizer is None: return # TODO throw Error? data = tokenized if isinstance(tokenized, (list, tuple)) else (tokenized,) - self._vocab += data + self._numericalizer.update(data) @property def finalized(self) -> bool: @@ -584,13 +609,13 @@ def finalized(self) -> bool: Whether the field's Vocab vas finalized. If the field has no vocab, returns True. """ - return True if self.vocab is None else self.vocab.finalized + return self._numericalizer is None or self._numericalizer.finalized def finalize(self): """Signals that this field's vocab can be built.""" - if self.use_vocab: - self.vocab.finalize() + if self._numericalizer is not None: + self._numericalizer.finalize() def _process_tokens( self, raw: Any, tokens: Union[Any, List[str]] @@ -616,8 +641,12 @@ def _process_tokens( raw, tokenized = self._run_posttokenization_hooks(raw, tokens) raw = raw if self._keep_raw else None - if self.eager and not self.vocab.finalized: - self.update_vocab(tokenized) + if ( + self.eager + and self._numericalizer is not None + and not self._numericalizer.finalized + ): + self.update_numericalizer(tokenized) return self.name, (raw, tokenized) def get_default_value(self) -> Union[int, float]: @@ -679,10 +708,7 @@ def numericalize( tokens = tokenized if isinstance(tokenized, (list, tuple)) else [tokenized] - if self.use_vocab: - return self.vocab.numericalize(tokens) - else: - return np.array([self._numericalizer(t) for t in tokens]) + return self._numericalizer.numericalize(tokens) def _pad_to_length( self, @@ -1030,9 +1056,8 @@ def __init__( def finalize(self): """Signals that this field's vocab can be built.""" - super().finalize() if self._num_of_classes is None: - self.fixed_length = self._num_of_classes = len(self.vocab) + self._fixed_length = self._num_of_classes = len(self.vocab) if self.use_vocab and len(self.vocab) > 
self._num_of_classes: raise ValueError( @@ -1040,6 +1065,7 @@ def finalize(self): f"of classes. Declared: {self._num_of_classes}, " f"Actual: {len(self.vocab)}" ) + super().finalize() def numericalize( self, data: Tuple[Optional[Any], Optional[Union[Any, List[str]]]] diff --git a/podium/storage/vectorizers/tfidf.py b/podium/storage/vectorizers/tfidf.py index 1e6f855f..dab91e9c 100644 --- a/podium/storage/vectorizers/tfidf.py +++ b/podium/storage/vectorizers/tfidf.py @@ -149,7 +149,7 @@ def fit(self, dataset, field): ValueError If the vocab or fields vocab are None """ - if self._vocab is None and (field is None or field.vocab is None): + if self._vocab is None and (field is None or not field.use_vocab): raise ValueError( "Vocab is not defined. User should define vocab in constructor " "or by providing field with a non-empty vocab property." diff --git a/podium/storage/vocab.py b/podium/storage/vocab.py index c8ed1b34..76eb458e 100644 --- a/podium/storage/vocab.py +++ b/podium/storage/vocab.py @@ -3,10 +3,12 @@ from collections import Counter from enum import Enum from itertools import chain -from typing import Iterable, Union +from typing import Iterable, List, Union import numpy as np +from podium.preproc import NumericalizerABC + def unique(values: Iterable): """Generator that iterates over the first occurrence of every value in values, @@ -60,7 +62,7 @@ class SpecialVocabSymbols(Enum): PAD = "" -class Vocab: +class Vocab(NumericalizerABC): """Class for storing vocabulary. It supports frequency counting and size limiting. 
@@ -97,6 +99,7 @@ def __init__( if true word frequencies will be saved for later use on the finalization """ + super(Vocab, self).__init__(eager) self._freqs = Counter() self._keep_freqs = keep_freqs self._min_freq = min_freq @@ -112,8 +115,6 @@ def __init__( self.stoi.update({k: v for v, k in enumerate(self.itos)}) self._max_size = max_size - self.eager = eager - self.finalized = False # flag to know if we're ready to numericalize @staticmethod def _init_default_unk_index(specials): @@ -194,6 +195,9 @@ def padding_index(self): raise ValueError("Padding symbol is not in the vocabulary.") return self.stoi[SpecialVocabSymbols.PAD] + def update(self, tokens: List[str]) -> None: + self.__iadd__(tokens) + def __iadd__(self, values: Union["Vocab", Iterable]): """Adds additional values or another Vocab to this Vocab. @@ -375,7 +379,7 @@ def finalize(self): if not self._keep_freqs: self._freqs = None # release memory - self.finalized = True + self.mark_finalized() def numericalize(self, data): """Method numericalizes given tokens. 
diff --git a/tests/arrow/test_pyarrow_tabular_dataset.py b/tests/arrow/test_pyarrow_tabular_dataset.py index 8acf5249..0d62c7af 100644 --- a/tests/arrow/test_pyarrow_tabular_dataset.py +++ b/tests/arrow/test_pyarrow_tabular_dataset.py @@ -139,21 +139,23 @@ def test_dump_and_load(pyarrow_dataset): def test_finalize_fields(data, fields, mocker): for field in fields: mocker.spy(field, "finalize") - mocker.spy(field, "update_vocab") + mocker.spy(field, "update_numericalizer") dataset = pyarrow_dataset(data, fields) for f in fields: # before finalization, no field's dict was updated - if f.vocab is not None: - assert not f.finalized + if f._numericalizer is not None: + assert not f._numericalizer.finalized dataset.finalize_fields() - fields_to_finalize = [f for f in fields if not f.eager and f.use_vocab] + fields_to_finalize = [ + f for f in fields if not f.eager and f._numericalizer is not None + ] for f in fields_to_finalize: # during finalization, only non-eager field's dict should be updated - assert f.update_vocab.call_count == (len(data) if (not f.eager) else 0) + assert f.update_numericalizer.call_count == (len(data) if (not f.eager) else 0) f.finalize.assert_called_once() # all fields should be finalized assert f.finalized diff --git a/tests/storage/test_dataset.py b/tests/storage/test_dataset.py index ef55a0f7..5e453071 100644 --- a/tests/storage/test_dataset.py +++ b/tests/storage/test_dataset.py @@ -78,7 +78,7 @@ def preprocess(self, data): return ((self.name, (raw, tokenized)),) - def update_vocab(self, tokenized): + def update_numericalizer(self, tokenized): assert not self.eager self.updated_count += 1 diff --git a/tests/storage/test_field.py b/tests/storage/test_field.py index 6f785512..1979602e 100644 --- a/tests/storage/test_field.py +++ b/tests/storage/test_field.py @@ -5,6 +5,7 @@ import numpy as np import pytest +from podium.preproc import NumericalizerABC from podium.storage import ( Field, LabelField, @@ -36,17 +37,19 @@ def tokenizer(self, 
string): return MockTokenizer() -class MockVocab(Mock): +class MockVocab(Mock, NumericalizerABC): def __init__(self, eager=True): - super(MockVocab, self).__init__(spec=Vocab) + Mock.__init__(self, spec=Vocab) + NumericalizerABC.__init__(self, eager) self.values = [] - self.finalized = False self.numericalized = False - self.eager = eager def padding_index(self): return PAD_NUM + def update(self, tokens): + self.__iadd__(tokens) + def __add__(self, values): if type(values) == type(self): pass @@ -61,8 +64,7 @@ def __iadd__(self, other): def finalize(self): if self.finalized: raise Exception - else: - self.finalized = True + self.mark_finalized() def numericalize(self, data): self.numericalized = True