From f82aee79e69cc9127ce6428da5abac7d9758344e Mon Sep 17 00:00:00 2001 From: Ivan Smokovic Date: Mon, 26 Oct 2020 23:32:43 +0100 Subject: [PATCH 01/26] WIP: DatasetABC --- podium/datasets/__init__.py | 2 + podium/datasets/dataset.py | 91 ++++----------- podium/datasets/dataset_abc.py | 208 +++++++++++++++++++++++++++++++++ 3 files changed, 230 insertions(+), 71 deletions(-) create mode 100644 podium/datasets/dataset_abc.py diff --git a/podium/datasets/__init__.py b/podium/datasets/__init__.py index e0493899..16674f89 100644 --- a/podium/datasets/__init__.py +++ b/podium/datasets/__init__.py @@ -1,6 +1,7 @@ """Package contains datasets""" from .dataset import Dataset, rationed_split, stratified_split +from .dataset_abc import DatasetABC from .hierarhical_dataset import HierarchicalDataset from .impl.catacx_dataset import CatacxDataset from .impl.conllu_dataset import CoNLLUDataset @@ -21,6 +22,7 @@ __all__ = [ "Dataset", + "DatasetABC", "TabularDataset", "HierarchicalDataset", "stratified_split", diff --git a/podium/datasets/dataset.py b/podium/datasets/dataset.py index 730644a0..e6b8058a 100644 --- a/podium/datasets/dataset.py +++ b/podium/datasets/dataset.py @@ -4,16 +4,16 @@ import logging import random from abc import ABC -from typing import Any, Callable +from typing import Any, Callable, Dict, List from podium.storage.example_factory import Example -from podium.storage.field import unpack_fields - +from podium.storage.field import unpack_fields, Field +from podium.datasets import DatasetABC _LOGGER = logging.getLogger(__name__) -class Dataset(ABC): +class Dataset(DatasetABC): """A general purpose container for datasets. A dataset is a shallow wrapper for a list of `Example` classes which store the instance data as well as the corresponding `Field` classes, @@ -42,53 +42,14 @@ def __init__(self, examples, fields, sort_key=None): together examples with similar lengths to minimize padding. """ - self.examples = examples + self._examples = examples # we don't need the column -> field mappings with nested tuples # and None values, we just need a flat list of fields - self.fields = unpack_fields(fields) - - self.field_dict = {field.name: field for field in self.fields} - self.sort_key = sort_key + self._fields = unpack_fields(fields) + self._sort_key = sort_key - def __getitem__(self, i): - """Returns an example or a new dataset containing the indexed examples. - - If indexed with an int, only the example at that position will be returned. - If Indexed with a slice or iterable, all examples indexed by the object - will be collected and a new dataset containing only those examples will be - returned. The new dataset will contain copies of the old dataset's fields and - will be identical to the original dataset, with the exception of the example - number and ordering. See wiki for detailed examples. - - Examples in the returned Dataset are the same ones present in the - original dataset. If a complete deep-copy of the dataset, or its slice, - is needed please refer to the `get` method. - - Usage example: - - example = dataset[1] # Indexing by single integer returns a single example - - new_dataset = dataset[1:10] # Multi-indexing returns a new dataset containing - # the indexed examples. - - Parameters - ---------- - i : int or slice or iterable - Index used to index examples. - - Returns - ------- - single example or Dataset - If i is an int, a single example will be returned. - If i is a slice or iterable, a copy of this dataset containing - only the indexed examples will be returned. 
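The indexing contract being lifted out of `Dataset` here is preserved verbatim in the base class added below: an `int` index returns a single `Example`, while a slice or an iterable of ints returns a new dataset that shares its `Example` objects with the original. In usage terms (a sketch, assuming a dataset with enough examples):

    example = dataset[3]           # int index: a single Example
    subset = dataset[2:8]          # slice: a new dataset over those examples
    picked = dataset[[0, 4, 2]]    # iterable of ints: order follows the indices
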
- - """ - - return self.get(i, deep_copy=False) - - def get(self, i, deep_copy=False): + def _get(self, i, deep_copy=False): """Returns an example or a new dataset containing the indexed examples. If indexed with an int, only the example at that position @@ -196,6 +157,18 @@ def attr_generator(attr): else: raise AttributeError(f"Dataset has no field {attr}.") + def _fields_list(self) -> List[Field]: + return self._fields + + def get_example_list(self) -> List[Example]: + return self._examples + + def filtered(self, predicate: Callable[[Example], bool]) -> "DatasetABC": + filtered_examples = [ex for ex in self.examples if predicate(ex)] + return Dataset( + examples=filtered_examples, fields=self.fields, sort_key=self._sort_key + ) + def filter(self, predicate, inplace=False): """Method filters examples with given predicate. @@ -438,30 +411,6 @@ def shuffle_examples(self, random_state=None): random.shuffle(self.examples) - def batch(self): - """Creates an input and target batch containing the whole dataset. - The format of the batch is the same as the batches returned by the - - Returns - ------- - input_batch, target_batch - Two objects containing the input and target batches over - the whole dataset. - """ - # dicts that will be used to create the InputBatch and TargetBatch - # objects - from podium.datasets import SingleBatchIterator - - return next(iter(SingleBatchIterator(self, shuffle=False))) - - def sorted(self, key: Callable[[Example], Any], reverse=False) -> "Dataset": - def index_key(i): - return key(self[i]) - - indices = list(range(len(self))) - indices.sort(key=index_key, reverse=reverse) - return self[indices] - def as_dataset(self): return self diff --git a/podium/datasets/dataset_abc.py b/podium/datasets/dataset_abc.py new file mode 100644 index 00000000..065576c3 --- /dev/null +++ b/podium/datasets/dataset_abc.py @@ -0,0 +1,208 @@ +from abc import ABC, abstractmethod +from typing import Union, Iterable, Iterator, List, Callable, Tuple, Dict, NamedTuple, Any + +from podium.storage import Example, Field +from podium.datasets import SingleBatchIterator + + +class DatasetABC(ABC): + # ================= Default methods ======================= + def __getitem__( + self, i: Union[int, Iterable[int], slice] + ) -> Union["DatasetABC", Example]: + """Returns an example or a new dataset containing the indexed examples. + + If indexed with an int, only the example at that position will be returned. + If Indexed with a slice or iterable, all examples indexed by the object + will be collected and a new dataset containing only those examples will be + returned. The new dataset will contain copies of the old dataset's fields and + will be identical to the original dataset, with the exception of the example + number and ordering. See wiki for detailed examples. + + Examples in the returned Dataset are the same ones present in the + original dataset. If a complete deep-copy of the dataset, or its slice, + is needed please refer to the `get` method. + + Usage example: + + example = dataset[1] # Indexing by single integer returns a single example + + new_dataset = dataset[1:10] # Multi-indexing returns a new dataset containing + # the indexed examples. + + Parameters + ---------- + i : int or slice or iterable + Index used to index examples. + + Returns + ------- + single example or Dataset + If i is an int, a single example will be returned. + If i is a slice or iterable, a copy of this dataset containing + only the indexed examples will be returned. 
+ + """ + + return self.get(i) + + def __iter__(self) -> Iterator[Example]: + """Iterates over all examples in the dataset in order. + + Yields + ------ + Example + Yields examples in the dataset. + """ + for i in range(len(self)): + yield self[i] + + def __getattr__(self, field_name: str) -> Iterator[Tuple[Any, Any]]: + """Returns an Iterator iterating over values of the field with the + given name for every example in the dataset. + + Parameters + ---------- + field_name : str + The name of the field whose values are to be returned. + + Returns + ------ + an Iterator iterating over values of the field with the given name + for every example in the dataset. + + Raises + ------ + AttributeError + If there is no Field with the given name. + """ + for example in self: + return getattr(example, field_name) + + def finalize_fields(self, *datasets: "DatasetABC") -> None: + """ + Builds vocabularies of all the non-eager fields in the dataset, + from the Dataset objects given as \\*args and then finalizes all the + fields. + + Parameters + ---------- + \\*datasets + A variable number of DatasetABC objects from which to build the + vocabularies for non-eager fields. If none provided, the + vocabularies are built from this Dataset (self). + """ + + # if there are non-eager fields, we need to build their vocabularies + fields_to_build = [f for f in self.fields if not f.eager and f.use_vocab] + if fields_to_build: + # there can be multiple datasets we want to iterate over + data_sources = [ds for ds in datasets if isinstance(ds, DatasetABC)] + + # use self as a data source if no other given + if not data_sources: + data_sources.append(self) + + # for each example in each dataset, + # update _all_ non-eager fields + for dataset in data_sources: + for example in dataset: + for field in fields_to_build: + _, tokenized = getattr(example, field.name) + field.update_vocab(tokenized) + + for field in self.fields: + field.finalize() + + def batch(self) -> Tuple[NamedTuple, NamedTuple]: + """Creates an input and target batch containing the whole dataset. + The format of the batch is the same as the batches returned by the + + Returns + ------- + input_batch, target_batch + Two objects containing the input and target batches over + the whole dataset. + """ + # Imported here because of circular import + from podium.datasets import SingleBatchIterator + + return next(iter(SingleBatchIterator(self, shuffle=False))) + + def sorted(self, key: Callable[[Example], Any], reverse=False) -> "DatasetABC": + """Creates a new DatasetABC instance in which all Examples are sorted according to + the value returned by `key`. + + Parameters + ---------- + key: callable + specifies a function of one argument that is used to extract a comparison key + from each Example. + + reverse: bool + If set to True, then the list elements are sorted as if each comparison were + reversed. + + Returns + ------- + DatasetABC + A new DatasetABC instance with sorted Examples. 
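Because `key` receives the whole `Example`, sorting by tokenized length needs no extra machinery. A sketch, assuming a Field named "text" whose stored value unpacks as a (raw, tokenized) pair:

    # shortest to longest by token count; pass reverse=True to flip the order
    by_length = dataset.sorted(key=lambda ex: len(ex["text"][1]))
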
+ """ + + def index_key(i): + return key(self[i]) + + indices = list(range(len(self))) + indices.sort(key=index_key, reverse=reverse) + return self[indices] + + @property + def fields(self) -> List[Field]: + """List containing all fields of this dataset.""" + return self._fields_list() + + @property + def field_dict(self) -> Dict[str, Field]: + return {f.name: f for f in self.fields} + + @property + def examples(self) -> List[Example]: + return self.get_example_list() + + # ================= Abstract methods ======================= + + @abstractmethod + def __len__(self) -> int: + pass + + @abstractmethod + def _fields_list(self) -> List[Field]: + """Returns a list of all Fields in this dataset. + + Returns + ------- + list of Fields + a list of all Fields in this Dataset + """ + pass + + @abstractmethod + def _get(self, i: Union[int, Iterable[int], slice]) -> Union["DatasetABC", Example]: + pass + + @abstractmethod + def get_example_list(self) -> List[Example]: + pass + + @abstractmethod + def filtered(self, predicate: Callable[[Example], bool]) -> "DatasetABC": + """Method filters examples with given predicate and returns a new DatasetABC + instance containing those examples. + + Parameters + ---------- + predicate : callable + predicate should be a callable that accepts example as input and returns + true if the example shouldn't be filtered, otherwise returns false + """ + pass From 7dde35a544ea01ae293e2c862730571670689266 Mon Sep 17 00:00:00 2001 From: Ivan Smokovic Date: Tue, 27 Oct 2020 11:32:34 +0100 Subject: [PATCH 02/26] WIP: DatasetABC --- podium/datasets/dataset.py | 95 +++++-------------- podium/datasets/dataset_abc.py | 87 +++++++++-------- podium/datasets/impl/sst_sentiment_dataset.py | 6 +- tests/storage/test_dataset.py | 8 +- 4 files changed, 77 insertions(+), 119 deletions(-) diff --git a/podium/datasets/dataset.py b/podium/datasets/dataset.py index e6b8058a..aa54095f 100644 --- a/podium/datasets/dataset.py +++ b/podium/datasets/dataset.py @@ -3,12 +3,11 @@ import itertools import logging import random -from abc import ABC from typing import Any, Callable, Dict, List +from podium.datasets.dataset_abc import DatasetABC from podium.storage.example_factory import Example -from podium.storage.field import unpack_fields, Field -from podium.datasets import DatasetABC +from podium.storage.field import Field, unpack_fields _LOGGER = logging.getLogger(__name__) @@ -43,11 +42,8 @@ def __init__(self, examples, fields, sort_key=None): """ self._examples = examples - - # we don't need the column -> field mappings with nested tuples - # and None values, we just need a flat list of fields - self._fields = unpack_fields(fields) - self._sort_key = sort_key + self.sort_key = sort_key + super().__init__(fields) def _get(self, i, deep_copy=False): """Returns an example or a new dataset containing the indexed examples. @@ -105,6 +101,9 @@ def _get(self, i, deep_copy=False): indexed_examples = [self.examples[index] for index in i] return self._dataset_copy_with_examples(indexed_examples, deep_copy=deep_copy) + def _get_examples(self) -> List[Example]: + return self._examples + def __len__(self): """Returns the number of examples in the dataset. 
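After this commit, `Dataset` keeps only its storage-specific pieces and inherits the shared behaviour from `DatasetABC`; the abstract surface at this point in the series is `__len__`, `_get` and `_get_examples`. A minimal, hypothetical list-backed subclass (not part of this patch series) only needs those three hooks:

    from podium.datasets.dataset_abc import DatasetABC

    class ListDataset(DatasetABC):
        # Hypothetical sketch of an in-memory backend built on the DatasetABC hooks.

        def __init__(self, examples, fields):
            self._examples = list(examples)
            super().__init__(fields)  # unpacks and stores the Fields

        def __len__(self):
            return len(self._examples)

        def _get_examples(self):
            return self._examples

        def _get(self, i):
            if isinstance(i, int):
                return self._examples[i]
            indices = range(*i.indices(len(self))) if isinstance(i, slice) else i
            return ListDataset([self._examples[j] for j in indices], self.fields)
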
@@ -148,27 +147,18 @@ def __getattr__(self, attr): if attr in self.field_dict: - def attr_generator(attr): - for x in self.examples: - yield x[attr] + def attr_generator(dataset, attr): + for x in dataset.examples: + yield getattr(x, attr) - return attr_generator(attr) + return attr_generator(self, attr) else: raise AttributeError(f"Dataset has no field {attr}.") - def _fields_list(self) -> List[Field]: - return self._fields - def get_example_list(self) -> List[Example]: return self._examples - def filtered(self, predicate: Callable[[Example], bool]) -> "DatasetABC": - filtered_examples = [ex for ex in self.examples if predicate(ex)] - return Dataset( - examples=filtered_examples, fields=self.fields, sort_key=self._sort_key - ) - def filter(self, predicate, inplace=False): """Method filters examples with given predicate. @@ -181,7 +171,7 @@ def filter(self, predicate, inplace=False): if True, do operation inplace and return None """ if inplace: - self.examples = [example for example in self.examples if predicate(example)] + self._examples = [example for example in self.examples if predicate(example)] return filtered_examples = [copy.deepcopy(ex) for ex in self.examples if predicate(ex)] @@ -191,48 +181,13 @@ def filter(self, predicate, inplace=False): examples=filtered_examples, fields=fields_copy, sort_key=self.sort_key ) - def finalize_fields(self, *args): - """ - Builds vocabularies of all the non-eager fields in the dataset, - from the Dataset objects given as \\*args and then finalizes all the - fields. - - Parameters - ---------- - \\*args - A variable number of Dataset objects from which to build the - vocabularies for non-eager fields. If none provided, the - vocabularies are built from this Dataset (self). - """ - - # if there are non-eager fields, we need to build their vocabularies - fields_to_build = [f for f in self.fields if not f.eager and f.use_vocab] - if fields_to_build: - # there can be multiple datasets we want to iterate over - data_sources = [arg for arg in args if isinstance(arg, Dataset)] - - # use self as a data source if no other given - if not data_sources: - data_sources.append(self) - - # for each example in each dataset, - # update _all_ non-eager fields - for dataset in data_sources: - for example in dataset: - for field in fields_to_build: - _, tokenized = example[field.name] - field.update_vocab(tokenized) - - for field in self.fields: - field.finalize() - def split( - self, - split_ratio=0.7, - stratified=False, - strata_field_name=None, - random_state=None, - shuffle=True, + self, + split_ratio=0.7, + stratified=False, + strata_field_name=None, + random_state=None, + shuffle=True, ): """Creates train-(validation)-test splits from this dataset. @@ -370,7 +325,7 @@ def __setstate__(self, state): self.__dict__ = state def _dataset_copy_with_examples( - self, examples: list, deep_copy: bool = False + self, examples: list, deep_copy: bool = False ) -> "Dataset": """Creates a new dataset with the same fields and sort_key. 
The new dataset contains only the fields passed to this function.Fields are deep-copied into @@ -414,10 +369,6 @@ def shuffle_examples(self, random_state=None): def as_dataset(self): return self - def __repr__(self): - fields = [field.name for field in self.fields] - return f"{type(self).__name__}[Size: {len(self.examples)}, Fields: {fields}]" - def check_split_ratio(split_ratio): """Checks that the split ratio argument is not malformed and if not @@ -550,8 +501,8 @@ def rationed_split(examples, train_ratio, val_ratio, test_ratio, shuffle): indices_tuple = ( indices[:train_len], # Train - indices[train_len : train_len + val_len], # Validation - indices[train_len + val_len :], # Test + indices[train_len: train_len + val_len], # Validation + indices[train_len + val_len:], # Test ) # Create a tuple of 3 lists, the middle of which is empty if only the @@ -562,7 +513,7 @@ def rationed_split(examples, train_ratio, val_ratio, test_ratio, shuffle): def stratified_split( - examples, train_ratio, val_ratio, test_ratio, strata_field_name, shuffle + examples, train_ratio, val_ratio, test_ratio, strata_field_name, shuffle ): """Performs a stratified split on a list of examples according to the given ratios and the given strata field. @@ -597,7 +548,7 @@ def stratified_split( """ # group the examples by the strata_field - strata = itertools.groupby(examples, key=lambda ex: ex[strata_field_name]) + strata = itertools.groupby(examples, key=lambda ex: getattr(ex, strata_field_name)) strata = (list(group) for _, group in strata) train_split, val_split, test_split = [], [], [] diff --git a/podium/datasets/dataset_abc.py b/podium/datasets/dataset_abc.py index 065576c3..334a7eb0 100644 --- a/podium/datasets/dataset_abc.py +++ b/podium/datasets/dataset_abc.py @@ -1,14 +1,39 @@ from abc import ABC, abstractmethod -from typing import Union, Iterable, Iterator, List, Callable, Tuple, Dict, NamedTuple, Any +from typing import Any, Callable, Dict, Iterable, Iterator, List, NamedTuple, Tuple, \ + Union -from podium.storage import Example, Field -from podium.datasets import SingleBatchIterator +from podium.storage import Example, Field, unpack_fields + +FieldArg = Union[Field, List[Field], None] +DictFields = Dict[str, FieldArg] +ListFields = List[FieldArg] class DatasetABC(ABC): + + def __init__(self, fields: Union[DictFields, ListFields]): + self._fields = unpack_fields(fields) + + # ==================== Properties ========================= + + @property + def fields(self) -> List[Field]: + """List containing all fields of this dataset.""" + return self._fields + + @property + def field_dict(self) -> Dict[str, Field]: + """Dictionary containing all field names mapping to their respective Fields.""" + return {f.name: f for f in self.fields} + + @property + def examples(self) -> List[Example]: + """List containing all Examples.""" + return self.get_example_list() + # ================= Default methods ======================= def __getitem__( - self, i: Union[int, Iterable[int], slice] + self, i: Union[int, Iterable[int], slice] ) -> Union["DatasetABC", Example]: """Returns an example or a new dataset containing the indexed examples. @@ -44,7 +69,7 @@ def __getitem__( """ - return self.get(i) + return self._get(i) def __iter__(self) -> Iterator[Example]: """Iterates over all examples in the dataset in order. 
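The `__getattr__` rewrite further down in this diff fixes a subtle bug from the first commit: the old body used `return` inside the loop, so the method handed back the first example's value instead of a generator over all of them. Routing the loop through a nested `attr_generator` is also deliberate, since a `yield` written directly inside `__getattr__` would turn every attribute lookup on the dataset into a generator function. With the fix, field iteration works as documented (assuming a Field named "text"):

    for raw, tokenized in dataset.text:
        ...  # one (raw, tokenized) pair per Example
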
@@ -156,18 +181,24 @@ def index_key(i): indices.sort(key=index_key, reverse=reverse) return self[indices] - @property - def fields(self) -> List[Field]: - """List containing all fields of this dataset.""" - return self._fields_list() + def filtered(self, predicate: Callable[[Example], bool]) -> "DatasetABC": + """Method filters examples with given predicate and returns a new DatasetABC + instance containing those examples. - @property - def field_dict(self) -> Dict[str, Field]: - return {f.name: f for f in self.fields} + Parameters + ---------- + predicate : callable + predicate should be a callable that accepts example as input and returns + true if the example shouldn't be filtered, otherwise returns false + """ + indices = [i for i, example in enumerate(self) if predicate(example)] + return self[indices] - @property - def examples(self) -> List[Example]: - return self.get_example_list() + # TODO shuffled? + + def __repr__(self): + fields = [field.name for field in self.fields] + return f"{type(self).__name__}[Size: {len(self.examples)}, Fields: {fields}]" # ================= Abstract methods ======================= @@ -175,34 +206,10 @@ def examples(self) -> List[Example]: def __len__(self) -> int: pass - @abstractmethod - def _fields_list(self) -> List[Field]: - """Returns a list of all Fields in this dataset. - - Returns - ------- - list of Fields - a list of all Fields in this Dataset - """ - pass - @abstractmethod def _get(self, i: Union[int, Iterable[int], slice]) -> Union["DatasetABC", Example]: pass @abstractmethod - def get_example_list(self) -> List[Example]: - pass - - @abstractmethod - def filtered(self, predicate: Callable[[Example], bool]) -> "DatasetABC": - """Method filters examples with given predicate and returns a new DatasetABC - instance containing those examples. 
- - Parameters - ---------- - predicate : callable - predicate should be a callable that accepts example as input and returns - true if the example shouldn't be filtered, otherwise returns false - """ + def _get_examples(self) -> List[Example]: pass diff --git a/podium/datasets/impl/sst_sentiment_dataset.py b/podium/datasets/impl/sst_sentiment_dataset.py index 6f20c87b..c957b0bc 100644 --- a/podium/datasets/impl/sst_sentiment_dataset.py +++ b/podium/datasets/impl/sst_sentiment_dataset.py @@ -68,13 +68,13 @@ def __init__(self, file_path, fields, fine_grained=False, subtrees=False): } ) # Assign these to enable filtering - self.examples = self._create_examples( + self._examples = self._create_examples( file_path=file_path, fields=fields, fine_grained=fine_grained, subtrees=subtrees, ) - self.fields = fields + self._fields = fields # If not fine-grained, return binary task: filter out neutral instances if not fine_grained: @@ -84,7 +84,7 @@ def filter_neutral(example): self.filter(predicate=filter_neutral, inplace=True) - super(SST, self).__init__(**{"examples": self.examples, "fields": self.fields}) + super(SST, self).__init__(**{"examples": self._examples, "fields": self._fields}) @staticmethod def _create_examples(file_path, fields, fine_grained, subtrees): diff --git a/tests/storage/test_dataset.py b/tests/storage/test_dataset.py index ef55a0f7..561048a1 100644 --- a/tests/storage/test_dataset.py +++ b/tests/storage/test_dataset.py @@ -588,13 +588,13 @@ def test_dataset_deep_copy(data, field_list): original_dataset = create_dataset(data, field_list) original_examples = original_dataset.examples - dataset_no_deep_copy = original_dataset.get(slice(0, 5), deep_copy=False) + dataset_no_deep_copy = original_dataset._get(slice(0, 5), deep_copy=False) for original, copy in zip(original_dataset.fields, dataset_no_deep_copy.fields): assert copy is original for original, copy in zip(original_examples, dataset_no_deep_copy.examples): assert copy is original - dataset_deep_copy = original_dataset.get(slice(0, 5), deep_copy=True) + dataset_deep_copy = original_dataset._get(slice(0, 5), deep_copy=True) assert original_dataset.fields is not dataset_deep_copy.fields for original, copy in zip(original_dataset.fields, dataset_deep_copy.fields): @@ -606,9 +606,9 @@ def test_dataset_deep_copy(data, field_list): assert copy["label"] == original["label"] original_example = original_examples[2] - no_copy_example = original_dataset.get(2, deep_copy=False) + no_copy_example = original_dataset._get(2, deep_copy=False) indexed_example = original_dataset[2] - deep_copied_example = original_dataset.get(2, deep_copy=True) + deep_copied_example = original_dataset._get(2, deep_copy=True) assert no_copy_example is original_example assert indexed_example is original_example From eced615232c22a9f93f125f1e5f8752c0e20b730 Mon Sep 17 00:00:00 2001 From: Ivan Smokovic Date: Tue, 27 Oct 2020 12:24:56 +0100 Subject: [PATCH 03/26] updated Dataset --- podium/datasets/dataset.py | 86 ++++++------------- podium/datasets/dataset_abc.py | 70 +++++++++++++-- podium/datasets/impl/sst_sentiment_dataset.py | 7 +- 3 files changed, 92 insertions(+), 71 deletions(-) diff --git a/podium/datasets/dataset.py b/podium/datasets/dataset.py index aa54095f..6ec13026 100644 --- a/podium/datasets/dataset.py +++ b/podium/datasets/dataset.py @@ -9,6 +9,7 @@ from podium.storage.example_factory import Example from podium.storage.field import Field, unpack_fields + _LOGGER = logging.getLogger(__name__) @@ -84,7 +85,7 @@ def _get(self, i, deep_copy=False): 
------- single example or Dataset If i is an int, a single example will be returned. - If i is a slice or iterable, a copy of this dataset containing + If i is a slice or iterable, a dataset containing only the indexed examples will be returned. """ @@ -101,10 +102,7 @@ def _get(self, i, deep_copy=False): indexed_examples = [self.examples[index] for index in i] return self._dataset_copy_with_examples(indexed_examples, deep_copy=deep_copy) - def _get_examples(self) -> List[Example]: - return self._examples - - def __len__(self): + def __len__(self) -> int: """Returns the number of examples in the dataset. Returns @@ -112,7 +110,10 @@ def __len__(self): int The number of examples in the dataset. """ - return len(self.examples) + return len(self._examples) + + def _get_examples(self) -> List[Example]: + return self._examples def __iter__(self): """Iterates over all examples in the dataset in order. @@ -122,42 +123,7 @@ def __iter__(self): example Yields examples in the dataset. """ - for x in self.examples: - yield x - - def __getattr__(self, attr): - """Returns an Iterator iterating over values of the field with the - given name for every example in the dataset. - - Parameters - ---------- - attr : str - The name of the field whose values are to be returned. - - Returns - ------ - an Iterator iterating over values of the field with the given name - for every example in the dataset. - - Raises - ------ - AttributeError - If there is no Field with the given name. - """ - - if attr in self.field_dict: - - def attr_generator(dataset, attr): - for x in dataset.examples: - yield getattr(x, attr) - - return attr_generator(self, attr) - - else: - raise AttributeError(f"Dataset has no field {attr}.") - - def get_example_list(self) -> List[Example]: - return self._examples + yield from self._examples def filter(self, predicate, inplace=False): """Method filters examples with given predicate. @@ -170,24 +136,26 @@ def filter(self, predicate, inplace=False): inplace : bool, default False if True, do operation inplace and return None """ + filtered_examples = [ex for ex in self.examples if predicate(ex)] + if inplace: - self._examples = [example for example in self.examples if predicate(example)] + self._examples = filtered_examples return + else: + return Dataset( + examples=filtered_examples, fields=self.fields, sort_key=self.sort_key + ) - filtered_examples = [copy.deepcopy(ex) for ex in self.examples if predicate(ex)] - fields_copy = copy.deepcopy(self.fields) - - return Dataset( - examples=filtered_examples, fields=fields_copy, sort_key=self.sort_key - ) + def filtered(self, predicate: Callable[[Example], bool]) -> "DatasetABC": + return self.filter(predicate, inplace=False) def split( - self, - split_ratio=0.7, - stratified=False, - strata_field_name=None, - random_state=None, - shuffle=True, + self, + split_ratio=0.7, + stratified=False, + strata_field_name=None, + random_state=None, + shuffle=True, ): """Creates train-(validation)-test splits from this dataset. @@ -325,7 +293,7 @@ def __setstate__(self, state): self.__dict__ = state def _dataset_copy_with_examples( - self, examples: list, deep_copy: bool = False + self, examples: list, deep_copy: bool = False ) -> "Dataset": """Creates a new dataset with the same fields and sort_key. 
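One behavioural note on the `filter` rework above: the non-inplace branch no longer deep-copies the surviving examples or the fields, so the returned dataset shares its `Example` and `Field` objects with the source, and `filtered` is now a thin alias for `filter(predicate, inplace=False)`. Usage, reusing the SST predicate shown earlier in the series:

    def not_neutral(example):
        return example["label"][1] != "neutral"

    binary = dataset.filtered(not_neutral)      # new dataset, shared Examples
    dataset.filter(not_neutral, inplace=True)   # mutates this dataset, returns None
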
The new dataset contains only the fields passed to this function.Fields are deep-copied into @@ -501,8 +469,8 @@ def rationed_split(examples, train_ratio, val_ratio, test_ratio, shuffle): indices_tuple = ( indices[:train_len], # Train - indices[train_len: train_len + val_len], # Validation - indices[train_len + val_len:], # Test + indices[train_len : train_len + val_len], # Validation + indices[train_len + val_len :], # Test ) # Create a tuple of 3 lists, the middle of which is empty if only the @@ -513,7 +481,7 @@ def rationed_split(examples, train_ratio, val_ratio, test_ratio, shuffle): def stratified_split( - examples, train_ratio, val_ratio, test_ratio, strata_field_name, shuffle + examples, train_ratio, val_ratio, test_ratio, strata_field_name, shuffle ): """Performs a stratified split on a list of examples according to the given ratios and the given strata field. diff --git a/podium/datasets/dataset_abc.py b/podium/datasets/dataset_abc.py index 334a7eb0..0266554c 100644 --- a/podium/datasets/dataset_abc.py +++ b/podium/datasets/dataset_abc.py @@ -1,18 +1,18 @@ from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Iterable, Iterator, List, NamedTuple, Tuple, \ - Union +from typing import Any, Callable, Dict, Iterable, Iterator, List, NamedTuple, Tuple, Union from podium.storage import Example, Field, unpack_fields + FieldArg = Union[Field, List[Field], None] DictFields = Dict[str, FieldArg] ListFields = List[FieldArg] class DatasetABC(ABC): - def __init__(self, fields: Union[DictFields, ListFields]): self._fields = unpack_fields(fields) + self._fields_dict = {f.name: f for f in self._fields} # ==================== Properties ========================= @@ -24,16 +24,17 @@ def fields(self) -> List[Field]: @property def field_dict(self) -> Dict[str, Field]: """Dictionary containing all field names mapping to their respective Fields.""" - return {f.name: f for f in self.fields} + return self._fields_dict @property def examples(self) -> List[Example]: """List containing all Examples.""" - return self.get_example_list() + return self._get_examples() # ================= Default methods ======================= + def __getitem__( - self, i: Union[int, Iterable[int], slice] + self, i: Union[int, Iterable[int], slice] ) -> Union["DatasetABC", Example]: """Returns an example or a new dataset containing the indexed examples. @@ -101,8 +102,16 @@ def __getattr__(self, field_name: str) -> Iterator[Tuple[Any, Any]]: AttributeError If there is no Field with the given name. """ - for example in self: - return getattr(example, field_name) + if field_name in self.field_dict: + + def attr_generator(_dataset, _field_name): + for x in _dataset: + yield getattr(x, _field_name) + + return attr_generator(self, field_name) + + else: + raise AttributeError(f"Dataset has no field {field_name}.") def finalize_fields(self, *datasets: "DatasetABC") -> None: """ @@ -204,12 +213,57 @@ def __repr__(self): @abstractmethod def __len__(self) -> int: + """Returns the number of examples in the dataset. + + Returns + ------- + int + The number of examples in the dataset. + """ pass @abstractmethod def _get(self, i: Union[int, Iterable[int], slice]) -> Union["DatasetABC", Example]: + """Returns an example or a new dataset containing the indexed examples. + + If indexed with an int, only the example at that position + will be returned. + If Indexed with a slice or iterable, all examples indexed by the object + will be collected and a new dataset containing only those examples will be + returned. 
The new dataset will contain copies of the old dataset's fields + and will be identical to the original dataset, with the exception of the + example number and ordering. See wiki for detailed examples. + + Example + ------- + # Indexing by a single integers returns a single example + example = dataset.get(1) + + # Same as the first example, but returns a deep_copy of the example + example_copy = dataset.get(1, deep_copy=True) + + # Multi-indexing returns a new dataset containing the indexed examples + s = slice(1, 10) + new_dataset = dataset.get(s) + + new_dataset_copy = dataset.get(s, deep_copy=True) + + Parameters + ---------- + i : int or slice or iterable + Index used to index examples. + + Returns + ------- + single example or DatasetABC + If i is an int, a single example will be returned. + If i is a slice or iterable, a dataset containing + only the indexed examples will be returned. + + """ pass @abstractmethod def _get_examples(self) -> List[Example]: + """Returns a list containing all examples of this dataset.""" pass diff --git a/podium/datasets/impl/sst_sentiment_dataset.py b/podium/datasets/impl/sst_sentiment_dataset.py index c957b0bc..51b95fc6 100644 --- a/podium/datasets/impl/sst_sentiment_dataset.py +++ b/podium/datasets/impl/sst_sentiment_dataset.py @@ -68,13 +68,12 @@ def __init__(self, file_path, fields, fine_grained=False, subtrees=False): } ) # Assign these to enable filtering - self._examples = self._create_examples( + examples = self._create_examples( file_path=file_path, fields=fields, fine_grained=fine_grained, subtrees=subtrees, ) - self._fields = fields # If not fine-grained, return binary task: filter out neutral instances if not fine_grained: @@ -82,9 +81,9 @@ def __init__(self, file_path, fields, fine_grained=False, subtrees=False): def filter_neutral(example): return example["label"][1] != "neutral" - self.filter(predicate=filter_neutral, inplace=True) + examples = [ex for ex in examples if filter_neutral(ex)] - super(SST, self).__init__(**{"examples": self._examples, "fields": self._fields}) + super(SST, self).__init__(**{"examples": examples, "fields": fields}) @staticmethod def _create_examples(file_path, fields, fine_grained, subtrees): From 09a7b5e05f4f49f5e52dc46931ff2c1655458c2c Mon Sep 17 00:00:00 2001 From: Ivan Smokovic Date: Tue, 27 Oct 2020 12:49:31 +0100 Subject: [PATCH 04/26] updated ArrowDataset --- podium/arrow/arrow_tabular_dataset.py | 129 ++------------------ tests/arrow/test_pyarrow_tabular_dataset.py | 4 +- 2 files changed, 10 insertions(+), 123 deletions(-) diff --git a/podium/arrow/arrow_tabular_dataset.py b/podium/arrow/arrow_tabular_dataset.py index 46c23ec8..d23e17f1 100644 --- a/podium/arrow/arrow_tabular_dataset.py +++ b/podium/arrow/arrow_tabular_dataset.py @@ -8,7 +8,7 @@ from collections import defaultdict from typing import Any, Callable, Dict, Iterable, Iterator, List, Tuple, Union -from podium.datasets import Dataset +from podium.datasets import Dataset, DatasetABC from podium.storage import ExampleFactory, Field, unpack_fields from podium.storage.example_factory import Example @@ -36,7 +36,7 @@ def _group(iterable, n): yield chunk -class ArrowDataset: +class ArrowDataset(DatasetABC): """Podium dataset implementation which uses PyArrow as its data storage backend. Examples are stored in a file which is then memory mapped for fast random access. """ @@ -78,11 +78,10 @@ def __init__( inferred data type if possible. 
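Since the backing Arrow table is memory mapped from `cache_path`, the cache-management methods kept by this refactor form a small lifecycle, sketched here from the methods visible in this diff:

    path = arrow_dataset.dump_cache()   # persist the table; returns the cache directory
    arrow_dataset.close()               # release the memory-mapped file
    arrow_dataset.delete_cache()        # close if needed, then remove cache_path from disk
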
""" self.cache_path = cache_path - self.fields = unpack_fields(fields) - self.field_dict = {field.name: field for field in fields} self.mmapped_file = mmapped_file self.table = table self.data_types = data_types + super().__init__(fields) @staticmethod def from_dataset( @@ -562,11 +561,10 @@ def dump_cache(self, cache_path: str = None) -> str: return cache_path - @property - def examples(self): + def _get_examples(self) -> List[Example]: """Loads this ArrowDataset into memory and returns a list containing the loaded Examples.""" - return self.as_dataset().examples + return list(ArrowDataset._recordbatch_to_examples(self.table, self.fields)) def as_dataset(self) -> Dataset: """Loads this ArrowDataset into memory and returns an Dataset object containing @@ -580,19 +578,6 @@ def as_dataset(self) -> Dataset: examples = list(ArrowDataset._recordbatch_to_examples(self.table, self.fields)) return Dataset(examples, self.fields) - def batch(self): - """Creates an input and target batch containing the whole dataset. - The format of the batch is the same as the batches returned by the Iterator class. - - Returns - ------- - input_batch, target_batch - Two objects containing the input and target batches over - the whole dataset. - """ - # TODO custom batch method? - return self.as_dataset().batch() - @staticmethod def _field_values( record_batch: pa.RecordBatch, fieldname: str @@ -632,7 +617,9 @@ def _field_values( return zip(raw_values, tokenized_values) - def __getitem__(self, item, deep_copy=False) -> Union[Example, "ArrowDataset"]: + def _get( + self, item: Union[int, Iterable[int], slice] + ) -> Union[Example, "ArrowDataset"]: """Returns an example or a new ArrowDataset containing the indexed examples. If indexed with an int, only the example at that position will be returned. If Indexed with a slice or iterable, all examples indexed by the object @@ -658,9 +645,6 @@ def __getitem__(self, item, deep_copy=False) -> Union[Example, "ArrowDataset"]: item: int or slice or iterable Index used to index examples. - deep_copy: bool - Not used. - Returns ------- Example or Dataset @@ -712,69 +696,6 @@ def __iter__(self) -> Iterator[Example]: """ yield from self._recordbatch_to_examples(self.table, self.fields) - def __getattr__(self, fieldname) -> Iterator[Tuple[Any, Any]]: - """Iterates over the raw and tokenized values of all examples in this dataset. - - Parameters - ---------- - fieldname: str - Name of the field whose values are to be iterated over. - - Returns - ------- - Iterator[Tuple[raw, tokenized]] - Iterator over the raw and tokenized values of all examples in this dataset. - - """ - if fieldname in self.field_dict: - return ArrowDataset._field_values(self.table, fieldname) - - else: - raise AttributeError(f"Dataset has no field {fieldname}.") - - def filter(self, predicate: Callable[[Example], bool]) -> "ArrowDataset": - """Creates a new ArrowDataset instance containing only Examples for which the - predicate returns True. - - Parameters - ---------- - predicate : Callable[[Example], bool] - Callable used as a filtering predicate. It takes an Example as a parameter - and returns True if the Example is to be accepted, and False otherwise. - - Returns - ------- - ArrowDataset - New ArrowDataset containing Filtered Examples. - """ - indices = [i for i, example in enumerate(self) if predicate(example)] - return self[indices] - - def sorted(self, key: Callable[[Example], Any], reverse=False) -> "ArrowDataset": - """Returns a new ArrowDataset with sorted Examples. 
- - Parameters - ---------- - key: Callable[[Example], Any] - Extracts a comparable value from an Example. - That value will be used to determine Example ordering. - - reverse: bool - If True, the returned dataset will be reversed. - - Returns - ------- - ArrowDataset - An ArrowDataset containing sorted Examples from this dataset. - """ - - def index_key(i, _dataset=self): - return key(_dataset[i]) - - indices = list(range(len(self))) - indices.sort(key=index_key, reverse=reverse) - return self[indices] - def close(self): """ Closes resources held by the ArrowDataset.""" if self.mmapped_file is not None: @@ -790,37 +711,3 @@ def delete_cache(self): if self.mmapped_file is not None: self.close() shutil.rmtree(self.cache_path) - - def finalize_fields(self, *datasets): - """Builds vocabularies of all the non-eager fields in the dataset, - from the Dataset objects given as \\*args and then finalizes all the - fields. - - Parameters - ---------- - \\*args - A variable number of Dataset objects from which to build the - vocabularies for non-eager fields. If none provided, the - vocabularies are built from this Dataset (self). - """ - # if there are non-eager fields, we need to build their vocabularies - fields_to_build = [f for f in self.fields if not f.eager and f.use_vocab] - if fields_to_build: - # there can be multiple datasets we want to iterate over - data_sources = [ - dataset for dataset in datasets if isinstance(dataset, Dataset) - ] - - # use self as a data source if no other given - if not data_sources: - data_sources.append(self) - - # for each example in each dataset, - # update _all_ non-eager fields - for dataset in data_sources: - for example in dataset: - for field in fields_to_build: - field.update_vocab(*example[field.name]) - - for field in self.fields: - field.finalize() diff --git a/tests/arrow/test_pyarrow_tabular_dataset.py b/tests/arrow/test_pyarrow_tabular_dataset.py index 74d31a3f..a0a581e0 100644 --- a/tests/arrow/test_pyarrow_tabular_dataset.py +++ b/tests/arrow/test_pyarrow_tabular_dataset.py @@ -159,11 +159,11 @@ def test_finalize_fields(data, fields, mocker): assert f.finalized -def test_filter(data, pyarrow_dataset): +def test_filtered(data, pyarrow_dataset): def filter_even(ex): return ex["number"][0] % 2 == 0 - filtered_dataset = pyarrow_dataset.filter(filter_even) + filtered_dataset = pyarrow_dataset.filtered(filter_even) filtered_data = [d[0] for d in data if d[0] % 2 == 0] for (raw, _), d in zip(filtered_dataset.number, filtered_data): From 4f9d05515873a4398b505d60f9d5253171095185 Mon Sep 17 00:00:00 2001 From: Ivan Smokovic Date: Tue, 27 Oct 2020 12:54:01 +0100 Subject: [PATCH 05/26] updated iterator.py --- podium/datasets/iterator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/podium/datasets/iterator.py b/podium/datasets/iterator.py index 13570b3c..c01929b7 100644 --- a/podium/datasets/iterator.py +++ b/podium/datasets/iterator.py @@ -213,7 +213,7 @@ def __iter__(self): def _create_batch(self, dataset): - examples = dataset.as_dataset().examples + examples = dataset.examples # dicts that will be used to create the InputBatch and TargetBatch # objects From 9884c0674285cc4190a8edccc5558eb20374386a Mon Sep 17 00:00:00 2001 From: Ivan Smokovic Date: Tue, 27 Oct 2020 13:13:10 +0100 Subject: [PATCH 06/26] flake8 fixes --- podium/arrow/arrow_tabular_dataset.py | 2 +- podium/datasets/dataset.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/podium/arrow/arrow_tabular_dataset.py b/podium/arrow/arrow_tabular_dataset.py 
index d23e17f1..4fe74685 100644 --- a/podium/arrow/arrow_tabular_dataset.py +++ b/podium/arrow/arrow_tabular_dataset.py @@ -6,7 +6,7 @@ import shutil import tempfile from collections import defaultdict -from typing import Any, Callable, Dict, Iterable, Iterator, List, Tuple, Union +from typing import Any, Dict, Iterable, Iterator, List, Tuple, Union from podium.datasets import Dataset, DatasetABC from podium.storage import ExampleFactory, Field, unpack_fields diff --git a/podium/datasets/dataset.py b/podium/datasets/dataset.py index 6ec13026..aad85c04 100644 --- a/podium/datasets/dataset.py +++ b/podium/datasets/dataset.py @@ -3,11 +3,10 @@ import itertools import logging import random -from typing import Any, Callable, Dict, List +from typing import Callable, List from podium.datasets.dataset_abc import DatasetABC from podium.storage.example_factory import Example -from podium.storage.field import Field, unpack_fields _LOGGER = logging.getLogger(__name__) From 73f711a87311e9734566b1542a1074174938f503 Mon Sep 17 00:00:00 2001 From: Ivan Smokovic Date: Wed, 28 Oct 2020 18:48:59 +0100 Subject: [PATCH 07/26] getattr -> dict notation --- podium/datasets/dataset_abc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/podium/datasets/dataset_abc.py b/podium/datasets/dataset_abc.py index 0266554c..d146ec13 100644 --- a/podium/datasets/dataset_abc.py +++ b/podium/datasets/dataset_abc.py @@ -106,7 +106,7 @@ def __getattr__(self, field_name: str) -> Iterator[Tuple[Any, Any]]: def attr_generator(_dataset, _field_name): for x in _dataset: - yield getattr(x, _field_name) + yield x[field_name] return attr_generator(self, field_name) @@ -142,7 +142,7 @@ def finalize_fields(self, *datasets: "DatasetABC") -> None: for dataset in data_sources: for example in dataset: for field in fields_to_build: - _, tokenized = getattr(example, field.name) + _, tokenized = example[field.name] field.update_vocab(tokenized) for field in self.fields: From 7395f93e377aa636552ae0d9c3581ba44d8658b8 Mon Sep 17 00:00:00 2001 From: Ivan Smokovic Date: Thu, 29 Oct 2020 20:10:40 +0100 Subject: [PATCH 08/26] fixed datasetABC __repr__ --- podium/datasets/dataset_abc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/podium/datasets/dataset_abc.py b/podium/datasets/dataset_abc.py index d146ec13..f4dd2334 100644 --- a/podium/datasets/dataset_abc.py +++ b/podium/datasets/dataset_abc.py @@ -206,8 +206,7 @@ def filtered(self, predicate: Callable[[Example], bool]) -> "DatasetABC": # TODO shuffled? 
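The `__repr__` fix that follows is more than cosmetic: the old body went through `self.examples`, and for subclasses such as `ArrowDataset` that property materializes every `Example` in memory just to report a size. `len(self)` keeps the representation cheap for lazy backends. Illustrative output (shape only, not an exact transcript):

    >>> print(dataset)
    Dataset[Size: 100, Fields: (...)]
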
def __repr__(self): - fields = [field.name for field in self.fields] - return f"{type(self).__name__}[Size: {len(self.examples)}, Fields: {fields}]" + return f"{type(self).__name__}[Size: {len(self)}, Fields: {self.fields}]" # ================= Abstract methods ======================= From ed11aad5a3dac1f533610c522b8e605e2fe5d007 Mon Sep 17 00:00:00 2001 From: Ivan Smokovic Date: Thu, 29 Oct 2020 20:27:42 +0100 Subject: [PATCH 09/26] _get -> __getitem__ --- podium/arrow/arrow_tabular_dataset.py | 2 +- podium/datasets/dataset.py | 43 +++++++++++++++- podium/datasets/dataset_abc.py | 73 ++++++--------------------- tests/storage/test_dataset.py | 8 +-- 4 files changed, 62 insertions(+), 64 deletions(-) diff --git a/podium/arrow/arrow_tabular_dataset.py b/podium/arrow/arrow_tabular_dataset.py index 4fe74685..6163d9c7 100644 --- a/podium/arrow/arrow_tabular_dataset.py +++ b/podium/arrow/arrow_tabular_dataset.py @@ -617,7 +617,7 @@ def _field_values( return zip(raw_values, tokenized_values) - def _get( + def __getitem__( self, item: Union[int, Iterable[int], slice] ) -> Union[Example, "ArrowDataset"]: """Returns an example or a new ArrowDataset containing the indexed examples. diff --git a/podium/datasets/dataset.py b/podium/datasets/dataset.py index aad85c04..e037ff23 100644 --- a/podium/datasets/dataset.py +++ b/podium/datasets/dataset.py @@ -3,7 +3,7 @@ import itertools import logging import random -from typing import Callable, List +from typing import Callable, Iterable, List, Union from podium.datasets.dataset_abc import DatasetABC from podium.storage.example_factory import Example @@ -45,7 +45,46 @@ def __init__(self, examples, fields, sort_key=None): self.sort_key = sort_key super().__init__(fields) - def _get(self, i, deep_copy=False): + def __getitem__( + self, i: Union[int, Iterable[int], slice] + ) -> Union["DatasetABC", Example]: + """Returns an example or a new dataset containing the indexed examples. + + If indexed with an int, only the example at that position will be returned. + If Indexed with a slice or iterable, all examples indexed by the object + will be collected and a new dataset containing only those examples will be + returned. The new dataset will contain copies of the old dataset's fields and + will be identical to the original dataset, with the exception of the example + number and ordering. See wiki for detailed examples. + + Examples in the returned Dataset are the same ones present in the + original dataset. If a complete deep-copy of the dataset, or its slice, + is needed please refer to the `get` method. + + Usage example: + + example = dataset[1] # Indexing by single integer returns a single example + + new_dataset = dataset[1:10] # Multi-indexing returns a new dataset containing + # the indexed examples. + + Parameters + ---------- + i : int or slice or iterable + Index used to index examples. + + Returns + ------- + single example or Dataset + If i is an int, a single example will be returned. + If i is a slice or iterable, a copy of this dataset containing + only the indexed examples will be returned. + + """ + + return self.get(i) + + def get(self, i, deep_copy=False): """Returns an example or a new dataset containing the indexed examples. 
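With `_get` promoted back into `__getitem__`, plain indexing is again the public no-copy path, and `get` survives solely for its `deep_copy` flag. The test changes later in this patch pin the semantics down: indexing and `get` with the default flag return the very same `Example` object, while `deep_copy=True` returns an independent copy:

    assert dataset.get(2) is dataset[2]        # shared object, no copying
    clone = dataset.get(2, deep_copy=True)     # detached deep copy
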
If indexed with an int, only the example at that position diff --git a/podium/datasets/dataset_abc.py b/podium/datasets/dataset_abc.py index f4dd2334..4fc32404 100644 --- a/podium/datasets/dataset_abc.py +++ b/podium/datasets/dataset_abc.py @@ -33,45 +33,6 @@ def examples(self) -> List[Example]: # ================= Default methods ======================= - def __getitem__( - self, i: Union[int, Iterable[int], slice] - ) -> Union["DatasetABC", Example]: - """Returns an example or a new dataset containing the indexed examples. - - If indexed with an int, only the example at that position will be returned. - If Indexed with a slice or iterable, all examples indexed by the object - will be collected and a new dataset containing only those examples will be - returned. The new dataset will contain copies of the old dataset's fields and - will be identical to the original dataset, with the exception of the example - number and ordering. See wiki for detailed examples. - - Examples in the returned Dataset are the same ones present in the - original dataset. If a complete deep-copy of the dataset, or its slice, - is needed please refer to the `get` method. - - Usage example: - - example = dataset[1] # Indexing by single integer returns a single example - - new_dataset = dataset[1:10] # Multi-indexing returns a new dataset containing - # the indexed examples. - - Parameters - ---------- - i : int or slice or iterable - Index used to index examples. - - Returns - ------- - single example or Dataset - If i is an int, a single example will be returned. - If i is a slice or iterable, a copy of this dataset containing - only the indexed examples will be returned. - - """ - - return self._get(i) - def __iter__(self) -> Iterator[Example]: """Iterates over all examples in the dataset in order. @@ -222,30 +183,28 @@ def __len__(self) -> int: pass @abstractmethod - def _get(self, i: Union[int, Iterable[int], slice]) -> Union["DatasetABC", Example]: + def __getitem__( + self, i: Union[int, Iterable[int], slice] + ) -> Union["DatasetABC", Example]: """Returns an example or a new dataset containing the indexed examples. - If indexed with an int, only the example at that position - will be returned. + If indexed with an int, only the example at that position will be returned. If Indexed with a slice or iterable, all examples indexed by the object will be collected and a new dataset containing only those examples will be - returned. The new dataset will contain copies of the old dataset's fields - and will be identical to the original dataset, with the exception of the - example number and ordering. See wiki for detailed examples. + returned. The new dataset will contain copies of the old dataset's fields and + will be identical to the original dataset, with the exception of the example + number and ordering. See wiki for detailed examples. - Example - ------- - # Indexing by a single integers returns a single example - example = dataset.get(1) + Examples in the returned Dataset are the same ones present in the + original dataset. If a complete deep-copy of the dataset, or its slice, + is needed please refer to the `get` method. 
- # Same as the first example, but returns a deep_copy of the example - example_copy = dataset.get(1, deep_copy=True) + Usage example: - # Multi-indexing returns a new dataset containing the indexed examples - s = slice(1, 10) - new_dataset = dataset.get(s) + example = dataset[1] # Indexing by single integer returns a single example - new_dataset_copy = dataset.get(s, deep_copy=True) + new_dataset = dataset[1:10] # Multi-indexing returns a new dataset containing + # the indexed examples. Parameters ---------- @@ -254,9 +213,9 @@ def _get(self, i: Union[int, Iterable[int], slice]) -> Union["DatasetABC", Examp Returns ------- - single example or DatasetABC + single example or Dataset If i is an int, a single example will be returned. - If i is a slice or iterable, a dataset containing + If i is a slice or iterable, a copy of this dataset containing only the indexed examples will be returned. """ diff --git a/tests/storage/test_dataset.py b/tests/storage/test_dataset.py index 561048a1..ef55a0f7 100644 --- a/tests/storage/test_dataset.py +++ b/tests/storage/test_dataset.py @@ -588,13 +588,13 @@ def test_dataset_deep_copy(data, field_list): original_dataset = create_dataset(data, field_list) original_examples = original_dataset.examples - dataset_no_deep_copy = original_dataset._get(slice(0, 5), deep_copy=False) + dataset_no_deep_copy = original_dataset.get(slice(0, 5), deep_copy=False) for original, copy in zip(original_dataset.fields, dataset_no_deep_copy.fields): assert copy is original for original, copy in zip(original_examples, dataset_no_deep_copy.examples): assert copy is original - dataset_deep_copy = original_dataset._get(slice(0, 5), deep_copy=True) + dataset_deep_copy = original_dataset.get(slice(0, 5), deep_copy=True) assert original_dataset.fields is not dataset_deep_copy.fields for original, copy in zip(original_dataset.fields, dataset_deep_copy.fields): @@ -606,9 +606,9 @@ def test_dataset_deep_copy(data, field_list): assert copy["label"] == original["label"] original_example = original_examples[2] - no_copy_example = original_dataset._get(2, deep_copy=False) + no_copy_example = original_dataset.get(2, deep_copy=False) indexed_example = original_dataset[2] - deep_copied_example = original_dataset._get(2, deep_copy=True) + deep_copied_example = original_dataset.get(2, deep_copy=True) assert no_copy_example is original_example assert indexed_example is original_example From 09d401dffd05c314ae77991386774fa609beb2f8 Mon Sep 17 00:00:00 2001 From: Ivan Smokovic Date: Thu, 29 Oct 2020 20:37:21 +0100 Subject: [PATCH 10/26] Small changes --- podium/datasets/dataset.py | 2 +- podium/datasets/dataset_abc.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/podium/datasets/dataset.py b/podium/datasets/dataset.py index e037ff23..6b1c92bd 100644 --- a/podium/datasets/dataset.py +++ b/podium/datasets/dataset.py @@ -554,7 +554,7 @@ def stratified_split( """ # group the examples by the strata_field - strata = itertools.groupby(examples, key=lambda ex: getattr(ex, strata_field_name)) + strata = itertools.groupby(examples, key=lambda ex: ex[strata_field_name]) strata = (list(group) for _, group in strata) train_split, val_split, test_split = [], [], [] diff --git a/podium/datasets/dataset_abc.py b/podium/datasets/dataset_abc.py index 4fc32404..f09bf20f 100644 --- a/podium/datasets/dataset_abc.py +++ b/podium/datasets/dataset_abc.py @@ -44,7 +44,7 @@ def __iter__(self) -> Iterator[Example]: for i in range(len(self)): yield self[i] - def __getattr__(self, field_name: str) -> 
Iterator[Tuple[Any, Any]]: + def __getattr__(self, field: Union[str, Field]) -> Iterator[Tuple[Any, Any]]: """Returns an Iterator iterating over values of the field with the given name for every example in the dataset. @@ -63,6 +63,8 @@ def __getattr__(self, field_name: str) -> Iterator[Tuple[Any, Any]]: AttributeError If there is no Field with the given name. """ + field_name = field.name if isinstance(field, Field) else field + if field_name in self.field_dict: def attr_generator(_dataset, _field_name): @@ -208,7 +210,7 @@ def __getitem__( Parameters ---------- - i : int or slice or iterable + i : int or slice or iterable of ints Index used to index examples. Returns From 44bfbcde118fa3e5dc5f6a72bc457c7a5cc18a05 Mon Sep 17 00:00:00 2001 From: Ivan Smokovic Date: Tue, 3 Nov 2020 17:18:09 +0100 Subject: [PATCH 11/26] made ._fields a tuple --- podium/datasets/dataset_abc.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/podium/datasets/dataset_abc.py b/podium/datasets/dataset_abc.py index f09bf20f..98f56413 100644 --- a/podium/datasets/dataset_abc.py +++ b/podium/datasets/dataset_abc.py @@ -11,8 +11,7 @@ class DatasetABC(ABC): def __init__(self, fields: Union[DictFields, ListFields]): - self._fields = unpack_fields(fields) - self._fields_dict = {f.name: f for f in self._fields} + self._fields = tuple(unpack_fields(fields)) # ==================== Properties ========================= @@ -24,7 +23,7 @@ def fields(self) -> List[Field]: @property def field_dict(self) -> Dict[str, Field]: """Dictionary containing all field names mapping to their respective Fields.""" - return self._fields_dict + return {f.name: f for f in self.fields} @property def examples(self) -> List[Example]: From 3b9c4aa3d0c55563644411139f0d856477e71405 Mon Sep 17 00:00:00 2001 From: Ivan Smokovic Date: Tue, 3 Nov 2020 17:26:58 +0100 Subject: [PATCH 12/26] Added shuffled --- podium/datasets/dataset_abc.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/podium/datasets/dataset_abc.py b/podium/datasets/dataset_abc.py index 98f56413..6c2b5586 100644 --- a/podium/datasets/dataset_abc.py +++ b/podium/datasets/dataset_abc.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Dict, Iterable, Iterator, List, NamedTuple, Tuple, Union +import numpy as np + from podium.storage import Example, Field, unpack_fields @@ -16,7 +18,7 @@ def __init__(self, fields: Union[DictFields, ListFields]): # ==================== Properties ========================= @property - def fields(self) -> List[Field]: + def fields(self) -> Tuple[Field]: """List containing all fields of this dataset.""" return self._fields @@ -153,7 +155,7 @@ def index_key(i): return self[indices] def filtered(self, predicate: Callable[[Example], bool]) -> "DatasetABC": - """Method filters examples with given predicate and returns a new DatasetABC + """Filters examples with given predicate and returns a new DatasetABC instance containing those examples. Parameters @@ -161,11 +163,28 @@ def filtered(self, predicate: Callable[[Example], bool]) -> "DatasetABC": predicate : callable predicate should be a callable that accepts example as input and returns true if the example shouldn't be filtered, otherwise returns false + + Returns + ------- + DatasetABC + A new DatasetABC instance containing only the Examples for which `predicate` + returned True. """ indices = [i for i, example in enumerate(self) if predicate(example)] return self[indices] - # TODO shuffled? 
+ def shuffled(self): + """Creates a new DatasetABC instance containing all Examples, but in shuffled + order. + + Returns + ------- + DatasetABC + A new DatasetABC instance containing all Examples, but in shuffled + order. + """ + shuffled_indices = np.random.permutation(len(self)) + return self[shuffled_indices] def __repr__(self): return f"{type(self).__name__}[Size: {len(self)}, Fields: {self.fields}]" From 15be9871c0de221a31c826d7fd45bf682536861e Mon Sep 17 00:00:00 2001 From: Ivan Smokovic Date: Tue, 3 Nov 2020 17:29:29 +0100 Subject: [PATCH 13/26] removed as_dataset --- podium/arrow/arrow_tabular_dataset.py | 12 ------------ podium/datasets/dataset.py | 3 --- tests/arrow/test_pyarrow_tabular_dataset.py | 10 ---------- 3 files changed, 25 deletions(-) diff --git a/podium/arrow/arrow_tabular_dataset.py b/podium/arrow/arrow_tabular_dataset.py index 6163d9c7..388492d7 100644 --- a/podium/arrow/arrow_tabular_dataset.py +++ b/podium/arrow/arrow_tabular_dataset.py @@ -566,18 +566,6 @@ def _get_examples(self) -> List[Example]: the loaded Examples.""" return list(ArrowDataset._recordbatch_to_examples(self.table, self.fields)) - def as_dataset(self) -> Dataset: - """Loads this ArrowDataset into memory and returns an Dataset object containing - the loaded data. - - Returns - ------- - Dataset - Dataset containing all examples of this ArrowDataset. - """ - examples = list(ArrowDataset._recordbatch_to_examples(self.table, self.fields)) - return Dataset(examples, self.fields) - @staticmethod def _field_values( record_batch: pa.RecordBatch, fieldname: str diff --git a/podium/datasets/dataset.py b/podium/datasets/dataset.py index 6b1c92bd..faf946b1 100644 --- a/podium/datasets/dataset.py +++ b/podium/datasets/dataset.py @@ -372,9 +372,6 @@ def shuffle_examples(self, random_state=None): random.shuffle(self.examples) - def as_dataset(self): - return self - def check_split_ratio(split_ratio): """Checks that the split ratio argument is not malformed and if not diff --git a/tests/arrow/test_pyarrow_tabular_dataset.py b/tests/arrow/test_pyarrow_tabular_dataset.py index a0a581e0..8acf5249 100644 --- a/tests/arrow/test_pyarrow_tabular_dataset.py +++ b/tests/arrow/test_pyarrow_tabular_dataset.py @@ -193,16 +193,6 @@ def test_batching(data, pyarrow_dataset): assert index == tokens_vocab[token] -def test_to_dataset(pyarrow_dataset): - dataset = pyarrow_dataset.as_dataset() - - assert isinstance(dataset, Dataset) - assert len(dataset) == len(pyarrow_dataset) - for arrow_ex, dataset_ex in zip(pyarrow_dataset, dataset): - assert arrow_ex.number == dataset_ex.number - assert arrow_ex.tokens == dataset_ex.tokens - - def test_from_dataset(data, fields): example_factory = ExampleFactory(fields) examples = [example_factory.from_list(raw_example) for raw_example in data] From 7fcec4d3519c639fef6ebe6b1c4906bcf358213a Mon Sep 17 00:00:00 2001 From: Ivan Smokovic Date: Tue, 3 Nov 2020 17:47:34 +0100 Subject: [PATCH 14/26] Added type hint to shuffled --- podium/datasets/dataset_abc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/podium/datasets/dataset_abc.py b/podium/datasets/dataset_abc.py index 6c2b5586..a9126f2b 100644 --- a/podium/datasets/dataset_abc.py +++ b/podium/datasets/dataset_abc.py @@ -173,7 +173,7 @@ def filtered(self, predicate: Callable[[Example], bool]) -> "DatasetABC": indices = [i for i, example in enumerate(self) if predicate(example)] return self[indices] - def shuffled(self): + def shuffled(self) -> "DatasetABC": """Creates a new DatasetABC instance containing all Examples, but 
in shuffled order. From cb0e447da21242180c2c61a13ab0503eb6971ee4 Mon Sep 17 00:00:00 2001 From: Ivan Smokovic Date: Tue, 3 Nov 2020 17:50:07 +0100 Subject: [PATCH 15/26] Added type hint overloads to __getitem__ --- podium/datasets/dataset_abc.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/podium/datasets/dataset_abc.py b/podium/datasets/dataset_abc.py index a9126f2b..75f969bf 100644 --- a/podium/datasets/dataset_abc.py +++ b/podium/datasets/dataset_abc.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Iterable, Iterator, List, NamedTuple, Tuple, Union +from typing import Any, Callable, Dict, Iterable, Iterator, List, NamedTuple, Tuple, \ + Union, overload import numpy as np @@ -202,6 +203,18 @@ def __len__(self) -> int: """ pass + @overload + def __getitem__( + self, i: int + ) -> Example: + ... + + @overload + def __getitem__( + self, i: Union[Iterable[int], slice] + ) -> "DatasetABC": + ... + @abstractmethod def __getitem__( self, i: Union[int, Iterable[int], slice] From aa192b32f3f307ae044d0637561e28aa1e3324da Mon Sep 17 00:00:00 2001 From: Ivan Smokovic Date: Tue, 3 Nov 2020 17:50:39 +0100 Subject: [PATCH 16/26] black correction --- podium/datasets/dataset_abc.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/podium/datasets/dataset_abc.py b/podium/datasets/dataset_abc.py index 75f969bf..37286128 100644 --- a/podium/datasets/dataset_abc.py +++ b/podium/datasets/dataset_abc.py @@ -1,6 +1,16 @@ from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Iterable, Iterator, List, NamedTuple, Tuple, \ - Union, overload +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + NamedTuple, + Tuple, + Union, + overload, +) import numpy as np @@ -204,15 +214,11 @@ def __len__(self) -> int: pass @overload - def __getitem__( - self, i: int - ) -> Example: + def __getitem__(self, i: int) -> Example: ... @overload - def __getitem__( - self, i: Union[Iterable[int], slice] - ) -> "DatasetABC": + def __getitem__(self, i: Union[Iterable[int], slice]) -> "DatasetABC": ... 
@abstractmethod From 23234dd32411dd6056a0da09ba26e3c645e99b7f Mon Sep 17 00:00:00 2001 From: Ivan Smokovic Date: Thu, 5 Nov 2020 15:01:21 +0100 Subject: [PATCH 17/26] WIP: NumericalizerABC prototype --- podium/datasets/dataset_abc.py | 2 +- podium/preproc/__init__.py | 3 +- podium/preproc/numericalizer_abc.py | 39 +++++ podium/storage/field.py | 152 +++++++++++--------- podium/storage/vectorizers/tfidf.py | 2 +- podium/storage/vocab.py | 14 +- tests/arrow/test_pyarrow_tabular_dataset.py | 11 +- tests/storage/test_dataset.py | 2 +- tests/storage/test_field.py | 16 +-- 9 files changed, 149 insertions(+), 92 deletions(-) create mode 100644 podium/preproc/numericalizer_abc.py diff --git a/podium/datasets/dataset_abc.py b/podium/datasets/dataset_abc.py index 37286128..8ab8a021 100644 --- a/podium/datasets/dataset_abc.py +++ b/podium/datasets/dataset_abc.py @@ -118,7 +118,7 @@ def finalize_fields(self, *datasets: "DatasetABC") -> None: for example in dataset: for field in fields_to_build: _, tokenized = example[field.name] - field.update_vocab(tokenized) + field.update_numericalizer(tokenized) for field in self.fields: field.finalize() diff --git a/podium/preproc/__init__.py b/podium/preproc/__init__.py index 6a5aab58..e557b131 100644 --- a/podium/preproc/__init__.py +++ b/podium/preproc/__init__.py @@ -12,7 +12,7 @@ from .sentencizers import SpacySentencizer from .stemmer import CroatianStemmer from .tokenizers import get_tokenizer - +from .numericalizer_abc import NumericalizerABC __all__ = [ "CroatianLemmatizer", @@ -26,4 +26,5 @@ "CroatianStemmer", "SpacySentencizer", "get_tokenizer", + "NumericalizerABC" ] diff --git a/podium/preproc/numericalizer_abc.py b/podium/preproc/numericalizer_abc.py new file mode 100644 index 00000000..f6ca336b --- /dev/null +++ b/podium/preproc/numericalizer_abc.py @@ -0,0 +1,39 @@ +from abc import ABC, abstractmethod +from typing import List + +import numpy as np + + +class NumericalizerABC(ABC): + + def __init__(self, eager=True): + self._finalized = False + self._eager = eager + + @abstractmethod + def numericalize(self, tokens: List[str]) -> np.ndarray: + pass + + def _finalize(self): + # Subclasses should override this method to add custom + # finalization logic + pass + + def update(self, tokens: List[str]) -> None: + pass + + def finalize(self) -> None: + self._finalize() + self._finalized = True + pass + + @property + def finalized(self) -> bool: + return self._finalized + + @property + def eager(self) -> bool: + return self._eager + + def __call__(self, tokens: List[str]) -> np.ndarray: + return self.numericalize(tokens) diff --git a/podium/storage/field.py b/podium/storage/field.py index 66edcfe7..d3f06f42 100644 --- a/podium/storage/field.py +++ b/podium/storage/field.py @@ -9,7 +9,7 @@ from podium.preproc.tokenizers import get_tokenizer from podium.storage.vocab import Vocab - +from podium.preproc import NumericalizerABC _LOGGER = logging.getLogger(__name__) @@ -18,7 +18,7 @@ PosttokHook = Callable[[Any, List[str]], Tuple[Any, List[str]]] Tokenizer = Callable[[Any], List[str]] TokenizerArg = Optional[Union[str, Tokenizer]] -Numericalizer = Callable[[str], Union[int, float]] +NumericalizerCallable = Callable[[str], Union[int, float]] class PretokenizationPipeline: @@ -74,10 +74,10 @@ class MultioutputField: for posttokenization processing (posttokenization hooks and vocab updating).""" def __init__( - self, - output_fields: List["Field"], - tokenizer: TokenizerArg = "split", - pretokenize_hooks: Iterable[PretokHook] = [], + self, + output_fields: 
List["Field"], + tokenizer: TokenizerArg = "split", + pretokenize_hooks: Iterable[PretokHook] = [], ): """Field that does pretokenization and tokenization once and passes it to its output fields. Output fields are any type of field. The output fields are used @@ -210,25 +210,36 @@ def remove_pretokenize_hooks(self): self._pretokenization_pipeline.clear() +class NumericalizerCallableWrapper(NumericalizerABC): + + def __init__(self, numericalizer: NumericalizerCallable): + super().__init__(eager=True) + self._wrapped_numericalizer = numericalizer + + def numericalize(self, tokens: List[str]) -> np.ndarray: + numericalized = [self._wrapped_numericalizer(tok) for tok in tokens] + return np.array(numericalized) + + class Field: """Holds the preprocessing and numericalization logic for a single field of a dataset. """ def __init__( - self, - name: str, - tokenizer: TokenizerArg = "split", - keep_raw: bool = False, - numericalizer: Optional[Union[Vocab, Numericalizer]] = None, - is_target: bool = False, - fixed_length: Optional[int] = None, - allow_missing_data: bool = False, - disable_batch_matrix: bool = False, - padding_token: Union[int, float] = -999, - missing_data_token: Union[int, float] = -1, - pretokenize_hooks: Iterable[PretokHook] = [], - posttokenize_hooks: Iterable[PosttokHook] = [], + self, + name: str, + tokenizer: TokenizerArg = "split", + keep_raw: bool = False, + numericalizer: Optional[Union[NumericalizerABC, NumericalizerCallable]] = None, + is_target: bool = False, + fixed_length: Optional[int] = None, + allow_missing_data: bool = False, + disable_batch_matrix: bool = False, + padding_token: Union[int, float] = -999, + missing_data_token: Union[int, float] = -1, + pretokenize_hooks: Iterable[PretokHook] = [], + posttokenize_hooks: Iterable[PosttokHook] = [], ): """Create a Field from arguments. @@ -328,12 +339,13 @@ def __init__( else: self._tokenizer = get_tokenizer(tokenizer) - if isinstance(numericalizer, Vocab): - self._vocab = numericalizer - self._numericalizer = self.vocab.__getitem__ - else: - self._vocab = None + if isinstance(numericalizer, NumericalizerABC) or numericalizer is None: self._numericalizer = numericalizer + elif isinstance(numericalizer, Callable): + self._numericalizer = NumericalizerCallableWrapper(numericalizer) + else: + # TODO raise error + pass self._keep_raw = keep_raw @@ -398,12 +410,16 @@ def eager(self): whether this field has a Vocab and whether that Vocab is marked as eager """ - return self.vocab is not None and self.vocab.eager + # Pretend to be eager if no numericalizer provided + return self._numericalizer is None or self._numericalizer.eager @property def vocab(self): """""" - return self._vocab + if not self.use_vocab: + # TODO raise error + raise TypeError() + return self._numericalizer @property def use_vocab(self): @@ -415,7 +431,7 @@ def use_vocab(self): Whether the field uses a vocab or not. """ - return self.vocab is not None + return isinstance(self._numericalizer, Vocab) @property def is_target(self): @@ -505,7 +521,7 @@ def _run_pretokenization_hooks(self, data: Any) -> Any: return self._pretokenize_pipeline(data) def _run_posttokenization_hooks( - self, data: Any, tokens: List[str] + self, data: Any, tokens: List[str] ) -> Tuple[Any, List[str]]: """Runs posttokenization hooks on tokenized data. 
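
For review context, a minimal sketch of how the reworked `numericalizer` argument is intended to behave with this change applied. `Field` and `Vocab` are the classes touched above; the token-length numericalizer is purely illustrative:

# Sketch only -- assumes this patch series is applied.
from podium.storage import Field, Vocab

# A Vocab is itself a NumericalizerABC, so it is stored as-is:
text = Field(name="text", numericalizer=Vocab())
assert text.use_vocab  # isinstance(field._numericalizer, Vocab)

# A plain callable is wrapped in NumericalizerCallableWrapper, which is
# always eager and numericalizes token by token:
length = Field(name="length", numericalizer=lambda token: len(token))
assert not length.use_vocab
assert length.eager

# A Field with no numericalizer pretends to be eager, as the property above notes:
raw = Field(name="raw", numericalizer=None)
assert raw.eager
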
@@ -527,7 +543,7 @@ def _run_posttokenization_hooks(
         return self._posttokenize_pipeline(data, tokens)
 
     def preprocess(
-        self, data: Any
+            self, data: Any
     ) -> Iterable[Tuple[str, Tuple[Any, Optional[List[str]]]]]:
         """Preprocesses raw data, tokenizing it if required, updating
         the vocab if the vocab is eager and preserving the raw data
@@ -569,7 +585,7 @@ def preprocess(
 
         return (self._process_tokens(processed_raw, tokenized),)
 
-    def update_vocab(self, tokenized: List[str]):
+    def update_numericalizer(self, tokenized: Union[str, List[str]]) -> None:
         """Updates the numericalizer with a data point in its tokenized form.
         If the field does not do tokenization, the data point is treated as
         a single token.
 
@@ -580,11 +596,11 @@ def update_vocab(self, tokenized: List[str]):
             updated with.
         """
 
-        if not self.use_vocab:
+        if self._numericalizer is None:
             return  # TODO throw Error?
 
         data = tokenized if isinstance(tokenized, (list, tuple)) else (tokenized,)
-        self._vocab += data
+        self._numericalizer.update(data)
 
     @property
     def finalized(self) -> bool:
@@ -597,16 +613,16 @@ def finalized(self) -> bool:
             Whether the field's Vocab was finalized. If the field has no vocab,
             returns True.
         """
-        return True if self.vocab is None else self.vocab.finalized
+        return self._numericalizer is None or self._numericalizer.finalized
 
     def finalize(self):
         """Signals that this field's vocab can be built."""
 
-        if self.use_vocab:
-            self.vocab.finalize()
+        if self._numericalizer is not None:
+            self._numericalizer.finalize()
 
     def _process_tokens(
-        self, raw: Any, tokens: Union[Any, List[str]]
+            self, raw: Any, tokens: Union[Any, List[str]]
     ) -> Tuple[str, Tuple[Any, Optional[Union[Any, List[str]]]]]:
         """Runs posttokenization processing on the provided data and tokens and
         updates the vocab if needed. Used by Multioutput field.
@@ -629,8 +645,9 @@ def _process_tokens(
         raw, tokenized = self._run_posttokenization_hooks(raw, tokens)
         raw = raw if self._keep_raw else None
 
-        if self.eager and not self.vocab.finalized:
-            self.update_vocab(tokenized)
+        if self.eager and self._numericalizer is not None \
+                and not self._numericalizer.finalized:
+            self.update_numericalizer(tokenized)
         return self.name, (raw, tokenized)
 
     def get_default_value(self) -> Union[int, float]:
@@ -653,7 +670,7 @@ def get_default_value(self) -> Union[int, float]:
         return self._missing_data_token
 
     def numericalize(
-        self, data: Tuple[Optional[Any], Optional[Union[Any, List[str]]]]
+            self, data: Tuple[Optional[Any], Optional[Union[Any, List[str]]]]
     ) -> Optional[Union[Any, np.ndarray]]:
         """Numericalize the already preprocessed data point based either on
         the vocab that was previously built, or on a custom numericalization
@@ -692,18 +709,15 @@ def numericalize(
 
         tokens = tokenized if isinstance(tokenized, (list, tuple)) else [tokenized]
 
-        if self.use_vocab:
-            return self.vocab.numericalize(tokens)
-        else:
-            return np.array([self._numericalizer(t) for t in tokens])
+        return self._numericalizer.numericalize(tokens)
 
     def _pad_to_length(
-        self,
-        array: np.ndarray,
-        length: int,
-        custom_pad_symbol: Optional[Union[int, float]] = None,
-        pad_left: bool = False,
-        truncate_left: bool = False,
+            self,
+            array: np.ndarray,
+            length: int,
+            custom_pad_symbol: Optional[Union[int, float]] = None,
+            pad_left: bool = False,
+            truncate_left: bool = False,
     ):
         """Either pads the given row with pad symbols, or truncates the row
         to be of given length.
The vocab provides the pad symbol for all @@ -748,7 +762,7 @@ def _pad_to_length( # truncating if truncate_left: - array = array[len(array) - length :] + array = array[len(array) - length:] else: array = array[:length] @@ -779,7 +793,7 @@ def _pad_to_length( return array def get_numericalization_for_example( - self, example, cache: bool = True + self, example, cache: bool = True ) -> Optional[Union[Any, np.ndarray]]: """Returns the numericalized data of this field for the provided example. The numericalized data is generated and cached in the example if 'cache' is true @@ -865,13 +879,13 @@ class LabelField(Field): """ def __init__( - self, - name: str, - numericalizer: Numericalizer = None, - allow_missing_data: bool = False, - is_target: bool = True, - missing_data_token: Union[int, float] = -1, - pretokenize_hooks: Iterable[PretokHook] = [], + self, + name: str, + numericalizer: NumericalizerCallable = None, + allow_missing_data: bool = False, + is_target: bool = True, + missing_data_token: Union[int, float] = -1, + pretokenize_hooks: Iterable[PretokHook] = [], ): """ Field subclass used when no tokenization is required. For example, with a field @@ -941,16 +955,16 @@ class MultilabelField(Field): """ def __init__( - self, - name: str, - tokenizer: TokenizerArg = None, - numericalizer: Numericalizer = None, - num_of_classes: Optional[int] = None, - is_target: bool = True, - allow_missing_data: bool = False, - missing_data_token: Union[int, float] = -1, - pretokenize_hooks: Iterable[PretokHook] = [], - posttokenize_hooks: Iterable[PosttokHook] = [], + self, + name: str, + tokenizer: TokenizerArg = None, + numericalizer: NumericalizerCallable = None, + num_of_classes: Optional[int] = None, + is_target: bool = True, + allow_missing_data: bool = False, + missing_data_token: Union[int, float] = -1, + pretokenize_hooks: Iterable[PretokHook] = [], + posttokenize_hooks: Iterable[PosttokHook] = [], ): """Create a MultilabelField from arguments. @@ -1055,7 +1069,7 @@ def finalize(self): ) def numericalize( - self, data: Tuple[Optional[Any], Optional[Union[Any, List[str]]]] + self, data: Tuple[Optional[Any], Optional[Union[Any, List[str]]]] ) -> np.ndarray: """Numericalize the already preprocessed data point based either on the vocab that was previously built, or on a custom numericalization diff --git a/podium/storage/vectorizers/tfidf.py b/podium/storage/vectorizers/tfidf.py index 1a075747..37e072cb 100644 --- a/podium/storage/vectorizers/tfidf.py +++ b/podium/storage/vectorizers/tfidf.py @@ -153,7 +153,7 @@ def fit(self, dataset, field): ValueError If the vocab or fields vocab are None """ - if self._vocab is None and (field is None or field.vocab is None): + if self._vocab is None and (field is None or not field.use_vocab): raise ValueError( "Vocab is not defined. User should define vocab in constructor " "or by providing field with a non-empty vocab property." diff --git a/podium/storage/vocab.py b/podium/storage/vocab.py index a3f7aa13..722babec 100644 --- a/podium/storage/vocab.py +++ b/podium/storage/vocab.py @@ -3,10 +3,11 @@ from collections import Counter from enum import Enum from itertools import chain -from typing import Iterable, Union +from typing import Iterable, Union, List import numpy as np +from podium.preproc import NumericalizerABC _LOGGER = logging.getLogger(__name__) @@ -63,7 +64,7 @@ class SpecialVocabSymbols(Enum): PAD = "" -class Vocab: +class Vocab(NumericalizerABC): """Class for storing vocabulary. It supports frequency counting and size limiting. 
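
Since `Vocab` now subclasses `NumericalizerABC`, it shares the update/finalize/numericalize lifecycle sketched below. The token data is illustrative, and the sketch assumes the series is applied in full:

# Sketch only -- assumes this patch series is applied.
from podium.storage import Vocab

vocab = Vocab(eager=False)
vocab.update(["the", "quick", "fox"])  # delegates to __iadd__
vocab.update(["the", "lazy", "dog"])

vocab.finalize()  # builds itos/stoi, then flips the finalized flag
assert vocab.finalized

indices = vocab.numericalize(["the", "fox"])  # numpy array of token indices
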
@@ -100,6 +101,7 @@ def __init__( if true word frequencies will be saved for later use on the finalization """ + super(Vocab, self).__init__(eager) self._freqs = Counter() self._keep_freqs = keep_freqs self._min_freq = min_freq @@ -115,8 +117,6 @@ def __init__( self.stoi.update({k: v for v, k in enumerate(self.itos)}) self._max_size = max_size - self.eager = eager - self.finalized = False # flag to know if we're ready to numericalize _LOGGER.debug("Vocabulary has been created and initialized.") @staticmethod @@ -198,6 +198,9 @@ def padding_index(self): raise ValueError("Padding symbol is not in the vocabulary.") return self.stoi[SpecialVocabSymbols.PAD] + def update(self, tokens: List[str]) -> None: + self.__iadd__(tokens) + def __iadd__(self, values: Union["Vocab", Iterable]): """Adds additional values or another Vocab to this Vocab. @@ -347,7 +350,7 @@ def __add__(self, values: Union["Vocab", Iterable]): new_vocab.finalize() return new_vocab - def finalize(self): + def _finalize(self): """Method finalizes vocab building. It also releases frequency counter if user set not to keep them. @@ -379,7 +382,6 @@ def finalize(self): if not self._keep_freqs: self._freqs = None # release memory - self.finalized = True _LOGGER.debug("Vocabulary is finalized.") def numericalize(self, data): diff --git a/tests/arrow/test_pyarrow_tabular_dataset.py b/tests/arrow/test_pyarrow_tabular_dataset.py index 8acf5249..73eb2290 100644 --- a/tests/arrow/test_pyarrow_tabular_dataset.py +++ b/tests/arrow/test_pyarrow_tabular_dataset.py @@ -139,21 +139,22 @@ def test_dump_and_load(pyarrow_dataset): def test_finalize_fields(data, fields, mocker): for field in fields: mocker.spy(field, "finalize") - mocker.spy(field, "update_vocab") + mocker.spy(field, "update_numericalizer") dataset = pyarrow_dataset(data, fields) for f in fields: # before finalization, no field's dict was updated - if f.vocab is not None: - assert not f.finalized + if f._numericalizer is not None: + assert not f._numericalizer.finalized dataset.finalize_fields() - fields_to_finalize = [f for f in fields if not f.eager and f.use_vocab] + fields_to_finalize = [f for f in fields + if not f.eager and f._numericalizer is not None] for f in fields_to_finalize: # during finalization, only non-eager field's dict should be updated - assert f.update_vocab.call_count == (len(data) if (not f.eager) else 0) + assert f.update_numericalizer.call_count == (len(data) if (not f.eager) else 0) f.finalize.assert_called_once() # all fields should be finalized assert f.finalized diff --git a/tests/storage/test_dataset.py b/tests/storage/test_dataset.py index ef55a0f7..5e453071 100644 --- a/tests/storage/test_dataset.py +++ b/tests/storage/test_dataset.py @@ -78,7 +78,7 @@ def preprocess(self, data): return ((self.name, (raw, tokenized)),) - def update_vocab(self, tokenized): + def update_numericalizer(self, tokenized): assert not self.eager self.updated_count += 1 diff --git a/tests/storage/test_field.py b/tests/storage/test_field.py index ccf36e1a..930c3a82 100644 --- a/tests/storage/test_field.py +++ b/tests/storage/test_field.py @@ -13,7 +13,7 @@ SpecialVocabSymbols, Vocab, ) - +from podium.preproc import NumericalizerABC ONE_TO_FIVE = [1, 2, 3, 4, 5] @@ -36,17 +36,19 @@ def tokenizer(self, string): return MockTokenizer() -class MockVocab(Mock): +class MockVocab(Mock, NumericalizerABC): def __init__(self, eager=True): - super(MockVocab, self).__init__(spec=Vocab) + Mock.__init__(self, spec=Vocab) + NumericalizerABC.__init__(self, eager) self.values = [] - 
self.finalized = False self.numericalized = False - self.eager = eager def padding_index(self): return PAD_NUM + def update(self, tokens): + self.__iadd__(tokens) + def __add__(self, values): if type(values) == type(self): pass @@ -58,11 +60,9 @@ def __add__(self, values): def __iadd__(self, other): return self.__add__(other) - def finalize(self): + def _finalize(self): if self.finalized: raise Exception - else: - self.finalized = True def numericalize(self, data): self.numericalized = True From 07ec68250ad3240fe185488619d870c6a852fe15 Mon Sep 17 00:00:00 2001 From: Ivan Smokovic Date: Thu, 5 Nov 2020 15:46:42 +0100 Subject: [PATCH 18/26] style changes --- podium/preproc/__init__.py | 5 +- podium/preproc/numericalizer_abc.py | 1 - podium/storage/field.py | 105 ++++++++++---------- podium/storage/vocab.py | 3 +- tests/arrow/test_pyarrow_tabular_dataset.py | 5 +- tests/storage/test_field.py | 3 +- 6 files changed, 64 insertions(+), 58 deletions(-) diff --git a/podium/preproc/__init__.py b/podium/preproc/__init__.py index e557b131..b2de82fa 100644 --- a/podium/preproc/__init__.py +++ b/podium/preproc/__init__.py @@ -9,10 +9,11 @@ TextCleanUp, ) from .lemmatizer import CroatianLemmatizer +from .numericalizer_abc import NumericalizerABC from .sentencizers import SpacySentencizer from .stemmer import CroatianStemmer from .tokenizers import get_tokenizer -from .numericalizer_abc import NumericalizerABC + __all__ = [ "CroatianLemmatizer", @@ -26,5 +27,5 @@ "CroatianStemmer", "SpacySentencizer", "get_tokenizer", - "NumericalizerABC" + "NumericalizerABC", ] diff --git a/podium/preproc/numericalizer_abc.py b/podium/preproc/numericalizer_abc.py index f6ca336b..5c753086 100644 --- a/podium/preproc/numericalizer_abc.py +++ b/podium/preproc/numericalizer_abc.py @@ -5,7 +5,6 @@ class NumericalizerABC(ABC): - def __init__(self, eager=True): self._finalized = False self._eager = eager diff --git a/podium/storage/field.py b/podium/storage/field.py index d3f06f42..20031f78 100644 --- a/podium/storage/field.py +++ b/podium/storage/field.py @@ -7,9 +7,10 @@ import numpy as np +from podium.preproc import NumericalizerABC from podium.preproc.tokenizers import get_tokenizer from podium.storage.vocab import Vocab -from podium.preproc import NumericalizerABC + _LOGGER = logging.getLogger(__name__) @@ -74,10 +75,10 @@ class MultioutputField: for posttokenization processing (posttokenization hooks and vocab updating).""" def __init__( - self, - output_fields: List["Field"], - tokenizer: TokenizerArg = "split", - pretokenize_hooks: Iterable[PretokHook] = [], + self, + output_fields: List["Field"], + tokenizer: TokenizerArg = "split", + pretokenize_hooks: Iterable[PretokHook] = [], ): """Field that does pretokenization and tokenization once and passes it to its output fields. Output fields are any type of field. 
The output fields are used @@ -211,7 +212,6 @@ def remove_pretokenize_hooks(self): class NumericalizerCallableWrapper(NumericalizerABC): - def __init__(self, numericalizer: NumericalizerCallable): super().__init__(eager=True) self._wrapped_numericalizer = numericalizer @@ -227,19 +227,19 @@ class Field: """ def __init__( - self, - name: str, - tokenizer: TokenizerArg = "split", - keep_raw: bool = False, - numericalizer: Optional[Union[NumericalizerABC, NumericalizerCallable]] = None, - is_target: bool = False, - fixed_length: Optional[int] = None, - allow_missing_data: bool = False, - disable_batch_matrix: bool = False, - padding_token: Union[int, float] = -999, - missing_data_token: Union[int, float] = -1, - pretokenize_hooks: Iterable[PretokHook] = [], - posttokenize_hooks: Iterable[PosttokHook] = [], + self, + name: str, + tokenizer: TokenizerArg = "split", + keep_raw: bool = False, + numericalizer: Optional[Union[NumericalizerABC, NumericalizerCallable]] = None, + is_target: bool = False, + fixed_length: Optional[int] = None, + allow_missing_data: bool = False, + disable_batch_matrix: bool = False, + padding_token: Union[int, float] = -999, + missing_data_token: Union[int, float] = -1, + pretokenize_hooks: Iterable[PretokHook] = [], + posttokenize_hooks: Iterable[PosttokHook] = [], ): """Create a Field from arguments. @@ -521,7 +521,7 @@ def _run_pretokenization_hooks(self, data: Any) -> Any: return self._pretokenize_pipeline(data) def _run_posttokenization_hooks( - self, data: Any, tokens: List[str] + self, data: Any, tokens: List[str] ) -> Tuple[Any, List[str]]: """Runs posttokenization hooks on tokenized data. @@ -543,7 +543,7 @@ def _run_posttokenization_hooks( return self._posttokenize_pipeline(data, tokens) def preprocess( - self, data: Any + self, data: Any ) -> Iterable[Tuple[str, Tuple[Any, Optional[List[str]]]]]: """Preprocesses raw data, tokenizing it if required, updating the vocab if the vocab is eager and preserving the raw data @@ -622,7 +622,7 @@ def finalize(self): self._numericalizer.finalize() def _process_tokens( - self, raw: Any, tokens: Union[Any, List[str]] + self, raw: Any, tokens: Union[Any, List[str]] ) -> Tuple[str, Tuple[Any, Optional[Union[Any, List[str]]]]]: """Runs posttokenization processing on the provided data and tokens and updates the vocab if needed. Used by Multioutput field. 
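
As a reference for the reformatted wrapper, this is roughly how `NumericalizerCallableWrapper` lifts a per-token callable into the numericalizer interface; the import path and the `len` callable are assumptions for illustration:

# Sketch only -- NumericalizerCallableWrapper lives in podium/storage/field.py.
import numpy as np

from podium.storage.field import NumericalizerCallableWrapper

wrapper = NumericalizerCallableWrapper(len)  # any per-token str -> int/float
result = wrapper(["a", "bb", "ccc"])  # __call__ forwards to numericalize()

assert isinstance(result, np.ndarray)
assert result.tolist() == [1, 2, 3]
assert wrapper.eager and not wrapper.finalized
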
@@ -645,8 +645,11 @@ def _process_tokens( raw, tokenized = self._run_posttokenization_hooks(raw, tokens) raw = raw if self._keep_raw else None - if self.eager and self._numericalizer is not None \ - and not self._numericalizer.finalized: + if ( + self.eager + and self._numericalizer is not None + and not self._numericalizer.finalized + ): self.update_numericalizer(tokenized) return self.name, (raw, tokenized) @@ -670,7 +673,7 @@ def get_default_value(self) -> Union[int, float]: return self._missing_data_token def numericalize( - self, data: Tuple[Optional[Any], Optional[Union[Any, List[str]]]] + self, data: Tuple[Optional[Any], Optional[Union[Any, List[str]]]] ) -> Optional[Union[Any, np.ndarray]]: """Numericalize the already preprocessed data point based either on the vocab that was previously built, or on a custom numericalization @@ -712,12 +715,12 @@ def numericalize( return self._numericalizer.numericalize(tokens) def _pad_to_length( - self, - array: np.ndarray, - length: int, - custom_pad_symbol: Optional[Union[int, float]] = None, - pad_left: bool = False, - truncate_left: bool = False, + self, + array: np.ndarray, + length: int, + custom_pad_symbol: Optional[Union[int, float]] = None, + pad_left: bool = False, + truncate_left: bool = False, ): """Either pads the given row with pad symbols, or truncates the row to be of given length. The vocab provides the pad symbol for all @@ -762,7 +765,7 @@ def _pad_to_length( # truncating if truncate_left: - array = array[len(array) - length:] + array = array[len(array) - length :] else: array = array[:length] @@ -793,7 +796,7 @@ def _pad_to_length( return array def get_numericalization_for_example( - self, example, cache: bool = True + self, example, cache: bool = True ) -> Optional[Union[Any, np.ndarray]]: """Returns the numericalized data of this field for the provided example. The numericalized data is generated and cached in the example if 'cache' is true @@ -879,13 +882,13 @@ class LabelField(Field): """ def __init__( - self, - name: str, - numericalizer: NumericalizerCallable = None, - allow_missing_data: bool = False, - is_target: bool = True, - missing_data_token: Union[int, float] = -1, - pretokenize_hooks: Iterable[PretokHook] = [], + self, + name: str, + numericalizer: NumericalizerCallable = None, + allow_missing_data: bool = False, + is_target: bool = True, + missing_data_token: Union[int, float] = -1, + pretokenize_hooks: Iterable[PretokHook] = [], ): """ Field subclass used when no tokenization is required. For example, with a field @@ -955,16 +958,16 @@ class MultilabelField(Field): """ def __init__( - self, - name: str, - tokenizer: TokenizerArg = None, - numericalizer: NumericalizerCallable = None, - num_of_classes: Optional[int] = None, - is_target: bool = True, - allow_missing_data: bool = False, - missing_data_token: Union[int, float] = -1, - pretokenize_hooks: Iterable[PretokHook] = [], - posttokenize_hooks: Iterable[PosttokHook] = [], + self, + name: str, + tokenizer: TokenizerArg = None, + numericalizer: NumericalizerCallable = None, + num_of_classes: Optional[int] = None, + is_target: bool = True, + allow_missing_data: bool = False, + missing_data_token: Union[int, float] = -1, + pretokenize_hooks: Iterable[PretokHook] = [], + posttokenize_hooks: Iterable[PosttokHook] = [], ): """Create a MultilabelField from arguments. 
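
The eager-update condition in `_process_tokens` above decides when a numericalizer is filled. A sketch of the two paths, with illustrative field names:

# Sketch only -- assumes this patch series is applied.
from podium.storage import Field, Vocab

eager_field = Field("text", numericalizer=Vocab(eager=True))
lazy_field = Field("text", numericalizer=Vocab(eager=False))

# Eager and not yet finalized: the vocab is updated during preprocess().
eager_field.preprocess("a quick example")

# Non-eager: preprocess() leaves the vocab untouched; it is updated later,
# when Dataset.finalize_fields() iterates over the loaded examples.
lazy_field.preprocess("a quick example")
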
@@ -1069,7 +1072,7 @@ def finalize(self): ) def numericalize( - self, data: Tuple[Optional[Any], Optional[Union[Any, List[str]]]] + self, data: Tuple[Optional[Any], Optional[Union[Any, List[str]]]] ) -> np.ndarray: """Numericalize the already preprocessed data point based either on the vocab that was previously built, or on a custom numericalization diff --git a/podium/storage/vocab.py b/podium/storage/vocab.py index 722babec..9efad1e3 100644 --- a/podium/storage/vocab.py +++ b/podium/storage/vocab.py @@ -3,12 +3,13 @@ from collections import Counter from enum import Enum from itertools import chain -from typing import Iterable, Union, List +from typing import Iterable, List, Union import numpy as np from podium.preproc import NumericalizerABC + _LOGGER = logging.getLogger(__name__) diff --git a/tests/arrow/test_pyarrow_tabular_dataset.py b/tests/arrow/test_pyarrow_tabular_dataset.py index 73eb2290..0d62c7af 100644 --- a/tests/arrow/test_pyarrow_tabular_dataset.py +++ b/tests/arrow/test_pyarrow_tabular_dataset.py @@ -150,8 +150,9 @@ def test_finalize_fields(data, fields, mocker): dataset.finalize_fields() - fields_to_finalize = [f for f in fields - if not f.eager and f._numericalizer is not None] + fields_to_finalize = [ + f for f in fields if not f.eager and f._numericalizer is not None + ] for f in fields_to_finalize: # during finalization, only non-eager field's dict should be updated assert f.update_numericalizer.call_count == (len(data) if (not f.eager) else 0) diff --git a/tests/storage/test_field.py b/tests/storage/test_field.py index 930c3a82..4a1b735b 100644 --- a/tests/storage/test_field.py +++ b/tests/storage/test_field.py @@ -5,6 +5,7 @@ import numpy as np import pytest +from podium.preproc import NumericalizerABC from podium.storage import ( Field, LabelField, @@ -13,7 +14,7 @@ SpecialVocabSymbols, Vocab, ) -from podium.preproc import NumericalizerABC + ONE_TO_FIVE = [1, 2, 3, 4, 5] From 65906ecc0acb60d95b9b95aa0efef44e7feabc8b Mon Sep 17 00:00:00 2001 From: I van Smokovic Date: Wed, 9 Dec 2020 20:16:48 +0100 Subject: [PATCH 19/26] Merged master --- podium/storage/vocab.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/podium/storage/vocab.py b/podium/storage/vocab.py index a4f202f1..eb1beec2 100644 --- a/podium/storage/vocab.py +++ b/podium/storage/vocab.py @@ -379,7 +379,7 @@ def _finalize(self): if not self._keep_freqs: self._freqs = None # release memory - self.finalized = True + self._finalized = True def numericalize(self, data): """Method numericalizes given tokens. From 125a0151933d3c09cfd9386acbef48e20360a0b6 Mon Sep 17 00:00:00 2001 From: I van Smokovic Date: Wed, 9 Dec 2020 20:29:54 +0100 Subject: [PATCH 20/26] minor style changes --- docs/source/conf.py | 1 + setup.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 1d2046be..48328740 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -16,6 +16,7 @@ # Quick fix for cross-reference warnings: from sphinx.domains.python import PythonDomain + sys.path.insert(0, os.path.abspath('../../')) diff --git a/setup.py b/setup.py index 63be8b59..b107202c 100644 --- a/setup.py +++ b/setup.py @@ -5,8 +5,8 @@ See http://takelab.fer.hr/podium/ for complete documentation. 
""" import re - from pathlib import Path + from setuptools import find_packages, setup From 7a26a04a853ffc1e22081057540da533daf19612 Mon Sep 17 00:00:00 2001 From: I van Smokovic Date: Wed, 9 Dec 2020 21:56:11 +0100 Subject: [PATCH 21/26] TODO docs --- podium/preproc/numericalizer_abc.py | 11 +++++++++++ podium/storage/field.py | 6 ++++-- podium/storage/vocab.py | 1 - 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/podium/preproc/numericalizer_abc.py b/podium/preproc/numericalizer_abc.py index 5c753086..506172b2 100644 --- a/podium/preproc/numericalizer_abc.py +++ b/podium/preproc/numericalizer_abc.py @@ -5,7 +5,18 @@ class NumericalizerABC(ABC): + """ABC that contains the interface for Podium numericalizers. Numericalizers are used + to transform tokens into vectors or any other custom datatype during batching.""" + def __init__(self, eager=True): + """Initialises the Numericalizer. + + Parameters: + ----------- + eager: bool + Whether the + + """ self._finalized = False self._eager = eager diff --git a/podium/storage/field.py b/podium/storage/field.py index cda60acf..9b114036 100644 --- a/podium/storage/field.py +++ b/podium/storage/field.py @@ -406,8 +406,10 @@ def eager(self): def vocab(self): """""" if not self.use_vocab: - # TODO raise error - raise TypeError() + numericalizer_type = type(self._numericalizer).__name__ + err_msg = f'Field "{self.name}" has no vocab, numericalizer type is ' \ + f'{numericalizer_type}.' + raise TypeError(f"") return self._numericalizer @property diff --git a/podium/storage/vocab.py b/podium/storage/vocab.py index eb1beec2..02b73beb 100644 --- a/podium/storage/vocab.py +++ b/podium/storage/vocab.py @@ -379,7 +379,6 @@ def _finalize(self): if not self._keep_freqs: self._freqs = None # release memory - self._finalized = True def numericalize(self, data): """Method numericalizes given tokens. From 5668253a151b9b0aa92b3943c1a98de367a7d08a Mon Sep 17 00:00:00 2001 From: I van Smokovic Date: Thu, 10 Dec 2020 18:44:05 +0100 Subject: [PATCH 22/26] _finalize -> mark_finalized --- podium/preproc/numericalizer_abc.py | 6 ++---- podium/storage/vocab.py | 3 ++- tests/storage/test_field.py | 3 ++- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/podium/preproc/numericalizer_abc.py b/podium/preproc/numericalizer_abc.py index 506172b2..2919314a 100644 --- a/podium/preproc/numericalizer_abc.py +++ b/podium/preproc/numericalizer_abc.py @@ -24,7 +24,7 @@ def __init__(self, eager=True): def numericalize(self, tokens: List[str]) -> np.ndarray: pass - def _finalize(self): + def finalize(self): # Subclasses should override this method to add custom # finalization logic pass @@ -32,10 +32,8 @@ def _finalize(self): def update(self, tokens: List[str]) -> None: pass - def finalize(self) -> None: - self._finalize() + def mark_finalized(self) -> None: self._finalized = True - pass @property def finalized(self) -> bool: diff --git a/podium/storage/vocab.py b/podium/storage/vocab.py index 02b73beb..76eb458e 100644 --- a/podium/storage/vocab.py +++ b/podium/storage/vocab.py @@ -347,7 +347,7 @@ def __add__(self, values: Union["Vocab", Iterable]): new_vocab.finalize() return new_vocab - def _finalize(self): + def finalize(self): """Method finalizes vocab building. It also releases frequency counter if user set not to keep them. @@ -379,6 +379,7 @@ def _finalize(self): if not self._keep_freqs: self._freqs = None # release memory + self.mark_finalized() def numericalize(self, data): """Method numericalizes given tokens. 
diff --git a/tests/storage/test_field.py b/tests/storage/test_field.py index e8829373..1979602e 100644 --- a/tests/storage/test_field.py +++ b/tests/storage/test_field.py @@ -61,9 +61,10 @@ def __add__(self, values): def __iadd__(self, other): return self.__add__(other) - def _finalize(self): + def finalize(self): if self.finalized: raise Exception + self.mark_finalized() def numericalize(self, data): self.numericalized = True From f110c5fa12bcc64dbf84caa618643575b24f115f Mon Sep 17 00:00:00 2001 From: I van Smokovic Date: Thu, 10 Dec 2020 18:45:22 +0100 Subject: [PATCH 23/26] black --- podium/storage/field.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/podium/storage/field.py b/podium/storage/field.py index 9b114036..afd98d6d 100644 --- a/podium/storage/field.py +++ b/podium/storage/field.py @@ -407,9 +407,11 @@ def vocab(self): """""" if not self.use_vocab: numericalizer_type = type(self._numericalizer).__name__ - err_msg = f'Field "{self.name}" has no vocab, numericalizer type is ' \ - f'{numericalizer_type}.' - raise TypeError(f"") + err_msg = ( + f'Field "{self.name}" has no vocab, numericalizer type is ' + f"{numericalizer_type}." + ) + raise TypeError(err_msg) return self._numericalizer @property From e28a77726aa3c3b64cdc38578c1a892d834a794d Mon Sep 17 00:00:00 2001 From: I van Smokovic Date: Thu, 10 Dec 2020 20:00:33 +0100 Subject: [PATCH 24/26] docs --- podium/preproc/numericalizer_abc.py | 58 ++++++++++++++++++++++++++--- podium/storage/field.py | 4 +- 2 files changed, 54 insertions(+), 8 deletions(-) diff --git a/podium/preproc/numericalizer_abc.py b/podium/preproc/numericalizer_abc.py index 2919314a..b24e3f3a 100644 --- a/podium/preproc/numericalizer_abc.py +++ b/podium/preproc/numericalizer_abc.py @@ -6,15 +6,23 @@ class NumericalizerABC(ABC): """ABC that contains the interface for Podium numericalizers. Numericalizers are used - to transform tokens into vectors or any other custom datatype during batching.""" + to transform tokens into vectors or any other custom datatype during batching. + + Attributes + ---------- + finalized: bool + Whether this numericalizer was finalized and is able to be used for + numericalization. + """ def __init__(self, eager=True): """Initialises the Numericalizer. - Parameters: - ----------- + Parameters + ---------- eager: bool - Whether the + Whether the Numericalizer is to be updated during loading of the dataset, or + after all data is loaded. """ self._finalized = False @@ -22,25 +30,63 @@ def __init__(self, eager=True): @abstractmethod def numericalize(self, tokens: List[str]) -> np.ndarray: + """Converts `tokens` into a numericalized format used in batches. + Numericalizations are most often numpy vectors, but any custom datatype is + supported. + + Parameters + ---------- + tokens: List[str] + A list of strings that represent the tokens of this data point. Can also be + any other datatype, as long as this Numericalizer supports it. + + Returns + ------- + Numericalization used in batches. Numericalizations are most often numpy vectors, + but any custom datatype is supported. + """ pass def finalize(self): - # Subclasses should override this method to add custom - # finalization logic + """Finalizes the Numericalizer and prepares it for numericalization. + This method must be overridden in classes that require finalization before + numericalization. 
The override must call `mark_finalized` after successful
+        completion."""
+        self.mark_finalized()
         pass
 
     def update(self, tokens: List[str]) -> None:
+        """Updates this Numericalizer with a single data point. Numericalizers that need
+        to be updated example by example must override this method. Numericalizers that
+        are eager get updated during the dataset loading process, while non-eager ones get
+        updated after loading is finished, after all eager numericalizers were fully
+        updated.
+
+        Parameters
+        ----------
+        tokens: List[str]
+            A list of strings that represent the tokens of this data point. Can also be
+            any other datatype, as long as this Numericalizer supports it.
+
+        """
         pass
 
     def mark_finalized(self) -> None:
+        """Marks the numericalizer as finalized. This method must be called after
+        finalization completes successfully."""
         self._finalized = True
 
     @property
     def finalized(self) -> bool:
+        """Whether this Numericalizer was finalized and is ready for numericalization."""
         return self._finalized
 
     @property
     def eager(self) -> bool:
+        """Whether this Numericalizer is eager. Numericalizers that
+        are eager get updated during the dataset loading process, while non-eager ones get
+        updated after loading is finished, after all eager numericalizers were fully
+        updated."""
         return self._eager
 
     def __call__(self, tokens: List[str]) -> np.ndarray:
diff --git a/podium/storage/field.py b/podium/storage/field.py
index afd98d6d..45068890 100644
--- a/podium/storage/field.py
+++ b/podium/storage/field.py
@@ -1053,9 +1053,8 @@ def __init__(
 
     def finalize(self):
         """Signals that this field's vocab can be built."""
-        super().finalize()
         if self._num_of_classes is None:
-            self.fixed_length = self._num_of_classes = len(self.vocab)
+            self._fixed_length = self._num_of_classes = len(self.vocab)
 
         if self.use_vocab and len(self.vocab) > self._num_of_classes:
             raise ValueError(
@@ -1063,6 +1062,7 @@
                 f"of classes. Declared: {self._num_of_classes}, "
                 f"Actual: {len(self.vocab)}"
             )
+        super().finalize()
 
     def numericalize(
         self, data: Tuple[Optional[Any], Optional[Union[Any, List[str]]]]

From dac186ccdcf5646f1d09203a7b40cb4cd51bd373 Mon Sep 17 00:00:00 2001
From: I van Smokovic
Date: Thu, 10 Dec 2020 20:25:54 +0100
Subject: [PATCH 25/26] added error message

---
 podium/storage/field.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/podium/storage/field.py b/podium/storage/field.py
index 75da7c92..c7adbfda 100644
--- a/podium/storage/field.py
+++ b/podium/storage/field.py
@@ -338,8 +338,9 @@ def __init__(
         elif isinstance(numericalizer, Callable):
             self._numericalizer = NumericalizerCallableWrapper(numericalizer)
         else:
-            # TODO raise error
-            pass
+            err_msg = f'Field {name}: unsupported numericalizer type ' \
+                      f'"{type(numericalizer).__name__}"'
+            raise TypeError(err_msg)
 
         self._keep_raw = keep_raw
 
@@ -568,6 +569,7 @@ def preprocess(
 
         # Preprocess the raw input
         # TODO keep unprocessed or processed raw?
+ # Keeping processed for now, may change in the future processed_raw = self._run_pretokenization_hooks(data) tokenized = ( self._tokenizer(processed_raw) From 56beaea9593cc6b830347402c9ca4ec3690553fe Mon Sep 17 00:00:00 2001 From: I van Smokovic Date: Thu, 10 Dec 2020 20:39:53 +0100 Subject: [PATCH 26/26] black compliance --- podium/storage/field.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/podium/storage/field.py b/podium/storage/field.py index c7adbfda..a4cc1429 100644 --- a/podium/storage/field.py +++ b/podium/storage/field.py @@ -338,8 +338,10 @@ def __init__( elif isinstance(numericalizer, Callable): self._numericalizer = NumericalizerCallableWrapper(numericalizer) else: - err_msg = f'Field {name}: unsupported numericalizer type ' \ - f'"{type(numericalizer).__name__}"' + err_msg = ( + f"Field {name}: unsupported numericalizer type " + f'"{type(numericalizer).__name__}"' + ) raise TypeError(err_msg) self._keep_raw = keep_raw
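
Taken together, the last patches make unsupported numericalizer types fail fast at `Field` construction, while accepting `NumericalizerABC` instances, `None`, and plain callables. A closing sketch, assuming the full series is applied; `HashNumericalizer` is hypothetical and not part of the library:

# Sketch only -- assumes the full patch series is applied.
import numpy as np

from podium.preproc import NumericalizerABC
from podium.storage import Field


class HashNumericalizer(NumericalizerABC):
    def numericalize(self, tokens):
        # Hypothetical: bucket tokens by hash instead of a learned vocabulary.
        return np.array([hash(token) % 1000 for token in tokens])

    def finalize(self):
        # ... custom finalization logic would run here ...
        self.mark_finalized()  # overrides flag completion themselves


Field("text", numericalizer=HashNumericalizer())  # accepted: NumericalizerABC
Field("text", numericalizer=len)                  # accepted: wrapped callable

try:
    Field("text", numericalizer=42)  # neither NumericalizerABC, None, nor callable
except TypeError as err:
    print(err)  # Field text: unsupported numericalizer type "int"
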