From d1ca22362378c90d244d4eeda03be6072d431a4b Mon Sep 17 00:00:00 2001
From: Wessel Bruinsma
Date: Fri, 23 Aug 2024 16:07:57 +0200
Subject: [PATCH 1/2] Fix config for 0.1 deg model

---
 aurora/model/aurora.py | 8 +++++++-
 docs/api.rst           | 3 +++
 docs/models.md         | 8 ++++----
 3 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py
index c1b6a8d..1f55043 100644
--- a/aurora/model/aurora.py
+++ b/aurora/model/aurora.py
@@ -13,7 +13,7 @@
 from aurora.model.lora import LoRAMode
 from aurora.model.swin3d import Swin3DTransformerBackbone
 
-__all__ = ["Aurora", "AuroraSmall"]
+__all__ = ["Aurora", "AuroraSmall", "AuroraHighRes"]
 
 
 class Aurora(torch.nn.Module):
@@ -250,3 +250,9 @@ def load_checkpoint(self, repo: str, name: str, strict: bool = True) -> None:
     num_heads=8,
     use_lora=False,
 )
+
+AuroraHighRes = partial(
+    Aurora,
+    encoder_depths=(6, 8, 8),
+    decoder_depths=(8, 8, 6),
+)
diff --git a/docs/api.rst b/docs/api.rst
index 7ab0e60..4653e26 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -22,3 +22,6 @@ Models
 
 .. autoclass:: aurora.AuroraSmall
     :members:
+
+.. autoclass:: aurora.AuroraHighRes
+    :members:
diff --git a/docs/models.md b/docs/models.md
index 06e3992..167f5b9 100644
--- a/docs/models.md
+++ b/docs/models.md
@@ -127,9 +127,9 @@ Aurora 0.1° Fine-Tuned is a high-resolution version of Aurora.
 ### Usage
 
 ```python
-from aurora import Aurora
+from aurora import AuroraHighRes
 
-model = Aurora()
+model = AuroraHighRes()
 model.load_checkpoint("microsoft/aurora", "aurora-0.1-finetuned.ckpt")
 ```
 
@@ -170,8 +170,8 @@ Therefore, you should use the static variables provided in
 you can turn off LoRA to obtain more realistic predictions at the expensive of slightly higher long-term MSE:
 
 ```python
-from aurora import Aurora
+from aurora import AuroraHighRes
 
-model = Aurora(use_lora=False) # Disable LoRA for more realistic samples.
+model = AuroraHighRes(use_lora=False) # Disable LoRA for more realistic samples.
 model.load_checkpoint("microsoft/aurora", "aurora-0.1-finetuned.ckpt", strict=False)
 ```

From da93affc1d3ddf04e1c2a86238b721a12290a711 Mon Sep 17 00:00:00 2001
From: Wessel Bruinsma
Date: Fri, 23 Aug 2024 16:08:41 +0200
Subject: [PATCH 2/2] Add missing colons

---
 aurora/model/aurora.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py
index 1f55043..82f37dd 100644
--- a/aurora/model/aurora.py
+++ b/aurora/model/aurora.py
@@ -63,14 +63,14 @@ def __init__(
             window_size (tuple[int, int, int], optional): Vertical height, height, and width of
                 the window of the underlying Swin transformer.
             encoder_depths (tuple[int, ...], optional): Number of blocks in each encoder layer.
-            encoder_num_heads (tuple[int, ...], optional) Number of attention heads in each encoder
+            encoder_num_heads (tuple[int, ...], optional): Number of attention heads in each encoder
                 layer. The dimensionality doubles after every layer. To keep the dimensionality of
                 every head constant, you want to double the number of heads after every layer. The
                 dimensionality of attention head of the first layer is determined by `embed_dim`
                 divided by the value here. For all cases except one, this is equal to `64`.
             decoder_depths (tuple[int, ...], optional): Number of blocks in each decoder layer.
                 Generally, you want this to be the reversal of `encoder_depths`.
-            decoder_num_heads (tuple[int, ...], optional) Number of attention heads in each decoder
+            decoder_num_heads (tuple[int, ...], optional): Number of attention heads in each decoder
                 layer. Generally, you want this to be the reversal of `encoder_num_heads`.
             latent_levels (int, optional): Number of latent pressure levels.
             patch_size (int, optional): Patch size.
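A note on the `AuroraHighRes` alias added in PATCH 1/2: it is not a new class but a `functools.partial` that pre-binds the 0.1° encoder/decoder depths, apparently mirroring the pattern the hunk context suggests is already used for `AuroraSmall`. The sketch below illustrates the resulting behaviour with a toy stand-in for `Aurora`; the real constructor takes many more arguments, and only the keyword arguments bound by this patch are reproduced here.

```python
from functools import partial

import torch


class Aurora(torch.nn.Module):
    """Toy stand-in for the real model; only the arguments touched by this patch appear here."""

    def __init__(
        self,
        encoder_depths: tuple[int, ...],
        decoder_depths: tuple[int, ...],
        use_lora: bool = True,
    ) -> None:
        super().__init__()
        self.encoder_depths = encoder_depths
        self.decoder_depths = decoder_depths
        self.use_lora = use_lora


# The pattern introduced by the patch: `partial` pre-binds the 0.1 deg configuration, so
# calling `AuroraHighRes(...)` is calling `Aurora(...)` with these depths already filled in.
AuroraHighRes = partial(
    Aurora,
    encoder_depths=(6, 8, 8),
    decoder_depths=(8, 8, 6),
)

model = AuroraHighRes(use_lora=False)  # Unbound keyword arguments still pass through.
assert model.encoder_depths == (6, 8, 8) and model.use_lora is False
```

Because the alias does not bind `use_lora`, the `use_lora=False` example in the updated `docs/models.md` works unchanged, with `strict=False` passed to `load_checkpoint` exactly as shown in the diff.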