
UW eScience Institute

ML training on preemptible VMs

This repository represents an ongoing cloud computing effort at the UW eScience Institute that aims to optimize cloud usage for research computing. The overarching goal is to identify the most prevalent cloud use cases in research and to provide, for each one, a walkthrough alongside a breakdown of cost and performance metrics.

Overview

At the moment, the repository showcases training a basic convolutional neural network (CNN) model on the CIFAR-10 dataset as well as fine-tuning a pre-trained ResNet50 model on the ImageNet dataset. The training is done on Google Cloud Platform (GCP) using its Spot VM instances, which are significantly cheaper than regular VM instances but are preemptible and can be terminated by Google at any time. We aim to design a workflow that can take advantage of the cost savings of Spot VMs while minimizing the impact of potential interruptions.

How?

By periodically checkpointing the model being trained to an external cloud storage bucket, we can save its progress incrementally and resume training from the last checkpoint if the VM is interrupted.
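
In code, the resume half of this cycle might look like the minimal sketch below. The checkpoint dictionary layout and the local path are illustrative assumptions, not necessarily the exact names used in checkpointing.py; the save half is sketched in the Checkpointing section later on.

import os
import torch

def resume_if_possible(model, optimizer, local_path="./checkpoints/latest.pth"):
    # If a checkpoint was previously downloaded from the bucket, restore the
    # model and optimizer state and return the epoch to resume from.
    if os.path.exists(local_path):
        checkpoint = torch.load(local_path)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return checkpoint["epoch"] + 1  # continue after the last saved epoch
    return 0  # no checkpoint found: start training from scratch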

Technical Details

CIFAR-10 task

Dataset: CIFAR-10, a popular dataset for image classification tasks. The dataset consists of 60,000 32x32 color images in 10 classes, with 6,000 images per class.

Model: A classic CNN model with 2 convolutional layers, a max-pooling layer, and 3 fully connected layers. The model is trained using the Adam optimizer and the categorical cross-entropy loss function.

Cloud Storage: A Google Cloud Storage bucket stores the .pth files of the model.
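
A minimal PyTorch sketch of such a model is shown below; the layer sizes follow the classic PyTorch CIFAR-10 tutorial and are assumptions, since the exact architecture lives in this repository's code.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)        # 3 input channels (RGB), 6 filters
        self.pool = nn.MaxPool2d(2, 2)         # max-pooling layer (applied twice)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 feature maps of 5x5 after pooling
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)           # 10 CIFAR-10 classes

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)                # flatten all dims except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)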

ImageNet task

Dataset: a subset of the full ImageNet dataset, consisting of 1.28 million images spanning 1,000 classes.

Model: A pre-trained ResNet50 model is fine-tuned on the ImageNet subset dataset. The model is trained using the Adam optimizer and the categorical cross-entropy loss function.

Cloud Storage: A Google Cloud Storage bucket stores the .pth files of the model.

Workflow

Training

CIFAR-10:

  1. Load and normalize CIFAR-10 using torchvision.
  2. Define the CNN model (two convolutional layers, a max-pooling layer, and three fully connected layers).
  3. Define the loss function and optimizer.
  4. Loop over the dataset n times, feeding inputs to the model and optimizing the weights.

Relevant tutorial: Training a Classifier in PyTorch
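
Putting those four steps together, a condensed sketch of the loop might look as follows (assuming the Net class sketched earlier; the batch size, learning rate, and epoch count are illustrative choices, not the repository's exact settings):

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

# 1. Load and normalize CIFAR-10
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = torchvision.datasets.CIFAR10(root="./data", train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)

# 2-3. Model, loss function, and optimizer
model = Net()  # the CNN sketched in Technical Details
criterion = nn.CrossEntropyLoss()  # PyTorch's categorical cross-entropy
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 4. Loop over the dataset n times
num_epochs = 10
for epoch in range(num_epochs):
    for inputs, labels in trainloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()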

ImageNet:

  1. Load the pre-trained ResNet50 model from torchvision.models.
  2. Replace the final fully connected layer with a new one that has the same number of outputs as the number of classes in the dataset.
  3. Define the loss function and optimizer.
  4. Loop over the dataset n times, feeding inputs to the model and optimizing the weights.

Relevant tutorial: ImageNet training in PyTorch
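
Steps 1-3 translate into a few lines of PyTorch. The sketch below uses the torchvision 0.13+ weights argument (older versions use pretrained=True instead); the learning rate is an illustrative choice.

import torch
import torch.nn as nn
from torchvision import models

num_classes = 1000  # classes in the ImageNet subset

# 1. Load the pre-trained ResNet50
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# 2. Replace the final fully connected layer to match the dataset's classes
model.fc = nn.Linear(model.fc.in_features, num_classes)

# 3. Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)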

Checkpointing

  1. At the end of each epoch, save the model's state dictionary to a .pth file in /checkpoints.
  2. Upload the .pth file to a cloud storage bucket.

Cloud Storage bucket:
A Google Cloud Storage bucket is used to store the .pth files of the model. Setting up such a bucket involves creating or using the default service account, downloading the JSON key file, and setting the service_account_json variable in checkpointing.py to the path of the JSON key file. Make sure to also replace the bucket_name variable with the name of your bucket.

Note: You can use any cloud storage service to store the model checkpoints as long as you modify the checkpointing.py script accordingly.
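
A minimal sketch of this save-and-upload step, using the official google-cloud-storage client, is shown below. The file naming and checkpoint dictionary layout are illustrative assumptions; set bucket_name and service_account_json as described above.

import os
import torch
from google.cloud import storage

service_account_json = "/path/to/key.json"  # your downloaded JSON key file
bucket_name = "your-bucket-name"            # your Cloud Storage bucket

def save_checkpoint(model, optimizer, epoch):
    # 1. Save the state dictionaries to a local .pth file
    os.makedirs("./checkpoints", exist_ok=True)
    local_path = f"./checkpoints/epoch_{epoch}.pth"
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }, local_path)

    # 2. Upload the .pth file to the bucket
    client = storage.Client.from_service_account_json(service_account_json)
    bucket = client.bucket(bucket_name)
    bucket.blob(f"checkpoints/epoch_{epoch}.pth").upload_from_filename(local_path)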

Preemption Handling

Detecting a preemption event involves different strategies depending on the cloud provider. However, a popular approach is to poll the metadata server for a preemption event.

In this repository, we demonstrate polling the metadata server for a preemption event every 5 seconds while the model trains. This approach uses Python's threading module to run the polling function in a separate thread, concurrently with training. If a preemption event is detected, the model's state is saved and uploaded to the assigned cloud storage bucket, and the script exits gracefully.
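
A minimal sketch of that polling pattern on GCP follows. It uses the documented instance/preempted metadata endpoint; the thread wiring and flag names are illustrative rather than the exact code in this repository.

import threading
import time
import requests

# GCE metadata endpoint that reports "TRUE" once the instance is preempted
METADATA_URL = "http://metadata.google.internal/computeMetadata/v1/instance/preempted"
preempted = threading.Event()

def poll_for_preemption(interval=5):
    # Poll the metadata server every `interval` seconds in a background thread
    while not preempted.is_set():
        try:
            resp = requests.get(METADATA_URL,
                                headers={"Metadata-Flavor": "Google"}, timeout=2)
            if resp.text.strip() == "TRUE":
                preempted.set()  # signal the training loop to checkpoint and exit
                return
        except requests.RequestException:
            pass  # metadata server unreachable, e.g. when running off-GCP
        time.sleep(interval)

watcher = threading.Thread(target=poll_for_preemption, daemon=True)
watcher.start()

# Inside the training loop:
#   if preempted.is_set():
#       save_checkpoint(model, optimizer, epoch)  # upload, then exit gracefully
#       sys.exit(0)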

Simulating a preemption event:
On GCP, you can simulate a preemption via a host maintenance event. Read more here.

gcloud compute instances simulate-maintenance-event VM_NAME --zone ZONE

Relevant resources for other cloud providers:

A note on SkyPilot

The repository includes a skypilot.yaml file containing the configuration for SkyPilot, a framework for running jobs across clouds.

TODO: Notes about task.yaml and auto-failover.

Getting Started

  1. If you haven't already, create a Google Cloud Storage bucket and replace the bucket_name variable in checkpointing.py with the name of your bucket.
  2. Create a service account; download the JSON key file and replace the service_account_json variable in checkpointing.py with the path to the JSON key file.
  3. To see the workflow in action, follow the installation instructions below:
# Clone the repository
git clone https://github.com/oorjitchowdhary/ml-training-preemptible-vms.git
cd ml-training-preemptible-vms

# Create a virtual environment
python3 -m venv venv
source venv/bin/activate

# Install the required packages
pip install -r requirements.txt

# Run the script
python index.py
