Skip to content

Commit

Permalink
fix(sdk): add IndexifyRouter test
Browse files Browse the repository at this point in the history
  • Loading branch information
seriousben committed Dec 20, 2024
1 parent 83098de commit 501be11
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ def _indexify_client(


def _is_router(func_wrapper: IndexifyFunctionWrapper) -> bool:
"""Determines if the function is a router.
A function is a router if it is an instance of IndexifyRouter or if it is an IndexifyRouter class.
"""
return str(
type(func_wrapper.indexify_function)
) == "<class 'indexify.functions_sdk.indexify_functions.IndexifyRouter'>" or isinstance(
Expand Down
112 changes: 61 additions & 51 deletions python-sdk/tests/test_graph_behaviours.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,59 @@ def create_router_graph():
return graph


class SimpleFunctionCtxClsObject(BaseModel):
x: int


class SimpleFunctionCtxCls(IndexifyFunction):
name = "SimpleFunctionCtxCls"

def __init__(self):
super().__init__()

def run(self, obj: SimpleFunctionCtxClsObject) -> SimpleFunctionCtxClsObject:
return SimpleFunctionCtxClsObject(x=obj.x + 1)


class SimpleRouterCtxClsObject(BaseModel):
x: int


class SimpleFunctionCtxCls1(IndexifyFunction):
name = "SimpleFunctionCtxCls1"

def __init__(self):
super().__init__()

def run(self, obj: SimpleRouterCtxClsObject) -> SimpleRouterCtxClsObject:
return SimpleRouterCtxClsObject(x=obj.x + 1)


class SimpleFunctionCtxCls2(IndexifyFunction):
name = "SimpleFunctionCtxCls2"

def __init__(self):
super().__init__()

def run(self, obj: SimpleRouterCtxClsObject) -> SimpleRouterCtxClsObject:
return SimpleRouterCtxClsObject(x=obj.x + 2)


class SimpleRouterCtxCls(IndexifyRouter):
name = "SimpleRouterCtxCls"

def __init__(self):
super().__init__()

def run(
self, obj: SimpleRouterCtxClsObject
) -> Union[SimpleFunctionCtxCls1, SimpleFunctionCtxCls2]:
if obj.x % 2 == 0:
return SimpleFunctionCtxCls1
else:
return SimpleFunctionCtxCls2


def create_simple_pipeline():
p = Pipeline("simple_pipeline", "A simple pipeline")
p.add_step(generate_seq)
Expand Down Expand Up @@ -242,23 +295,13 @@ def test_simple_function(self, is_remote):

@parameterized.expand([(False), (True)])
def test_simple_function_cls(self, is_remote):
class MyObject(BaseModel):
x: int

class SimpleFunctionCtxCls(IndexifyFunction):
name = "SimpleFunctionCtxCls"

def __init__(self):
super().__init__()

def run(self, obj: MyObject) -> MyObject:
return MyObject(x=obj.x + 1)

graph = Graph(name="test_simple_function_cls", start_node=SimpleFunctionCtxCls)
graph = remote_or_local_graph(graph, is_remote)
invocation_id = graph.run(block_until_done=True, obj=MyObject(x=1))
invocation_id = graph.run(
block_until_done=True, obj=SimpleFunctionCtxClsObject(x=1)
)
output = graph.output(invocation_id, "SimpleFunctionCtxCls")
self.assertEqual(output, [MyObject(x=2)])
self.assertEqual(output, [SimpleFunctionCtxClsObject(x=2)])

@parameterized.expand([(False), (True)])
def test_simple_function_with_json_encoding(self, is_remote):
Expand Down Expand Up @@ -660,47 +703,14 @@ def test_router_graph_behavior(self, is_remote):

@parameterized.expand([(False), (True)])
def test_router_graph_behavior_cls(self, is_remote):
class MyObject(BaseModel):
x: int

class SimpleFunctionCtxCls1(IndexifyFunction):
name = "SimpleFunctionCtxCls1"

def __init__(self):
super().__init__()

def run(self, obj: MyObject) -> MyObject:
return MyObject(x=obj.x + 1)

class SimpleFunctionCtxCls2(IndexifyFunction):
name = "SimpleFunctionCtxCls2"

def __init__(self):
super().__init__()

def run(self, obj: MyObject) -> MyObject:
return MyObject(x=obj.x + 2)

class SimpleRouterCtxCls(IndexifyRouter):
name = "SimpleRouterCtxCls"

def __init__(self):
super().__init__()

def run(
self, obj: MyObject
) -> Union[SimpleFunctionCtxCls1, SimpleFunctionCtxCls2]:
if obj.x % 2 == 0:
return SimpleFunctionCtxCls1
else:
return SimpleFunctionCtxCls2

graph = Graph(name="test_simple_function_cls", start_node=SimpleRouterCtxCls)
graph.route(SimpleRouterCtxCls, [SimpleFunctionCtxCls1, SimpleFunctionCtxCls2])
graph = remote_or_local_graph(graph, is_remote)
invocation_id = graph.run(block_until_done=True, obj=MyObject(x=1))
invocation_id = graph.run(
block_until_done=True, obj=SimpleRouterCtxClsObject(x=1)
)
output = graph.output(invocation_id, "SimpleFunctionCtxCls2")
self.assertEqual(output, [MyObject(x=3)])
self.assertEqual(output, [SimpleRouterCtxClsObject(x=3)])

@parameterized.expand([(False), (True)])
def test_invoke_file(self, is_remote):
Expand Down

0 comments on commit 501be11

Please sign in to comment.