From 16aaa37998ab7a750cc6500b55e5a1598d46ee07 Mon Sep 17 00:00:00 2001 From: skshetry <18718008+skshetry@users.noreply.github.com> Date: Tue, 23 Jul 2024 16:45:32 +0545 Subject: [PATCH] dc: try to fix dataset_stats for DataChain.from_storage() generated dataset (#151) --- src/datachain/data_storage/warehouse.py | 4 +++- tests/func/test_datachain.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/datachain/data_storage/warehouse.py b/src/datachain/data_storage/warehouse.py index 7c396a981..18aa359a6 100644 --- a/src/datachain/data_storage/warehouse.py +++ b/src/datachain/data_storage/warehouse.py @@ -390,7 +390,9 @@ def dataset_stats( expressions: tuple[_ColumnsClauseArgument[Any], ...] = ( sa.func.count(table.c.sys__id), ) - if "size" in table.columns: + if "file__size" in table.columns: + expressions = (*expressions, sa.func.sum(table.c.file__size)) + elif "size" in table.columns: expressions = (*expressions, sa.func.sum(table.c.size)) query = select(*expressions) ((nrows, *rest),) = self.db.execute(query) diff --git a/tests/func/test_datachain.py b/tests/func/test_datachain.py index aabdd57d4..a4356dadd 100644 --- a/tests/func/test_datachain.py +++ b/tests/func/test_datachain.py @@ -4,6 +4,7 @@ import pandas as pd import pytest +from datachain.dataset import DatasetStats from datachain.lib.dc import DataChain from datachain.lib.file import File @@ -205,3 +206,12 @@ def test_show_no_truncate(capsys, catalog): for i in range(3): assert client[i] in normalized_output assert details[i] in normalized_output + + +def test_from_storage_dataset_stats(tmp_dir, catalog): + for i in range(4): + (tmp_dir / f"file{i}.txt").write_text(f"file{i}") + + dc = DataChain.from_storage(tmp_dir.as_uri(), catalog=catalog).save("test-data") + stats = catalog.dataset_stats(dc.name, dc.version) + assert stats == DatasetStats(num_objects=4, size=20)