Skip to content

Commit

Permalink
Fix storage dependencies (#421)
Browse files Browse the repository at this point in the history
* fix storage dependencies

* fix test

* fix dataset parsing
  • Loading branch information
ilongin authored Sep 16, 2024
1 parent 7a7f2fc commit 78ee1ba
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/datachain/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def parse(

if is_listing_dataset(dataset_name):
dependency_type = DatasetDependencyType.STORAGE # type: ignore[arg-type]
dependency_name = listing_uri_from_name(dataset_name)
dependency_name, _ = Client.parse_url(listing_uri_from_name(dataset_name))

return cls(
id,
Expand Down
18 changes: 17 additions & 1 deletion tests/func/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from PIL import Image
from sqlalchemy import Column

from datachain.client.local import FileClient
from datachain.data_storage.sqlite import SQLiteWarehouse
from datachain.dataset import DatasetStats
from datachain.dataset import DatasetDependencyType, DatasetStats
from datachain.lib.dc import C, DataChain, DataChainColumnError
from datachain.lib.file import File, ImageFile
from datachain.lib.listing import (
Expand Down Expand Up @@ -178,6 +179,21 @@ def _list_dataset_name(uri: str) -> str:
)


def test_from_storage_dependencies(cloud_test_catalog, cloud_type):
ctc = cloud_test_catalog
src_uri = ctc.src_uri
uri = f"{src_uri}/cats"
ds_name = "dep"
DataChain.from_storage(uri, session=ctc.session).save(ds_name)
dependencies = ctc.session.catalog.get_dataset_dependencies(ds_name, 1)
assert len(dependencies) == 1
assert dependencies[0].type == DatasetDependencyType.STORAGE
if cloud_type == "file":
assert dependencies[0].name == FileClient.root_path().as_uri()
else:
assert dependencies[0].name == src_uri


@pytest.mark.parametrize("use_cache", [True, False])
def test_map_file(cloud_test_catalog, use_cache):
ctc = cloud_test_catalog
Expand Down
7 changes: 4 additions & 3 deletions tests/func/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import sqlalchemy as sa

from datachain.catalog.catalog import DATASET_INTERNAL_ERROR_MESSAGE
from datachain.client.local import FileClient
from datachain.data_storage.sqlite import SQLiteWarehouse
from datachain.dataset import LISTING_PREFIX, DatasetDependencyType, DatasetStatus
from datachain.dataset import DatasetDependencyType, DatasetStatus
from datachain.error import DatasetInvalidVersionError, DatasetNotFoundError
from datachain.lib.dc import DataChain
from datachain.lib.listing import parse_listing_uri
Expand Down Expand Up @@ -805,7 +806,7 @@ def test_dataset_stats_registered_ds(cloud_test_catalog, dogs_dataset):


@pytest.mark.parametrize("indirect", [True, False])
def test_dataset_storage_dependencies(cloud_test_catalog, indirect):
def test_dataset_storage_dependencies(cloud_test_catalog, cloud_type, indirect):
ctc = cloud_test_catalog
session = ctc.session
catalog = session.catalog
Expand All @@ -824,7 +825,7 @@ def test_dataset_storage_dependencies(cloud_test_catalog, indirect):
{
"id": ANY,
"type": DatasetDependencyType.STORAGE,
"name": lst_dataset.name.removeprefix(LISTING_PREFIX),
"name": uri if cloud_type != "file" else FileClient.root_path().as_uri(),
"version": "1",
"created_at": lst_dataset.get_version(1).created_at,
"dependencies": [],
Expand Down

0 comments on commit 78ee1ba

Please sign in to comment.