Commit
Add checkpoint conversion script, argument descriptions, and grpc interaction (GoogleCloudPlatform#531)

* Add checkpoint conversion script and argument descriptions

* edit model server logs

* modify checkpoint job and add deployment grpc instructions

* rename conversion script and pass optional version argument

* remove grpc specific deployment file for jetstream
vivianrwu authored Apr 25, 2024
1 parent feb8291 commit 36a972f
Showing 9 changed files with 279 additions and 34 deletions.
29 changes: 29 additions & 0 deletions tutorials-and-examples/inference-servers/checkpoints/Dockerfile
@@ -0,0 +1,29 @@
# Ubuntu:22.04
# Use Ubuntu 22.04 from Docker Hub.
# https://hub.docker.com/_/ubuntu/tags?page=1&name=22.04
FROM ubuntu:22.04

ENV DEBIAN_FRONTEND=noninteractive

RUN apt -y update && apt install -y --no-install-recommends \
ca-certificates \
git \
python3.10 \
python3-pip \
curl \
gnupg

RUN update-alternatives --install \
/usr/bin/python3 python3 /usr/bin/python3.10 1

RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
RUN apt -y update && apt install -y google-cloud-cli

RUN pip install kaggle

RUN git clone https://github.com/pylint-dev/pylint.git

COPY checkpoint_converter.sh /usr/bin/
RUN chmod +x /usr/bin/checkpoint_converter.sh
ENTRYPOINT ["/usr/bin/checkpoint_converter.sh"]
18 changes: 18 additions & 0 deletions tutorials-and-examples/inference-servers/checkpoints/README.md
@@ -0,0 +1,18 @@
# Checkpoint conversion

The `checkpoint_converter.sh` script shows how to convert an inference checkpoint for various model servers.

Build the checkpoint conversion image and push it to your registry:
```
docker build -t inference-checkpoint .
docker tag inference-checkpoint gcr.io/${PROJECT_ID}/inference-checkpoint:latest
docker push gcr.io/${PROJECT_ID}/inference-checkpoint:latest
```

Now you can use it in a [Kubernetes job](../jetstream/maxtext/single-host-inference/checkpoint-job.yaml) and pass the following arguments:

```
- -i=INFERENCE_SERVER
- -b=BUCKET_NAME
- -m=MODEL_PATH
```
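
For example, invoking the script directly would look like this (a sketch with a hypothetical bucket name; in practice the script runs as the Job's container entrypoint on a TPU node, with the Kaggle credentials mounted at `/kaggle`):

```
./checkpoint_converter.sh -i=jetstream-maxtext -b=my-checkpoints-bucket -m=google/gemma/maxtext/7b-it/2
```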
92 changes: 92 additions & 0 deletions tutorials-and-examples/inference-servers/checkpoints/checkpoint_converter.sh
@@ -0,0 +1,92 @@
#!/bin/bash

export KAGGLE_CONFIG_DIR="/kaggle"
INFERENCE_SERVER="jetstream-maxtext"
BUCKET_NAME=""
MODEL_PATH=""

print_usage() {
printf "Usage: $0 [ -b BUCKET_NAME ] [ -i INFERENCE_SERVER ] [ -m MODEL_PATH ] [ -v VERSION ]"
}

print_inference_server_unknown() {
printf "Enter a valid inference server [ -i INFERENCE_SERVER ]"
printf "Valid options: jetstream-maxtext"
}

download_kaggle_checkpoint() {
BUCKET_NAME=$1
MODEL_NAME=$2
VARIATION_NAME=$3
MODEL_PATH=$4

mkdir -p /data/${MODEL_NAME}_${VARIATION_NAME}
kaggle models instances versions download ${MODEL_PATH} --untar -p /data/${MODEL_NAME}_${VARIATION_NAME}
echo -e "\nCompleted extraction to /data/${MODEL_NAME}_${VARIATION_NAME}"

gcloud storage rsync --recursive --no-clobber /data/${MODEL_NAME}_${VARIATION_NAME} gs://${BUCKET_NAME}/base/${MODEL_NAME}_${VARIATION_NAME}
echo -e "\nCompleted copy of data to gs://${BUCKET_NAME}/base/${MODEL_NAME}_${VARIATION_NAME}"
}

convert_maxtext_checkpoint() {
BUCKET_NAME=$1
MODEL_NAME=$2
VARIATION_NAME=$3
MODEL_SIZE=$4
MAXTEXT_VERSION=$5

if [ -z "${MAXTEXT_VERSION}" ]; then
MAXTEXT_VERSION=jetstream-v0.2.0
fi

git clone https://github.com/google/maxtext.git

# checkout stable MaxText commit
cd maxtext
git checkout ${MAXTEXT_VERSION}
python3 -m pip install -r requirements.txt
echo -e "\Cloned MaxText repository and completed installing requirements"

python3 MaxText/convert_gemma_chkpt.py --base_model_path gs://${BUCKET_NAME}/base/${MODEL_NAME}_${VARIATION_NAME}/${VARIATION_NAME} --maxtext_model_path gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME} --model_size ${MODEL_SIZE}
echo -e "\nCompleted conversion of checkpoint to gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}"

RUN_NAME=0

python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml force_unroll=true model_name=${MODEL_NAME}-${MODEL_SIZE} async_checkpointing=false run_name=${RUN_NAME} load_parameters_path=gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}/0/items base_output_directory=gs://${BUCKET_NAME}/final/unscanned/${MODEL_NAME}_${VARIATION_NAME}
echo -e "\nCompleted unscanning checkpoint to gs://${BUCKET_NAME}/final/unscanned/${MODEL_NAME}_${VARIATION_NAME}/${RUN_NAME}/checkpoints/0/items"
}


while getopts 'b:i:m:v:' flag; do
case "${flag}" in
b) BUCKET_NAME="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
i) INFERENCE_SERVER="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
m) MODEL_PATH="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
v) VERSION="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
*) print_usage
exit 1 ;;
esac
done

if [ -z "${BUCKET_NAME}" ]; then
echo "BUCKET_NAME is empty, please provide a GCS bucket name"
exit 1
fi

if [ -z "${MODEL_PATH}" ]; then
echo "MODEL_PATH is empty, please provide the model path"
exit 1
fi

echo "Inference server is ${INFERENCE_SERVER}"
MODEL_NAME=$(echo ${MODEL_PATH} | awk -F'/' '{print $2}')
VARIATION_NAME=$(echo ${MODEL_PATH} | awk -F'/' '{print $4}')
MODEL_SIZE=$(echo ${VARIATION_NAME} | awk -F'-' '{print $1}')
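# Example: MODEL_PATH=google/gemma/maxtext/7b-it/2 yields
# MODEL_NAME=gemma, VARIATION_NAME=7b-it, MODEL_SIZE=7b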

case ${INFERENCE_SERVER} in

jetstream-maxtext)
download_kaggle_checkpoint "$BUCKET_NAME" "$MODEL_NAME" "$VARIATION_NAME" "$MODEL_PATH"
convert_maxtext_checkpoint "$BUCKET_NAME" "$MODEL_NAME" "$VARIATION_NAME" "$MODEL_SIZE" "$VERSION"
;;
*) print_inference_server_unknown
exit 1 ;;
esac
@@ -4,6 +4,7 @@
FROM ubuntu:22.04

ENV DEBIAN_FRONTEND=noninteractive
ENV JETSTREAM_VERSION=v0.2.0

RUN apt -y update && apt install -y --no-install-recommends \
ca-certificates \
@@ -16,6 +17,7 @@ RUN update-alternatives --install \

RUN git clone https://github.com/google/JetStream.git && \
cd /JetStream && \
git checkout ${JETSTREAM_VERSION} && \
pip install -e .

RUN pip3 install uvicorn
@@ -39,11 +39,6 @@ class GenerateRequest(pydantic.BaseModel):
app = fastapi.FastAPI()
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1000)

-channel = grpc.insecure_channel("127.0.0.1:9000")
-grpc.channel_ready_future(channel).result()
-stub = jetstream_pb2_grpc.OrchestratorStub(channel)
-

@app.get("/")
def root():
"""Root path for MaxText + Jetstream HTTP Server."""
@@ -64,10 +59,9 @@ async def generate(request: GenerateRequest):
priority=request.priority,
max_tokens=request.max_tokens,
)
-loop = asyncio.get_running_loop()
-response = await loop.run_in_executor(
-executor, generate_prompt, stub, request
-)
+# generate_prompt is now a coroutine, so it is awaited directly
+response = await generate_prompt(request)
response = {"response": response}
response = fastapi.Response(
content=json.dumps(response, indent=4), media_type="application/json"
@@ -78,13 +72,16 @@ async def generate(request: GenerateRequest):
raise fastapi.HTTPException(status_code=500, detail=str(e))


-def generate_prompt(
-orchestrator_stub: jetstream_pb2_grpc.OrchestratorStub,
+async def generate_prompt(
request: jetstream_pb2.DecodeRequest,
):
"""Generate a prompt."""
-response = orchestrator_stub.Decode(request)
-output = ""
-for token_list in response:
-output += str(token_list.response[0])
-return output

+options = [("grpc.keepalive_timeout_ms", 10000)]
+async with grpc.aio.insecure_channel("127.0.0.1:9000", options=options) as channel:
+stub = jetstream_pb2_grpc.OrchestratorStub(channel)
+response = stub.Decode(request)
+output = ""
+async for token_list in response:
+output += str(token_list.response[0])
+return output
@@ -4,6 +4,8 @@
FROM ubuntu:22.04

ENV DEBIAN_FRONTEND=noninteractive
ENV MAXTEXT_VERSION=jetstream-v0.2.0
ENV JETSTREAM_VERSION=v0.2.0

RUN apt -y update && apt install -y --no-install-recommends \
ca-certificates \
@@ -18,9 +20,11 @@ RUN git clone https://github.com/google/maxtext.git && \
git clone https://github.com/google/JetStream.git

RUN cd maxtext/ && \
git checkout ${MAXTEXT_VERSION} && \
bash setup.sh

RUN cd /JetStream && \
git checkout ${JETSTREAM_VERSION} && \
pip install -e .

COPY maxengine_server_entrypoint.sh /usr/bin/
tutorials-and-examples/inference-servers/jetstream/maxtext/single-host-inference/README.md
@@ -73,17 +73,74 @@ $ kubectl annotate serviceaccount default \
iam.gke.io/gcp-service-account=jetstream-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com
```

### Create a Cloud Storage bucket to store the Gemma-7b model checkpoint

```
gcloud storage buckets create $BUCKET_NAME
```

### Get access to the model

Access the [model consent page](https://www.kaggle.com/models/google/gemma) and request access with your Kaggle Account. Accept the Terms and Conditions.

Obtain a Kaggle API token: in your Kaggle settings, under the `API` section, click `Create New Token`. A `kaggle.json` file will be downloaded.

Create a Secret to store the Kaggle credentials:
```
kubectl create secret generic kaggle-secret \
--from-file=kaggle.json
```

## Convert the Gemma-7b checkpoint

You can follow [these instructions](https://github.com/google/maxtext/blob/main/end_to_end/test_gemma.sh#L14) to convert the Gemma-7b checkpoint from orbax to a MaxText compatible checkpoint.
Alternatively, we have created a job, `checkpoint-job.yaml`, that does the following:
1. Downloads the base Orbax checkpoint from Kaggle
2. Uploads the checkpoint to a Cloud Storage bucket
3. Converts the checkpoint to a MaxText compatible checkpoint
4. Unscans the checkpoint for use in inference

In the manifest, replace BUCKET_NAME in the `-b` argument with the name of the Cloud Storage bucket you created above. Do not include the `gs://` prefix.
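
One way to substitute the placeholder in place (a sketch, assuming your bucket name is exported as `$BUCKET_NAME` and the manifest contains the literal placeholder `BUCKET_NAME`, as shown at the end of this page):

```
sed -i "s|BUCKET_NAME|${BUCKET_NAME}|g" checkpoint-job.yaml
```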

Apply the manifest:
```
kubectl apply -f checkpoint-job.yaml
```

Observe the logs:
```
kubectl logs -f jobs/data-loader-7b
```
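
If you prefer to block until the job completes instead of tailing logs, `kubectl wait` works as well (a sketch; the 20 minute timeout is an assumption sized to the roughly 10 minute runtime noted below):

```
kubectl wait --for=condition=complete --timeout=20m job/data-loader-7b
```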

You should see the following output once the job has completed. This will take around 10 minutes:
```
Successfully generated decode checkpoint at: gs://BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items
+ echo -e '\nCompleted unscanning checkpoint to gs://BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items'
Completed unscanning checkpoint to gs://BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items
```
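
As an optional sanity check, you can list the converted checkpoint at the path printed in the log line above:

```
gcloud storage ls gs://${BUCKET_NAME}/final/unscanned/gemma_7b-it/0/checkpoints/0/items
```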

## Deploy Maxengine Server and HTTP Server

In this example, we will deploy a Maxengine server targeting the Gemma-7b model. You can use the provided Maxengine server and HTTP server images already in `deployment.yaml` or [build your own](#optionals).

Add desired overrides to your yaml file by editing the `args` in `deployment.yaml`. You can reference the [MaxText base config file](https://github.com/google/maxtext/blob/main/MaxText/configs/base.yml) for the values that can be overridden.

Configure the model checkpoint by adding `load_parameters_path=<GCS bucket path to your checkpoint>` under `args`. You can optionally deploy `deployment.yaml` without the checkpoint path.
In the manifest, ensure the value of BUCKET_NAME is the name of the Cloud Storage bucket that was used when converting your checkpoint.

Argument descriptions:
```
tokenizer_path: The file path to your model’s tokenizer
load_parameters_path: Your checkpoint path (GCS bucket path)
per_device_batch_size: Decoding batch size per device (1 TPU chip = 1 device)
max_prefill_predict_length: Maximum length for the prefill when doing autoregression
max_target_length: Maximum sequence length
model_name: Model name
ici_fsdp_parallelism: The number of shards for FSDP parallelism
ici_autoregressive_parallelism: The number of shards for autoregressive parallelism
ici_tensor_parallelism: The number of shards for tensor parallelism
weight_dtype: Weight data type (e.g. bfloat16)
scan_layers: Scan layers boolean flag
```
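
For illustration, here is a sketch of what the `args` might look like for the Gemma-7b checkpoint produced above. Every value is an assumption to adapt to your setup, except `load_parameters_path`, which follows the output path printed by the checkpoint job:

```
args:
- tokenizer_path=assets/tokenizer.gemma
- model_name=gemma-7b
- per_device_batch_size=4
- max_prefill_predict_length=1024
- max_target_length=2048
- scan_layers=false
- weight_dtype=bfloat16
- load_parameters_path=gs://BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items
```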

Deploy the manifest file for the Maxengine server and HTTP server:
```
@@ -97,20 +154,15 @@ Wait for the containers to finish creating:
kubectl get deployment
NAME READY UP-TO-DATE AVAILABLE AGE
-maxengine-server 1/1 1 1 2m45s
+maxengine-server 2/2 2 2 ##s
```

Check the Maxengine pod’s logs to verify that compilation is done. You should see logs similar to the following:
```
kubectl logs deploy/maxengine-server -f -c maxengine-server
-2024-03-14 06:03:37,750 - jax._src.dispatch - DEBUG - Finished XLA compilation of jit(generate) in 8.170992851257324 sec
-2024-03-14 06:03:38,779 - root - INFO - Generate engine 0 step 1 - slots free : 96 / 96, took 11807.21ms
-2024-03-14 06:03:38,780 - root - INFO - Generate thread making a decision with: prefill_backlog=0 generate_free_slots=96
-2024-03-14 06:03:38,831 - root - INFO - Detokenising generate step 0 took 46.34ms
-2024-03-14 06:03:39,793 - root - INFO - Generate engine 0 step 2 - slots free : 96 / 96, took 1013.51ms
-2024-03-14 06:03:39,793 - root - INFO - Generate thread making a decision with: prefill_backlog=0 generate_free_slots=96
-2024-03-14 06:03:39,797 - root - INFO - Generate engine 0 step 3 - slots free : 96 / 96, took 3.35ms
+2024-03-29 17:09:08,047 - jax._src.dispatch - DEBUG - Finished XLA compilation of jit(initialize) in 0.26236414909362793 sec
+2024-03-29 17:09:08,150 - root - INFO - ---------Generate params 0 loaded.---------
```

Check the HTTP server logs; this can take a couple of minutes:
@@ -128,7 +180,7 @@ INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
Run the following command to set up port forwarding to the http server:

```
-kubectl port-forward svc/jetstream-http-svc 8000:8000
+kubectl port-forward svc/jetstream-svc 8000:8000
```

In a new terminal, send a request to the server:
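
The request body itself is collapsed in this diff view. As a minimal sketch, assuming the endpoint is mounted at `/generate` and accepts the `prompt`, `priority`, and `max_tokens` fields of the `GenerateRequest` model shown in the HTTP server code above:

```
curl --request POST \
--header "Content-Type: application/json" \
--data '{"prompt": "What are the top 5 programming languages?", "priority": 0, "max_tokens": 200}' \
localhost:8000/generate
```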
@@ -164,4 +216,16 @@ Build the HTTP Server Dockerfile from [here](../http-server) and upload to your registry:
docker build -t jetstream-http .
docker tag jetstream-http gcr.io/${PROJECT_ID}/jetstream/maxtext/jetstream-http:latest
docker push gcr.io/${PROJECT_ID}/jetstream/maxtext/jetstream-http:latest
```

### Interact with the Maxengine server directly using gRPC

The JetStream HTTP server is great for initial testing and for validating end-to-end requests and responses. If you would like to interact with the Maxengine server directly, for use cases such as [benchmarking](https://github.com/google/JetStream/tree/main/benchmarks), follow the JetStream benchmarking setup, apply the `deployment.yaml` manifest file, and interact with the JetStream gRPC server at port 9000.

```
kubectl apply -f deployment.yaml
kubectl port-forward svc/jetstream-svc 9000:9000
```

To run benchmarking, pass in the flag `--server 127.0.0.1` when running the benchmarking script.
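
For example, from a clone of the JetStream repository (a sketch; `--server 127.0.0.1` comes from the note above, while the script path and `--port` flag are assumptions that may vary by JetStream version):

```
python benchmarks/benchmark_serving.py --server 127.0.0.1 --port 9000
```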
32 changes: 32 additions & 0 deletions tutorials-and-examples/inference-servers/jetstream/maxtext/single-host-inference/checkpoint-job.yaml
@@ -0,0 +1,32 @@
apiVersion: batch/v1
kind: Job
metadata:
name: data-loader-7b
spec:
ttlSecondsAfterFinished: 30
template:
spec:
restartPolicy: Never
containers:
- name: inference-checkpoint
image: us-docker.pkg.dev/cloud-tpu-images/inference/inference-checkpoint:v0.2.0
args:
- -b=BUCKET_NAME
- -m=google/gemma/maxtext/7b-it/2
volumeMounts:
- mountPath: "/kaggle/"
name: kaggle-credentials
readOnly: true
resources:
requests:
google.com/tpu: 8
limits:
google.com/tpu: 8
nodeSelector:
cloud.google.com/gke-tpu-topology: 2x4
cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
volumes:
- name: kaggle-credentials
secret:
defaultMode: 0400
secretName: kaggle-secret