diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index 17e17fab9..2535935df 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -2731,6 +2731,49 @@ def test_group_by_multiple_partition_by(test_session): ) +def test_group_by_no_partition_by(test_session): + from datachain import func + + ds = ( + DataChain.from_values( + col1=["a", "a", "b", "b", "b", "c"], + col2=[1, 2, 1, 2, 1, 2], + col3=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + col4=["1", "2", "3", "4", "5", "6"], + session=test_session, + ) + .order_by("col4") + .group_by( + cnt=func.count(), + cnt_col=func.count("col2"), + sum=func.sum("col3"), + concat=func.concat("col4"), + value=func.any_value("col3"), + collect=func.collect("col3"), + ) + .save("my-ds") + ) + + assert ds.signals_schema.serialize() == { + "cnt": "int", + "cnt_col": "int", + "sum": "float", + "concat": "str", + "value": "float", + "collect": "list[float]", + } + assert ds.to_records() == [ + { + "cnt": 6, + "cnt_col": 6, + "sum": 21.0, + "concat": "123456", + "value": 1.0, + "collect": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + }, + ] + + def test_group_by_error(test_session): from datachain import func