diff --git a/optformer/data/tasks.py b/optformer/data/tasks.py index 1b4e391..2f429a3 100644 --- a/optformer/data/tasks.py +++ b/optformer/data/tasks.py @@ -23,6 +23,7 @@ import seqio import t5.data import tensorflow as tf +import tensorflow_datasets as tfds Study = converters.Study @@ -272,10 +273,14 @@ def supports_arbitrary_sharding(self) -> bool: def get_dataset( self, - split: str, + split: str = tfds.Split.TRAIN, shuffle: bool = True, seed: Optional[int] = None, shard_info: Optional[seqio.ShardInfo] = None, + *, + sequence_length: Optional[Mapping[str, int]] = None, # Unused + use_cached: bool = False, # Unused + num_epochs: Optional[int] = 1, # Unused ) -> tf.data.Dataset: raise NotImplementedError