forked from GoogleCloudPlatform/ai-on-gke
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add checkpoint conversion script, argument descriptions, and grpc int…
…eraction (GoogleCloudPlatform#531) * Add checkpoint conversion script and argument descriptions * edit model server logs * modify checkpoint job and add deployment grpc instructions * rename conversion script and pass optional version argument * remove grpc specific deployment file for jetstream
- Loading branch information
Showing
9 changed files
with
279 additions
and
34 deletions.
There are no files selected for viewing
29 changes: 29 additions & 0 deletions
29
tutorials-and-examples/inference-servers/checkpoints/Dockerfile
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Ubuntu:22.04 | ||
# Use Ubuntu 22.04 from Docker Hub. | ||
# https://hub.docker.com/_/ubuntu/tags?page=1&name=22.04 | ||
FROM ubuntu:22.04 | ||
|
||
ENV DEBIAN_FRONTEND=noninteractive | ||
|
||
RUN apt -y update && apt install -y --no-install-recommends \ | ||
ca-certificates \ | ||
git \ | ||
python3.10 \ | ||
python3-pip \ | ||
curl \ | ||
gnupg | ||
|
||
RUN update-alternatives --install \ | ||
/usr/bin/python3 python3 /usr/bin/python3.10 1 | ||
|
||
RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - | ||
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list | ||
RUN apt -y update && apt install -y google-cloud-cli | ||
|
||
RUN pip install kaggle | ||
|
||
RUN git clone https://github.com/pylint-dev/pylint.git | ||
|
||
COPY checkpoint_converter.sh /usr/bin/ | ||
RUN chmod +x /usr/bin/checkpoint_converter.sh | ||
ENTRYPOINT ["/usr/bin/checkpoint_converter.sh"] |
18 changes: 18 additions & 0 deletions
18
tutorials-and-examples/inference-servers/checkpoints/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Checkpoint conversion | ||
|
||
The `checkpoint_entrypoint.sh` script overviews how to convert your inference checkpoint for various model servers. | ||
|
||
Build the checkpoint conversion Dockerfile | ||
``` | ||
docker build -t inference-checkpoint . | ||
docker tag inference-checkpoint gcr.io/${PROJECT_ID}/inference-checkpoint:latest | ||
docker push gcr.io/${PROJECT_ID}/inference-checkpoint:latest | ||
``` | ||
|
||
Now you can use it in a [Kubernetes job](../jetstream/maxtext/single-host-inference/checkpoint-job.yaml) and pass the following arguments | ||
|
||
``` | ||
- -i=INFERENCE_SERVER | ||
- -b=BUCKET_NAME | ||
- -m=MODEL_PATH | ||
``` |
92 changes: 92 additions & 0 deletions
92
tutorials-and-examples/inference-servers/checkpoints/checkpoint_converter.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
#!/bin/bash | ||
|
||
export KAGGLE_CONFIG_DIR="/kaggle" | ||
INFERENCE_SERVER="jetstream-maxtext" | ||
BUCKET_NAME="" | ||
MODEL_PATH="" | ||
|
||
print_usage() { | ||
printf "Usage: $0 [ -b BUCKET_NAME ] [ -i INFERENCE_SERVER ] [ -m MODEL_PATH ] [ -v VERSION ]" | ||
} | ||
|
||
print_inference_server_unknown() { | ||
printf "Enter a valid inference server [ -i INFERENCE_SERVER ]" | ||
printf "Valid options: jetstream-maxtext" | ||
} | ||
|
||
download_kaggle_checkpoint() { | ||
BUCKET_NAME=$1 | ||
MODEL_NAME=$2 | ||
VARIATION_NAME=$3 | ||
MODEL_PATH=$4 | ||
|
||
mkdir -p /data/${MODEL_NAME}_${VARIATION_NAME} | ||
kaggle models instances versions download ${MODEL_PATH} --untar -p /data/${MODEL_NAME}_${VARIATION_NAME} | ||
echo -e "\nCompleted extraction to /data/${MODEL_NAME}_${VARIATION_NAME}" | ||
|
||
gcloud storage rsync --recursive --no-clobber /data/${MODEL_NAME}_${VARIATION_NAME} gs://${BUCKET_NAME}/base/${MODEL_NAME}_${VARIATION_NAME} | ||
echo -e "\nCompleted copy of data to gs://${BUCKET_NAME}/base/${MODEL_NAME}_${VARIATION_NAME}" | ||
} | ||
|
||
convert_maxtext_checkpoint() { | ||
BUCKET_NAME=$1 | ||
MODEL_NAME=$2 | ||
VARIATION_NAME=$3 | ||
MODEL_SIZE=$4 | ||
MAXTEXT_VERSION=$5 | ||
|
||
if [ -z $MAXTEXT_VERSION ]; then | ||
MAXTEXT_VERSION=jetstream-v0.2.0 | ||
fi | ||
|
||
git clone https://github.com/google/maxtext.git | ||
|
||
# checkout stable MaxText commit | ||
cd maxtext | ||
git checkout ${MAXTEXT_VERSION} | ||
python3 -m pip install -r requirements.txt | ||
echo -e "\Cloned MaxText repository and completed installing requirements" | ||
|
||
python3 MaxText/convert_gemma_chkpt.py --base_model_path gs://${BUCKET_NAME}/base/${MODEL_NAME}_${VARIATION_NAME}/${VARIATION_NAME} --maxtext_model_path gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME} --model_size ${MODEL_SIZE} | ||
echo -e "\nCompleted conversion of checkpoint to gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}" | ||
|
||
RUN_NAME=0 | ||
|
||
python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml force_unroll=true model_name=${MODEL_NAME}-${MODEL_SIZE} async_checkpointing=false run_name=${RUN_NAME} load_parameters_path=gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}/0/items base_output_directory=gs://${BUCKET_NAME}/final/unscanned/${MODEL_NAME}_${VARIATION_NAME} | ||
echo -e "\nCompleted unscanning checkpoint to gs://${BUCKET_NAME}/final/unscanned/${MODEL_NAME}_${VARIATION_NAME}/${RUN_NAME}/checkpoints/0/items" | ||
} | ||
|
||
|
||
while getopts 'b:i:m:' flag; do | ||
case "${flag}" in | ||
b) BUCKET_NAME="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;; | ||
i) INFERENCE_SERVER="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;; | ||
m) MODEL_PATH="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;; | ||
v) VERSION="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;; | ||
*) print_usage | ||
exit 1 ;; | ||
esac | ||
done | ||
|
||
if [ -z $BUCKET_NAME ]; then | ||
echo "BUCKET_NAME is empty, please provide a GSBucket" | ||
fi | ||
|
||
if [ -z $MODEL_PATH ]; then | ||
echo "MODEL_PATH is empty, please provide the model path" | ||
fi | ||
|
||
echo "Inference server is ${INFERENCE_SERVER}" | ||
MODEL_NAME=$(echo ${MODEL_PATH} | awk -F'/' '{print $2}') | ||
VARIATION_NAME=$(echo ${MODEL_PATH} | awk -F'/' '{print $4}') | ||
MODEL_SIZE=$(echo ${VARIATION_NAME} | awk -F'-' '{print $1}') | ||
|
||
case ${INFERENCE_SERVER} in | ||
|
||
jetstream-maxtext) | ||
download_kaggle_checkpoint "$BUCKET_NAME" "$MODEL_NAME" "$VARIATION_NAME" "$MODEL_PATH" | ||
convert_maxtext_checkpoint "$BUCKET_NAME" "$MODEL_NAME" "$VARIATION_NAME" "$MODEL_SIZE" "$VERSION" | ||
;; | ||
*) print_inference_server_unknown | ||
exit 1 ;; | ||
esac |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
32 changes: 32 additions & 0 deletions
32
...nd-examples/inference-servers/jetstream/maxtext/single-host-inference/checkpoint-job.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
apiVersion: batch/v1 | ||
kind: Job | ||
metadata: | ||
name: data-loader-7b | ||
spec: | ||
ttlSecondsAfterFinished: 30 | ||
template: | ||
spec: | ||
restartPolicy: Never | ||
containers: | ||
- name: inference-checkpoint | ||
image: us-docker.pkg.dev/cloud-tpu-images/inference/inference-checkpoint:v0.2.0 | ||
args: | ||
- -b=BUCKET_NAME | ||
- -m=google/gemma/maxtext/7b-it/2 | ||
volumeMounts: | ||
- mountPath: "/kaggle/" | ||
name: kaggle-credentials | ||
readOnly: true | ||
resources: | ||
requests: | ||
google.com/tpu: 8 | ||
limits: | ||
google.com/tpu: 8 | ||
nodeSelector: | ||
cloud.google.com/gke-tpu-topology: 2x4 | ||
cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice | ||
volumes: | ||
- name: kaggle-credentials | ||
secret: | ||
defaultMode: 0400 | ||
secretName: kaggle-secret |
Oops, something went wrong.