From 9549a10b82d9f505eeddad608ee0afd4259b5e4f Mon Sep 17 00:00:00 2001 From: Robert Neale Date: Fri, 3 Feb 2023 22:25:35 -0800 Subject: [PATCH] Standardize inheritance of DatasetProviderBase. This is mostly trying to align the inherited function get_dataset for all dataset providers. All changes should be non-destructive, and thus not affect current clients. PiperOrigin-RevId: 507072854 --- optformer/data/tasks.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/optformer/data/tasks.py b/optformer/data/tasks.py index 1b4e391..2f429a3 100644 --- a/optformer/data/tasks.py +++ b/optformer/data/tasks.py @@ -23,6 +23,7 @@ import seqio import t5.data import tensorflow as tf +import tensorflow_datasets as tfds Study = converters.Study @@ -272,10 +273,14 @@ def supports_arbitrary_sharding(self) -> bool: def get_dataset( self, - split: str, + split: str = tfds.Split.TRAIN, shuffle: bool = True, seed: Optional[int] = None, shard_info: Optional[seqio.ShardInfo] = None, + *, + sequence_length: Optional[Mapping[str, int]] = None, # Unused + use_cached: bool = False, # Unused + num_epochs: Optional[int] = 1, # Unused ) -> tf.data.Dataset: raise NotImplementedError