Skip to content

Commit

Permalink
Allow specifying model overrides when starting a sax server with a st…
Browse files Browse the repository at this point in the history
…atic

PiperOrigin-RevId: 667681444
Change-Id: I2e2bbde7cf542e72f9abd67a52e47f5c59649e5b
  • Loading branch information
Sax Authors authored and copybara-github committed Aug 26, 2024
1 parent 04cf3df commit 666f3c9
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 11 deletions.
1 change: 1 addition & 0 deletions saxml/server/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ pytype_strict_library(
":model_service_base",
":servable_model_registry",
":spmd_backend",
"//pyglib/flags/contrib:json_flag",
"//saxml/protobuf:modelet_py_pb2",
"//saxml/protobuf:modelet_py_pb2_grpc",
"//saxml/server/jax:jax_spmd_backend",
Expand Down
44 changes: 33 additions & 11 deletions saxml/server/model_service_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""The main module of model services."""

import json
import re
from typing import Optional, Sequence

Expand Down Expand Up @@ -52,9 +53,7 @@
_PLATFORM_TOPOLOGY = flags.DEFINE_string(
'platform_topology', '1', 'Optional topology description.'
)
_TAGS = flags.DEFINE_list(
'tags', [], 'Optional list of string tags.'
)
_TAGS = flags.DEFINE_list('tags', [], 'Optional list of string tags.')
_JAX_PROFILER_PORT = flags.DEFINE_integer(
'jax_profiler_port',
None,
Expand All @@ -66,9 +65,7 @@
_JAX_CACHE_DIR = flags.DEFINE_string(
'jax_cache_dir',
None,
(
'If set, tries to use cached jits at this directory.'
),
'If set, tries to use cached jits at this directory.',
)

# Internal tuning knobs. Consult sax-dev@ before tweaking these.
Expand All @@ -81,6 +78,13 @@
_CHECKPOINTS = flags.DEFINE_list(
'checkpoints', [], 'Optional model checkpoints to load at startup time.'
)
_MODEL_CONFIG_OVERRIDES = flags.DEFINE_list(
'model_config_overrides',
[],
'Optional model config overrides for the models loaded at startup time. The'
' format is comma-separated JSON for each model. For example:'
' \'{"BATCH_SIZE": 4, "BATCH_WAIT_SECS": 30},{"NUM_SAMPLES": 4}\'',
)
_DETERMINISTIC_RNG = flags.DEFINE_bool(
'deterministic_rng',
False,
Expand All @@ -104,12 +108,13 @@


@flags.multi_flags_validator(
['models', 'model_keys', 'checkpoints'],
['models', 'model_keys', 'checkpoints', 'model_config_overrides'],
message='models, model_keys, and checkpoints must have the same length',
)
def _check_model_checkpoint_flags(flags_dict):
return len(flags_dict['models']) == len(flags_dict['checkpoints']) and (
len(flags_dict['models']) == len(flags_dict['model_keys'])
and len(flags_dict['models']) == len(flags_dict['model_config_overrides'])
)


Expand All @@ -118,6 +123,7 @@ def _load_static_model(
model: str,
model_key: str,
checkpoint: str,
model_config_overrides: dict[str, str],
channel_creds: Optional[grpc.ChannelCredentials],
) -> None:
"""Loads statically specified model to a started service."""
Expand All @@ -132,7 +138,10 @@ def _load_static_model(
grpc.channel_ready_future(channel).result(timeout=10)
stub = modelet_pb2_grpc.ModeletStub(channel)
req = modelet_pb2.LoadRequest(
model_key=model_key, model_path=model, checkpoint_path=checkpoint
model_key=model_key,
model_path=model,
checkpoint_path=checkpoint,
overrides=model_config_overrides,
)
try:
stub.Load(req)
Expand Down Expand Up @@ -232,10 +241,23 @@ def run(channel_creds: Optional[grpc.ChannelCredentials]) -> None:
logging.info('Starting runner %d.', jax.process_index())
runner.start()
if is_primary:
for model, key, ckpt in zip(
_MODELS.value, _MODEL_KEYS.value, _CHECKPOINTS.value
model_config_overrides: list[dict[str, str]] = [
{key: str(value) for key, value in json.loads(x).items()}
for x in _MODEL_CONFIG_OVERRIDES.value
]
if not model_config_overrides:
model_config_overrides = [
dict[str, str]() for _ in range(len(_MODELS.value))
]
for model, key, ckpt, overrides in zip(
_MODELS.value,
_MODEL_KEYS.value,
_CHECKPOINTS.value,
model_config_overrides,
):
_load_static_model(_PORT.value, model, key, ckpt, channel_creds)
_load_static_model(
_PORT.value, model, key, ckpt, overrides, channel_creds
)
runner.on_initial_models_load_completion()
runner.wait()
finally:
Expand Down

0 comments on commit 666f3c9

Please sign in to comment.