Skip to content

Commit

Permalink
dc: try to fix dataset_stats for DataChain.from_storage() generated d…
Browse files Browse the repository at this point in the history
…ataset (#151)
  • Loading branch information
skshetry authored Jul 23, 2024
1 parent 5312913 commit 16aaa37
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/datachain/data_storage/warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,9 @@ def dataset_stats(
expressions: tuple[_ColumnsClauseArgument[Any], ...] = (
sa.func.count(table.c.sys__id),
)
if "size" in table.columns:
if "file__size" in table.columns:
expressions = (*expressions, sa.func.sum(table.c.file__size))
elif "size" in table.columns:
expressions = (*expressions, sa.func.sum(table.c.size))
query = select(*expressions)
((nrows, *rest),) = self.db.execute(query)
Expand Down
10 changes: 10 additions & 0 deletions tests/func/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pandas as pd
import pytest

from datachain.dataset import DatasetStats
from datachain.lib.dc import DataChain
from datachain.lib.file import File

Expand Down Expand Up @@ -205,3 +206,12 @@ def test_show_no_truncate(capsys, catalog):
for i in range(3):
assert client[i] in normalized_output
assert details[i] in normalized_output


def test_from_storage_dataset_stats(tmp_dir, catalog):
for i in range(4):
(tmp_dir / f"file{i}.txt").write_text(f"file{i}")

dc = DataChain.from_storage(tmp_dir.as_uri(), catalog=catalog).save("test-data")
stats = catalog.dataset_stats(dc.name, dc.version)
assert stats == DatasetStats(num_objects=4, size=20)

0 comments on commit 16aaa37

Please sign in to comment.