Skip to content

Commit

Permalink
Add source files and user guide for single host inferencing with JetS…
Browse files Browse the repository at this point in the history
…tream and MaxText on GKE (#358)

* Add source files and user guide for single host inferencing with JetStream and MaxText on GKE

* fix Dockerfile ubuntu version and entrypoint; address yaml and README comments

* restructure folder

* Delete jetstream directory
  • Loading branch information
vivianrwu committed Mar 16, 2024
1 parent 8eb320d commit 10e8abf
Show file tree
Hide file tree
Showing 6 changed files with 377 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Ubuntu:22.04
# Use Ubuntu 22.04 from Docker Hub.
# https://hub.docker.com/_/ubuntu/tags?page=1&name=22.04
FROM ubuntu:22.04

ENV DEBIAN_FRONTEND=noninteractive

RUN apt -y update && apt install -y --no-install-recommends \
ca-certificates \
git \
python3.10 \
python3-pip

RUN update-alternatives --install \
/usr/bin/python3 python3 /usr/bin/python3.10 1

RUN git clone https://github.com/google/JetStream.git && \
cd /JetStream && \
pip install -e .

RUN pip3 install uvicorn
RUN pip3 install fastapi
RUN pip3 install pydantic
ENV PYTHONDONTWRITEBYTECODE=1

COPY http_server.py /maxengine/httpserver/
WORKDIR /maxengine/httpserver

CMD ["uvicorn", "http_server:app", "--host=0.0.0.0", "--port=8000"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""HTTP Server to interact with MaxText + JetStream Server."""

import asyncio
import concurrent.futures
import json
import logging
from typing import Optional

import fastapi
import grpc
from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc
import pydantic


class GenerateRequest(pydantic.BaseModel):
server: Optional[str] = "127.0.0.1"
port: Optional[str] = "9000"
session_cache: Optional[str] = ""
prompt: Optional[str] = "This is an example prompt"
priority: Optional[int] = 0
max_tokens: Optional[int] = 100


app = fastapi.FastAPI()
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1000)

channel = grpc.insecure_channel("127.0.0.1:9000")
grpc.channel_ready_future(channel).result()
stub = jetstream_pb2_grpc.OrchestratorStub(channel)


@app.get("/")
def root():
"""Root path for MaxText + Jetstream HTTP Server."""
response = {"message": "HTTP Server for MaxText + JetStream"}
response = fastapi.Response(
content=json.dumps(response, indent=4), media_type="application/json"
)
return response


@app.post("/generate", status_code=200)
async def generate(request: GenerateRequest):
"""Generate a prompt."""
try:
request = jetstream_pb2.DecodeRequest(
session_cache=request.session_cache,
additional_text=request.prompt,
priority=request.priority,
max_tokens=request.max_tokens,
)
loop = asyncio.get_running_loop()
response = await loop.run_in_executor(
executor, generate_prompt, stub, request
)
response = {"response": response}
response = fastapi.Response(
content=json.dumps(response, indent=4), media_type="application/json"
)
return response
except Exception as e:
logging.exception("Exception in generate")
raise fastapi.HTTPException(status_code=500, detail=str(e))


def generate_prompt(
orchestrator_stub: jetstream_pb2_grpc.OrchestratorStub,
request: jetstream_pb2.DecodeRequest,
):
"""Generate a prompt."""
response = orchestrator_stub.Decode(request)
output = ""
for token_list in response:
output += str(token_list.response[0])
return output
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Ubuntu:22.04
# Use Ubuntu 22.04 from Docker Hub.
# https://hub.docker.com/_/ubuntu/tags?page=1&name=22.04
FROM ubuntu:22.04

ENV DEBIAN_FRONTEND=noninteractive

RUN apt -y update && apt install -y --no-install-recommends \
ca-certificates \
git \
python3.10 \
python3-pip

RUN update-alternatives --install \
/usr/bin/python3 python3 /usr/bin/python3.10 1

RUN git clone https://github.com/google/maxtext.git && \
git clone https://github.com/google/JetStream.git

RUN cd maxtext/ && \
bash setup.sh

RUN cd /JetStream && \
pip install -e .

COPY maxengine_server_entrypoint.sh /usr/bin/

RUN chmod +x /usr/bin/maxengine_server_entrypoint.sh

ENTRYPOINT ["/usr/bin/maxengine_server_entrypoint.sh"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/bin/bash
cd /maxtext
python3 MaxText/maxengine_server.py \
MaxText/configs/base.yml $@
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# Serve a LLM using a single-host TPU on GKE with JetStream and MaxText

## Background
This tutorial shows you how to serve a large language model (LLM) using Tensor Processing Units (TPUs) on Google Kubernetes Engine (GKE) with [JetStream](https://github.com/google/JetStream) and [MaxText](https://github.com/google/maxtext).

## Setup

### Set default environment variables
```
gcloud config set project [PROJECT_ID]
export PROJECT_ID=$(gcloud config get project)
export REGION=[COMPUTE_REGION]
export ZONE=[ZONE]
```

### Create GKE cluster and node pool
```
# Create zonal cluster with 2 CPU nodes
gcloud container clusters create jetstream-maxtext \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--workload-pool=${PROJECT_ID}.svc.id.goog \
--release-channel=rapid \
--num-nodes=2
# Create one v5e TPU pool with topology 2x4 (1 TPU node with 8 chips)
gcloud container node-pools create tpu \
--cluster=jetstream-maxtext \
--zone=${ZONE} \
--num-nodes=2 \
--machine-type=ct5lp-hightpu-8t \
--project=${PROJECT_ID}
```
You have created the following resources:

- Standard cluster with 2 CPU nodes.
- One v5e TPU node pool with 2 nodes, each with 8 chips.

### Configure Applications to use Workload Identity
Prerequisite: make sure you have the following roles

```
roles/container.admin
roles/iam.serviceAccountAdmin
```

Follow [these steps](https://cloud.google.com/kubernetes-engine/docs/how-to/workload-identity#authenticating_to) to configure the IAM and Kubernetes service account:

```
# Get credentials for your cluster
$ gcloud container clusters get-credentials jetstream-maxtext \
--zone=${ZONE}
# Create an IAM service account.
$ gcloud iam service-accounts create jetstream-iam-sa
# Ensure the IAM service account has necessary roles. Here we add roles/storage.objectUser for gcs bucket access.
$ gcloud projects add-iam-policy-binding ${PROJECT_ID} \
--member "serviceAccount:jetstream-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com" \
--role roles/storage.objectUser
$ gcloud projects add-iam-policy-binding ${PROJECT_ID} \
--member "serviceAccount:jetstream-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com" \
--role roles/storage.insightsCollectorService
# Allow the Kubernetes default service account to impersonate the IAM service account
$ gcloud iam service-accounts add-iam-policy-binding jetstream-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com \
--role roles/iam.workloadIdentityUser \
--member "serviceAccount:${PROJECT_ID}.svc.id.goog[default/default]"
# Annotate the Kubernetes service account with the email address of the IAM service account.
$ kubectl annotate serviceaccount default \
iam.gke.io/gcp-service-account=jetstream-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com
```

## Convert the Gemma-7b checkpoint

You can follow [these instructions](https://github.com/google/maxtext/blob/main/end_to_end/test_gemma.sh#L14) to convert the Gemma-7b checkpoint from orbax to a MaxText compatible checkpoint.

## Deploy Maxengine Server and HTTP Server

In this example, we will deploy a Maxengine server targeting Gemma-7b model. You can use the provided Maxengine server and HTTP server images already in `deployment.yaml` or [build your own](#optionals).

Add desired overrides to your yaml file by editing the `args` in `deployment.yaml`. You can reference the [MaxText base config file](https://github.com/google/maxtext/blob/main/MaxText/configs/base.yml) on what values can be overridden.

Configure the model checkpoint by adding `load_parameters_path=<GCS bucket path to your checkpoint>` under `args`, you can optionally deploy `deployment.yaml` without adding the checkpoint path.

Deploy the manifest file for the Maxengine server and HTTP server:
```
kubectl apply -f deployment.yaml
```

## Verify the deployment

Wait for the containers to finish creating:
```
kubectl get deployment
NAME READY UP-TO-DATE AVAILABLE AGE
maxengine-server 1/1 1 1 2m45s
```

Check the Maxengine pod’s logs, and verify the compilation is done. You will see similar logs of the following:
```
kubectl logs deploy/maxengine-server -f -c maxengine-server
2024-03-14 06:03:37,750 - jax._src.dispatch - DEBUG - Finished XLA compilation of jit(generate) in 8.170992851257324 sec
2024-03-14 06:03:38,779 - root - INFO - Generate engine 0 step 1 - slots free : 96 / 96, took 11807.21ms
2024-03-14 06:03:38,780 - root - INFO - Generate thread making a decision with: prefill_backlog=0 generate_free_slots=96
2024-03-14 06:03:38,831 - root - INFO - Detokenising generate step 0 took 46.34ms
2024-03-14 06:03:39,793 - root - INFO - Generate engine 0 step 2 - slots free : 96 / 96, took 1013.51ms
2024-03-14 06:03:39,793 - root - INFO - Generate thread making a decision with: prefill_backlog=0 generate_free_slots=96
2024-03-14 06:03:39,797 - root - INFO - Generate engine 0 step 3 - slots free : 96 / 96, took 3.35ms
```

Check http server logs, this can take a couple minutes:
```
kubectl logs deploy/maxengine-server -f -c jetstream-http
INFO: Started server process [1]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
```

### Send sample requests

Run the following command to set up port forwarding to the http server:

```
kubectl port-forward svc/jetstream-http-svc 8000:8000
```

In a new terminal, send a request to the server:

```
curl --request POST --header "Content-type: application/json" -s localhost:8000/generate --data '{
"prompt": "What are the top 5 programming languages",
"max_tokens": 200
}'
```

The output should be similar to the following:
```
{
"response": " in 2021?\n\nThe answer to this question is not as simple as it may seem. There are many factors that go into determining the most popular programming languages, and they can change from year to year.\n\nIn this blog post, we will discuss the top 5 programming languages in 2021 and why they are so popular.\n\n<h2><strong>1. Python</strong></h2>\n\nPython is a high-level programming language that is used for web development, data analysis, and machine learning. It is one of the most popular languages in the world and is used by many companies such as Google, Facebook, and Instagram.\n\nPython is easy to learn and has a large community of developers who are always willing to help out.\n\n<h2><strong>2. Java</strong></h2>\n\nJava is a general-purpose programming language that is used for web development, mobile development, and game development. It is one of the most popular languages in the"
}
```

## Optionals
### Build and upload Maxengine Server image

Build the Maxengine Server from [here](../maxengine-server) and upload to your project
```
docker build -t maxengine-server .
docker tag maxengine-server gcr.io/${PROJECT_ID}/jetstream/maxtext/maxengine-server:latest
docker push gcr.io/${PROJECT_ID}/jetstream/maxtext/maxengine-server:latest
```

### Build and upload HTTP Server image

Build the HTTP Server Dockerfile from [here](../http-server) and upload to your project
```
docker build -t jetstream-http .
docker tag jetstream-http gcr.io/${PROJECT_ID}/jetstream/maxtext/jetstream-http:latest
docker push gcr.io/${PROJECT_ID}/jetstream/maxtext/jetstream-http:latest
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: maxengine-server
spec:
replicas: 2
selector:
matchLabels:
app: maxengine-server
template:
metadata:
labels:
app: maxengine-server
spec:
nodeSelector:
cloud.google.com/gke-tpu-topology: 2x4
cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
containers:
- name: maxengine-server
image: us-docker.pkg.dev/cloud-tpu-images/inference/maxengine-server:dev
securityContext:
privileged: true
args:
- model_name=gemma-7b
- tokenizer_path=assets/tokenizer.gemma
- per_device_batch_size=12
- max_prefill_predict_length=1024
- max_target_length=2048
- steps=10
- async_checkpointing=false
- ici_fsdp_parallelism=1
- ici_autoregressive_parallelism=-1
- scan_layers=false
- weight_dtype=bfloat16
ports:
- containerPort: 9000
resources:
requests:
google.com/tpu: 8
limits:
google.com/tpu: 8
- name: jetstream-http
image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:dev
ports:
- containerPort: 8000
---
apiVersion: v1
kind: Service
metadata:
name: jetstream-http-svc
spec:
selector:
app: maxengine-server
ports:
- protocol: TCP
port: 8000
targetPort: 8000

0 comments on commit 10e8abf

Please sign in to comment.