Merge pull request #12 from microsoft/wesselb/fix-highres-model-settings
Add `AuroraHighRes`
megstanley authored Aug 23, 2024
2 parents c39cfb1 + da93aff commit 82c67ce
Showing 3 changed files with 16 additions and 7 deletions.
aurora/model/aurora.py (12 changes: 9 additions & 3 deletions)
@@ -13,7 +13,7 @@
from aurora.model.lora import LoRAMode
from aurora.model.swin3d import Swin3DTransformerBackbone

__all__ = ["Aurora", "AuroraSmall"]
__all__ = ["Aurora", "AuroraSmall", "AuroraHighRes"]


class Aurora(torch.nn.Module):
@@ -63,14 +63,14 @@ def __init__(
window_size (tuple[int, int, int], optional): Vertical height, height, and width of the
window of the underlying Swin transformer.
encoder_depths (tuple[int, ...], optional): Number of blocks in each encoder layer.
-encoder_num_heads (tuple[int, ...], optional) Number of attention heads in each encoder
+encoder_num_heads (tuple[int, ...], optional): Number of attention heads in each encoder
layer. The dimensionality doubles after every layer. To keep the dimensionality of
every head constant, you want to double the number of heads after every layer. The
dimensionality of attention head of the first layer is determined by `embed_dim`
divided by the value here. For all cases except one, this is equal to `64`.
decoder_depths (tuple[int, ...], optional): Number of blocks in each decoder layer.
Generally, you want this to be the reversal of `encoder_depths`.
-decoder_num_heads (tuple[int, ...], optional) Number of attention heads in each decoder
+decoder_num_heads (tuple[int, ...], optional): Number of attention heads in each decoder
layer. Generally, you want this to be the reversal of `encoder_num_heads`.
latent_levels (int, optional): Number of latent pressure levels.
patch_size (int, optional): Patch size.
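The head-dimension rule spelled out in the docstring above can be sanity-checked with a few lines of plain Python. This is only an illustration: the `embed_dim = 512` and `encoder_num_heads = (8, 16, 32)` values are assumed defaults and do not appear in this diff.

```python
# Illustration of the head-dimension rule from the docstring above.
# embed_dim and encoder_num_heads are assumed defaults, not taken from this diff.
embed_dim = 512
encoder_num_heads = (8, 16, 32)

# The embedding dimensionality doubles after every encoder layer...
layer_dims = [embed_dim * 2**i for i in range(len(encoder_num_heads))]  # [512, 1024, 2048]

# ...so doubling the number of heads as well keeps the per-head dimensionality constant.
head_dims = [dim // heads for dim, heads in zip(layer_dims, encoder_num_heads)]
print(head_dims)  # [64, 64, 64]
```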
@@ -250,3 +250,9 @@ def load_checkpoint(self, repo: str, name: str, strict: bool = True) -> None:
    num_heads=8,
    use_lora=False,
)
+
+AuroraHighRes = partial(
+    Aurora,
+    encoder_depths=(6, 8, 8),
+    decoder_depths=(8, 8, 6),
+)
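Because `AuroraHighRes` is a `functools.partial` over `Aurora`, it only pre-binds the high-resolution encoder and decoder depths; every other keyword argument still passes through. A minimal sketch of that equivalence, assuming only what the partial above binds (the `use_lora=False` keyword is borrowed from the documentation change below):

```python
from aurora import Aurora, AuroraHighRes

# AuroraHighRes pre-binds encoder_depths and decoder_depths only,
# so these two constructions produce the same architecture.
model = AuroraHighRes(use_lora=False)
equivalent = Aurora(encoder_depths=(6, 8, 8), decoder_depths=(8, 8, 6), use_lora=False)
```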
docs/api.rst (3 changes: 3 additions & 0 deletions)
@@ -22,3 +22,6 @@ Models

.. autoclass:: aurora.AuroraSmall
   :members:
+
+.. autoclass:: aurora.AuroraHighRes
+   :members:
docs/models.md (8 changes: 4 additions & 4 deletions)
Expand Up @@ -127,9 +127,9 @@ Aurora 0.1° Fine-Tuned is a high-resolution version of Aurora.
### Usage

```python
-from aurora import Aurora
+from aurora import AuroraHighRes

-model = Aurora()
+model = AuroraHighRes()
model.load_checkpoint("microsoft/aurora", "aurora-0.1-finetuned.ckpt")
```

@@ -170,8 +170,8 @@ Therefore, you should use the static variables provided in
you can turn off LoRA to obtain more realistic predictions at the expense of slightly higher long-term MSE:

```python
-from aurora import Aurora
+from aurora import AuroraHighRes

-model = Aurora(use_lora=False) # Disable LoRA for more realistic samples.
+model = AuroraHighRes(use_lora=False) # Disable LoRA for more realistic samples.
model.load_checkpoint("microsoft/aurora", "aurora-0.1-finetuned.ckpt", strict=False)
```
