From 0fd14fd965d5ea191e66877e4f0a023a000bfd3c Mon Sep 17 00:00:00 2001 From: vivianrwu Date: Thu, 11 Jul 2024 11:54:15 -0700 Subject: [PATCH] Add HuggingFace support for automated inference checkpoint conversion (#712) * Add HuggingFace support for automated inference checkpoint conversion * Add HuggingFace support for inference checkpoint conversion * fix llama checkpoint names * update containers to v0.2.3 / v0.2.2 * update containers to v0.2.3 / v0.2.2 --- .../inference-servers/checkpoints/Dockerfile | 4 +- .../inference-servers/checkpoints/README.md | 33 +++-- .../checkpoints/checkpoint_converter.sh | 115 +++++++++++++----- .../jetstream-pytorch-server/Dockerfile | 4 +- .../pytorch/single-host-inference/README.md | 89 ++++++++++++-- .../single-host-inference/checkpoint-job.yaml | 16 ++- .../single-host-inference/deployment.yaml | 10 +- .../single-host-inference/pd-deployment.yaml | 10 +- 8 files changed, 219 insertions(+), 62 deletions(-) diff --git a/tutorials-and-examples/inference-servers/checkpoints/Dockerfile b/tutorials-and-examples/inference-servers/checkpoints/Dockerfile index 918f7a586..f1c6cd871 100644 --- a/tutorials-and-examples/inference-servers/checkpoints/Dockerfile +++ b/tutorials-and-examples/inference-servers/checkpoints/Dockerfile @@ -20,7 +20,9 @@ RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyri RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list RUN apt -y update && apt install -y google-cloud-cli -RUN pip install kaggle +RUN pip install kaggle && \ +pip install huggingface_hub[cli] && \ +pip install google-jetstream COPY checkpoint_converter.sh /usr/bin/ RUN chmod +x /usr/bin/checkpoint_converter.sh diff --git a/tutorials-and-examples/inference-servers/checkpoints/README.md b/tutorials-and-examples/inference-servers/checkpoints/README.md index d5c79d3ce..8bbc0645a 100644 --- a/tutorials-and-examples/inference-servers/checkpoints/README.md +++ b/tutorials-and-examples/inference-servers/checkpoints/README.md @@ -11,20 +11,37 @@ docker push gcr.io/${PROJECT_ID}/inference-checkpoint:latest Now you can use it in a [Kubernetes job](../jetstream/maxtext/single-host-inference/checkpoint-job.yaml) and pass the following arguments -Jetstream + MaxText +## Jetstream + MaxText ``` -- -i=INFERENCE_SERVER +- -s=INFERENCE_SERVER - -b=BUCKET_NAME - -m=MODEL_PATH - -v=VERSION (Optional) ``` -Jetstream + Pytorch/XLA +## Jetstream + Pytorch/XLA ``` -- -i=INFERENCE_SERVER +- -s=INFERENCE_SERVER - -m=MODEL_PATH -- -q=QUANTIZE (Optional) -- -v=VERSION -- -1=EXTRA_PARAM_1 -- -2=EXTRA_PARAM_2 +- -n=MODEL_NAME +- -q=QUANTIZE_WEIGHTS (Optional) (default=False) +- -t=QUANTIZE_TYPE (Optional) (default=int8_per_channel) +- -v=VERSION (Optional) (default=jetstream-v0.2.3) +- -i=INPUT_DIRECTORY (Optional) +- -o=OUTPUT_DIRECTORY +- -h=HUGGINGFACE (Optional) (default=False) +``` + +## Argument descriptions: +``` +b) BUCKET_NAME: (str) GSBucket, without gs:// +s) INFERENCE_SERVER: (str) Inference server, ex. jetstream-maxtext, jetstream-pytorch +m) MODEL_PATH: (str) Model path, varies depending on inference server and location of base checkpoint +n) MODEL_NAME: (str) Model name, ex. llama-2, llama-3, gemma +h) HUGGINGFACE: (bool) Checkpoint is from HuggingFace. +q) QUANTIZE_WEIGHTS: (str) Whether to quantize weights +t) QUANTIZE_TYPE: (str) Quantization type, QUANTIZE_WEIGHTS must be set to true. 
Available quantize types: {"int8", "int4"} x {"per_channel", "blockwise"}
+v) VERSION: (str) Version of inference server to override, ex. jetstream-v0.2.2, jetstream-v0.2.3
+i) INPUT_DIRECTORY: (str) Input checkpoint directory, likely a GSBucket path
+o) OUTPUT_DIRECTORY: (str) Output checkpoint directory, likely a GSBucket path
 ```
\ No newline at end of file
diff --git a/tutorials-and-examples/inference-servers/checkpoints/checkpoint_converter.sh b/tutorials-and-examples/inference-servers/checkpoints/checkpoint_converter.sh
index c2c9a5f69..d52ae35ec 100644
--- a/tutorials-and-examples/inference-servers/checkpoints/checkpoint_converter.sh
+++ b/tutorials-and-examples/inference-servers/checkpoints/checkpoint_converter.sh
@@ -1,16 +1,17 @@
 #!/bin/bash
 
 export KAGGLE_CONFIG_DIR="/kaggle"
+export HUGGINGFACE_TOKEN_DIR="/huggingface"
 INFERENCE_SERVER="jetstream-maxtext"
 BUCKET_NAME=""
 MODEL_PATH=""
 
 print_usage() {
-    printf "Usage: $0 [ -b BUCKET_NAME ] [ -i INFERENCE_SERVER ] [ -m MODEL_PATH ] [ -q QUANTIZE ] [ -v VERSION ] [ -1 EXTRA_PARAM_1 ] [ -2 EXTRA_PARAM_2 ]"
+    printf "Usage: $0 [ -b BUCKET_NAME ] [ -s INFERENCE_SERVER ] [ -m MODEL_PATH ] [ -n MODEL_NAME ] [ -h HUGGINGFACE ] [ -q QUANTIZE_WEIGHTS ] [ -t QUANTIZE_TYPE ] [ -v VERSION ] [ -i INPUT_DIRECTORY ] [ -o OUTPUT_DIRECTORY ]"
 }
 
 print_inference_server_unknown() {
-    printf "Enter a valid inference server [ -i INFERENCE_SERVER ]"
+    printf "Enter a valid inference server [ -s INFERENCE_SERVER ]"
     printf "Valid options: jetstream-maxtext, jetstream-pytorch"
 }
 
@@ -43,6 +44,31 @@ download_kaggle_checkpoint() {
   echo -e "\nCompleted copy of data to gs://${BUCKET_NAME}/base/${MODEL_NAME}_${VARIATION_NAME}"
 }
 
+download_huggingface_checkpoint() {
+  MODEL_PATH=$1
+  MODEL_NAME=$2
+
+  INPUT_CKPT_DIR_LOCAL=/base/
+  mkdir /base/
+  huggingface-cli login --token $(cat ${HUGGINGFACE_TOKEN_DIR}/HUGGINGFACE_TOKEN)
+  huggingface-cli download ${MODEL_PATH} --local-dir ${INPUT_CKPT_DIR_LOCAL}
+
+  if [[ $MODEL_NAME == *"llama"* ]]; then
+    if [[ $MODEL_NAME == "llama-2" ]]; then
+      TOKENIZER_PATH=/base/tokenizer.model
+      if [[ $MODEL_PATH != *"hf"* ]]; then
+        HUGGINGFACE="False"
+      fi
+    else
+      TOKENIZER_PATH=/base/original/tokenizer.model
+    fi
+  elif [[ $MODEL_NAME == *"gemma"* ]]; then
+    TOKENIZER_PATH=/base/tokenizer.model
+  else
+    echo -e "Unable to determine the tokenizer.model path for ${MODEL_NAME}. You may have to upload it manually."
+ fi +} + convert_maxtext_checkpoint() { BUCKET_NAME=$1 MODEL_NAME=$2 @@ -60,7 +86,7 @@ convert_maxtext_checkpoint() { cd maxtext git checkout ${MAXTEXT_VERSION} python3 -m pip install -r requirements.txt - echo -e "\Cloned MaxText repository and completed installing requirements" + echo -e "\nCloned MaxText repository and completed installing requirements" python3 MaxText/convert_gemma_chkpt.py --base_model_path gs://${BUCKET_NAME}/base/${MODEL_NAME}_${VARIATION_NAME}/${VARIATION_NAME} --maxtext_model_path gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME} --model_size ${MODEL_SIZE} echo -e "\nCompleted conversion of checkpoint to gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}" @@ -73,59 +99,92 @@ convert_maxtext_checkpoint() { convert_pytorch_checkpoint() { MODEL_PATH=$1 - INPUT_CKPT_DIR=$2 - OUTPUT_CKPT_DIR=$3 - QUANTIZE=$4 - PYTORCH_VERSION=$5 - JETSTREAM_VERSION=v0.2.2 + MODEL_NAME=$2 + HUGGINGFACE=$3 + INPUT_CKPT_DIR=$4 + OUTPUT_CKPT_DIR=$5 + QUANTIZE_TYPE=$6 + QUANTIZE_WEIGHTS=$7 + PYTORCH_VERSION=$8 if [ -z $PYTORCH_VERSION ]; then - PYTORCH_VERSION=jetstream-v0.2.2 + PYTORCH_VERSION=jetstream-v0.2.3 fi CKPT_PATH="$(echo ${INPUT_CKPT_DIR} | awk -F'gs://' '{print $2}')" BUCKET_NAME="$(echo ${CKPT_PATH} | awk -F'/' '{print $1}')" TO_REPLACE=gs://${BUCKET_NAME} - INPUT_CKPT_DIR_LOCAL=${INPUT_CKPT_DIR/${TO_REPLACE}/${MODEL_PATH}} - OUTPUT_CKPT_DIR_LOCAL=/pt-ckpt/ - if [ -z $QUANTIZE ]; then - QUANTIZE="False" - fi + OUTPUT_CKPT_DIR_LOCAL=/pt-ckpt/ - git clone https://github.com/google/JetStream.git git clone https://github.com/google/jetstream-pytorch.git - cd JetStream - git checkout ${JETSTREAM_VERSION} - pip install -e # checkout stable Pytorch commit - cd ../jetstream-pytorch + cd /jetstream-pytorch git checkout ${PYTORCH_VERSION} bash install_everything.sh - export PYTHONPATH=$PYTHONPATH:$(pwd)/deps/xla/experimental/torch_xla2:$(pwd)/JetStream:$(pwd) + echo -e "\nCloned JetStream PyTorch repository and completed installing requirements" echo -e "\nRunning conversion script to convert model weights. This can take a couple minutes..." - python3 -m convert_checkpoints --input_checkpoint_dir=${INPUT_CKPT_DIR_LOCAL} --output_checkpoint_dir=${OUTPUT_CKPT_DIR_LOCAL} --quantize=${QUANTIZE} + + if [ $HUGGINGFACE == "True" ]; then + echo "Checkpoint weights are from HuggingFace" + download_huggingface_checkpoint "$MODEL_PATH" "$MODEL_NAME" + else + HUGGINGFACE="False" + + # Example: + # the input checkpoint directory is gs://jetstream-checkpoints/llama-2-7b/base-checkpoint/ + # the local checkpoint directory will be /models/llama-2-7b/base-checkpoint/ + # INPUT_CKPT_DIR_LOCAL=${INPUT_CKPT_DIR/${TO_REPLACE}/${MODEL_PATH}} + INPUT_CKPT_DIR_LOCAL=${INPUT_CKPT_DIR/${TO_REPLACE}/${MODEL_PATH}} + TOKENIZER_PATH=${INPUT_CKPT_DIR_LOCAL}/tokenizer.model + fi + + if [ -z $QUANTIZE_WEIGHTS ]; then + QUANTIZE_WEIGHTS="False" + fi + + # Possible quantizations: + # 1. quantize_weights = False, we run without specifying quantize_type + # 2. quantize_weights = True, we run without specifying quantize_type to use the default int8_per_channel + # 3. 
quantize_weights = True, we run and specify quantize_type + # We can use the same command for case #1 and #2, since both have quantize_weights set without needing to specify quantize_type + + echo -e "\n quantize weights: ${QUANTIZE_WEIGHTS}" + if [ $QUANTIZE_WEIGHTS == "True" ]; then + # quantize_type is required, it will be set to the default value if not turned on + if [ -n $QUANTIZE_TYPE ]; then + python3 -m convert_checkpoints --model_name=${MODEL_NAME} --input_checkpoint_dir=${INPUT_CKPT_DIR_LOCAL} --output_checkpoint_dir=${OUTPUT_CKPT_DIR_LOCAL} --quantize_type=${QUANTIZE_TYPE} --quantize_weights=${QUANTIZE_WEIGHTS} --from_hf=${HUGGINGFACE} + fi + else + # quantize_weights should be false, but if not the convert_checkpoints script will catch it + python3 -m convert_checkpoints --model_name=${MODEL_NAME} --input_checkpoint_dir=${INPUT_CKPT_DIR_LOCAL} --output_checkpoint_dir=${OUTPUT_CKPT_DIR_LOCAL} --quantize_weights=${QUANTIZE_WEIGHTS} --from_hf=${HUGGINGFACE} + fi echo -e "\nCompleted conversion of checkpoint to ${OUTPUT_CKPT_DIR_LOCAL}" echo -e "\nUploading converted checkpoint from local path ${OUTPUT_CKPT_DIR_LOCAL} to GSBucket ${OUTPUT_CKPT_DIR}" + gcloud storage cp -r ${OUTPUT_CKPT_DIR_LOCAL}/* ${OUTPUT_CKPT_DIR} + gcloud storage cp ${TOKENIZER_PATH} ${OUTPUT_CKPT_DIR} echo -e "\nCompleted uploading converted checkpoint from local path ${OUTPUT_CKPT_DIR_LOCAL} to GSBucket ${OUTPUT_CKPT_DIR}" } -while getopts 'b:i:m:q:v:1:2:' flag; do +while getopts 'b:s:m:n:h:t:q:v:i:o:' flag; do case "${flag}" in b) BUCKET_NAME="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;; - i) INFERENCE_SERVER="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;; + s) INFERENCE_SERVER="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;; m) MODEL_PATH="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;; - q) QUANTIZE="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;; + n) MODEL_NAME="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;; + h) HUGGINGFACE="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;; + t) QUANTIZE_TYPE="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;; + q) QUANTIZE_WEIGHTS="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;; v) VERSION="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;; - 1) EXTRA_PARAM_1="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;; - 2) EXTRA_PARAM_2="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;; + i) INPUT_DIRECTORY="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;; + o) OUTPUT_DIRECTORY="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;; *) print_usage exit 1 ;; esac @@ -142,8 +201,8 @@ case ${INFERENCE_SERVER} in convert_maxtext_checkpoint "$BUCKET_NAME" "$MODEL_NAME" "$VARIATION_NAME" "$MODEL_SIZE" "$VERSION" ;; jetstream-pytorch) - check_model_path "$MODEL_PATH" - convert_pytorch_checkpoint "$MODEL_PATH" "$EXTRA_PARAM_1" "$EXTRA_PARAM_2" "$QUANTIZE" "$VERSION" + check_model_path "$MODEL_PATH" + convert_pytorch_checkpoint "$MODEL_PATH" "$MODEL_NAME" "$HUGGINGFACE" "$INPUT_DIRECTORY" "$OUTPUT_DIRECTORY" "$QUANTIZE_TYPE" "$QUANTIZE_WEIGHTS" "$VERSION" ;; *) print_inference_server_unknown exit 1 ;; diff --git a/tutorials-and-examples/inference-servers/jetstream/pytorch/jetstream-pytorch-server/Dockerfile b/tutorials-and-examples/inference-servers/jetstream/pytorch/jetstream-pytorch-server/Dockerfile index 81fcdffc9..a4bc13a58 100644 --- a/tutorials-and-examples/inference-servers/jetstream/pytorch/jetstream-pytorch-server/Dockerfile +++ b/tutorials-and-examples/inference-servers/jetstream/pytorch/jetstream-pytorch-server/Dockerfile @@ -4,7 +4,7 @@ FROM ubuntu:22.04 ENV DEBIAN_FRONTEND=noninteractive -ENV 
PYTORCH_JETSTREAM_VERSION=jetstream-v0.2.2 +ENV PYTORCH_JETSTREAM_VERSION=jetstream-v0.2.3 RUN apt -y update && apt install -y --no-install-recommends \ ca-certificates \ @@ -20,8 +20,6 @@ cd /jetstream-pytorch && \ git checkout ${PYTORCH_JETSTREAM_VERSION} && \ bash install_everything.sh -ENV PYTHONPATH=$PYTHONPATH:$(pwd)/deps/xla/experimental/torch_xla2:$(pwd)/JetStream:$(pwd) - COPY jetstream_pytorch_server_entrypoint.sh /usr/bin/ RUN chmod +x /usr/bin/jetstream_pytorch_server_entrypoint.sh diff --git a/tutorials-and-examples/inference-servers/jetstream/pytorch/single-host-inference/README.md b/tutorials-and-examples/inference-servers/jetstream/pytorch/single-host-inference/README.md index 04b2fe5c5..910b6c944 100644 --- a/tutorials-and-examples/inference-servers/jetstream/pytorch/single-host-inference/README.md +++ b/tutorials-and-examples/inference-servers/jetstream/pytorch/single-host-inference/README.md @@ -66,7 +66,7 @@ $ kubectl annotate serviceaccount default \ iam.gke.io/gcp-service-account=jetstream-iam-sa@${PROJECT_ID}.iam.gserviceaccount.com ``` -### Create a Cloud Storage bucket to store the Llama2-7b model checkpoint +### Create a Cloud Storage bucket to store your model checkpoint ``` BUCKET_NAME= @@ -74,12 +74,71 @@ gcloud storage buckets create $BUCKET_NAME ``` ## Checkpoint conversion + +### [Option #1] Download weights from GitHub Follow the instructions here to download the llama-2-7b weights: https://github.com/meta-llama/llama#download -Upload your weights to your GSBucket +``` +ls llama + +llama-2-7b tokenizer.model .. +``` + +Upload your weights and tokenizer to your GSBucket + +``` +gcloud storage cp -r llama-2-7b/* gs://BUCKET_NAME/llama-2-7b/base/ +gcloud storage cp tokenizer.model gs://BUCKET_NAME/llama-2-7b/base/ +``` + +### [Option #2] Download weights from HuggingFace +Accept the terms and conditions from https://huggingface.co/meta-llama/Llama-2-7b-hf. + +For llama-3-8b: https://huggingface.co/meta-llama/Meta-Llama-3-8B. + +For gemma-2b: https://huggingface.co/google/gemma-2b-pytorch. + +Obtain a HuggingFace CLI token by going to your HuggingFace settings and under the `Access Tokens`, generate a `New token`. Edit permissions to your access token to have read access to your respective checkpoint repository. + +Copy your access token and create a Secret to store the HuggingFace token ``` -gcloud storage cp -r /* gs://BUCKET_NAME/llama-2-7b/base/ +kubectl create secret generic huggingface-secret \ + --from-literal=HUGGINGFACE_TOKEN= +``` + +### Apply the checkpoint conversion job + +For the following models, replace the following arguments in `checkpoint-job.yaml` + +#### Llama-2-7b-hf +``` +- -s=jetstream-pytorch +- -m=meta-llama/Llama-2-7b-hf +- -o=gs://BUCKET_NAME/pytorch/llama-2-7b/final/bf16/ +- -n=llama-2 +- -q=False +- -h=True +``` + +#### Llama-3-8b +``` +- -s=jetstream-pytorch +- -m=meta-llama/Meta-Llama-3-8B +- -o=gs://BUCKET_NAME/pytorch/llama-3-8b/final/bf16/ +- -n=llama-3 +- -q=False +- -h=True +``` + +#### Gemma-2b +``` +- -s=jetstream-pytorch +- -m=google/gemma-2b-pytorch +- -o=gs://BUCKET_NAME/pytorch/gemma-2b/final/bf16/ +- -n=gemma +- -q=False +- -h=True ``` Run the checkpoint conversion job. This will use the [checkpoint conversion script](https://github.com/google/jetstream-pytorch/blob/main/convert_checkpoints.py) from Jetstream-pytorch to create a compatible Pytorch checkpoint @@ -95,19 +154,19 @@ Observe your checkpoint kubectl logs -f jobs/checkpoint-converter # This can take several minutes ... 
-Completed uploading converted checkpoint from local path /pt-ckpt/ to GSBucket gs://BUCKET_NAME/pytorch/llama2-7b/final/bf16/" +Completed uploading converted checkpoint from local path /pt-ckpt/ to GSBucket gs://BUCKET_NAME/pytorch/llama-2-7b/final/bf16/" ``` -Now your converted checkpoint will be located in `gs://BUCKET_NAME/pytorch/llama2-7b/final/bf16/` +Now your converted checkpoint will be located in `gs://BUCKET_NAME/pytorch/llama-2-7b/final/bf16/` ## Deploy the Jetstream Pytorch server The following flags are set in the manifest file ``` ---param_size: Size of model +--size: Size of model +--model_name: Name of model (llama-2, llama-3, gemma) --batch_size: Batch size --max_cache_length: Maximum length of kv cache ---platform=tpu: TPU machine type (8 for v5e-8, 4 for v4-8) --tokenizer_path: Path to model tokenizer file --checkpoint_path: Path to checkpoint Optional flags to add @@ -115,6 +174,18 @@ Optional flags to add --quantize_kv_cache (Default False): Quantized kv cache ``` +For llama3-8b, you can use the following arguments: +``` +- --size=8b +- --model_name=llama-3 +- --batch_size=80 +- --max_cache_length=2048 +- --quantize_weights=False +- --quantize_kv_cache=False +- --tokenizer_path=/models/pytorch/llama3-8b/final/bf16/tokenizer.model +- --checkpoint_path=/models/pytorch/llama3-8b/final/bf16/model.safetensors +``` + ``` kubectl apply -f deployment.yaml ``` @@ -122,8 +193,8 @@ kubectl apply -f deployment.yaml ### Verify the deployment ``` kubectl get deployment -NAME READY UP-TO-DATE AVAILABLE AGE -jetstream-pytorch-server 2/2 2 2 ##s +NAME READY UP-TO-DATE AVAILABLE AGE +jetstream-pytorch-server 2/2 2 2 ##s ``` View the HTTP server logs to check that the model has been loaded and compiled. It may take the server a few minutes to complete this operation. 
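For example, you can tail the server logs with kubectl. This is a minimal sketch that assumes the Deployment and container names used in `deployment.yaml` (`jetstream-pytorch-server`); adjust if you renamed them:

```
# Follow the model server container logs; use -c jetstream-http to follow the HTTP proxy instead
kubectl logs -f deploy/jetstream-pytorch-server -c jetstream-pytorch-server
```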
diff --git a/tutorials-and-examples/inference-servers/jetstream/pytorch/single-host-inference/checkpoint-job.yaml b/tutorials-and-examples/inference-servers/jetstream/pytorch/single-host-inference/checkpoint-job.yaml index 99079648c..f48c1ac79 100644 --- a/tutorials-and-examples/inference-servers/jetstream/pytorch/single-host-inference/checkpoint-job.yaml +++ b/tutorials-and-examples/inference-servers/jetstream/pytorch/single-host-inference/checkpoint-job.yaml @@ -12,16 +12,20 @@ spec: restartPolicy: Never containers: - name: inference-checkpoint - image: us-docker.pkg.dev/cloud-tpu-images/inference/inference-checkpoint:v0.2.0 + image: us-docker.pkg.dev/cloud-tpu-images/inference/inference-checkpoint:v0.2.3 args: - - -i=jetstream-pytorch + - -s=jetstream-pytorch - -m=/models - - -1=gs://BUCKET_NAME/pytorch/llama2-7b/base/ - - -2=gs://BUCKET_NAME/pytorch/llama2-7b/final/bf16/ + - -i=gs://BUCKET_NAME/pytorch/llama2-7b/base/ + - -o=gs://BUCKET_NAME/pytorch/llama2-7b/final/bf16/ + - -q=False volumeMounts: - mountPath: "/kaggle/" name: kaggle-credentials readOnly: true + - mountPath: "/huggingface/" + name: huggingface-credentials + readOnly: true - name: gcs-fuse-checkpoint mountPath: /models readOnly: true @@ -38,6 +42,10 @@ spec: secret: defaultMode: 0400 secretName: kaggle-secret + - name: huggingface-credentials + secret: + defaultMode: 0400 + secretName: huggingface-secret - name: gcs-fuse-checkpoint csi: driver: gcsfuse.csi.storage.gke.io diff --git a/tutorials-and-examples/inference-servers/jetstream/pytorch/single-host-inference/deployment.yaml b/tutorials-and-examples/inference-servers/jetstream/pytorch/single-host-inference/deployment.yaml index 51dffb062..126d8cfbd 100644 --- a/tutorials-and-examples/inference-servers/jetstream/pytorch/single-host-inference/deployment.yaml +++ b/tutorials-and-examples/inference-servers/jetstream/pytorch/single-host-inference/deployment.yaml @@ -19,12 +19,14 @@ spec: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice containers: - name: jetstream-pytorch-server - image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pytorch-server:v0.2.0 + image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pytorch-server:v0.2.3 args: - - --param_size=7b + - --size=7b + - --model_name=llama-2 - --batch_size=80 - --max_cache_length=2048 - - --platform=tpu=8 + - --quantize_weights=False + - --quantize_kv_cache=False - --tokenizer_path=/jetstream-pytorch/jetstream_pt/third_party/llama2/tokenizer.model - --checkpoint_path=/models/pytorch/llama-2-7b/final/bf16/model.safetensors ports: @@ -39,7 +41,7 @@ spec: limits: google.com/tpu: 8 - name: jetstream-http - image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.0 + image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.2 ports: - containerPort: 8000 volumes: diff --git a/tutorials-and-examples/inference-servers/jetstream/pytorch/single-host-inference/pd-deployment.yaml b/tutorials-and-examples/inference-servers/jetstream/pytorch/single-host-inference/pd-deployment.yaml index e6491b5e7..6297ce41b 100644 --- a/tutorials-and-examples/inference-servers/jetstream/pytorch/single-host-inference/pd-deployment.yaml +++ b/tutorials-and-examples/inference-servers/jetstream/pytorch/single-host-inference/pd-deployment.yaml @@ -17,16 +17,16 @@ spec: cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice containers: - name: jetstream-pytorch-server - image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pytorch-server:v0.2.0 + image: 
us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-pytorch-server:v0.2.3 args: - - --param_size=7b + - --size=7b + - --model_name=llama-2 - --batch_size=80 - --max_cache_length=2048 - - --platform=tpu=8 - --quantize_weights=False - --quantize_kv_cache=False - --tokenizer_path=/jetstream-pytorch/jetstream_pt/third_party/llama2/tokenizer.model - - --checkpoint_path=/models/llama2-7b/bf16/model.safetensors + - --checkpoint_path=/models/pytorch/llama-2-7b/final/bf16/model.safetensors ports: - containerPort: 9000 volumeMounts: @@ -38,7 +38,7 @@ spec: limits: google.com/tpu: 8 - name: jetstream-http - image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.0 + image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.2 ports: - containerPort: 8000 volumes:
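Once both containers report ready, a quick smoke test is to port-forward the HTTP container and send a request. This is a minimal sketch that assumes the `jetstream-http` container serves a JSON `/generate` endpoint on its `containerPort` 8000; verify the route and request schema against the JetStream HTTP server version you deploy:

```
# Forward local port 8000 to the jetstream-http container and send a test prompt
kubectl port-forward deploy/jetstream-pytorch-server 8000:8000 &
curl -s -X POST localhost:8000/generate \
  -H "Content-Type: application/json" \
  -d '{"prompt": "What is a TPU?", "max_tokens": 128}'
```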