Merge pull request #43 from danielvegamyhre/indexed-job
Distributed ML training on GPUs using Indexed Jobs on GKE
Showing 3 changed files with 335 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
FROM pytorch/pytorch:latest | ||
RUN pip install tqdm | ||
COPY mnist.py mnist.py |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
# Running distributed ML training workloads on GKE using Indexed Jobs | ||
|
||
In this guide you will run a distributed ML training workload on GKE using an [Indexed Job](https://kubernetes.io/blog/2021/04/19/introducing-indexed-jobs/). | ||
|
||
Specifically, you will train a handwritten digit image classifier on the classic MNIST dataset | ||
using PyTorch. The training computation will be distributed across 4 GPU nodes in a GKE cluster. | ||
|
||
## Prerequisites | ||
- [Google Cloud](https://cloud.google.com/) account is set up. | ||
- [gcloud](https://cloud.google.com/sdk/docs/install) command line tool is installed and configured to use your GCP project. | ||
- [kubectl](https://kubernetes.io/docs/tasks/tools/) command line utility is installed. | ||
- [docker](https://docs.docker.com/engine/install/) is installed. | ||
|
||
### 1. Create a standard GKE cluster | ||
Run the command: | ||
|
||
```bash | ||
gcloud container clusters create demo --zone us-central1-c | ||
``` | ||
|
||
You should see output indicating the cluster is being created (this can take about 10 minutes). | ||
|
||
### 2. Create a GPU node pool | ||
You can choose any supported GPU type you wish, using a supported machine type. See the [docs](https://cloud.google.com/kubernetes-engine/docs/how-to/gpus) for more details. In this example, we are using NVIDIA Tesla T4s with the N1 machine family. | ||
|
||
```bash | ||
gcloud container node-pools create gpu-pool \ | ||
--accelerator type=nvidia-tesla-t4,count=1,gpu-driver-version=LATEST \ | ||
--machine-type n1-standard-4 \ | ||
--zone us-central1-c --cluster demo \ | ||
--node-locations us-central1-c \ | ||
--num-nodes 4 | ||
``` | ||
|
||
Creating this GPU node pool will take a few minutes. | ||
|
||
### 3. Build and push the Docker image to GCR | ||
Make a local copy of the [mnist.py](mnist.py) file, which defines a convolutional neural network and the training logic that trains it on the classic [MNIST](https://en.wikipedia.org/wiki/MNIST_database) dataset. | ||
|
||
Next, make a local copy of the [Dockerfile](Dockerfile) and run the following commands to build the container image and push it to your GCR repository: | ||
|
||
```bash | ||
export PROJECT_ID=<your GCP project ID> | ||
docker build -t pytorch-mnist-gpu -f Dockerfile . | ||
docker tag pytorch-mnist-gpu gcr.io/$PROJECT_ID/pytorch-mnist-gpu:latest | ||
docker push gcr.io/$PROJECT_ID/pytorch-mnist-gpu:latest | ||
``` | ||
|
||
|
||
### 4. Define an Indexed Job and Headless Service | ||
|
||
In the YAML below, we configure an Indexed Job to run 4 pods, 1 for each GPU node, and use [torchrun](https://pytorch.org/docs/stable/elastic/run.html) to kick off a distributed training job for the CNN model on the MNIST dataset. This training job will use 1 T4 GPU on each node in the node pool. | ||
|
||
We also define a [headless service](https://kubernetes.io/docs/concepts/services-networking/service/#headless-services) which selects the | ||
pods owned by this Indexed Job. This will trigger the creation of the DNS records needed for the pods to communicate with each other | ||
over the network via hostnames. | ||
|
||
Copy the YAML below into a local file named `mnist.yaml`, and be sure to replace `<PROJECT_ID>` in the container image with your GCP project ID. | ||
|
||
```yaml | ||
apiVersion: v1 | ||
kind: Service | ||
metadata: | ||
  name: headless-svc | ||
spec: | ||
  clusterIP: None | ||
  selector: | ||
    job-name: pytorchworker | ||
--- | ||
apiVersion: batch/v1 | ||
kind: Job | ||
metadata: | ||
  name: pytorchworker | ||
spec: | ||
  backoffLimit: 0 | ||
  completions: 4 | ||
  parallelism: 4 | ||
  completionMode: Indexed | ||
  template: | ||
    spec: | ||
      subdomain: headless-svc | ||
      restartPolicy: Never | ||
      nodeSelector: | ||
        cloud.google.com/gke-accelerator: nvidia-tesla-t4 | ||
      tolerations: | ||
      - operator: "Exists" | ||
        key: nvidia.com/gpu | ||
      containers: | ||
      - name: pytorch | ||
        image: gcr.io/<PROJECT_ID>/pytorch-mnist-gpu:latest | ||
        imagePullPolicy: Always | ||
        # Request one GPU per pod so the container can access the T4 | ||
        # and the 4 pods are spread one per GPU node. | ||
        resources: | ||
          limits: | ||
            nvidia.com/gpu: 1 | ||
        ports: | ||
        - containerPort: 3389 | ||
        env: | ||
        - name: MASTER_ADDR | ||
          value: pytorchworker-0.headless-svc | ||
        - name: MASTER_PORT | ||
          value: "3389" | ||
        - name: PYTHONUNBUFFERED | ||
          value: "1" | ||
        - name: LOGLEVEL | ||
          value: "INFO" | ||
        - name: RANK | ||
          valueFrom: | ||
            fieldRef: | ||
              fieldPath: metadata.annotations['batch.kubernetes.io/job-completion-index'] | ||
        command: | ||
        - bash | ||
        - -xc | ||
        - | | ||
          printenv | ||
          torchrun --rdzv_id=123 --nnodes=4 --nproc_per_node=1 --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT --node_rank=$RANK mnist.py --epochs=1 --log-interval=1 | ||
``` | ||
### 5. Run the training job | ||
Run the following command to create the Kubernetes resources we defined above and run the training job: | ||
```bash | ||
kubectl apply -f mnist.yaml | ||
``` | ||
|
||
You should see 4 pods created (note the container image is large and may take a few minutes to pull before the container starts running): | ||
|
||
``` | ||
$ kubectl get pods | ||
NAME READY STATUS RESTARTS AGE | ||
pytorchworker-0-bbsmk 0/1 ContainerCreating 0 15s | ||
pytorchworker-1-92tbl 0/1 ContainerCreating 0 15s | ||
pytorchworker-2-nbrgf 0/1 ContainerCreating 0 15s | ||
pytorchworker-3-rsrdf 0/1 ContainerCreating 0 15s | ||
``` | ||
|
||
### 6. Observe training logs | ||
|
||
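Once the pods transition from the `ContainerCreating` status to the `Running` status, you can observe the training logs by examining the pod logs. Use the full pod name shown by `kubectl get pods`; the random suffix will differ in your cluster. | ||
|
||
```bash | ||
$ kubectl logs -f pytorchworker-1-92tbl | ||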
|
||
+ torchrun --rdzv_id=123 --nnodes=4 --nproc_per_node=1 --master_addr=pytorchworker-0.headless-svc --master_port=3389 --node_rank=1 mnist.py --epochs=1 --log-interval=1 | ||
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz | ||
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz | ||
100%|██████████| 9912422/9912422 [00:00<00:00, 90162259.46it/s] | ||
Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw | ||
|
||
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz | ||
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz | ||
100%|██████████| 28881/28881 [00:00<00:00, 33279036.76it/s] | ||
Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw | ||
|
||
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz | ||
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz | ||
100%|██████████| 1648877/1648877 [00:00<00:00, 23474415.33it/s] | ||
Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw | ||
|
||
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz | ||
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz | ||
100%|██████████| 4542/4542 [00:00<00:00, 19165521.90it/s] | ||
Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw | ||
|
||
Train Epoch: 1 [0/60000 (0%)] Loss: 2.297087 | ||
Train Epoch: 1 [64/60000 (0%)] Loss: 2.550339 | ||
Train Epoch: 1 [128/60000 (1%)] Loss: 2.361300 | ||
... | ||
|
||
Train Epoch: 1 [14912/60000 (99%)] Loss: 0.051500 | ||
Train Epoch: 1 [5616/60000 (100%)] Loss: 0.209231 | ||
235it [00:36, 6.51it/s] | ||
|
||
Test set: Average loss: 0.0825, Accuracy: 9720/10000 (97%) | ||
|
||
INFO:torch.distributed.elastic.agent.server.api:[default] worker group successfully finished. Waiting 300 seconds for other agents to finish. | ||
INFO:torch.distributed.elastic.agent.server.api:Local worker group finished (SUCCEEDED). Waiting 300 seconds for other agents to finish | ||
INFO:torch.distributed.elastic.agent.server.api:Done waiting for other agents. Elapsed: 0.0015289783477783203 seconds | ||
``` |
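The link between the Indexed Job and the `torchrun` flags can be sketched in a few lines of Python. This is only an illustration, not part of the guide's code: the helper names `pod_hostname` and `torchrun_args` are hypothetical, and the sketch reads the `JOB_COMPLETION_INDEX` environment variable that Kubernetes sets in Indexed Job pods, whereas the manifest above passes the same index through the `RANK` variable. It assumes pod hostnames follow the `<job-name>-<index>` pattern, which is why `pytorchworker-0.headless-svc` resolves to the rank 0 pod.

```python
import os


def pod_hostname(job_name: str, index: int, service: str) -> str:
    """DNS name of the pod with the given completion index.

    Assumes the pod hostname is <job-name>-<index> and the pod sets
    `subdomain` to the headless service name, as in the manifest above.
    """
    return f"{job_name}-{index}.{service}"


def torchrun_args(job_name, service, nnodes, port, env=None):
    """Build the torchrun flags for the pod this code runs in (hypothetical helper)."""
    env = os.environ if env is None else env
    # Every pod gets a distinct index in [0, completions), used as its node rank.
    rank = int(env.get("JOB_COMPLETION_INDEX", "0"))
    return [
        "torchrun",
        f"--nnodes={nnodes}",
        "--nproc_per_node=1",
        # Every node rendezvouses at the pod with index 0.
        f"--master_addr={pod_hostname(job_name, 0, service)}",
        f"--master_port={port}",
        f"--node_rank={rank}",
    ]


if __name__ == "__main__":
    fake_env = {"JOB_COMPLETION_INDEX": "1"}  # simulate the pod with index 1
    print(" ".join(torchrun_args("pytorchworker", "headless-svc", 4, 3389, fake_env)))
```

Because each pod derives its rank from its own completion index, all 4 pods can share one pod template and still start with distinct `--node_rank` values.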
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
import argparse | ||
from tqdm import tqdm | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.optim as optim | ||
from torchvision import datasets, transforms | ||
from torch.optim.lr_scheduler import StepLR | ||
from torch.utils.data.distributed import DistributedSampler | ||
|
||
|
||
class CNN(nn.Module): | ||
    ''' | ||
    Convolutional neural network. | ||
    ''' | ||
    def __init__(self): | ||
        super(CNN, self).__init__() | ||
        self.conv1 = nn.Conv2d(1, 32, 3, 1) | ||
        self.conv2 = nn.Conv2d(32, 64, 3, 1) | ||
        self.dropout1 = nn.Dropout(0.25) | ||
        self.dropout2 = nn.Dropout(0.5) | ||
        self.fc1 = nn.Linear(9216, 128) | ||
        self.fc2 = nn.Linear(128, 10) | ||
|
||
    def forward(self, x): | ||
        x = self.conv1(x) | ||
        x = F.relu(x) | ||
        x = self.conv2(x) | ||
        x = F.relu(x) | ||
        x = F.max_pool2d(x, 2) | ||
        x = self.dropout1(x) | ||
        x = torch.flatten(x, 1) | ||
        x = self.fc1(x) | ||
        x = F.relu(x) | ||
        x = self.dropout2(x) | ||
        x = self.fc2(x) | ||
        output = F.log_softmax(x, dim=1) | ||
        return output | ||
|
||
|
||
def train(args, model, device, train_loader, optimizer, epoch): | ||
    model.train() | ||
    for batch_idx, (data, target) in tqdm(enumerate(train_loader)): | ||
        data, target = data.to(device), target.to(device) | ||
        optimizer.zero_grad() | ||
        output = model(data) | ||
        loss = F.nll_loss(output, target) | ||
        loss.backward() | ||
        optimizer.step() | ||
        if batch_idx % args.log_interval == 0: | ||
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( | ||
                epoch, batch_idx * len(data), len(train_loader.dataset), | ||
                100. * batch_idx / len(train_loader), loss.item())) | ||
            if args.dry_run: | ||
                break | ||
|
||
|
||
def test(model, device, test_loader): | ||
    model.eval() | ||
    test_loss = 0 | ||
    correct = 0 | ||
    with torch.no_grad(): | ||
        for data, target in test_loader: | ||
            data, target = data.to(device), target.to(device) | ||
            output = model(data) | ||
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss | ||
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability | ||
            correct += pred.eq(target.view_as(pred)).sum().item() | ||
|
||
    test_loss /= len(test_loader.dataset) | ||
|
||
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( | ||
        test_loss, correct, len(test_loader.dataset), | ||
        100. * correct / len(test_loader.dataset))) | ||
|
||
|
||
def main(): | ||
    # Training settings | ||
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example') | ||
    parser.add_argument('--batch-size', type=int, default=64, metavar='N', | ||
                        help='input batch size for training (default: 64)') | ||
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', | ||
                        help='input batch size for testing (default: 1000)') | ||
    parser.add_argument('--epochs', type=int, default=14, metavar='N', | ||
                        help='number of epochs to train (default: 14)') | ||
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR', | ||
                        help='learning rate (default: 1.0)') | ||
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M', | ||
                        help='Learning rate step gamma (default: 0.7)') | ||
    parser.add_argument('--no-cuda', action='store_true', default=False, | ||
                        help='disables CUDA training') | ||
    parser.add_argument('--no-mps', action='store_true', default=False, | ||
                        help='disables macOS GPU training') | ||
    parser.add_argument('--dry-run', action='store_true', default=False, | ||
                        help='quickly check a single pass') | ||
    parser.add_argument('--seed', type=int, default=1, metavar='S', | ||
                        help='random seed (default: 1)') | ||
    parser.add_argument('--log-interval', type=int, default=10, metavar='N', | ||
                        help='how many batches to wait before logging training status') | ||
    parser.add_argument('--save-model', action='store_true', default=False, | ||
                        help='For Saving the current Model') | ||
    args = parser.parse_args() | ||
    use_cuda = not args.no_cuda and torch.cuda.is_available() | ||
    use_mps = not args.no_mps and torch.backends.mps.is_available() | ||
|
||
    torch.manual_seed(args.seed) | ||
|
||
    # Initialize distributed training coordination. | ||
    torch.distributed.init_process_group(backend="gloo") | ||
|
||
    # Honor --no-cuda: use the GPU only when it is available and not disabled. | ||
    device = "cuda" if use_cuda else "cpu" | ||
|
||
    train_kwargs = {'batch_size': args.batch_size} | ||
    test_kwargs = {'batch_size': args.test_batch_size} | ||
    if use_cuda: | ||
        cuda_kwargs = {'num_workers': 1, | ||
                       'pin_memory': True} | ||
        train_kwargs.update(cuda_kwargs) | ||
        test_kwargs.update(cuda_kwargs) | ||
|
||
    # Set up distributed training to use DDP. | ||
    model = torch.nn.parallel.DistributedDataParallel(CNN().to(device)) | ||
|
||
    transform=transforms.Compose([ | ||
        transforms.ToTensor(), | ||
        transforms.Normalize((0.1307,), (0.3081,)) | ||
        ]) | ||
|
||
    train_set = datasets.MNIST('../data', train=True, download=True, | ||
                               transform=transform) | ||
    test_set = datasets.MNIST('../data', train=False, | ||
                              transform=transform) | ||
|
||
    # Set up distributed data sampling so each worker in our DDP set up will | ||
    # process a specific subset/partition of the training data. | ||
    train_sampler = DistributedSampler(dataset=train_set) | ||
|
||
    train_loader = torch.utils.data.DataLoader(train_set,**train_kwargs, sampler=train_sampler) | ||
    test_loader = torch.utils.data.DataLoader(test_set, **test_kwargs) | ||
|
||
|
||
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr) | ||
|
||
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) | ||
    for epoch in range(1, args.epochs + 1): | ||
        # Reshuffle the per-worker partitions each epoch. | ||
        train_sampler.set_epoch(epoch) | ||
        train(args, model, device, train_loader, optimizer, epoch) | ||
        test(model, device, test_loader) | ||
        scheduler.step() | ||
|
||
    if args.save_model: | ||
        torch.save(model.state_dict(), "mnist_cnn.pt") | ||
|
||
|
||
if __name__ == '__main__': | ||
    main() |
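The progress counters in the training logs above follow from how the training set is split across workers. The sketch below is a worked illustration in plain Python rather than PyTorch code: it assumes 60000 training images, 4 workers, and the default batch size of 64, and it models the split as a simple strided partition, ignoring the shuffling that `DistributedSampler` applies by default. The constant and function names are hypothetical.

```python
import math

DATASET_SIZE = 60000  # MNIST training images
WORLD_SIZE = 4        # one worker per GPU node
BATCH_SIZE = 64       # default --batch-size in mnist.py


def shard_indices(rank, dataset_size=DATASET_SIZE, world_size=WORLD_SIZE):
    """Indices one rank would process under an unshuffled, strided split."""
    return list(range(rank, dataset_size, world_size))


# Each worker sees one quarter of the data.
samples_per_rank = math.ceil(DATASET_SIZE / WORLD_SIZE)            # 15000

# Number of iterations per epoch on each worker ("235it" in the logs).
batches_per_rank = math.ceil(samples_per_rank / BATCH_SIZE)        # 235

# The final batch is smaller than the others.
last_batch_size = samples_per_rank - (batches_per_rank - 1) * BATCH_SIZE  # 24

# train() logs batch_idx * len(data), so the last line reports 234 * 24.
last_logged_count = (batches_per_rank - 1) * last_batch_size       # 5616

if __name__ == "__main__":
    print(len(shard_indices(1)), batches_per_rank, last_batch_size, last_logged_count)
```

This is why each worker's log ends near `14912/60000` and then `5616/60000` instead of counting up to 60000: the denominator is the full dataset size, while each worker only iterates over its own 15000 samples.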