diff --git a/python-sdk/indexify/function_executor/handlers/run_function/handler.py b/python-sdk/indexify/function_executor/handlers/run_function/handler.py index 0bab792b9..7b33235aa 100644 --- a/python-sdk/indexify/function_executor/handlers/run_function/handler.py +++ b/python-sdk/indexify/function_executor/handlers/run_function/handler.py @@ -139,6 +139,10 @@ def _indexify_client( def _is_router(func_wrapper: IndexifyFunctionWrapper) -> bool: + """Determines if the function is a router. + + A function is a router if it is an instance of IndexifyRouter or if it is an IndexifyRouter class. + """ return str( type(func_wrapper.indexify_function) ) == "" or isinstance( diff --git a/python-sdk/tests/test_graph_behaviours.py b/python-sdk/tests/test_graph_behaviours.py index 915ef9de6..2a57ff31e 100644 --- a/python-sdk/tests/test_graph_behaviours.py +++ b/python-sdk/tests/test_graph_behaviours.py @@ -214,6 +214,59 @@ def create_router_graph(): return graph +class SimpleFunctionCtxClsObject(BaseModel): + x: int + + +class SimpleFunctionCtxCls(IndexifyFunction): + name = "SimpleFunctionCtxCls" + + def __init__(self): + super().__init__() + + def run(self, obj: SimpleFunctionCtxClsObject) -> SimpleFunctionCtxClsObject: + return SimpleFunctionCtxClsObject(x=obj.x + 1) + + +class SimpleRouterCtxClsObject(BaseModel): + x: int + + +class SimpleFunctionCtxCls1(IndexifyFunction): + name = "SimpleFunctionCtxCls1" + + def __init__(self): + super().__init__() + + def run(self, obj: SimpleRouterCtxClsObject) -> SimpleRouterCtxClsObject: + return SimpleRouterCtxClsObject(x=obj.x + 1) + + +class SimpleFunctionCtxCls2(IndexifyFunction): + name = "SimpleFunctionCtxCls2" + + def __init__(self): + super().__init__() + + def run(self, obj: SimpleRouterCtxClsObject) -> SimpleRouterCtxClsObject: + return SimpleRouterCtxClsObject(x=obj.x + 2) + + +class SimpleRouterCtxCls(IndexifyRouter): + name = "SimpleRouterCtxCls" + + def __init__(self): + super().__init__() + + def run( + self, obj: SimpleRouterCtxClsObject + ) -> Union[SimpleFunctionCtxCls1, SimpleFunctionCtxCls2]: + if obj.x % 2 == 0: + return SimpleFunctionCtxCls1 + else: + return SimpleFunctionCtxCls2 + + def create_simple_pipeline(): p = Pipeline("simple_pipeline", "A simple pipeline") p.add_step(generate_seq) @@ -242,23 +295,13 @@ def test_simple_function(self, is_remote): @parameterized.expand([(False), (True)]) def test_simple_function_cls(self, is_remote): - class MyObject(BaseModel): - x: int - - class SimpleFunctionCtxCls(IndexifyFunction): - name = "SimpleFunctionCtxCls" - - def __init__(self): - super().__init__() - - def run(self, obj: MyObject) -> MyObject: - return MyObject(x=obj.x + 1) - graph = Graph(name="test_simple_function_cls", start_node=SimpleFunctionCtxCls) graph = remote_or_local_graph(graph, is_remote) - invocation_id = graph.run(block_until_done=True, obj=MyObject(x=1)) + invocation_id = graph.run( + block_until_done=True, obj=SimpleFunctionCtxClsObject(x=1) + ) output = graph.output(invocation_id, "SimpleFunctionCtxCls") - self.assertEqual(output, [MyObject(x=2)]) + self.assertEqual(output, [SimpleFunctionCtxClsObject(x=2)]) @parameterized.expand([(False), (True)]) def test_simple_function_with_json_encoding(self, is_remote): @@ -660,47 +703,14 @@ def test_router_graph_behavior(self, is_remote): @parameterized.expand([(False), (True)]) def test_router_graph_behavior_cls(self, is_remote): - class MyObject(BaseModel): - x: int - - class SimpleFunctionCtxCls1(IndexifyFunction): - name = "SimpleFunctionCtxCls1" - - def __init__(self): - super().__init__() - - def run(self, obj: MyObject) -> MyObject: - return MyObject(x=obj.x + 1) - - class SimpleFunctionCtxCls2(IndexifyFunction): - name = "SimpleFunctionCtxCls2" - - def __init__(self): - super().__init__() - - def run(self, obj: MyObject) -> MyObject: - return MyObject(x=obj.x + 2) - - class SimpleRouterCtxCls(IndexifyRouter): - name = "SimpleRouterCtxCls" - - def __init__(self): - super().__init__() - - def run( - self, obj: MyObject - ) -> Union[SimpleFunctionCtxCls1, SimpleFunctionCtxCls2]: - if obj.x % 2 == 0: - return SimpleFunctionCtxCls1 - else: - return SimpleFunctionCtxCls2 - graph = Graph(name="test_simple_function_cls", start_node=SimpleRouterCtxCls) graph.route(SimpleRouterCtxCls, [SimpleFunctionCtxCls1, SimpleFunctionCtxCls2]) graph = remote_or_local_graph(graph, is_remote) - invocation_id = graph.run(block_until_done=True, obj=MyObject(x=1)) + invocation_id = graph.run( + block_until_done=True, obj=SimpleRouterCtxClsObject(x=1) + ) output = graph.output(invocation_id, "SimpleFunctionCtxCls2") - self.assertEqual(output, [MyObject(x=3)]) + self.assertEqual(output, [SimpleRouterCtxClsObject(x=3)]) @parameterized.expand([(False), (True)]) def test_invoke_file(self, is_remote):