Skip to content

Commit

Permalink
test(sdk): add coverage for returning a list mapping to args/kwargs (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
seriousben authored Dec 8, 2024
1 parent 7e188a6 commit 5309853
Showing 1 changed file with 49 additions and 1 deletion.
50 changes: 49 additions & 1 deletion python-sdk/tests/test_graph_behaviours.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
from pathlib import Path
from typing import List, Union
from typing import List, Tuple, Union

from parameterized import parameterized
from pydantic import BaseModel
Expand Down Expand Up @@ -471,6 +471,54 @@ def my_func_2(input: dict) -> int:
self.assertEqual(len(output), 1)
self.assertEqual(output[0], 6)

@parameterized.expand([(False), (True)])
def test_return_dict_args_as_kwargs_in_list(self, is_remote):
@indexify_function()
def my_func(text: str) -> List[dict]:
return [dict(index=index, char=char) for index, char in enumerate(text)]

@indexify_function()
def my_func_2(index: int, char: str) -> str:
return f"{char}={index}"

graph = Graph(
name="test_return_dict_args_as_kwargs_in_list",
description="test",
start_node=my_func,
)

graph.add_edge(my_func, my_func_2)
graph = remote_or_local_graph(graph, is_remote)
invocation_id = graph.run(block_until_done=True, text="hi")
output = graph.output(invocation_id, my_func_2.name)
self.assertEqual(len(output), 2)
self.assertIn("h=0", output)
self.assertIn("i=1", output)

@parameterized.expand([(False), (True)])
def test_return_dict_args_as_dict_in_list(self, is_remote):
@indexify_function()
def my_func(text: str) -> List[dict]:
return [dict(index=index, char=char) for index, char in enumerate(text)]

@indexify_function()
def my_func_2(data: dict) -> str:
return f"{data['char']}={data['index']}"

graph = Graph(
name="test_return_dict_args_as_dict_in_list",
description="test",
start_node=my_func,
)

graph.add_edge(my_func, my_func_2)
graph = remote_or_local_graph(graph, is_remote)
invocation_id = graph.run(block_until_done=True, text="hi")
output = graph.output(invocation_id, my_func_2.name)
self.assertEqual(len(output), 2)
self.assertIn("h=0", output)
self.assertIn("i=1", output)

@parameterized.expand([(False), (True)])
def test_return_multiple_dict_as_args(self, is_remote):
@indexify_function(input_encoder="json", output_encoder="json")
Expand Down

0 comments on commit 5309853

Please sign in to comment.