Skip to content

Commit

Permalink
Add instructions for loading models and run in CI
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Aug 5, 2024
1 parent 9d3a288 commit 0375bca
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 3 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ jobs:
fail-fast: false
matrix:
version: ["3.10", "3.11"]
environment: huggingface-access

name: Test with Python ${{ matrix.version }}
steps:
Expand All @@ -25,6 +26,9 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade --no-cache-dir -e '.[dev]'
- name: Log into HugginFace
run: |
huggingface-cli login --token ${{ secrets.HUGGINGFACE_TOKEN }}
- name: Run tests
run: |
pytest -v --cov=aurora --cov-report term-missing
17 changes: 16 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ Install with `pip`:
pip install microsoft-aurora
```

Run an untrained small model on random data:
Run the pretrained small model on random data:

```python
import torch
Expand All @@ -66,6 +66,8 @@ from aurora import AuroraSmall, Batch, Metadata

model = AuroraSmall()

model.load_checkpoint("wbruinsma/aurora", "aurora-0.25-small-pretrained.ckpt")

batch = Batch(
surf_vars={k: torch.randn(1, 2, 16, 32) for k in ("2t", "10u", "10v", "msl")},
static_vars={k: torch.randn(1, 2, 16, 32) for k in ("lsm", "z", "slt")},
Expand All @@ -83,6 +85,12 @@ prediction = model.forward(batch)
print(prediction.surf_vars["2t"])
```

Note that this will incur a 500 MB download
and you may need to authenticate with `huggingface-cli login`.

See the [HuggingFace repository `wbruinsma/aurora`](https://huggingface.co/wbruinsma/aurora)
for an overview of which models are available.

## Contributing

See [`CONTRIBUTING.md`](CONTRIBUTING.md).
Expand Down Expand Up @@ -148,6 +156,13 @@ First, install the repository in editable mode and setup `pre-commit`:
make install
```

Then configure the HuggingFace repository where the weights can be found and log into HuggingFace:

```bash
export HUGGINGFACE_REPO=wbruinsma/aurora
huggingface-cli login
```

To run the tests and print coverage, run

```bash
Expand Down
18 changes: 16 additions & 2 deletions aurora/model/aurora.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from datetime import timedelta
from functools import partial

from torch import nn
import torch
from huggingface_hub import hf_hub_download

from aurora.batch import Batch
from aurora.model.decoder import Perceiver3DDecoder
Expand All @@ -16,7 +17,7 @@
"""type: Tuple of variable names."""


class Aurora(nn.Module):
class Aurora(torch.nn.Module):
"""The Aurora model.
Defaults to to the 1.3 B parameter configuration.
Expand Down Expand Up @@ -141,6 +142,18 @@ def forward(self, batch: Batch) -> Batch:

return pred

def load_checkpoint(self, repo: str, name: str) -> None:
path = hf_hub_download(repo_id=repo, filename=name)
d = torch.load(path, map_location="cpu")

# Rename keys to ensure compatibility.
for k, v in list(d.items()):
if k.startswith("net."):
del d[k]
d[k[4:]] = v

self.load_state_dict(d, strict=True)


AuroraSmall = partial(
Aurora,
Expand All @@ -150,4 +163,5 @@ def forward(self, batch: Batch) -> Batch:
decoder_num_heads=(16, 8, 4),
embed_dim=256,
num_heads=8,
use_lora=False,
)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"torch",
"einops",
"timm==0.6.13",
"huggingface-hub",
]

[project.optional-dependencies]
Expand Down
3 changes: 3 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

import os
from datetime import datetime

import torch
Expand All @@ -10,6 +11,8 @@
def test_aurora_small():
model = AuroraSmall()

model.load_checkpoint(os.environ["HUGGINGFACE_REPO"], "aurora-0.25-small-pretrained.ckpt")

batch = Batch(
surf_vars={k: torch.randn(1, 2, 16, 32) for k in ("2t", "10u", "10v", "msl")},
static_vars={k: torch.randn(1, 2, 16, 32) for k in ("lsm", "z", "slt")},
Expand Down

0 comments on commit 0375bca

Please sign in to comment.