Skip to content

Commit 2b27b66

Browse files
authored
Adds support for cbor2 serialization for function calling [SDK-584] (#3474)
Deployed functions now accept both pickle and CBOR inputs and, when allowed, try to mirror the input format in their output. In addition, functions now advertise which input and output formats they support, so that a lookup can introspect them. This enables future versions of Modal clients (in particular libmodal clients in other languages) to use CBOR instead of pickle for payloads.
1 parent 2813429 commit 2b27b66

20 files changed

+1075
-272
lines changed

.vscode/settings.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
{
22
"python.testing.unittestEnabled": false,
33
"python.testing.pytestEnabled": true,
4+
"python.testing.pytestArgs": [
5+
"test"
6+
],
47
}

modal/_container_entrypoint.py

Lines changed: 19 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,14 @@
3232
_PartialFunctionFlags,
3333
)
3434
from modal._serialization import deserialize, deserialize_params
35-
from modal._utils.async_utils import TaskContext, synchronizer
35+
from modal._utils.async_utils import TaskContext, aclosing, synchronizer
3636
from modal._utils.function_utils import (
3737
callable_has_non_self_params,
3838
)
3939
from modal.app import App, _App
4040
from modal.client import Client, _Client
4141
from modal.config import logger
42-
from modal.exception import ExecutionError, InputCancellation, InvalidError
42+
from modal.exception import ExecutionError, InputCancellation
4343
from modal.running_app import RunningApp, running_app_from_layout
4444
from modal_proto import api_pb2
4545

@@ -184,17 +184,13 @@ def call_function(
184184
batch_wait_ms: int,
185185
):
186186
async def run_input_async(io_context: IOContext) -> None:
187-
started_at = time.time()
188187
reset_context = execution_context._set_current_context_ids(
189188
io_context.input_ids, io_context.function_call_ids, io_context.attempt_tokens
190189
)
190+
started_at = time.time()
191191
async with container_io_manager.handle_input_exception.aio(io_context, started_at):
192-
res = io_context.call_finalized_function()
193192
# TODO(erikbern): any exception below shouldn't be considered a user exception
194193
if io_context.finalized_function.is_generator:
195-
if not inspect.isasyncgen(res):
196-
raise InvalidError(f"Async generator function returned value of type {type(res)}")
197-
198194
# Send up to this many outputs at a time.
199195
current_function_call_id = execution_context.current_function_call_id()
200196
assert current_function_call_id is not None # Set above.
@@ -204,33 +200,24 @@ async def run_input_async(io_context: IOContext) -> None:
204200
async with container_io_manager.generator_output_sender(
205201
current_function_call_id,
206202
current_attempt_token,
207-
io_context.finalized_function.data_format,
203+
io_context._generator_output_format(),
208204
generator_queue,
209205
):
210206
item_count = 0
211-
async for value in res:
212-
await container_io_manager._queue_put.aio(generator_queue, value)
213-
item_count += 1
207+
async with aclosing(io_context.call_generator_async()) as gen:
208+
async for value in gen:
209+
await container_io_manager._queue_put.aio(generator_queue, value)
210+
item_count += 1
214211

215-
message = api_pb2.GeneratorDone(items_total=item_count)
216-
await container_io_manager.push_outputs.aio(
217-
io_context,
218-
started_at,
219-
message,
220-
api_pb2.DATA_FORMAT_GENERATOR_DONE,
212+
await container_io_manager._send_outputs.aio(
213+
started_at, io_context.output_items_generator_done(started_at, item_count)
221214
)
222215
else:
223-
if not inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
224-
raise InvalidError(
225-
f"Async (non-generator) function returned value of type {type(res)}"
226-
" You might need to use @app.function(..., is_generator=True)."
227-
)
228-
value = await res
216+
value = await io_context.call_function_async()
229217
await container_io_manager.push_outputs.aio(
230218
io_context,
231219
started_at,
232220
value,
233-
io_context.finalized_function.data_format,
234221
)
235222
reset_context()
236223

@@ -240,13 +227,9 @@ def run_input_sync(io_context: IOContext) -> None:
240227
io_context.input_ids, io_context.function_call_ids, io_context.attempt_tokens
241228
)
242229
with container_io_manager.handle_input_exception(io_context, started_at):
243-
res = io_context.call_finalized_function()
244-
245230
# TODO(erikbern): any exception below shouldn't be considered a user exception
246231
if io_context.finalized_function.is_generator:
247-
if not inspect.isgenerator(res):
248-
raise InvalidError(f"Generator function returned value of type {type(res)}")
249-
232+
gen = io_context.call_generator_sync()
250233
# Send up to this many outputs at a time.
251234
current_function_call_id = execution_context.current_function_call_id()
252235
assert current_function_call_id is not None # Set above.
@@ -256,25 +239,20 @@ def run_input_sync(io_context: IOContext) -> None:
256239
with container_io_manager.generator_output_sender(
257240
current_function_call_id,
258241
current_attempt_token,
259-
io_context.finalized_function.data_format,
242+
io_context._generator_output_format(),
260243
generator_queue,
261244
):
262245
item_count = 0
263-
for value in res:
246+
for value in gen:
264247
container_io_manager._queue_put(generator_queue, value)
265248
item_count += 1
266249

267-
message = api_pb2.GeneratorDone(items_total=item_count)
268-
container_io_manager.push_outputs(io_context, started_at, message, api_pb2.DATA_FORMAT_GENERATOR_DONE)
269-
else:
270-
if inspect.iscoroutine(res) or inspect.isgenerator(res) or inspect.isasyncgen(res):
271-
raise InvalidError(
272-
f"Sync (non-generator) function return value of type {type(res)}."
273-
" You might need to use @app.function(..., is_generator=True)."
274-
)
275-
container_io_manager.push_outputs(
276-
io_context, started_at, res, io_context.finalized_function.data_format
250+
container_io_manager._send_outputs(
251+
started_at, io_context.output_items_generator_done(started_at, item_count)
277252
)
253+
else:
254+
values = io_context.call_function_sync()
255+
container_io_manager.push_outputs(io_context, started_at, values)
278256
reset_context()
279257

280258
if container_io_manager.input_concurrency_enabled:

modal/_functions.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,7 @@ async def create(
150150
args,
151151
kwargs,
152152
stub,
153-
max_object_size_bytes=function._max_object_size_bytes,
154-
method_name=function._use_method_name,
153+
function=function,
155154
function_call_invocation_type=function_call_invocation_type,
156155
)
157156

@@ -439,8 +438,7 @@ async def create(
439438
args,
440439
kwargs,
441440
control_plane_stub,
442-
max_object_size_bytes=function._max_object_size_bytes,
443-
method_name=function._use_method_name,
441+
function=function,
444442
)
445443

446444
request = api_pb2.AttemptStartRequest(
@@ -698,6 +696,7 @@ def from_local(
698696
experimental_options: Optional[dict[str, str]] = None,
699697
_experimental_proxy_ip: Optional[str] = None,
700698
_experimental_custom_scaling_factor: Optional[float] = None,
699+
restrict_output: bool = False,
701700
) -> "_Function":
702701
"""mdmd:hidden
703702
@@ -834,17 +833,23 @@ def from_local(
834833
is_web_endpoint=is_web_endpoint,
835834
ignore_first_argument=True,
836835
)
836+
if is_web_endpoint:
837+
method_input_formats = [api_pb2.DATA_FORMAT_ASGI]
838+
method_output_formats = [api_pb2.DATA_FORMAT_ASGI]
839+
else:
840+
method_input_formats = [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]
841+
if restrict_output:
842+
method_output_formats = [api_pb2.DATA_FORMAT_CBOR]
843+
else:
844+
method_output_formats = [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]
845+
837846
method_definition = api_pb2.MethodDefinition(
838847
webhook_config=partial_function.params.webhook_config,
839848
function_type=function_type,
840849
function_name=function_name,
841850
function_schema=method_schema,
842-
supported_input_formats=[api_pb2.DATA_FORMAT_ASGI]
843-
if is_web_endpoint
844-
else [api_pb2.DATA_FORMAT_PICKLE],
845-
supported_output_formats=[api_pb2.DATA_FORMAT_ASGI]
846-
if is_web_endpoint
847-
else [api_pb2.DATA_FORMAT_PICKLE],
851+
supported_input_formats=method_input_formats,
852+
supported_output_formats=method_output_formats,
848853
)
849854
method_definitions[method_name] = method_definition
850855

@@ -869,16 +874,18 @@ def _deps(only_explicit_mounts=False) -> list[_Object]:
869874
return deps
870875

871876
if info.is_service_class():
872-
# classes don't have data formats themselves - methods do
877+
# classes don't have data formats themselves - input/output formats are set per method above
873878
supported_input_formats = []
874879
supported_output_formats = []
875880
elif webhook_config is not None:
876881
supported_input_formats = [api_pb2.DATA_FORMAT_ASGI]
877882
supported_output_formats = [api_pb2.DATA_FORMAT_ASGI]
878883
else:
879-
# TODO: add CBOR support
880-
supported_input_formats = [api_pb2.DATA_FORMAT_PICKLE]
881-
supported_output_formats = [api_pb2.DATA_FORMAT_PICKLE]
884+
supported_input_formats = [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]
885+
if restrict_output:
886+
supported_output_formats = [api_pb2.DATA_FORMAT_CBOR]
887+
else:
888+
supported_output_formats = [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]
882889

883890
async def _preload(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
884891
assert resolver.client and resolver.client.stub

0 commit comments

Comments
 (0)