From 666f3c9f9d5272f2ec4e630023ac2810f1fb2a6c Mon Sep 17 00:00:00 2001 From: Sax Authors Date: Mon, 26 Aug 2024 12:51:33 -0700 Subject: [PATCH] Allow specifying model overrides when starting a sax server with a static model PiperOrigin-RevId: 667681444 Change-Id: I2e2bbde7cf542e72f9abd67a52e47f5c59649e5b --- saxml/server/BUILD | 1 + saxml/server/model_service_main.py | 44 ++++++++++++++++++++++-------- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/saxml/server/BUILD b/saxml/server/BUILD index 0e96146a..86ab57ec 100644 --- a/saxml/server/BUILD +++ b/saxml/server/BUILD @@ -313,6 +313,7 @@ pytype_strict_library( ":model_service_base", ":servable_model_registry", ":spmd_backend", + "//pyglib/flags/contrib:json_flag", "//saxml/protobuf:modelet_py_pb2", "//saxml/protobuf:modelet_py_pb2_grpc", "//saxml/server/jax:jax_spmd_backend", diff --git a/saxml/server/model_service_main.py b/saxml/server/model_service_main.py index 3a4d2e02..012ee7a3 100644 --- a/saxml/server/model_service_main.py +++ b/saxml/server/model_service_main.py @@ -13,6 +13,7 @@ # limitations under the License. """The main module of model services.""" +import json import re from typing import Optional, Sequence @@ -52,9 +53,7 @@ _PLATFORM_TOPOLOGY = flags.DEFINE_string( 'platform_topology', '1', 'Optional topology description.' ) -_TAGS = flags.DEFINE_list( - 'tags', [], 'Optional list of string tags.' -) +_TAGS = flags.DEFINE_list('tags', [], 'Optional list of string tags.') _JAX_PROFILER_PORT = flags.DEFINE_integer( 'jax_profiler_port', None, @@ -66,9 +65,7 @@ _JAX_CACHE_DIR = flags.DEFINE_string( 'jax_cache_dir', None, - ( - 'If set, tries to use cached jits at this directory.' - ), + 'If set, tries to use cached jits at this directory.', ) # Internal tuning knobs. Consult sax-dev@ before tweaking these. @@ -81,6 +78,13 @@ _CHECKPOINTS = flags.DEFINE_list( 'checkpoints', [], 'Optional model checkpoints to load at startup time.' 
) +_MODEL_CONFIG_OVERRIDES = flags.DEFINE_list( + 'model_config_overrides', + [], + 'Optional model config overrides for the models loaded at startup time. The' + ' format is comma-separated JSON for each model. For example:' + ' \'{"BATCH_SIZE": 4, "BATCH_WAIT_SECS": 30},{"NUM_SAMPLES": 4}\'', +) _DETERMINISTIC_RNG = flags.DEFINE_bool( 'deterministic_rng', False, @@ -104,12 +108,13 @@ @flags.multi_flags_validator( - ['models', 'model_keys', 'checkpoints'], + ['models', 'model_keys', 'checkpoints', 'model_config_overrides'], message='models, model_keys, and checkpoints must have the same length', ) def _check_model_checkpoint_flags(flags_dict): return len(flags_dict['models']) == len(flags_dict['checkpoints']) and ( len(flags_dict['models']) == len(flags_dict['model_keys']) + and len(flags_dict['models']) == len(flags_dict['model_config_overrides']) ) @@ -118,6 +123,7 @@ def _load_static_model( model: str, model_key: str, checkpoint: str, + model_config_overrides: dict[str, str], channel_creds: Optional[grpc.ChannelCredentials], ) -> None: """Loads statically specified model to a started service.""" @@ -132,7 +138,10 @@ def _load_static_model( grpc.channel_ready_future(channel).result(timeout=10) stub = modelet_pb2_grpc.ModeletStub(channel) req = modelet_pb2.LoadRequest( - model_key=model_key, model_path=model, checkpoint_path=checkpoint + model_key=model_key, + model_path=model, + checkpoint_path=checkpoint, + overrides=model_config_overrides, ) try: stub.Load(req) @@ -232,10 +241,23 @@ def run(channel_creds: Optional[grpc.ChannelCredentials]) -> None: logging.info('Starting runner %d.', jax.process_index()) runner.start() if is_primary: - for model, key, ckpt in zip( - _MODELS.value, _MODEL_KEYS.value, _CHECKPOINTS.value + model_config_overrides: list[dict[str, str]] = [ + {key: str(value) for key, value in json.loads(x).items()} + for x in _MODEL_CONFIG_OVERRIDES.value + ] + if not model_config_overrides: + model_config_overrides = [ + dict[str, str]() for _ in 
range(len(_MODELS.value)) + ] + for model, key, ckpt, overrides in zip( + _MODELS.value, + _MODEL_KEYS.value, + _CHECKPOINTS.value, + model_config_overrides, ): - _load_static_model(_PORT.value, model, key, ckpt, channel_creds) + _load_static_model( + _PORT.value, model, key, ckpt, overrides, channel_creds + ) runner.on_initial_models_load_completion() runner.wait() finally: