5
5
LICENSE file in the root directory of this source tree.
6
6
"""
7
7
8
+ import inspect
8
9
import json
9
10
import logging
10
11
from abc import ABC
11
12
from io import TextIOWrapper , StringIO
12
13
from json import JSONDecodeError
13
14
from typing import Generic , Any , TypeVar
14
- from collections .abc import Callable , Mapping
15
+ from collections .abc import Awaitable , Callable , Mapping
15
16
import warnings
16
17
17
18
from pydantic import JsonValue , ValidationError
@@ -225,6 +226,15 @@ def __init_subclass__(cls, **kwargs):
225
226
"The DEPRECATED 'process' method must not be implemented "
226
227
"alongside 'process_input' or 'process_response'."
227
228
)
229
+ if is_process_overridden and inspect .iscoroutinefunction (
230
+ inspect .unwrap (cls .process )
231
+ ):
232
+ # we don't want to add async capabilities to the deprecated function
233
+ raise TypeError (
234
+ f"Cannot create concrete class { cls .__name__ } . "
235
+ "The DEPRECATED 'process' method does not support async. "
236
+ "Implement 'process_input' and/or 'process_response' instead."
237
+ )
228
238
229
239
return
230
240
@@ -875,15 +885,18 @@ async def _parse_and_process(self, request: Request) -> Response:
875
885
prompt_hash , response_hash = (None , None )
876
886
if input_direction :
877
887
prompt_hash = prompt .hash ()
878
- result : Result | Reject = self .process_input (
888
+ result = await self ._handle_process_function (
889
+ self .process_input ,
879
890
metadata = metadata ,
880
891
parameters = parameters ,
881
892
prompt = prompt ,
882
893
request = request ,
883
894
)
895
+
884
896
else :
885
897
response_hash = response .hash ()
886
- result : Result | Reject = self .process_response (
898
+ result = await self ._handle_process_function (
899
+ self .process_response ,
887
900
metadata = metadata ,
888
901
parameters = parameters ,
889
902
prompt = prompt ,
@@ -1014,13 +1027,22 @@ def _is_method_overridden(self, method_name: str) -> bool:
1014
1027
# the method object directly from the Processor class, then it has been overridden.
1015
1028
return instance_class_method_obj is not base_class_method_obj
1016
1029
1030
+ def _process_fallback (self , ** kwargs ) -> Result | Reject :
1031
+ warnings .warn (
1032
+ f"{ type (self ).__name__ } uses the deprecated 'process' method. "
1033
+ "Implement 'process_input' and/or 'process_response' instead." ,
1034
+ DeprecationWarning ,
1035
+ stacklevel = 2 ,
1036
+ )
1037
+ return self .process (** kwargs )
1038
+
1017
1039
def process_input (
1018
1040
self ,
1019
1041
prompt : PROMPT ,
1020
1042
metadata : Metadata ,
1021
1043
parameters : PARAMS ,
1022
1044
request : Request ,
1023
- ) -> Result | Reject :
1045
+ ) -> Result | Reject | Awaitable [ Result | Reject ] :
1024
1046
"""
1025
1047
This abstract method is for implementors of the processor to define
1026
1048
with their own custom logic. Errors should be raised as a subclass
@@ -1043,23 +1065,17 @@ def process_input(self, prompt, response, metadata, parameters, request):
1043
1065
1044
1066
return Result(processor_result=result)
1045
1067
"""
1046
- if self ._is_method_overridden ("process" ):
1047
- warnings .warn (
1048
- f"{ type (self ).__name__ } uses the deprecated 'process' method for input. "
1049
- "Implement 'process_input' instead." ,
1050
- DeprecationWarning ,
1051
- stacklevel = 2 , # Points the warning to the caller of process_input
1068
+ if not self ._is_method_overridden ("process" ):
1069
+ raise NotImplementedError (
1070
+ f"{ type (self ).__name__ } must implement 'process_input' or the "
1071
+ "deprecated 'process' method to handle input."
1052
1072
)
1053
- return self .process (
1054
- prompt = prompt ,
1055
- response = None ,
1056
- metadata = metadata ,
1057
- parameters = parameters ,
1058
- request = request ,
1059
- )
1060
- raise NotImplementedError (
1061
- f"{ type (self ).__name__ } must implement 'process_input' or the "
1062
- "deprecated 'process' method to handle input."
1073
+ return self ._process_fallback (
1074
+ prompt = prompt ,
1075
+ response = None ,
1076
+ metadata = metadata ,
1077
+ parameters = parameters ,
1078
+ request = request ,
1063
1079
)
1064
1080
1065
1081
def process_response (
@@ -1069,7 +1085,7 @@ def process_response(
1069
1085
metadata : Metadata ,
1070
1086
parameters : PARAMS ,
1071
1087
request : Request ,
1072
- ) -> Result | Reject :
1088
+ ) -> Result | Reject | Awaitable [ Result | Reject ] :
1073
1089
"""
1074
1090
This abstract method is for implementors of the processor to define
1075
1091
with their own custom logic. Errors should be raised as a subclass
@@ -1096,23 +1112,17 @@ def process_response(self, prompt, response, metadata, parameters, request):
1096
1112
return Result(processor_result=result)
1097
1113
"""
1098
1114
1099
- if self ._is_method_overridden ("process" ):
1100
- warnings .warn (
1101
- f"{ type (self ).__name__ } uses the deprecated 'process' method for response. "
1102
- "Implement 'process_response' instead." ,
1103
- DeprecationWarning ,
1104
- stacklevel = 2 , # Points the warning to the caller of process_input
1115
+ if not self ._is_method_overridden ("process" ):
1116
+ raise NotImplementedError (
1117
+ f"{ type (self ).__name__ } must implement 'process_response' or the "
1118
+ "deprecated 'process' method to handle input."
1105
1119
)
1106
- return self .process (
1107
- prompt = prompt ,
1108
- response = response ,
1109
- metadata = metadata ,
1110
- parameters = parameters ,
1111
- request = request ,
1112
- )
1113
- raise NotImplementedError (
1114
- f"{ type (self ).__name__ } must implement 'process_response' or the "
1115
- "deprecated 'process' method to handle input."
1120
+ return self ._process_fallback (
1121
+ prompt = prompt ,
1122
+ response = response ,
1123
+ metadata = metadata ,
1124
+ parameters = parameters ,
1125
+ request = request ,
1116
1126
)
1117
1127
1118
1128
def process (
@@ -1159,6 +1169,13 @@ def process(self, prompt, response, metadata, parameters, request):
1159
1169
"'process_input'/'process_response'."
1160
1170
)
1161
1171
1172
+ async def _handle_process_function (self , func , ** kwargs ) -> Result | Reject :
1173
+ if inspect .iscoroutinefunction (func ):
1174
+ result = await func (** kwargs )
1175
+ else :
1176
+ result = func (** kwargs )
1177
+ return result
1178
+
1162
1179
1163
1180
def _validation_error_as_messages (err : ValidationError ) -> list [str ]:
1164
1181
return [_error_details_to_str (e ) for e in err .errors ()]
0 commit comments