diff --git a/src/datachain/catalog/catalog.py b/src/datachain/catalog/catalog.py
index d6f27b047..5854024e2 100644
--- a/src/datachain/catalog/catalog.py
+++ b/src/datachain/catalog/catalog.py
@@ -1,4 +1,5 @@
 import ast
+import glob
 import io
 import json
 import logging
@@ -709,7 +710,15 @@ def enlist_source(
 
         client_config = client_config or self.client_config
         client, path = self.parse_url(source, **client_config)
-        prefix = posixpath.dirname(path)
+        # Use the parent directory as the prefix only when the last path
+        # component is a glob pattern or the source is a single file;
+        # otherwise the path itself is the prefix.
+        stem = posixpath.basename(posixpath.normpath(path))
+        prefix = (
+            posixpath.dirname(path)
+            if glob.has_magic(stem) or client.fs.isfile(source)
+            else path
+        )
         storage_dataset_name = Storage.dataset_name(
             client.uri, posixpath.join(prefix, "")
         )
diff --git a/tests/func/test_catalog.py b/tests/func/test_catalog.py
index 681b5c329..4db8660ca 100644
--- a/tests/func/test_catalog.py
+++ b/tests/func/test_catalog.py
@@ -1068,6 +1068,48 @@ def test_storage_stats(cloud_test_catalog):
     assert stats.size == 15
 
 
+@pytest.mark.parametrize("cloud_type", ["s3", "azure", "gs"], indirect=True)
+def test_enlist_source_handles_slash(cloud_test_catalog):
+    """A directory source enlists the same with or without a trailing slash."""
+    catalog = cloud_test_catalog.catalog
+    src_uri = cloud_test_catalog.src_uri
+
+    catalog.enlist_source(f"{src_uri}/dogs", ttl=1234)
+    stats = catalog.storage_stats(src_uri)
+    assert stats.num_objects == len(DEFAULT_TREE["dogs"])
+    assert stats.size == 15
+
+    catalog.enlist_source(f"{src_uri}/dogs/", ttl=1234, force_update=True)
+    stats = catalog.storage_stats(src_uri)
+    assert stats.num_objects == len(DEFAULT_TREE["dogs"])
+    assert stats.size == 15
+
+
+@pytest.mark.parametrize("cloud_type", ["s3", "azure", "gs"], indirect=True)
+def test_enlist_source_handles_glob(cloud_test_catalog):
+    """Enlisting a glob source covers the glob's parent directory."""
+    catalog = cloud_test_catalog.catalog
+    src_uri = cloud_test_catalog.src_uri
+
+    catalog.enlist_source(f"{src_uri}/dogs/*.jpg", ttl=1234)
+    stats = catalog.storage_stats(src_uri)
+
+    assert stats.num_objects == len(DEFAULT_TREE["dogs"])
+    assert stats.size == 15
+
+
+@pytest.mark.parametrize("cloud_type", ["s3", "azure", "gs"], indirect=True)
+def test_enlist_source_handles_file(cloud_test_catalog):
+    """Enlisting a single-file source covers the file's parent directory."""
+    catalog = cloud_test_catalog.catalog
+    src_uri = cloud_test_catalog.src_uri
+
+    catalog.enlist_source(f"{src_uri}/dogs/dog1", ttl=1234)
+    stats = catalog.storage_stats(src_uri)
+    assert stats.num_objects == len(DEFAULT_TREE["dogs"])
+    assert stats.size == 15
+
+
 @pytest.mark.parametrize("from_cli", [False, True])
 def test_garbage_collect(cloud_test_catalog, from_cli, capsys):
     catalog = cloud_test_catalog.catalog
diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py
index a070b6aff..58c9b9025 100644
--- a/tests/unit/lib/test_datachain.py
+++ b/tests/unit/lib/test_datachain.py
@@ -660,7 +660,7 @@ def test_parse_tabular_partitions(tmp_dir, catalog):
 
 def test_parse_tabular_empty(tmp_dir, catalog):
     path = tmp_dir / "test.parquet"
-    with pytest.raises(DataChainParamsError):
+    with pytest.raises(FileNotFoundError):
         DataChain.from_storage(path.as_uri()).parse_tabular()
 
 