diff --git a/tests/trace/test_dataset.py b/tests/trace/test_dataset.py index 490b43238040..0ba6a140853f 100644 --- a/tests/trace/test_dataset.py +++ b/tests/trace/test_dataset.py @@ -11,3 +11,13 @@ def test_basic_dataset_lifecycle(client): == list(dataset.rows) == [{"a": 5, "b": 6}, {"a": 7, "b": 10}] ) + + +def test_dataset_iteration(client): + dataset = weave.Dataset(rows=[{"a": 5, "b": 6}, {"a": 7, "b": 10}]) + rows = list(dataset) + assert rows == [{"a": 5, "b": 6}, {"a": 7, "b": 10}] + + # Test that we can iterate multiple times + rows2 = list(dataset) + assert rows2 == rows diff --git a/weave/flow/dataset.py b/weave/flow/dataset.py index 03a415a4bc7c..0bcd9c60b81d 100644 --- a/weave/flow/dataset.py +++ b/weave/flow/dataset.py @@ -1,3 +1,4 @@ +from collections.abc import Iterator from typing import Any from pydantic import field_validator @@ -65,3 +66,6 @@ def convert_to_table(cls, rows: Any) -> weave.Table: "Attempted to construct a Dataset row with an empty dict." ) return rows + + def __iter__(self) -> Iterator[dict]: + return iter(self.rows)