Skip to content

Commit

Permalink
Merge pull request #50 from vivianrwu/saxml-httpserver
Browse files Browse the repository at this point in the history
Add userguide for HTTP server and Saxml
  • Loading branch information
richardsliu committed Sep 29, 2023
2 parents 4bed81c + 41e7208 commit 9c3cb55
Show file tree
Hide file tree
Showing 3 changed files with 366 additions and 0 deletions.
39 changes: 39 additions & 0 deletions saxml-on-gke/httpserver/Dockerfile.http
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
FROM python:3.10

ENV SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL True

RUN set -e

RUN apt -y update && apt install -y \
apt-transport-https \
curl \
gnupg patch python3-pip

RUN git clone https://github.com/google/saxml.git && \
cd saxml && \
git checkout r1.0.0

RUN curl https://bazel.build/bazel-release.pub.gpg | apt-key --keyring /usr/share/keyrings/bazel-archive-keyring.gpg add -
RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
RUN echo "deb [arch=amd64 signed-by=/usr/share/keyrings/bazel-archive-keyring.gpg] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
RUN apt -y update && apt install -y bazel-5.4.0 google-cloud-cli
RUN update-alternatives --install /usr/bin/bazel bazel /usr/bin/bazel-5.4.0 20

RUN pip3 install -U pip numpy
RUN pip3 install --upgrade google-api-python-client

COPY . saxml/saxml/httpserver
WORKDIR saxml

RUN bazel build saxml/client/python:sax.cc --compile_one_dependency

ENV PYTHONPATH "${PYTHONPATH}:/saxml/bazel-bin/saxml/client/python"

EXPOSE 8888

ENV PYTHONUNBUFFERED 1

WORKDIR /saxml/saxml/httpserver

CMD ["python3.10", "http_server.py"]
137 changes: 137 additions & 0 deletions saxml-on-gke/httpserver/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Inferencing using Saxml and an HTTP Server

## Background

[Saxml](https://github.com/google/saxml) is an experimental system that serves [Paxml](https://github.com/google/paxml), [JAX](https://github.com/google/jax), and [PyTorch](https://pytorch.org/) models for inference. A Sax cell (aka Sax cluster) consists of an admin server and a group of model servers. The admin server keeps track of model servers, assigns published models to model servers to serve, and helps clients locate model servers serving specific published models.

In order to interact with the Sax cluster today, users can use the command line tool, [saxutil](https://github.com/google/saxml#use-sax), or interact directly with the [Sax client](https://github.com/google/saxml/tree/main/saxml/client/).

This tutorial uses an HTTP Server to handle HTTP requests to Sax, supporting features such as model publishing, listing, updating, unpublishing, and generating predictions. The HTTP server uses the [Python Sax client](https://github.com/google/saxml/tree/main/saxml/client/python) in order to communicate with the Sax cluster and handle routing within the Sax system. With an HTTP server, interaction with Sax can also expand to further than at the VM-level. For example, integration with GKE and load balancing will enable requests to Sax from inside and outside the GKE cluster.

**This tutorial focuses on the deployment of the HTTP server and assumes you have already deployed a Sax Admin Server and Sax Model Server according to the [OSS SAX Docker Guide](https://github.com/google/saxml/tree/main/saxml/tools/docker)**

### Build Dockerfile.http

Build the HTTP Server image:

```
docker build -f Dockerfile.http -t sax-http .
```

### Run the HTTP Server Locally

If you haven't already, create a GCS Bucket to store Sax Cluster information:

```
GSBUCKET=${USER}-sax-data
gcloud storage buckets create gs://${GSBUCKET}
```

```
docker run -e SAX_ROOT=gs://${GSBUCKET}/sax-root -p 8888:8888 -it sax-http
```

In another terminal:

```
$ curl localhost:8888
```

You will see the output below:

```
{
"message": "HTTP Server for SAX Client"
}
```

### Publish a model

```
$ curl --request POST \
-s \
localhost:8888/publish \
--data '
{
"model": "/sax/test/lm2b",
"model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd2BTest",
"checkpoint": "None",
"replicas": 1
}
'
```

You will see the output below:

```
{
"model": "/sax/test/lm2b",
"path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd2BTest",
"checkpoint": "None",
"replicas": 1
}
```

### List a Sax Cell

```
$ curl --request GET \
-s \
localhost:8888/listcell \
--data '
{
"model": "/sax/test/lm2b"
}
'
```
You will see the output below:

```
{
"model": "/sax/test/lm2b",
"model_path": "saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd2BTest",
"checkpoint": "None",
"max_replicas": 1,
"active_replicas": 1
}
```

### Generate a prediction

```
$ json_payload=$(cat << EOF
{
"model": "/sax/test/lm2b",
"query": "Q: Who is Harry Potter's mom? A: "
}
EOF
)
$ curl --request POST \
--header "Content-type: application/json" \
-s \
localhost:8888/generate \
--data "$json_payload"
```

The result should be printed in the terminal

### Unpublish a model

```
$ curl --request POST \
-s \
localhost:8888/unpublish \
--data '
{
"model": "/sax/test/lm2b"
}
'
```

You will see the output below:

```
{
"model": "/sax/test/lm2b"
}
```
190 changes: 190 additions & 0 deletions saxml-on-gke/httpserver/http_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
"""HTTP Server to interact with SAX Cluster, SAX Admin Server, and SAX Model Server."""

import http.server
import json
import sax

class Server(http.server.BaseHTTPRequestHandler):
"""Handler for HTTP Server."""

invalid_res = {
'message': "Invalid Request"
}
get_dict = {"/", "/listcell"}
post_dict = {"/publish", "/unpublish", "/generate"}
put_dict = {"/update"}

def success_res(self, res):
self.send_response(200)
self.end_headers()
self.wfile.write(json.dumps(res, indent=4).encode('utf-8'))
self.wfile.write('\n'.encode('utf-8'))
return

def error_res(self, e):
self.send_response(400)
self.end_headers()
error = {'Error': str(e)}
self.wfile.write(json.dumps(error, indent=4).encode('utf-8'))
self.wfile.write('\n'.encode('utf-8'))
return

def do_GET(self):
"""Handles GET requests."""

if self.path not in self.get_dict:
self.send_response(400)
self.end_headers()
self.wfile.write(json.dumps(self.invalid_res).encode('utf-8'))
self.wfile.write('\n'.encode('utf-8'))
return

if self.path == '/':
default_res = {'message': 'HTTP Server for SAX Client'}
self.success_res(default_res)
return

content_length = int(self.headers['content-length'])
data = self.rfile.read(content_length).decode('utf-8')
params = json.loads(data)

if self.path == '/listcell':
"""List details about a published model."""

if len(params) != 1:
self.error_res("Provide model for list cell")
return

try:
model = params['model']
details = sax.ListDetail(model)
details_res = {
'model': model,
'model_path': details.model,
'checkpoint': details.ckpt,
'max_replicas': details.max_replicas,
'active_replicas': details.active_replicas,
}
self.success_res(details_res)

except Exception as e:
self.error_res(e)

def do_POST(self):
"""Handles POST requests."""

if self.path not in self.post_dict:
self.send_response(400)
self.end_headers()
self.wfile.write(json.dumps(self.invalid_res).encode('utf-8'))
self.wfile.write('\n'.encode('utf-8'))
return

content_length = int(self.headers['content-length'])
data = self.rfile.read(content_length).decode('utf-8')
params = json.loads(data)

if self.path == '/publish':
"""Publishes a model."""

if len(params) != 4:
self.error_res("Provide model, model path, checkpoint, and replica number for publish")
return

try:
model = params['model']
path = params['model_path']
ckpt = params['checkpoint']
replicas = int(params['replicas'])
sax.Publish(model, path, ckpt, replicas)
publish_res = {
'model': model,
'path': path,
'checkpoint': ckpt,
'replicas': replicas,
}
self.success_res(publish_res)

except Exception as e:
self.error_res(e)

if self.path == '/unpublish':
"""Unpublishes a model."""

if len(params) != 1:
self.error_res("Provide model for unpublish")
return

try:
model = params['model']
sax.Unpublish(model)
unpublish_res = {
'model': model,
}
self.success_res(unpublish_res)

except Exception as e:
self.error_res(e)

if self.path == '/generate':
"""Generates a text input using a published language model."""

if len(params) != 2:
self.error_res("Provide model and query for generate")
return

try:
model = params['model']
query = params['query']
sax.ListDetail(model)
model_open = sax.Model(model)
lm = model_open.LM()
res = lm.Generate(query)
generate_res = {
'generate_response': res,
}
self.success_res(generate_res)

except Exception as e:
self.error_res(e)

def do_PUT(self):
"""Handles PUT requests."""

if self.path not in self.put_dict:
self.send_response(400)
self.end_headers()
self.wfile.write(json.dumps(self.invalid_res).encode('utf-8'))
self.wfile.write('\n'.encode('utf-8'))
return

content_length = int(self.headers['content-length'])
data = self.rfile.read(content_length).decode('utf-8')
params = json.loads(data)

if self.path == '/update':
"""Updates a model."""

if len(params) != 4:
self.error_res("Provide model, model path, checkpoint, and replica number for update")
return

try:
model = params['model']
path = params['model_path']
ckpt = params['checkpoint']
replicas = int(params['replicas'])
sax.Update(model, path, ckpt, replicas)
update_res = {
'model': model,
'path': path,
'checkpoint': ckpt,
'replicas': replicas,
}
self.success_res(update_res)

except Exception as e:
self.error_res(e)

s = http.server.HTTPServer(('0.0.0.0', 8888), Server)
s.serve_forever()

0 comments on commit 9c3cb55

Please sign in to comment.