diff --git a/seqio/dataset_providers.py b/seqio/dataset_providers.py index 92118df6..dab7ef84 100644 --- a/seqio/dataset_providers.py +++ b/seqio/dataset_providers.py @@ -538,7 +538,7 @@ def __init__( @property def splits(self): """Overrides since we can't call `info.splits` until after init.""" - return self._splits or self._tfds_dataset.info.splits + return self._splits or self.tfds_dataset.info.splits @property def tfds_dataset(self) -> utils.LazyTfdsLoader: diff --git a/seqio/utils.py b/seqio/utils.py index 212b9e2f..5768c029 100644 --- a/seqio/utils.py +++ b/seqio/utils.py @@ -94,6 +94,15 @@ def _validate_tfds_name(name: str) -> None: raise ValueError(f"TFDS name must contain a version number, got: {name}") +def _get_data_dir_override(tfds_name: Optional[str]) -> Optional[str]: + """Returns the data dir in case it is overridden.""" + if ( + _TFDS_DATA_DIR_OVERRIDE + ): + return _TFDS_DATA_DIR_OVERRIDE + return None + + @dataclasses.dataclass(frozen=True) class TfdsSplit: """Points to a specific TFDS split. @@ -148,6 +157,7 @@ def __init__( _validate_tfds_name(name) self._name = name self._data_dir = data_dir + self._data_dir_override = None self._split_map = split_map self._decoders = decoders self._builder_kwargs = builder_kwargs @@ -238,19 +248,22 @@ def data_dir(self) -> Optional[str]: ) return None + if self._data_dir_override is None: + if data_dir_override := _get_data_dir_override(tfds_name=self.name): + self._data_dir_override = data_dir_override - if ( - _TFDS_DATA_DIR_OVERRIDE - ): - if self._data_dir: - logging.warning( - "Overriding TFDS data directory '%s' with '%s' for dataset '%s'.", - self._data_dir, - _TFDS_DATA_DIR_OVERRIDE, - self.name, - ) - return _TFDS_DATA_DIR_OVERRIDE - return self._data_dir + if self._data_dir_override and self._data_dir: + logging.warning( + "Overriding TFDS data directory '%s' with '%s' for dataset '%s'.", + self._data_dir, + self._data_dir_override, + self.name, + ) + + return self._data_dir_override or self._data_dir + + def override_data_dir(self, data_dir: str) -> None: + self._data_dir_override = data_dir @property def read_config(self): diff --git a/seqio/utils_test.py b/seqio/utils_test.py index a7ca73f9..5261177b 100644 --- a/seqio/utils_test.py +++ b/seqio/utils_test.py @@ -285,6 +285,41 @@ def test_read_config_override(self, mock_tfds_load): # reset to default global override utils.set_tfds_read_config_override(None) + @mock.patch("tensorflow_datasets.builder") + def test_override_data_dir(self, mock_tfds_builder): + mock_builder1 = mock.create_autospec(tfds.core.DatasetBuilder) + mock_builder2 = mock.create_autospec(tfds.core.DatasetBuilder) + mock_builder3 = mock.create_autospec(tfds.core.DatasetBuilder) + mock_tfds_builder.side_effect = [ + mock_builder1, + mock_builder2, + mock_builder3, + ] + + orig_data_dir = "/data" + override1 = "/override1" + override2 = "/override2" + + utils.set_tfds_data_dir_override(override1) + + # Should use `override1` that was set globally. + loader = utils.LazyTfdsLoader(name="a/b:1.0.0", data_dir=orig_data_dir) + self.assertEqual(override1, loader.data_dir) + self.assertEqual(loader.builder, mock_builder1) + + loader.override_data_dir(override2) + self.assertEqual(override2, loader.data_dir) + self.assertEqual(loader.builder, mock_builder2) + + # Set back to original data dir and check whether the cache works. + loader.override_data_dir(orig_data_dir) + self.assertEqual(orig_data_dir, loader.data_dir) + self.assertEqual(loader.builder, mock_builder3) + self.assertEqual(mock_tfds_builder.call_count, 3) + + # Unset it to not influence other tests. + utils.set_tfds_data_dir_override(None) + class TransformUtilsTest(parameterized.TestCase): @@ -375,7 +410,6 @@ def fn(ex, seed): mapped_ds = fn(ds) # pylint: disable=no-value-for-parameter results = [7, 5, 6, 6, 7, 11, 12, 16, 15, 15] expected_ds = [{"field": results[i]} for i in range(10)] - print("gaurav", list(mapped_ds.as_numpy_iterator())) self.assertListEqual(list(mapped_ds.as_numpy_iterator()), expected_ds) def test_random_map_fn_with_kwargs(self): @@ -432,7 +466,6 @@ def fn(ex, seeds, val, sequence_length): mapped_ds = map_fn(ds, sequence_length={"field": -1}) # pylint: disable=no-value-for-parameter results = [13, 15, 16, 13, 9, 16, 15, 26, 18, 16] expected_ds = [{"field": results[i]} for i in range(10)] - print("gaurav", list(mapped_ds.as_numpy_iterator())) self.assertListEqual(list(mapped_ds.as_numpy_iterator()), expected_ds)