Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sdk): fix support for args when using IndexifyFunction and IndexifyRouter classes #1127

Merged
merged 3 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python-sdk/indexify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .functions_sdk.image import Image
from .functions_sdk.indexify_functions import (
IndexifyFunction,
IndexifyRouter,
get_ctx,
indexify_function,
indexify_router,
Expand All @@ -23,6 +24,7 @@
"indexify_function",
"get_ctx",
"IndexifyFunction",
"IndexifyRouter",
"indexify_router",
"DEFAULT_SERVICE_URL",
"IndexifyClient",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,10 @@ def _indexify_client(


def _is_router(func_wrapper: IndexifyFunctionWrapper) -> bool:
return (
str(type(func_wrapper.indexify_function))
== "<class 'indexify.functions_sdk.indexify_functions.IndexifyRouter'>"
return str(
type(func_wrapper.indexify_function)
) == "<class 'indexify.functions_sdk.indexify_functions.IndexifyRouter'>" or isinstance(
func_wrapper.indexify_function, IndexifyRouter
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we only use the isinstance check here?
It's clear easier to maintain if it gives the same results.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my testing, the way we build the IndexifyFunction using type(...) fails the is instance check.

Will do more testing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. If isinstance() fails and we won't figure out why, then let's not add isinstance() here and add a comment that's it's not working. So others don't try to quickfix it with the isinstance() check.

Copy link
Member Author

@seriousben seriousben Dec 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would not be adding code for no reason :)

The isintance is needed when using the IndexifyRouter class. But it looks like when using the indexify_router function, it is not enough.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like it can either be an instance or a class.

)


Expand Down
88 changes: 31 additions & 57 deletions python-sdk/indexify/functions_sdk/indexify_functions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import inspect
import traceback
from inspect import Parameter
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Expand All @@ -14,6 +14,7 @@
)

from pydantic import BaseModel, Field, PrivateAttr
from typing_extensions import get_type_hints

from .data_objects import IndexifyData
from .image import DEFAULT_IMAGE, Image
Expand Down Expand Up @@ -89,10 +90,22 @@ class IndexifyFunction:
def run(self, *args, **kwargs) -> Union[List[Any], Any]:
pass

def partial(self, **kwargs) -> Callable:
from functools import partial

return partial(self.run, **kwargs)
def _call_run(self, *args, **kwargs) -> Union[List[Any], Any]:
# Process dictionary argument mapping it to args or to kwargs.
if self.accumulate and len(args) == 2 and isinstance(args[1], dict):
sig = inspect.signature(self.run)
new_args = [args[0]] # Keep the accumulate argument
dict_arg = args[1]
new_args_from_dict, new_kwargs = _process_dict_arg(dict_arg, sig)
new_args.extend(new_args_from_dict)
return self.run(*new_args, **new_kwargs)
elif len(args) == 1 and isinstance(args[0], dict):
sig = inspect.signature(self.run)
dict_arg = args[0]
new_args, new_kwargs = _process_dict_arg(dict_arg, sig)
return self.run(*new_args, **new_kwargs)

return self.run(*args, **kwargs)

@classmethod
def deserialize_output(cls, output: IndexifyData) -> Any:
Expand All @@ -111,10 +124,16 @@ class IndexifyRouter:
def run(self, *args, **kwargs) -> Optional[List[IndexifyFunction]]:
pass

# Create run method that preserves signature
def _call_run(self, *args, **kwargs):
# Process dictionary argument mapping it to args or to kwargs.
if len(args) == 1 and isinstance(args[0], dict):
sig = inspect.signature(self.run)
dict_arg = args[0]
new_args, new_kwargs = _process_dict_arg(dict_arg, sig)
return self.run(*new_args, **new_kwargs)

from inspect import Parameter, signature

from typing_extensions import get_type_hints
return self.run(*args, **kwargs)


def _process_dict_arg(dict_arg: dict, sig: inspect.Signature) -> Tuple[list, dict]:
Expand Down Expand Up @@ -147,25 +166,6 @@ def indexify_router(
output_encoder: Optional[str] = "cloudpickle",
):
def construct(fn):
# Get function signature using inspect.signature
fn_sig = signature(fn)
fn_hints = get_type_hints(fn)

# Create run method that preserves signature
def run(self, *args, **kwargs):
# Process dictionary argument mapping it to args or to kwargs.
if len(args) == 1 and isinstance(args[0], dict):
sig = inspect.signature(fn)
dict_arg = args[0]
new_args, new_kwargs = _process_dict_arg(dict_arg, sig)
return fn(*new_args, **new_kwargs)

return fn(*args, **kwargs)

# Apply original signature and annotations to run method
run.__signature__ = fn_sig
run.__annotations__ = fn_hints

attrs = {
"name": name if name else fn.__name__,
"description": (
Expand All @@ -177,7 +177,7 @@ def run(self, *args, **kwargs):
"placement_constraints": placement_constraints,
"input_encoder": input_encoder,
"output_encoder": output_encoder,
"run": run,
"run": staticmethod(fn),
}

return type("IndexifyRouter", (IndexifyRouter,), attrs)
Expand All @@ -195,32 +195,6 @@ def indexify_function(
placement_constraints: List[PlacementConstraints] = [],
):
def construct(fn):
# Get function signature using inspect.signature
fn_sig = signature(fn)
fn_hints = get_type_hints(fn)

# Create run method that preserves signature
def run(self, *args, **kwargs):
# Process dictionary argument mapping it to args or to kwargs.
if self.accumulate and len(args) == 2 and isinstance(args[1], dict):
sig = inspect.signature(fn)
new_args = [args[0]] # Keep the accumulate argument
dict_arg = args[1]
new_args_from_dict, new_kwargs = _process_dict_arg(dict_arg, sig)
new_args.extend(new_args_from_dict)
return fn(*new_args, **new_kwargs)
elif len(args) == 1 and isinstance(args[0], dict):
sig = inspect.signature(fn)
dict_arg = args[0]
new_args, new_kwargs = _process_dict_arg(dict_arg, sig)
return fn(*new_args, **new_kwargs)

return fn(*args, **kwargs)

# Apply original signature and annotations to run method
run.__signature__ = fn_sig
run.__annotations__ = fn_hints

attrs = {
"name": name if name else fn.__name__,
"description": (
Expand All @@ -233,7 +207,7 @@ def run(self, *args, **kwargs):
"accumulate": accumulate,
"input_encoder": input_encoder,
"output_encoder": output_encoder,
"run": run,
"run": staticmethod(fn),
}

return type("IndexifyFunction", (IndexifyFunction,), attrs)
Expand Down Expand Up @@ -303,7 +277,7 @@ def run_router(
args += input
else:
args.append(input)
extracted_data = self.indexify_function.run(*args, **kwargs)
extracted_data = self.indexify_function._call_run(*args, **kwargs)
except Exception as e:
return [], traceback.format_exc()
if not isinstance(extracted_data, list) and extracted_data is not None:
Expand All @@ -330,7 +304,7 @@ def run_fn(
args.append(input)

try:
extracted_data = self.indexify_function.run(*args, **kwargs)
extracted_data = self.indexify_function._call_run(*args, **kwargs)
except Exception as e:
return [], traceback.format_exc()
if extracted_data is None:
Expand Down
14 changes: 0 additions & 14 deletions python-sdk/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,20 +98,6 @@ def extractor_c(url: str) -> str:
result, _ = extractor_wrapper.run_fn({"url": "foo"})
self.assertEqual(result[0], "123")

# FIXME: Partial extractor is not working
# def test_partial_extractor(self):
# @extractor()
# def extractor_c(url: str, some_other_param: str) -> str:
# """
# Random description of extractor_c
# """
# return f"hello {some_other_param}"

# print(type(extractor_c))
# partial_extractor = extractor_c.partial(some_other_param="world")
# result = partial_extractor.extract(BaseData.from_data(url="foo"))
# self.assertEqual(result[0].payload, "hello world")


if __name__ == "__main__":
unittest.main()
65 changes: 65 additions & 0 deletions python-sdk/tests/test_graph_behaviours.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from indexify import (
Graph,
IndexifyFunction,
IndexifyRouter,
Pipeline,
RemoteGraph,
RemotePipeline,
Expand Down Expand Up @@ -239,6 +240,26 @@ def test_simple_function(self, is_remote):
output = graph.output(invocation_id, "simple_function")
self.assertEqual(output, [MyObject(x="ab")])

@parameterized.expand([(False), (True)])
def test_simple_function_cls(self, is_remote):
class MyObject(BaseModel):
x: int

class SimpleFunctionCtxCls(IndexifyFunction):
name = "SimpleFunctionCtxCls"

def __init__(self):
super().__init__()

def run(self, obj: MyObject) -> MyObject:
return MyObject(x=obj.x + 1)

graph = Graph(name="test_simple_function_cls", start_node=SimpleFunctionCtxCls)
graph = remote_or_local_graph(graph, is_remote)
invocation_id = graph.run(block_until_done=True, obj=MyObject(x=1))
output = graph.output(invocation_id, "SimpleFunctionCtxCls")
self.assertEqual(output, [MyObject(x=2)])

@parameterized.expand([(False), (True)])
def test_simple_function_with_json_encoding(self, is_remote):
graph = Graph(
Expand Down Expand Up @@ -637,6 +658,50 @@ def test_router_graph_behavior(self, is_remote):
output_str = graph.output(invocation_id, "make_it_string_from_int")
self.assertEqual(output_str, ["7"])

@parameterized.expand([(False), (True)])
def test_router_graph_behavior_cls(self, is_remote):
class MyObject(BaseModel):
x: int

class SimpleFunctionCtxCls1(IndexifyFunction):
name = "SimpleFunctionCtxCls1"

def __init__(self):
super().__init__()

def run(self, obj: MyObject) -> MyObject:
return MyObject(x=obj.x + 1)

class SimpleFunctionCtxCls2(IndexifyFunction):
name = "SimpleFunctionCtxCls2"

def __init__(self):
super().__init__()

def run(self, obj: MyObject) -> MyObject:
return MyObject(x=obj.x + 2)

class SimpleRouterCtxCls(IndexifyRouter):
name = "SimpleRouterCtxCls"

def __init__(self):
super().__init__()

def run(
self, obj: MyObject
) -> Union[SimpleFunctionCtxCls1, SimpleFunctionCtxCls2]:
if obj.x % 2 == 0:
return SimpleFunctionCtxCls1
else:
return SimpleFunctionCtxCls2

graph = Graph(name="test_simple_function_cls", start_node=SimpleRouterCtxCls)
graph.route(SimpleRouterCtxCls, [SimpleFunctionCtxCls1, SimpleFunctionCtxCls2])
graph = remote_or_local_graph(graph, is_remote)
invocation_id = graph.run(block_until_done=True, obj=MyObject(x=1))
output = graph.output(invocation_id, "SimpleFunctionCtxCls2")
self.assertEqual(output, [MyObject(x=3)])

@parameterized.expand([(False), (True)])
def test_invoke_file(self, is_remote):
graph = Graph(
Expand Down
Loading