From ed5d11ad881905c47742f77594007291184c5d92 Mon Sep 17 00:00:00 2001 From: Sax Authors Date: Mon, 26 Aug 2024 15:19:27 -0700 Subject: [PATCH] Allow specifying model overrides when starting a sax server with a static PiperOrigin-RevId: 667735701 Change-Id: Ibeecdfb38a82cfa10d8be2f6a684cafef20aa19d --- saxml/server/BUILD | 1 - saxml/server/model_service_main.py | 44 ++++++++---------------------- 2 files changed, 11 insertions(+), 34 deletions(-) diff --git a/saxml/server/BUILD b/saxml/server/BUILD index 86ab57ec..0e96146a 100644 --- a/saxml/server/BUILD +++ b/saxml/server/BUILD @@ -313,7 +313,6 @@ pytype_strict_library( ":model_service_base", ":servable_model_registry", ":spmd_backend", - "//pyglib/flags/contrib:json_flag", "//saxml/protobuf:modelet_py_pb2", "//saxml/protobuf:modelet_py_pb2_grpc", "//saxml/server/jax:jax_spmd_backend", diff --git a/saxml/server/model_service_main.py b/saxml/server/model_service_main.py index 012ee7a3..3a4d2e02 100644 --- a/saxml/server/model_service_main.py +++ b/saxml/server/model_service_main.py @@ -13,7 +13,6 @@ # limitations under the License. """The main module of model services.""" -import json import re from typing import Optional, Sequence @@ -53,7 +52,9 @@ _PLATFORM_TOPOLOGY = flags.DEFINE_string( 'platform_topology', '1', 'Optional topology description.' ) -_TAGS = flags.DEFINE_list('tags', [], 'Optional list of string tags.') +_TAGS = flags.DEFINE_list( + 'tags', [], 'Optional list of string tags.' +) _JAX_PROFILER_PORT = flags.DEFINE_integer( 'jax_profiler_port', None, @@ -65,7 +66,9 @@ _JAX_CACHE_DIR = flags.DEFINE_string( 'jax_cache_dir', None, - 'If set, tries to use cached jits at this directory.', + ( + 'If set, tries to use cached jits at this directory.' + ), ) # Internal tuning knobs. Consult sax-dev@ before tweaking these. @@ -78,13 +81,6 @@ _CHECKPOINTS = flags.DEFINE_list( 'checkpoints', [], 'Optional model checkpoints to load at startup time.' ) -_MODEL_CONFIG_OVERRIDES = flags.DEFINE_list( - 'model_config_overrides', - [], - 'Optional model config overrides for the models loaded at startup time. The' - ' format is comma-separated JSON for each model. For example:' - ' \'{"BATCH_SIZE": 4, "BATCH_WAIT_SECS": 30},{"NUM_SAMPLES": 4}\'', -) _DETERMINISTIC_RNG = flags.DEFINE_bool( 'deterministic_rng', False, @@ -108,13 +104,12 @@ @flags.multi_flags_validator( - ['models', 'model_keys', 'checkpoints', 'model_config_overrides'], + ['models', 'model_keys', 'checkpoints'], message='models, model_keys, and checkpoints must have the same length', ) def _check_model_checkpoint_flags(flags_dict): return len(flags_dict['models']) == len(flags_dict['checkpoints']) and ( len(flags_dict['models']) == len(flags_dict['model_keys']) - and len(flags_dict['models']) == len(flags_dict['model_config_overrides']) ) @@ -123,7 +118,6 @@ def _load_static_model( model: str, model_key: str, checkpoint: str, - model_config_overrides: dict[str, str], channel_creds: Optional[grpc.ChannelCredentials], ) -> None: """Loads statically specified model to a started service.""" @@ -138,10 +132,7 @@ def _load_static_model( grpc.channel_ready_future(channel).result(timeout=10) stub = modelet_pb2_grpc.ModeletStub(channel) req = modelet_pb2.LoadRequest( - model_key=model_key, - model_path=model, - checkpoint_path=checkpoint, - overrides=model_config_overrides, + model_key=model_key, model_path=model, checkpoint_path=checkpoint ) try: stub.Load(req) @@ -241,23 +232,10 @@ def run(channel_creds: Optional[grpc.ChannelCredentials]) -> None: logging.info('Starting runner %d.', jax.process_index()) runner.start() if is_primary: - model_config_overrides: list[dict[str, str]] = [ - {key: str(value) for key, value in json.loads(x).items()} - for x in _MODEL_CONFIG_OVERRIDES.value - ] - if not model_config_overrides: - model_config_overrides = [ - dict[str, str]() for _ in range(len(_MODELS.value)) - ] - for model, key, ckpt, overrides in zip( - _MODELS.value, - _MODEL_KEYS.value, - _CHECKPOINTS.value, - model_config_overrides, + for model, key, ckpt in zip( + _MODELS.value, _MODEL_KEYS.value, _CHECKPOINTS.value ): - _load_static_model( - _PORT.value, model, key, ckpt, overrides, channel_creds - ) + _load_static_model(_PORT.value, model, key, ckpt, channel_creds) runner.on_initial_models_load_completion() runner.wait() finally: