3232 _PartialFunctionFlags ,
3333)
3434from modal ._serialization import deserialize , deserialize_params
35- from modal ._utils .async_utils import TaskContext , synchronizer
35+ from modal ._utils .async_utils import TaskContext , aclosing , synchronizer
3636from modal ._utils .function_utils import (
3737 callable_has_non_self_params ,
3838)
3939from modal .app import App , _App
4040from modal .client import Client , _Client
4141from modal .config import logger
42- from modal .exception import ExecutionError , InputCancellation , InvalidError
42+ from modal .exception import ExecutionError , InputCancellation
4343from modal .running_app import RunningApp , running_app_from_layout
4444from modal_proto import api_pb2
4545
@@ -184,17 +184,13 @@ def call_function(
184184 batch_wait_ms : int ,
185185):
186186 async def run_input_async (io_context : IOContext ) -> None :
187- started_at = time .time ()
188187 reset_context = execution_context ._set_current_context_ids (
189188 io_context .input_ids , io_context .function_call_ids , io_context .attempt_tokens
190189 )
190+ started_at = time .time ()
191191 async with container_io_manager .handle_input_exception .aio (io_context , started_at ):
192- res = io_context .call_finalized_function ()
193192 # TODO(erikbern): any exception below shouldn't be considered a user exception
194193 if io_context .finalized_function .is_generator :
195- if not inspect .isasyncgen (res ):
196- raise InvalidError (f"Async generator function returned value of type { type (res )} " )
197-
198194 # Send up to this many outputs at a time.
199195 current_function_call_id = execution_context .current_function_call_id ()
200196 assert current_function_call_id is not None # Set above.
@@ -204,33 +200,24 @@ async def run_input_async(io_context: IOContext) -> None:
204200 async with container_io_manager .generator_output_sender (
205201 current_function_call_id ,
206202 current_attempt_token ,
207- io_context .finalized_function . data_format ,
203+ io_context ._generator_output_format () ,
208204 generator_queue ,
209205 ):
210206 item_count = 0
211- async for value in res :
212- await container_io_manager ._queue_put .aio (generator_queue , value )
213- item_count += 1
207+ async with aclosing (io_context .call_generator_async ()) as gen :
208+ async for value in gen :
209+ await container_io_manager ._queue_put .aio (generator_queue , value )
210+ item_count += 1
214211
215- message = api_pb2 .GeneratorDone (items_total = item_count )
216- await container_io_manager .push_outputs .aio (
217- io_context ,
218- started_at ,
219- message ,
220- api_pb2 .DATA_FORMAT_GENERATOR_DONE ,
212+ await container_io_manager ._send_outputs .aio (
213+ started_at , io_context .output_items_generator_done (started_at , item_count )
221214 )
222215 else :
223- if not inspect .iscoroutine (res ) or inspect .isgenerator (res ) or inspect .isasyncgen (res ):
224- raise InvalidError (
225- f"Async (non-generator) function returned value of type { type (res )} "
226- " You might need to use @app.function(..., is_generator=True)."
227- )
228- value = await res
216+ value = await io_context .call_function_async ()
229217 await container_io_manager .push_outputs .aio (
230218 io_context ,
231219 started_at ,
232220 value ,
233- io_context .finalized_function .data_format ,
234221 )
235222 reset_context ()
236223
@@ -240,13 +227,9 @@ def run_input_sync(io_context: IOContext) -> None:
240227 io_context .input_ids , io_context .function_call_ids , io_context .attempt_tokens
241228 )
242229 with container_io_manager .handle_input_exception (io_context , started_at ):
243- res = io_context .call_finalized_function ()
244-
245230 # TODO(erikbern): any exception below shouldn't be considered a user exception
246231 if io_context .finalized_function .is_generator :
247- if not inspect .isgenerator (res ):
248- raise InvalidError (f"Generator function returned value of type { type (res )} " )
249-
232+ gen = io_context .call_generator_sync ()
250233 # Send up to this many outputs at a time.
251234 current_function_call_id = execution_context .current_function_call_id ()
252235 assert current_function_call_id is not None # Set above.
@@ -256,25 +239,20 @@ def run_input_sync(io_context: IOContext) -> None:
256239 with container_io_manager .generator_output_sender (
257240 current_function_call_id ,
258241 current_attempt_token ,
259- io_context .finalized_function . data_format ,
242+ io_context ._generator_output_format () ,
260243 generator_queue ,
261244 ):
262245 item_count = 0
263- for value in res :
246+ for value in gen :
264247 container_io_manager ._queue_put (generator_queue , value )
265248 item_count += 1
266249
267- message = api_pb2 .GeneratorDone (items_total = item_count )
268- container_io_manager .push_outputs (io_context , started_at , message , api_pb2 .DATA_FORMAT_GENERATOR_DONE )
269- else :
270- if inspect .iscoroutine (res ) or inspect .isgenerator (res ) or inspect .isasyncgen (res ):
271- raise InvalidError (
272- f"Sync (non-generator) function return value of type { type (res )} ."
273- " You might need to use @app.function(..., is_generator=True)."
274- )
275- container_io_manager .push_outputs (
276- io_context , started_at , res , io_context .finalized_function .data_format
250+ container_io_manager ._send_outputs (
251+ started_at , io_context .output_items_generator_done (started_at , item_count )
277252 )
253+ else :
254+ values = io_context .call_function_sync ()
255+ container_io_manager .push_outputs (io_context , started_at , values )
278256 reset_context ()
279257
280258 if container_io_manager .input_concurrency_enabled :
0 commit comments