Skip to content

Commit 950ecde

Browse files
feat: async support for process functions (#9)
* feat: async support for process_ * Apply suggestions from code review Co-authored-by: andrey-zelenkov <[email protected]> * chore: fix type checking on implementation side --------- Co-authored-by: andrey-zelenkov <[email protected]>
1 parent b63f745 commit 950ecde

File tree

5 files changed

+229
-135
lines changed

5 files changed

+229
-135
lines changed

src/f5_ai_gateway_sdk/processor.py

Lines changed: 54 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
LICENSE file in the root directory of this source tree.
66
"""
77

8+
import inspect
89
import json
910
import logging
1011
from abc import ABC
1112
from io import TextIOWrapper, StringIO
1213
from json import JSONDecodeError
1314
from typing import Generic, Any, TypeVar
14-
from collections.abc import Callable, Mapping
15+
from collections.abc import Awaitable, Callable, Mapping
1516
import warnings
1617

1718
from pydantic import JsonValue, ValidationError
@@ -225,6 +226,15 @@ def __init_subclass__(cls, **kwargs):
225226
"The DEPRECATED 'process' method must not be implemented "
226227
"alongside 'process_input' or 'process_response'."
227228
)
229+
if is_process_overridden and inspect.iscoroutinefunction(
230+
inspect.unwrap(cls.process)
231+
):
232+
# we don't want to add async capabilities to the deprecated function
233+
raise TypeError(
234+
f"Cannot create concrete class {cls.__name__}. "
235+
"The DEPRECATED 'process' method does not support async. "
236+
"Implement 'process_input' and/or 'process_response' instead."
237+
)
228238

229239
return
230240

@@ -875,15 +885,18 @@ async def _parse_and_process(self, request: Request) -> Response:
875885
prompt_hash, response_hash = (None, None)
876886
if input_direction:
877887
prompt_hash = prompt.hash()
878-
result: Result | Reject = self.process_input(
888+
result = await self._handle_process_function(
889+
self.process_input,
879890
metadata=metadata,
880891
parameters=parameters,
881892
prompt=prompt,
882893
request=request,
883894
)
895+
884896
else:
885897
response_hash = response.hash()
886-
result: Result | Reject = self.process_response(
898+
result = await self._handle_process_function(
899+
self.process_response,
887900
metadata=metadata,
888901
parameters=parameters,
889902
prompt=prompt,
@@ -1014,13 +1027,22 @@ def _is_method_overridden(self, method_name: str) -> bool:
10141027
# the method object directly from the Processor class, then it has been overridden.
10151028
return instance_class_method_obj is not base_class_method_obj
10161029

1030+
def _process_fallback(self, **kwargs) -> Result | Reject:
1031+
warnings.warn(
1032+
f"{type(self).__name__} uses the deprecated 'process' method. "
1033+
"Implement 'process_input' and/or 'process_response' instead.",
1034+
DeprecationWarning,
1035+
stacklevel=2,
1036+
)
1037+
return self.process(**kwargs)
1038+
10171039
def process_input(
10181040
self,
10191041
prompt: PROMPT,
10201042
metadata: Metadata,
10211043
parameters: PARAMS,
10221044
request: Request,
1023-
) -> Result | Reject:
1045+
) -> Result | Reject | Awaitable[Result | Reject]:
10241046
"""
10251047
This abstract method is for implementors of the processor to define
10261048
with their own custom logic. Errors should be raised as a subclass
@@ -1043,23 +1065,17 @@ def process_input(self, prompt, response, metadata, parameters, request):
10431065
10441066
return Result(processor_result=result)
10451067
"""
1046-
if self._is_method_overridden("process"):
1047-
warnings.warn(
1048-
f"{type(self).__name__} uses the deprecated 'process' method for input. "
1049-
"Implement 'process_input' instead.",
1050-
DeprecationWarning,
1051-
stacklevel=2, # Points the warning to the caller of process_input
1068+
if not self._is_method_overridden("process"):
1069+
raise NotImplementedError(
1070+
f"{type(self).__name__} must implement 'process_input' or the "
1071+
"deprecated 'process' method to handle input."
10521072
)
1053-
return self.process(
1054-
prompt=prompt,
1055-
response=None,
1056-
metadata=metadata,
1057-
parameters=parameters,
1058-
request=request,
1059-
)
1060-
raise NotImplementedError(
1061-
f"{type(self).__name__} must implement 'process_input' or the "
1062-
"deprecated 'process' method to handle input."
1073+
return self._process_fallback(
1074+
prompt=prompt,
1075+
response=None,
1076+
metadata=metadata,
1077+
parameters=parameters,
1078+
request=request,
10631079
)
10641080

10651081
def process_response(
@@ -1069,7 +1085,7 @@ def process_response(
10691085
metadata: Metadata,
10701086
parameters: PARAMS,
10711087
request: Request,
1072-
) -> Result | Reject:
1088+
) -> Result | Reject | Awaitable[Result | Reject]:
10731089
"""
10741090
This abstract method is for implementors of the processor to define
10751091
with their own custom logic. Errors should be raised as a subclass
@@ -1096,23 +1112,17 @@ def process_response(self, prompt, response, metadata, parameters, request):
10961112
return Result(processor_result=result)
10971113
"""
10981114

1099-
if self._is_method_overridden("process"):
1100-
warnings.warn(
1101-
f"{type(self).__name__} uses the deprecated 'process' method for response. "
1102-
"Implement 'process_response' instead.",
1103-
DeprecationWarning,
1104-
stacklevel=2, # Points the warning to the caller of process_input
1115+
if not self._is_method_overridden("process"):
1116+
raise NotImplementedError(
1117+
f"{type(self).__name__} must implement 'process_response' or the "
1118+
"deprecated 'process' method to handle input."
11051119
)
1106-
return self.process(
1107-
prompt=prompt,
1108-
response=response,
1109-
metadata=metadata,
1110-
parameters=parameters,
1111-
request=request,
1112-
)
1113-
raise NotImplementedError(
1114-
f"{type(self).__name__} must implement 'process_response' or the "
1115-
"deprecated 'process' method to handle input."
1120+
return self._process_fallback(
1121+
prompt=prompt,
1122+
response=response,
1123+
metadata=metadata,
1124+
parameters=parameters,
1125+
request=request,
11161126
)
11171127

11181128
def process(
@@ -1159,6 +1169,13 @@ def process(self, prompt, response, metadata, parameters, request):
11591169
"'process_input'/'process_response'."
11601170
)
11611171

1172+
async def _handle_process_function(self, func, **kwargs) -> Result | Reject:
1173+
if inspect.iscoroutinefunction(func):
1174+
result = await func(**kwargs)
1175+
else:
1176+
result = func(**kwargs)
1177+
return result
1178+
11621179

11631180
def _validation_error_as_messages(err: ValidationError) -> list[str]:
11641181
return [_error_details_to_str(e) for e in err.errors()]

0 commit comments

Comments
 (0)