Numericalizer abc #227
base: master
@@ -0,0 +1,93 @@ (new file)
from abc import ABC, abstractmethod
from typing import List

import numpy as np


class NumericalizerABC(ABC):
    """ABC that contains the interface for Podium numericalizers. Numericalizers are used
    to transform tokens into vectors or any other custom datatype during batching.

    Attributes
    ----------
    finalized: bool
        Whether this numericalizer was finalized and is able to be used for
        numericalization.
    """

    def __init__(self, eager=True):
[Review comment] Type hint for eager is missing.
"""Initialises the Numericalizer. | ||
|
||
Parameters | ||
---------- | ||
eager: bool | ||
Whether the Numericalizer is to be updated during loading of the dataset, or | ||
after all data is loaded. | ||
|
||
""" | ||
self._finalized = False | ||
self._eager = eager | ||
|
||
@abstractmethod | ||
def numericalize(self, tokens: List[str]) -> np.ndarray: | ||
"""Converts `tokens` into a numericalized format used in batches. | ||
Numericalizations are most often numpy vectors, but any custom datatype is | ||
supported. | ||
|
||
Parameters | ||
---------- | ||
tokens: List[str] | ||
A list of strings that represent the tokens of this data point. Can also be | ||
any other datatype, as long as this Numericalizer supports it. | ||
|
||
Returns | ||
------- | ||
Numericalization used in batches. Numericalizations are most often numpy vectors, | ||
but any custom datatype is supported. | ||
""" | ||
pass | ||
|
||
def finalize(self): | ||
"""Finalizes the Numericalizer and prepares it for numericalization. | ||
This method must be overridden in classes that require finalization before | ||
numericalization. The override must call `mark_finalize` after successful | ||
completion.""" | ||
self.mark_finalized() | ||
pass | ||
|
||
def update(self, tokens: List[str]) -> None: | ||
"""Updates this Numericalizer with a single data point. Numericalizers that need | ||
to be updated example by example must override this method. Numericalizers that | ||
are eager get updated during the dataset loading process, while non-eager ones get | ||
updated after loading is finished, after all eager numericalizers were fully | ||
updated. | ||
|
||
Parameters | ||
---------- | ||
tokens: List[str] | ||
A list of strings that represent the tokens of this data point. Can also be | ||
any other datatype, as long as this Numericalizer supports it. | ||
|
||
""" | ||
pass | ||
|
||
def mark_finalized(self) -> None: | ||
"""Marks the field as finalized. This method must be called after finalization | ||
completes successfully.""" | ||
self._finalized = True | ||
|
||
[Review comment on lines +74 to +78] Honestly, not a huge fan of …
[Review comment] I'd stay with …
    @property
    def finalized(self) -> bool:
        """Whether this Numericalizer was finalized and is ready for numericalization."""
        return self._finalized

    @property
    def eager(self) -> bool:
        """Whether this Numericalizer is eager. Numericalizers that
        are eager get updated during the dataset loading process, while non-eager ones get
        updated after loading is finished, after all eager numericalizers were fully
        updated."""
        return self._eager

    def __call__(self, tokens: List[str]) -> np.ndarray:
        return self.numericalize(tokens)
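
To make the intended contract concrete, below is a minimal sketch of a concrete numericalizer built on this ABC. It is not part of the PR: the class name and the index-assignment scheme are made up for illustration, and the snippet only shows the expected interplay of eager, update, finalize/mark_finalized and numericalize.

from typing import Dict, List

import numpy as np

from podium.preproc import NumericalizerABC


# Hypothetical example, not from the diff: assigns each seen token an integer index.
class TokenIndexNumericalizer(NumericalizerABC):
    def __init__(self, eager: bool = True):
        super().__init__(eager=eager)
        self._index: Dict[str, int] = {}

    def update(self, tokens: List[str]) -> None:
        # Called once per data point: eagerly during dataset loading,
        # or after loading for non-eager numericalizers.
        for token in tokens:
            self._index.setdefault(token, len(self._index))

    def finalize(self) -> None:
        # No extra work needed here; just mark the numericalizer as usable.
        self.mark_finalized()

    def numericalize(self, tokens: List[str]) -> np.ndarray:
        # Unknown tokens map to -1 in this toy scheme.
        return np.array([self._index.get(token, -1) for token in tokens])

An instance would then receive update(tokens) for every example, finalize() once, and be used through numericalize(tokens) or a plain call, since __call__ delegates to numericalize.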
Second changed file:
@@ -6,14 +6,16 @@

 import numpy as np

+from podium.preproc import NumericalizerABC
 from podium.preproc.tokenizers import get_tokenizer
 from podium.storage.vocab import Vocab


 PretokenizationHookType = Callable[[Any], Any]
 PosttokenizationHookType = Callable[[Any, List[str]], Tuple[Any, List[str]]]
 TokenizerType = Optional[Union[str, Callable[[Any], List[str]]]]
-NumericalizerType = Callable[[str], Union[int, float]]
+NumericalizerCallableType = Callable[[str], Union[int, float]]
+NumericalizerType = Union[NumericalizerABC, NumericalizerCallableType]


 class PretokenizationPipeline:
@@ -205,6 +207,16 @@ def remove_pretokenize_hooks(self):
         self._pretokenization_pipeline.clear()


+class NumericalizerCallableWrapper(NumericalizerABC):
+    def __init__(self, numericalizer: NumericalizerType):
+        super().__init__(eager=True)
+        self._wrapped_numericalizer = numericalizer
+
+    def numericalize(self, tokens: List[str]) -> np.ndarray:
+        numericalized = [self._wrapped_numericalizer(tok) for tok in tokens]
+        return np.array(numericalized)
+
+
 class Field:
     """Holds the preprocessing and numericalization logic for a single
     field of a dataset.
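
As a quick illustration (hypothetical, not from the diff) of what the wrapper does: any per-token callable becomes a numericalizer whose per-token results are packed into a numpy array.

# Hypothetical usage of NumericalizerCallableWrapper; len is just an arbitrary str -> int callable.
lengths = NumericalizerCallableWrapper(len)
lengths.numericalize(["some", "tokens"])  # -> array([4, 6])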
@@ -321,12 +333,16 @@ def __init__(
         else:
             self._tokenizer = get_tokenizer(tokenizer)

-        if isinstance(numericalizer, Vocab):
-            self._vocab = numericalizer
-            self._numericalizer = self.vocab.__getitem__
-        else:
-            self._vocab = None
+        if isinstance(numericalizer, NumericalizerABC) or numericalizer is None:
             self._numericalizer = numericalizer
+        elif isinstance(numericalizer, Callable):
+            self._numericalizer = NumericalizerCallableWrapper(numericalizer)
+        else:
+            err_msg = (
+                f"Field {name}: unsupported numericalizer type "
+                f'"{type(numericalizer).__name__}"'
+            )
+            raise TypeError(err_msg)
[Review comment on lines +341 to +345] Would put this error message directly inside the TypeError.
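
A sketch of what that suggestion would look like (illustrative only, not part of the diff):

# Reviewer's suggestion, roughly: build the message inline in the raise.
raise TypeError(
    f"Field {name}: unsupported numericalizer type "
    f'"{type(numericalizer).__name__}"'
)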

         self._keep_raw = keep_raw

@@ -385,12 +401,20 @@ def eager(self):
             whether this field has a Vocab and whether that Vocab is
             marked as eager
         """
-        return self.vocab is not None and self.vocab.eager
+        # Pretend to be eager if no numericalizer provided
+        return self._numericalizer is None or self._numericalizer.eager

     @property
     def vocab(self):
         """"""
-        return self._vocab
+        if not self.use_vocab:
+            numericalizer_type = type(self._numericalizer).__name__
+            err_msg = (
+                f'Field "{self.name}" has no vocab, numericalizer type is '
+                f"{numericalizer_type}."
+            )
+            raise TypeError(err_msg)
+        return self._numericalizer

     @property
     def use_vocab(self):
@@ -402,7 +426,7 @@ def use_vocab(self):
             Whether the field uses a vocab or not.
         """

-        return self.vocab is not None
+        return isinstance(self._numericalizer, Vocab)

     @property
     def is_target(self):
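
To illustrate the behavioural change for callers (hypothetical snippet; field is assumed to be a Field instance): vocab no longer returns None for fields without a Vocab, it raises, so use_vocab should be checked first.

# Hypothetical caller-side code under the new behaviour.
if field.use_vocab:
    vocab = field.vocab   # safe: the underlying numericalizer is a Vocab
else:
    vocab = None          # accessing field.vocab here would now raise TypeError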
@@ -547,6 +571,7 @@ def preprocess(

         # Preprocess the raw input
         # TODO keep unprocessed or processed raw?
+        # Keeping processed for now, may change in the future
         processed_raw = self._run_pretokenization_hooks(data)
         tokenized = (
             self._tokenizer(processed_raw)
@@ -556,7 +581,7 @@

         return (self._process_tokens(processed_raw, tokenized),)

-    def update_vocab(self, tokenized: List[str]):
+    def update_numericalizer(self, tokenized: Union[str, List[str]]) -> None:
         """Updates the vocab with a data point in its tokenized form.
         If the field does not do tokenization,
@@ -567,11 +592,11 @@ def update_vocab(self, tokenized: List[str]):
             updated with.
         """

-        if not self.use_vocab:
+        if self._numericalizer is None:
             return  # TODO throw Error?

         data = tokenized if isinstance(tokenized, (list, tuple)) else (tokenized,)
-        self._vocab += data
+        self._numericalizer.update(data)

     @property
     def finalized(self) -> bool:
@@ -584,13 +609,13 @@ def finalized(self) -> bool:
             Whether the field's Vocab vas finalized. If the field has no
             vocab, returns True.
         """
-        return True if self.vocab is None else self.vocab.finalized
+        return self._numericalizer is None or self._numericalizer.finalized

     def finalize(self):
         """Signals that this field's vocab can be built."""

-        if self.use_vocab:
-            self.vocab.finalize()
+        if self._numericalizer is not None:
+            self._numericalizer.finalize()

     def _process_tokens(
         self, raw: Any, tokens: Union[Any, List[str]]
@@ -616,8 +641,12 @@ def _process_tokens(
         raw, tokenized = self._run_posttokenization_hooks(raw, tokens)
         raw = raw if self._keep_raw else None

-        if self.eager and not self.vocab.finalized:
-            self.update_vocab(tokenized)
+        if (
+            self.eager
+            and self._numericalizer is not None
+            and not self._numericalizer.finalized
+        ):
+            self.update_numericalizer(tokenized)
         return self.name, (raw, tokenized)

     def get_default_value(self) -> Union[int, float]:
@@ -679,10 +708,7 @@ def numericalize(

         tokens = tokenized if isinstance(tokenized, (list, tuple)) else [tokenized]

-        if self.use_vocab:
-            return self.vocab.numericalize(tokens)
-        else:
-            return np.array([self._numericalizer(t) for t in tokens])
+        return self._numericalizer.numericalize(tokens)

     def _pad_to_length(
         self,
@@ -1030,16 +1056,16 @@ def __init__(

     def finalize(self):
         """Signals that this field's vocab can be built."""
-        super().finalize()
         if self._num_of_classes is None:
-            self.fixed_length = self._num_of_classes = len(self.vocab)
+            self._fixed_length = self._num_of_classes = len(self.vocab)

         if self.use_vocab and len(self.vocab) > self._num_of_classes:
             raise ValueError(
                 "Number of classes in data is greater than the declared number "
                 f"of classes. Declared: {self._num_of_classes}, "
                 f"Actual: {len(self.vocab)}"
             )
+        super().finalize()

     def numericalize(
         self, data: Tuple[Optional[Any], Optional[Union[Any, List[str]]]]
[Review comment] Would rename to Numericalizer, to follow naming conventions in collections.abc.
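
Tying the changes together, a rough usage sketch under the new interface. This is hypothetical: only the name, tokenizer, numericalizer and keep_raw parameters of Field are visible in this diff, and TokenIndexNumericalizer refers to the illustrative subclass sketched above, not to anything in the PR.

# Hypothetical end-to-end sketch based on the dispatch added in __init__.
vocab_field = Field("text", numericalizer=Vocab())                    # use_vocab -> True
wrapped_field = Field("score", numericalizer=float)                   # plain callable, wrapped automatically
custom_field = Field("ids", numericalizer=TokenIndexNumericalizer())  # NumericalizerABC subclass
broken_field = Field("bad", numericalizer=42)                         # raises TypeError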