Skip to content

Commit

Permalink
Query and forward parameter overrides.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 663419069
Change-Id: I7bd088367801ef479b5c99d9e421eeb3a9bef594
  • Loading branch information
Sax Authors authored and copybara-github committed Aug 15, 2024
1 parent fb07395 commit e294e8b
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 24 deletions.
4 changes: 0 additions & 4 deletions saxml/server/BUILD
Original file line number | Diff line number | Diff line change
Expand Up @@ -312,9 +312,6 @@ pytype_strict_library(
":model_service_base",
":servable_model_registry",
":spmd_backend",
"//saxml/client/python:sax",
"//saxml/protobuf:admin_py_pb2",
"//saxml/protobuf:admin_py_pb2_grpc",
"//saxml/protobuf:modelet_py_pb2",
"//saxml/protobuf:modelet_py_pb2_grpc",
"//saxml/server/jax:jax_spmd_backend",
Expand All @@ -324,7 +321,6 @@ pytype_strict_library(
"//third_party/py/grpcio",
"//third_party/py/jax",
"//third_party/py/tensorflow:tensorflow_no_contrib",
"@pybind11_abseil//pybind11_abseil:status",
],
)

Expand Down
22 changes: 2 additions & 20 deletions saxml/server/model_service_main.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -22,17 +22,13 @@
import grpc
import jax
from jax.experimental.compilation_cache import compilation_cache
from saxml.client.python import sax
from saxml.protobuf import modelet_pb2
from saxml.protobuf import modelet_pb2_grpc
from saxml.server import model_service_base
from saxml.server import servable_model_registry
from saxml.server import spmd_backend
import tensorflow as tf

from google3.third_party.pybind11_abseil import status as absl_status


_SAX_CELL = flags.DEFINE_string(
'sax_cell',
None,
Expand Down Expand Up @@ -123,23 +119,11 @@ def _load_static_model(
model_key: str,
checkpoint: str,
channel_creds: Optional[grpc.ChannelCredentials],
sax_cell: Optional[str],
) -> None:
"""Loads statically specified model to a started service."""
logging.info(
'Loading key %s, model %s, checkpoint %s.', model_key, model, checkpoint
)
# Get overrides that might have been provided via 'saxutil publish' and apply
# them.
overrides = {}
if sax_cell:
try:
overrides = sax.ListDetail(model_key).overrides
logging.info('Got overrides: %s', overrides)
except absl_status.StatusNotOk as e:
logging.warning(
"Could not get model details, not applying overrides: '%s'", e
)
if channel_creds is None:
channel = grpc.insecure_channel(f'localhost:{port}')
else:
Expand All @@ -148,8 +132,7 @@ def _load_static_model(
grpc.channel_ready_future(channel).result(timeout=10)
stub = modelet_pb2_grpc.ModeletStub(channel)
req = modelet_pb2.LoadRequest(
model_key=model_key, model_path=model, checkpoint_path=checkpoint,
overrides=overrides,
model_key=model_key, model_path=model, checkpoint_path=checkpoint
)
try:
stub.Load(req)
Expand Down Expand Up @@ -252,8 +235,7 @@ def run(channel_creds: Optional[grpc.ChannelCredentials]) -> None:
for model, key, ckpt in zip(
_MODELS.value, _MODEL_KEYS.value, _CHECKPOINTS.value
):
_load_static_model(_PORT.value, model, key, ckpt, channel_creds,
_SAX_CELL.value)
_load_static_model(_PORT.value, model, key, ckpt, channel_creds)
runner.on_initial_models_load_completion()
runner.wait()
finally:
Expand Down

0 comments on commit e294e8b

Please sign in to comment.