From 1d625ddcfa6698b08b0845db87474a359c5eed23 Mon Sep 17 00:00:00 2001 From: Mofi Rahman Date: Tue, 11 Jun 2024 15:35:02 -0400 Subject: [PATCH] add finetuning gemma on GKE with L4 GPUs example (#697) * add finetuning gemma example Signed-off-by: Mofi Rahman * resolve comments Signed-off-by: Mofi Rahman * remove extra tag from title Signed-off-by: Mofi Rahman --------- Signed-off-by: Mofi Rahman --- .../finetuning-gemma-2b-on-l4/Dockerfile | 30 +++ .../finetuning-gemma-2b-on-l4/README.md | 174 ++++++++++++ .../finetuning-gemma-2b-on-l4/cloudbuild.yaml | 5 + .../finetuning-gemma-2b-on-l4/finetune.py | 247 ++++++++++++++++++ .../finetuning-gemma-2b-on-l4/finetune.yaml | 57 ++++ 5 files changed, 513 insertions(+) create mode 100644 tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/Dockerfile create mode 100644 tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/README.md create mode 100644 tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/cloudbuild.yaml create mode 100644 tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/finetune.py create mode 100644 tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/finetune.yaml diff --git a/tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/Dockerfile b/tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/Dockerfile new file mode 100644 index 000000000..bc77319cc --- /dev/null +++ b/tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/Dockerfile @@ -0,0 +1,30 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +FROM nvidia/cuda:12.2.0-runtime-ubuntu22.04 + +RUN apt-get update && \ + apt-get -y --no-install-recommends install python3-dev gcc python3-pip git && \ + rm -rf /var/lib/apt/lists/* + +RUN pip3 install --no-cache-dir \ + accelerate==0.30.1 bitsandbytes==0.43.1 \ + datasets==2.19.1 transformers==4.41.0 \ + peft==0.11.1 trl==0.8.6 torch==2.3.0 + +COPY finetune.py /finetune.py + +ENV PYTHONUNBUFFERED 1 + +CMD python3 /finetune.py --device cuda diff --git a/tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/README.md b/tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/README.md new file mode 100644 index 000000000..9d9c31a1d --- /dev/null +++ b/tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/README.md @@ -0,0 +1,174 @@ +# Tutorial: Finetuning Gemma 2b on GKE using L4 GPUs + +We’ll walk through fine-tuning a Gemma 2b model using GKE using 8 x L4 GPUs. L4 GPUs are suitable for many use cases beyond serving models. We will demonstrate how the L4 GPU is a great option for fine tuning LLMs, at a fraction of the cost of using a higher end GPU. + +Let’s get started and fine-tune Gemma 2B on the [b-mc2/sql-create-context](https://huggingface.co/datasets/b-mc2/sql-create-context) dataset using GKE. +Parameter Efficient Fine Tuning (PEFT) and LoRA is used so fine-tuning is posible +on GPUs with less GPU memory. + +As part of this tutorial, you will get to do the following: + +1. Prepare your environment with a GKE cluster in + Autopilot mode. +2. Create a finetune container. +3. Use GPU to finetune the Gemma 2B model and upload the model to huggingface. + +## Prerequisites + +* A terminal with `kubectl` and `gcloud` installed. Cloud Shell works great! +* Create a [Hugging Face](https://huggingface.co/) account, if you don't already have one. +* Ensure your project has sufficient quota for GPUs. To learn more, see [About GPUs](/kubernetes-engine/docs/concepts/gpus#gpu-quota) and [Allocation quotas](/compute/resource-usage#gpu_quota). +* To get access to the Gemma models for deployment to GKE, you must first sign the license consent agreement then generate a Hugging Face access token. Make sure the token has `Write` permission. + +## Creating the GKE cluster with L4 nodepools + +Let’s start by setting a few environment variables that will be used throughout this post. You should modify these variables to meet your environment and needs. + +Download the code and files used throughout the tutorial: + +```bash +git clone https://github.com/GoogleCloudPlatform/ai-on-gke +cd ai-on-gke/tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4 +``` + +Run the following commands to set the env variables and make sure to replace ``: + +```bash +gcloud config set project +export PROJECT_ID=$(gcloud config get project) +export REGION=us-central1 +export HF_TOKEN= +export CLUSTER_NAME=finetune-gemma +``` + +> Note: You might have to rerun the export commands if for some reason you reset your shell and the variables are no longer set. This can happen for example when your Cloud Shell disconnects. + +Create the GKE cluster by running: + +```bash +gcloud container clusters create-auto ${CLUSTER_NAME} \ + --project=${PROJECT_ID} \ + --region=${REGION} \ + --release-channel=rapid \ + --cluster-version=1.29 +``` + +### Create a Kubernetes secret for Hugging Face credentials + +In your shell session, do the following: + + 1. Configure `kubectl` to communicate with your cluster: + + ```sh + gcloud container clusters get-credentials ${CLUSTER_NAME} --location=${REGION} + ``` + + 2. Create a Kubernetes Secret that contains the Hugging Face token: + + ```sh + kubectl create secret generic hf-secret \ + --from-literal=hf_api_token=${HF_TOKEN} \ + --dry-run=client -o yaml | kubectl apply -f - + ``` + +### Containerize the Code with Docker and Cloud Build + +1. Create an Artifact Registry Docker Repository + + ```sh + gcloud artifacts repositories create gemma \ + --project=${PROJECT_ID} \ + --repository-format=docker \ + --location=us \ + --description="Gemma Repo" + ``` + +2. Execute the build and create inference container image. + + ```sh + gcloud builds submit . + ``` + +## Run Finetune Job on GKE + +1. Open the `finetune.yaml` manifest. +2. Edit the `image` name with the container image built with Cloud Build and `NEW_MODEL` environment variable value. This `NEW_MODEL` will be the name of the model you would save as a public model in your Hugging Face account. +3. Run the following command to create the finetune job: + + ```sh + kubectl apply -f finetune.yaml + ``` + +4. Monitor the job by running: + + ```sh + watch kubectl get pods + ``` + +5. You can check the logs of the job by running: + + ```sh + kubectl logs -f -l app=gemma-finetune + ``` + +6. Once the job is completed, you can check the model in Hugging Face. + +## Serve the Finetuned Model on GKE + +To deploy the finetuned model on GKE you can follow the instructions from Deploy a pre-trained Gemma model on [Hugging Face TGI](https://cloud.google.com/kubernetes-engine/docs/tutorials/serve-gemma-gpu-tgi#deploy-pretrained) or [vLLM](https://cloud.google.com/kubernetes-engine/docs/tutorials/serve-gemma-gpu-vllm#deploy-vllm). Select the Gemma 2B instruction and change the `MODEL_ID` to `/gemma-2b-sql-finetuned`. + +### Set up port forwarding + +Once the model is deploye, run the following command to set up port forwarding to the model: + +```sh +kubectl port-forward service/llm-service 8000:8000 +``` + +The output is similar to the following: + +```sh +Forwarding from 127.0.0.1:8000 -> 8000 +``` + +### Interact with the model using curl + +Once the model is deployed In a new terminal session, use curl to chat with your model: + +> The following example command is for TGI. + +```sh +USER_PROMPT="Question: What is the total number of attendees with age over 30 at kubecon eu? Context: CREATE TABLE attendees (name VARCHAR, age INTEGER, kubecon VARCHAR)" + +curl -X POST http://localhost:8000/generate \ + -H "Content-Type: application/json" \ + -d @- < 30 AND kubecon = 'eu'\n"} +``` + +## Clean Up + +To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources. + +### Delete the deployed resources + +To avoid incurring charges to your Google Cloud account for the resources that you created in this guide, run the following command: + +```sh +gcloud container clusters delete ${CLUSTER_NAME} \ + --region=${REGION} +``` diff --git a/tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/cloudbuild.yaml b/tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/cloudbuild.yaml new file mode 100644 index 000000000..cf85100ff --- /dev/null +++ b/tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/cloudbuild.yaml @@ -0,0 +1,5 @@ +steps: +- name: 'gcr.io/cloud-builders/docker' + args: [ 'build', '-t', 'us-docker.pkg.dev/$PROJECT_ID/gemma/finetune-gemma-gpu:1.0.0', '.' ] +images: +- 'us-docker.pkg.dev/$PROJECT_ID/gemma/finetune-gemma-gpu:1.0.0' diff --git a/tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/finetune.py b/tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/finetune.py new file mode 100644 index 000000000..b7d464b1b --- /dev/null +++ b/tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/finetune.py @@ -0,0 +1,247 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import torch +from datasets import load_dataset, Dataset +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + TrainingArguments, +) +from peft import LoraConfig, PeftModel + +from trl import SFTTrainer + +# The model that you want to train from the Hugging Face hub +model_name = os.getenv("MODEL_NAME", "google/gemma-2b") + +# The instruction dataset to use +dataset_name = "b-mc2/sql-create-context" + +# Fine-tuned model name +new_model = os.getenv("NEW_MODEL", "gemma-2b-sql") + +################################################################################ +# QLoRA parameters +################################################################################ + +# LoRA attention dimension +lora_r = int(os.getenv("LORA_R", "4")) + +# Alpha parameter for LoRA scaling +lora_alpha = int(os.getenv("LORA_ALPHA", "8")) + +# Dropout probability for LoRA layers +lora_dropout = 0.1 + +################################################################################ +# bitsandbytes parameters +################################################################################ + +# Activate 4-bit precision base model loading +use_4bit = True + +# Compute dtype for 4-bit base models +bnb_4bit_compute_dtype = "float16" + +# Quantization type (fp4 or nf4) +bnb_4bit_quant_type = "nf4" + +# Activate nested quantization for 4-bit base models (double quantization) +use_nested_quant = False + +################################################################################ +# TrainingArguments parameters +################################################################################ + +# Output directory where the model predictions and checkpoints will be stored +output_dir = "./results" + +# Number of training epochs +num_train_epochs = 1 + +# Enable fp16/bf16 training (set bf16 to True with an A100) +fp16 = True +bf16 = False + +# Batch size per GPU for training +per_device_train_batch_size = int(os.getenv("TRAIN_BATCH_SIZE", "1")) + +# Batch size per GPU for evaluation +per_device_eval_batch_size = int(os.getenv("EVAL_BATCH_SIZE", "2")) + +# Number of update steps to accumulate the gradients for +gradient_accumulation_steps = int(os.getenv("GRADIENT_ACCUMULATION_STEPS", "1")) + +# Enable gradient checkpointing +gradient_checkpointing = True + +# Maximum gradient normal (gradient clipping) +max_grad_norm = 0.3 + +# Initial learning rate (AdamW optimizer) +learning_rate = 2e-4 + +# Weight decay to apply to all layers except bias/LayerNorm weights +weight_decay = 0.001 + +# Optimizer to use +optim = "paged_adamw_32bit" + +# Learning rate schedule +lr_scheduler_type = "cosine" + +# Number of training steps (overrides num_train_epochs) +max_steps = -1 + +# Ratio of steps for a linear warmup (from 0 to learning rate) +warmup_ratio = 0.03 + +# Group sequences into batches with same length +# Saves memory and speeds up training considerably +group_by_length = True + +# Save checkpoint every X updates steps +save_steps = 0 + +# Log every X updates steps +logging_steps = int(os.getenv("LOGGING_STEPS", "50")) + +################################################################################ +# SFT parameters +################################################################################ + +# Maximum sequence length to use +max_seq_length = int(os.getenv("MAX_SEQ_LENGTH", "512")) + +# Pack multiple short examples in the same input sequence to increase efficiency +packing = False + +# Load the entire model on the GPU 0 +device_map = {'':torch.cuda.current_device()} + +# Set limit to a positive number +limit = int(os.getenv("DATASET_LIMIT", "5000")) + +dataset = load_dataset(dataset_name, split="train") +if limit != -1: + dataset = dataset.shuffle(seed=42).select(range(limit)) + + +def transform(data): + question = data['question'] + context = data['context'] + answer = data['answer'] + template = "Question: {question}\nContext: {context}\nAnswer: {answer}" + return {'text': template.format(question=question, context=context, answer=answer)} + + +transformed = dataset.map(transform) + +# Load tokenizer and model with QLoRA configuration +compute_dtype = getattr(torch, bnb_4bit_compute_dtype) + +bnb_config = BitsAndBytesConfig( + load_in_4bit=use_4bit, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_compute_dtype=compute_dtype, + bnb_4bit_use_double_quant=use_nested_quant, +) + +# Check GPU compatibility with bfloat16 +if compute_dtype == torch.float16 and use_4bit: + major, _ = torch.cuda.get_device_capability() + if major >= 8: + print("=" * 80) + print("Your GPU supports bfloat16") + print("=" * 80) + +# Load base model +# model = AutoModelForCausalLM.from_pretrained("google/gemma-7b") +model = AutoModelForCausalLM.from_pretrained( + model_name, + quantization_config=bnb_config, + device_map=device_map, + torch_dtype=torch.float16, +) +model.config.use_cache = False +model.config.pretraining_tp = 1 + +# Load LLaMA tokenizer +tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) +tokenizer.pad_token = tokenizer.eos_token +tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training + +# Load LoRA configuration +peft_config = LoraConfig( + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + r=lora_r, + bias="none", + task_type="CAUSAL_LM", + target_modules=["q_proj", "v_proj"] +) + +# Set training parameters +training_arguments = TrainingArguments( + output_dir=output_dir, + num_train_epochs=num_train_epochs, + per_device_train_batch_size=per_device_train_batch_size, + gradient_accumulation_steps=gradient_accumulation_steps, + optim=optim, + save_steps=save_steps, + logging_steps=logging_steps, + learning_rate=learning_rate, + weight_decay=weight_decay, + fp16=fp16, + bf16=bf16, + max_grad_norm=max_grad_norm, + max_steps=max_steps, + warmup_ratio=warmup_ratio, + group_by_length=group_by_length, + lr_scheduler_type=lr_scheduler_type, +) + +trainer = SFTTrainer( + model=model, + train_dataset=transformed, + peft_config=peft_config, + dataset_text_field="text", + max_seq_length=max_seq_length, + tokenizer=tokenizer, + args=training_arguments, + packing=packing, +) + +trainer.train() + +trainer.model.save_pretrained(new_model) + +# Reload model in FP16 and merge it with LoRA weights +base_model = AutoModelForCausalLM.from_pretrained( + model_name, + low_cpu_mem_usage=True, + return_dict=True, + torch_dtype=torch.float16, + device_map=device_map, +) +model = PeftModel.from_pretrained(base_model, new_model) +model = model.merge_and_unload() + + + +model.push_to_hub(new_model, check_pr=True) + +tokenizer.push_to_hub(new_model, check_pr=True) diff --git a/tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/finetune.yaml b/tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/finetune.yaml new file mode 100644 index 000000000..908fe4b09 --- /dev/null +++ b/tutorials-and-examples/genAI-LLM/finetuning-gemma-2b-on-l4/finetune.yaml @@ -0,0 +1,57 @@ +apiVersion: batch/v1 +kind: Job +metadata: + name: finetune-job + namespace: default + labels: + app: gemma-finetune +spec: + backoffLimit: 2 + template: + metadata: + annotations: + kubectl.kubernetes.io/default-container: finetuner + spec: + terminationGracePeriodSeconds: 600 + containers: + - name: finetuner + image: + resources: + limits: + nvidia.com/gpu: "8" + env: + - name: MODEL_NAME + value: "google/gemma-2b" + - name: NEW_MODEL + value: "" + - name: LORA_R + value: "8" + - name: LORA_ALPHA + value: "16" + - name: TRAIN_BATCH_SIZE + value: "1" + - name: EVAL_BATCH_SIZE + value: "2" + - name: GRADIENT_ACCUMULATION_STEPS + value: "2" + - name: DATASET_LIMIT + value: "1000" + - name: MAX_SEQ_LENGTH + value: "512" + - name: LOGGING_STEPS + value: "5" + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-secret + key: hf_api_token + volumeMounts: + - mountPath: /dev/shm + name: dshm + volumes: + - name: dshm + emptyDir: + medium: Memory + nodeSelector: + cloud.google.com/gke-accelerator: nvidia-l4 + restartPolicy: OnFailure