Document how to compute gradients without running OOM (#30)
* Rework ID-based param

* Update docs

* Test decoder init

* Add checkpointing and AMP
wesselb authored Sep 11, 2024
1 parent b65b87d commit c91c3ab
Showing 2 changed files with 50 additions and 7 deletions.
30 changes: 23 additions & 7 deletions aurora/model/aurora.py
@@ -1,17 +1,21 @@
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

import contextlib
import dataclasses
from datetime import timedelta
from functools import partial

import torch
from huggingface_hub import hf_hub_download
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
)

from aurora.batch import Batch
from aurora.model.decoder import Perceiver3DDecoder
from aurora.model.encoder import Perceiver3DEncoder
from aurora.model.lora import LoRAMode
from aurora.model.swin3d import Swin3DTransformerBackbone
from aurora.model.swin3d import BasicLayer3D, Swin3DTransformerBackbone

__all__ = ["Aurora", "AuroraSmall", "AuroraHighRes"]

@@ -47,6 +51,7 @@ def __init__(
use_lora: bool = True,
lora_steps: int = 40,
lora_mode: LoRAMode = "single",
autocast: bool = False,
) -> None:
"""Construct an instance of the model.
@@ -92,11 +97,14 @@ def __init__(
lora_mode (str, optional): LoRA mode. `"single"` uses the same LoRA for all roll-out
steps, and `"all"` uses a different LoRA for every roll-out step. Defaults to
`"single"`.
autocast (bool, optional): Use `torch.autocast` to reduce memory usage. Defaults to
`False`.
"""
super().__init__()
self.surf_vars = surf_vars
self.atmos_vars = atmos_vars
self.patch_size = patch_size
self.autocast = autocast

self.encoder = Perceiver3DEncoder(
surf_vars=surf_vars,
@@ -181,12 +189,13 @@ def forward(self, batch: Batch) -> Batch:
batch,
lead_time=timedelta(hours=6),
)
x = self.backbone(
x,
lead_time=timedelta(hours=6),
patch_res=patch_res,
rollout_step=batch.metadata.rollout_step,
)
with torch.autocast(device_type="cuda") if self.autocast else contextlib.nullcontext():
x = self.backbone(
x,
lead_time=timedelta(hours=6),
patch_res=patch_res,
rollout_step=batch.metadata.rollout_step,
)
pred = self.decoder(
x,
batch,
@@ -297,6 +306,13 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None:

self.load_state_dict(d, strict=strict)

def configure_activation_checkpointing(self):
"""Configure activation checkpointing.
This is required in order to compute gradients without running out of memory.
"""
apply_activation_checkpointing(self, check_fn=lambda x: isinstance(x, BasicLayer3D))


AuroraSmall = partial(
Aurora,
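
The new `configure_activation_checkpointing` method relies on PyTorch's `apply_activation_checkpointing` to wrap every `BasicLayer3D` in the backbone, so that each layer's activations are recomputed during the backward pass instead of being kept in memory. A minimal standalone sketch of that mechanism follows; the toy `Block` module and tensor shapes are illustrative assumptions and not part of the Aurora codebase:

```python
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)


class Block(torch.nn.Module):
    """Toy stand-in for a backbone layer such as `BasicLayer3D`."""

    def __init__(self, dim: int = 128) -> None:
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim, dim), torch.nn.GELU(), torch.nn.Linear(dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = torch.nn.Sequential(Block(), Block())

# Wrap every module matching `check_fn` so its activations are recomputed
# during the backward pass rather than cached, trading compute for memory.
apply_activation_checkpointing(model, check_fn=lambda m: isinstance(m, Block))

x = torch.randn(8, 128, requires_grad=True)
model(x).sum().backward()
```
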
27 changes: 27 additions & 0 deletions docs/finetuning.md
@@ -10,6 +10,33 @@ model = Aurora(use_lora=False) # Model is not fine-tuned.
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
```

## Computing Gradients

To compute gradients, you will need an A100 GPU with 80 GB of memory.
In addition, you will need to use [PyTorch AMP](https://pytorch.org/docs/stable/amp.html)
and activation (gradient) checkpointing.
You can do this as follows:

```python
from aurora import Aurora

model = Aurora(
use_lora=False, # Model was not fine-tuned.
autocast=True, # Use AMP.
)
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")

batch = ... # Load some data.

model = model.cuda()
model.train()
model.configure_activation_checkpointing()

pred = model.forward(batch)
loss = ...
loss.backward()
```
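
The `batch = ...` and `loss = ...` placeholders above are deliberately left open, because data loading and the training objective depend on your fine-tuning task. As one hypothetical sketch, assuming a ground-truth `Batch` called `target` whose `surf_vars` tensors match the shapes of the prediction, the loss could be a mean absolute error over the surface variables:

```python
import torch

# Hypothetical objective: mean absolute error over the surface variables.
# `pred` is the prediction from `model.forward(batch)` above; `target` is
# assumed to be a ground-truth `Batch` whose `surf_vars` tensors have the
# same shapes (and device) as the corresponding entries of `pred.surf_vars`.
loss = torch.stack(
    [
        (pred.surf_vars[name] - target.surf_vars[name]).abs().mean()
        for name in pred.surf_vars
    ]
).mean()

loss.backward()
```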

## Extending Aurora with New Variables

Aurora can be extended with new variables by adjusting the keyword arguments `surf_vars`,
