-
Notifications
You must be signed in to change notification settings - Fork 158
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #50 from vivianrwu/saxml-httpserver
Add userguide for HTTP server and Saxml
- Loading branch information
Showing
3 changed files
with
366 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
FROM python:3.10 | ||
|
||
ENV SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL True | ||
|
||
RUN set -e | ||
|
||
RUN apt -y update && apt install -y \ | ||
apt-transport-https \ | ||
curl \ | ||
gnupg patch python3-pip | ||
|
||
RUN git clone https://github.com/google/saxml.git && \ | ||
cd saxml && \ | ||
git checkout r1.0.0 | ||
|
||
RUN curl https://bazel.build/bazel-release.pub.gpg | apt-key --keyring /usr/share/keyrings/bazel-archive-keyring.gpg add - | ||
RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - | ||
RUN echo "deb [arch=amd64 signed-by=/usr/share/keyrings/bazel-archive-keyring.gpg] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list | ||
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list | ||
RUN apt -y update && apt install -y bazel-5.4.0 google-cloud-cli | ||
RUN update-alternatives --install /usr/bin/bazel bazel /usr/bin/bazel-5.4.0 20 | ||
|
||
RUN pip3 install -U pip numpy | ||
RUN pip3 install --upgrade google-api-python-client | ||
|
||
COPY . saxml/saxml/httpserver | ||
WORKDIR saxml | ||
|
||
RUN bazel build saxml/client/python:sax.cc --compile_one_dependency | ||
|
||
ENV PYTHONPATH "${PYTHONPATH}:/saxml/bazel-bin/saxml/client/python" | ||
|
||
EXPOSE 8888 | ||
|
||
ENV PYTHONUNBUFFERED 1 | ||
|
||
WORKDIR /saxml/saxml/httpserver | ||
|
||
CMD ["python3.10", "http_server.py"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# Inferencing using Saxml and an HTTP Server | ||
|
||
## Background | ||
|
||
[Saxml](https://github.com/google/saxml) is an experimental system that serves [Paxml](https://github.com/google/paxml), [JAX](https://github.com/google/jax), and [PyTorch](https://pytorch.org/) models for inference. A Sax cell (aka Sax cluster) consists of an admin server and a group of model servers. The admin server keeps track of model servers, assigns published models to model servers to serve, and helps clients locate model servers serving specific published models. | ||
|
||
In order to interact with the Sax cluster today, users can use the command line tool, [saxutil](https://github.com/google/saxml#use-sax), or interact directly with the [Sax client](https://github.com/google/saxml/tree/main/saxml/client/). | ||
|
||
This tutorial uses an HTTP Server to handle HTTP requests to Sax, supporting features such as model publishing, listing, updating, unpublishing, and generating predictions. The HTTP server uses the [Python Sax client](https://github.com/google/saxml/tree/main/saxml/client/python) in order to communicate with the Sax cluster and handle routing within the Sax system. With an HTTP server, interaction with Sax can also expand to further than at the VM-level. For example, integration with GKE and load balancing will enable requests to Sax from inside and outside the GKE cluster. | ||
|
||
**This tutorial focuses on the deployment of the HTTP server and assumes you have already deployed a Sax Admin Server and Sax Model Server according to the [OSS SAX Docker Guide](https://github.com/google/saxml/tree/main/saxml/tools/docker)** | ||
|
||
### Build Dockerfile.http | ||
|
||
Build the HTTP Server image: | ||
|
||
``` | ||
docker build -f Dockerfile.http -t sax-http . | ||
``` | ||
|
||
### Run the HTTP Server Locally | ||
|
||
If you haven't already, create a GCS Bucket to store Sax Cluster information: | ||
|
||
``` | ||
GSBUCKET=${USER}-sax-data | ||
gcloud storage buckets create gs://${GSBUCKET} | ||
``` | ||
|
||
``` | ||
docker run -e SAX_ROOT=gs://${GSBUCKET}/sax-root -p 8888:8888 -it sax-http | ||
``` | ||
|
||
In another terminal: | ||
|
||
``` | ||
$ curl localhost:8888 | ||
``` | ||
|
||
You will see the output below: | ||
|
||
``` | ||
{ | ||
"message": "HTTP Server for SAX Client" | ||
} | ||
``` | ||
|
||
### Publish a model | ||
|
||
``` | ||
$ curl --request POST \ | ||
-s \ | ||
localhost:8888/publish \ | ||
--data ' | ||
{ | ||
"model": "/sax/test/lm2b", | ||
"model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd2BTest", | ||
"checkpoint": "None", | ||
"replicas": 1 | ||
} | ||
' | ||
``` | ||
|
||
You will see the output below: | ||
|
||
``` | ||
{ | ||
"model": "/sax/test/lm2b", | ||
"path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd2BTest", | ||
"checkpoint": "None", | ||
"replicas": 1 | ||
} | ||
``` | ||
|
||
### List a Sax Cell | ||
|
||
``` | ||
$ curl --request GET \ | ||
-s \ | ||
localhost:8888/listcell \ | ||
--data ' | ||
{ | ||
"model": "/sax/test/lm2b" | ||
} | ||
' | ||
``` | ||
You will see the output below: | ||
|
||
``` | ||
{ | ||
"model": "/sax/test/lm2b", | ||
"model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd2BTest", | ||
"checkpoint": "None", | ||
"max_replicas": 1, | ||
"active_replicas": 1 | ||
} | ||
``` | ||
|
||
### Generate a prediction | ||
|
||
``` | ||
$ json_payload=$(cat << EOF | ||
{ | ||
"model": "/sax/test/lm2b", | ||
"query": "Q: Who is Harry Potter's mom? A: " | ||
} | ||
EOF | ||
) | ||
$ curl --request POST \ | ||
--header "Content-type: application/json" \ | ||
-s \ | ||
localhost:8888/generate \ | ||
--data "$json_payload" | ||
``` | ||
|
||
The result should be printed in the terminal | ||
|
||
### Unpublish a model | ||
|
||
``` | ||
$ curl --request POST \ | ||
-s \ | ||
localhost:8888/unpublish \ | ||
--data ' | ||
{ | ||
"model": "/sax/test/lm2b" | ||
} | ||
' | ||
``` | ||
|
||
You will see the output below: | ||
|
||
``` | ||
{ | ||
"model": "/sax/test/lm2b" | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
"""HTTP Server to interact with SAX Cluster, SAX Admin Server, and SAX Model Server.""" | ||
|
||
import http.server | ||
import json | ||
import sax | ||
|
||
class Server(http.server.BaseHTTPRequestHandler): | ||
"""Handler for HTTP Server.""" | ||
|
||
invalid_res = { | ||
'message': "Invalid Request" | ||
} | ||
get_dict = {"/", "/listcell"} | ||
post_dict = {"/publish", "/unpublish", "/generate"} | ||
put_dict = {"/update"} | ||
|
||
def success_res(self, res): | ||
self.send_response(200) | ||
self.end_headers() | ||
self.wfile.write(json.dumps(res, indent=4).encode('utf-8')) | ||
self.wfile.write('\n'.encode('utf-8')) | ||
return | ||
|
||
def error_res(self, e): | ||
self.send_response(400) | ||
self.end_headers() | ||
error = {'Error': str(e)} | ||
self.wfile.write(json.dumps(error, indent=4).encode('utf-8')) | ||
self.wfile.write('\n'.encode('utf-8')) | ||
return | ||
|
||
def do_GET(self): | ||
"""Handles GET requests.""" | ||
|
||
if self.path not in self.get_dict: | ||
self.send_response(400) | ||
self.end_headers() | ||
self.wfile.write(json.dumps(self.invalid_res).encode('utf-8')) | ||
self.wfile.write('\n'.encode('utf-8')) | ||
return | ||
|
||
if self.path == '/': | ||
default_res = {'message': 'HTTP Server for SAX Client'} | ||
self.success_res(default_res) | ||
return | ||
|
||
content_length = int(self.headers['content-length']) | ||
data = self.rfile.read(content_length).decode('utf-8') | ||
params = json.loads(data) | ||
|
||
if self.path == '/listcell': | ||
"""List details about a published model.""" | ||
|
||
if len(params) != 1: | ||
self.error_res("Provide model for list cell") | ||
return | ||
|
||
try: | ||
model = params['model'] | ||
details = sax.ListDetail(model) | ||
details_res = { | ||
'model': model, | ||
'model_path': details.model, | ||
'checkpoint': details.ckpt, | ||
'max_replicas': details.max_replicas, | ||
'active_replicas': details.active_replicas, | ||
} | ||
self.success_res(details_res) | ||
|
||
except Exception as e: | ||
self.error_res(e) | ||
|
||
def do_POST(self): | ||
"""Handles POST requests.""" | ||
|
||
if self.path not in self.post_dict: | ||
self.send_response(400) | ||
self.end_headers() | ||
self.wfile.write(json.dumps(self.invalid_res).encode('utf-8')) | ||
self.wfile.write('\n'.encode('utf-8')) | ||
return | ||
|
||
content_length = int(self.headers['content-length']) | ||
data = self.rfile.read(content_length).decode('utf-8') | ||
params = json.loads(data) | ||
|
||
if self.path == '/publish': | ||
"""Publishes a model.""" | ||
|
||
if len(params) != 4: | ||
self.error_res("Provide model, model path, checkpoint, and replica number for publish") | ||
return | ||
|
||
try: | ||
model = params['model'] | ||
path = params['model_path'] | ||
ckpt = params['checkpoint'] | ||
replicas = int(params['replicas']) | ||
sax.Publish(model, path, ckpt, replicas) | ||
publish_res = { | ||
'model': model, | ||
'path': path, | ||
'checkpoint': ckpt, | ||
'replicas': replicas, | ||
} | ||
self.success_res(publish_res) | ||
|
||
except Exception as e: | ||
self.error_res(e) | ||
|
||
if self.path == '/unpublish': | ||
"""Unpublishes a model.""" | ||
|
||
if len(params) != 1: | ||
self.error_res("Provide model for unpublish") | ||
return | ||
|
||
try: | ||
model = params['model'] | ||
sax.Unpublish(model) | ||
unpublish_res = { | ||
'model': model, | ||
} | ||
self.success_res(unpublish_res) | ||
|
||
except Exception as e: | ||
self.error_res(e) | ||
|
||
if self.path == '/generate': | ||
"""Generates a text input using a published language model.""" | ||
|
||
if len(params) != 2: | ||
self.error_res("Provide model and query for generate") | ||
return | ||
|
||
try: | ||
model = params['model'] | ||
query = params['query'] | ||
sax.ListDetail(model) | ||
model_open = sax.Model(model) | ||
lm = model_open.LM() | ||
res = lm.Generate(query) | ||
generate_res = { | ||
'generate_response': res, | ||
} | ||
self.success_res(generate_res) | ||
|
||
except Exception as e: | ||
self.error_res(e) | ||
|
||
def do_PUT(self): | ||
"""Handles PUT requests.""" | ||
|
||
if self.path not in self.put_dict: | ||
self.send_response(400) | ||
self.end_headers() | ||
self.wfile.write(json.dumps(self.invalid_res).encode('utf-8')) | ||
self.wfile.write('\n'.encode('utf-8')) | ||
return | ||
|
||
content_length = int(self.headers['content-length']) | ||
data = self.rfile.read(content_length).decode('utf-8') | ||
params = json.loads(data) | ||
|
||
if self.path == '/update': | ||
"""Updates a model.""" | ||
|
||
if len(params) != 4: | ||
self.error_res("Provide model, model path, checkpoint, and replica number for update") | ||
return | ||
|
||
try: | ||
model = params['model'] | ||
path = params['model_path'] | ||
ckpt = params['checkpoint'] | ||
replicas = int(params['replicas']) | ||
sax.Update(model, path, ckpt, replicas) | ||
update_res = { | ||
'model': model, | ||
'path': path, | ||
'checkpoint': ckpt, | ||
'replicas': replicas, | ||
} | ||
self.success_res(update_res) | ||
|
||
except Exception as e: | ||
self.error_res(e) | ||
|
||
s = http.server.HTTPServer(('0.0.0.0', 8888), Server) | ||
s.serve_forever() |