Add configurability to dropout in the MultiHeadSelfAttention module (#12)
* feat: add configurability to dropout in MultiHeadSelfAttention

Co-authored-by: Rilwan (Akanni) Adewoyin <[email protected]>

* test: adjust to dropout_p

* doc: update changelog

* Feature/integrate reusable workflows (#16)

* ci: add public pr label

* ci: add readthedocs update check

* ci: add downstream ci

* ci: add ci-config

* chore(deps): remove unused dependency

* docs: update changelog

* ci: switch to main

* chore: changelog 0.2.1

* Update error messages from invalid sub_graph in model instantiation (#20)

* ci: inherit pypi publish flow (#17)

* ci: inherit pypi publish flow

Co-authored-by: Helen Theissen <[email protected]>

* docs: add to changelog

* fix: typo in reusable workflow

* fix: another typo

* chore: bump actions/setup-python to v5

* ci: run downstream-ci for changes in src and tests

* docs: update changelog

---------

Co-authored-by: Helen Theissen <[email protected]>

* Update CHANGELOG.md to KeepChangelog format

* [pre-commit.ci] pre-commit autoupdate (#25)

updates:
- [github.com/psf/black-pre-commit-mirror: 24.4.2 → 24.8.0](psf/black-pre-commit-mirror@24.4.2...24.8.0)
- [github.com/astral-sh/ruff-pre-commit: v0.4.6 → v0.6.2](astral-sh/ruff-pre-commit@v0.4.6...v0.6.2)
- [github.com/tox-dev/pyproject-fmt: 2.1.3 → 2.2.1](tox-dev/pyproject-fmt@2.1.3...2.2.1)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Ci/changelog-release-updater (#26)

* ci: add changelog release updater

* docs: update changelog

---------

Co-authored-by: Rilwan (Akanni) Adewoyin <[email protected]>
Co-authored-by: Gert Mertes <[email protected]>
Co-authored-by: Mario Santa Cruz <[email protected]>
Co-authored-by: Jesper Dramsch <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
6 people authored Sep 18, 2024
1 parent 43846f0 commit 0219266
Showing 9 changed files with 62 additions and 15 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -14,13 +14,16 @@ Keep it human-readable, your future self will thank you!

### Added

- CI workflow to update the changelog on release
- Configurability of the dropout probability in the MultiHeadSelfAttention module
- CI workflow to update the changelog on release
- Remapper: Preprocessor for remapping one variable to multiple ones. Includes changes to the data indices since the remapper changes the number of variables. With optional config keywords.
- Codeowners file
- Pygrep precommit hooks
- Docsig precommit hooks
- Changelog merge strategy


### Changed

- Update CI to inherit from common infrastructure reusable workflows
15 changes: 11 additions & 4 deletions src/anemoi/models/layers/attention.py
@@ -40,19 +40,19 @@ def __init__(
bias: bool = False,
is_causal: bool = False,
window_size: Optional[int] = None,
dropout: float = 0.0,
dropout_p: float = 0.0,
):
super().__init__()

assert (
embed_dim % num_heads == 0
), f"Embedding dimension ({embed_dim}) must be divisible by number of heads ({num_heads})"

self.dropout = dropout
self.num_heads = num_heads
self.embed_dim = embed_dim
self.head_dim = embed_dim // num_heads # q k v
self.window_size = (window_size, window_size) # flash attention
self.dropout_p = dropout_p
self.is_causal = is_causal

self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
@@ -86,15 +86,22 @@ def forward(
query = shard_heads(query, shapes=shapes, mgroup=model_comm_group)
key = shard_heads(key, shapes=shapes, mgroup=model_comm_group)
value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
dropout_p = self.dropout_p if self.training else 0.0

if _FLASH_ATTENTION_AVAILABLE:
query, key, value = (
einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
)
out = self.attention(query, key, value, causal=False, window_size=self.window_size)
out = self.attention(query, key, value, causal=False, window_size=self.window_size, dropout_p=dropout_p)
out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
else:
out = self.attention(query, key, value, is_causal=False) # expects (batch heads grid variable) format
out = self.attention(
query,
key,
value,
is_causal=False,
dropout_p=dropout_p,
) # expects (batch heads grid variable) format

out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")
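
To make the new argument concrete, here is a minimal usage sketch (not part of the diff). It assumes anemoi-models is installed; the constructor call mirrors the updated tests in tests/layers/test_attention.py further down, and the sizes are illustrative.

from anemoi.models.layers.attention import MultiHeadSelfAttention

# num_heads=16, embed_dim=64; embed_dim must be divisible by num_heads (see the assert above).
mhsa = MultiHeadSelfAttention(16, 64, dropout_p=0.1)
assert mhsa.dropout_p == 0.1

# Per the forward() hunk, the configured probability is only used while the module
# is in training mode; in eval mode the effective dropout probability is 0.0.
mhsa.train()  # attention dropout active with p=0.1
mhsa.eval()   # attention dropout disabled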
12 changes: 10 additions & 2 deletions src/anemoi/models/layers/block.py
@@ -55,7 +55,15 @@ def forward(
class TransformerProcessorBlock(BaseBlock):
"""Transformer block with MultiHeadSelfAttention and MLPs."""

def __init__(self, num_channels, hidden_dim, num_heads, activation, window_size: int):
def __init__(
self,
num_channels: int,
hidden_dim: int,
num_heads: int,
activation: str,
window_size: int,
dropout_p: float = 0.0,
):
super().__init__()

try:
@@ -72,7 +80,7 @@ def __init__(self, num_channels, hidden_dim, num_heads, activation, window_size:
window_size=window_size,
bias=False,
is_causal=False,
dropout=0.0,
dropout_p=dropout_p,
)

self.mlp = nn.Sequential(
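
A short sketch of constructing the block with the new argument (again, not part of the diff). The positional order num_channels, hidden_dim, num_heads, activation, window_size follows the updated block tests below; the values are illustrative.

from anemoi.models.layers.block import TransformerProcessorBlock

block = TransformerProcessorBlock(64, 128, 4, "GELU", 512, dropout_p=0.1)
# The block does not apply dropout itself; it passes dropout_p straight through
# to its MultiHeadSelfAttention, as shown in the second hunk above.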
4 changes: 4 additions & 0 deletions src/anemoi/models/layers/chunk.py
@@ -73,6 +73,7 @@ def __init__(
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
activation: str = "GELU",
dropout_p: float = 0.0,
) -> None:
"""Initialize TransformerProcessor.
@@ -88,6 +89,8 @@ def __init__(
ratio of mlp hidden dimension to embedding dimension, default 4
activation : str, optional
Activation function, by default "GELU"
dropout_p: float
Dropout probability used for multi-head self attention, default 0.0
"""
super().__init__(num_channels=num_channels, num_layers=num_layers)

@@ -98,6 +101,7 @@ def __init__(
num_heads=num_heads,
activation=activation,
window_size=window_size,
dropout_p=dropout_p,
)

def forward(
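
The same pattern one level up, as a hedged sketch: the chunk accepts dropout_p and hands it to every block it builds. Keyword names mirror the updated chunk test below; the values are illustrative.

from anemoi.models.layers.chunk import TransformerProcessorChunk

chunk = TransformerProcessorChunk(
    num_channels=64,
    num_layers=2,
    num_heads=16,
    mlp_hidden_ratio=4,
    activation="GELU",
    window_size=13,
    dropout_p=0.1,  # forwarded to each TransformerProcessorBlock in this chunk
)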
4 changes: 4 additions & 0 deletions src/anemoi/models/layers/processor.py
@@ -95,6 +95,7 @@ def __init__(
cpu_offload: bool = False,
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
dropout_p: float = 0.1,
**kwargs,
) -> None:
"""Initialize TransformerProcessor.
@@ -113,6 +114,8 @@
ratio of mlp hidden dimension to embedding dimension, default 4
activation : str, optional
Activation function, by default "GELU"
dropout_p: float, optional
Dropout probability used for multi-head self attention, default 0.1
"""
super().__init__(
num_channels=num_channels,
@@ -133,6 +136,7 @@
num_layers=self.chunk_size,
window_size=window_size,
activation=activation,
dropout_p=dropout_p,
)

self.offload_layers(cpu_offload)
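
Finally, an end-to-end sketch of enabling attention dropout from the processor. Keyword names follow the __init__ shown above and the updated processor test; where the diff is truncated (e.g. window_size) the names are assumptions taken from the surrounding code, and the values are illustrative. Note that the processor's signature defaults dropout_p to 0.1, while the lower-level layers default to 0.0.

from anemoi.models.layers.processor import TransformerProcessor

processor = TransformerProcessor(
    num_layers=2,
    window_size=10,
    num_channels=128,
    num_chunks=2,
    cpu_offload=False,
    num_heads=16,
    mlp_hidden_ratio=4,
    dropout_p=0.1,  # threaded processor -> chunk -> block -> MultiHeadSelfAttention
)
processor.eval()  # as in the attention module, dropout is inactive outside training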
13 changes: 10 additions & 3 deletions tests/layers/block/test_block_transformer.py
@@ -29,11 +29,14 @@ class TestTransformerProcessorBlock:
num_heads=st.integers(min_value=1, max_value=10),
activation=st.sampled_from(["ReLU", "GELU", "Tanh"]),
window_size=st.integers(min_value=1, max_value=512),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
)
@settings(max_examples=10)
def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, window_size):
def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, window_size, dropout_p):
num_channels = num_heads * factor_attention_heads
block = TransformerProcessorBlock(num_channels, hidden_dim, num_heads, activation, window_size)
block = TransformerProcessorBlock(
num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p
)
assert isinstance(block, TransformerProcessorBlock)

assert isinstance(block.layer_norm1, nn.LayerNorm)
@@ -49,6 +52,7 @@ def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, w
window_size=st.integers(min_value=1, max_value=512),
shapes=st.lists(st.integers(min_value=1, max_value=10), min_size=3, max_size=3),
batch_size=st.integers(min_value=1, max_value=40),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
)
@settings(max_examples=10)
def test_forward_output(
@@ -60,9 +64,12 @@ def test_forward_output(
window_size,
shapes,
batch_size,
dropout_p,
):
num_channels = num_heads * factor_attention_heads
block = TransformerProcessorBlock(num_channels, hidden_dim, num_heads, activation, window_size)
block = TransformerProcessorBlock(
num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p
)

x = torch.randn((batch_size, num_channels))

4 changes: 4 additions & 0 deletions tests/layers/chunk/test_chunk_transformer.py
@@ -20,6 +20,7 @@ def init(self):
mlp_hidden_ratio: int = 4
activation: str = "GELU"
window_size: int = 13
dropout_p: float = 0.1

# num_heads must be evenly divisible by num_channels for MHSA
return (
@@ -29,6 +30,7 @@ def init(self):
mlp_hidden_ratio,
activation,
window_size,
dropout_p,
)

@pytest.fixture
@@ -40,6 +42,7 @@ def processor_chunk(self, init):
mlp_hidden_ratio,
activation,
window_size,
dropout_p,
) = init
return TransformerProcessorChunk(
num_channels=num_channels,
@@ -48,6 +51,7 @@ def processor_chunk(self, init):
mlp_hidden_ratio=mlp_hidden_ratio,
activation=activation,
window_size=window_size,
dropout_p=dropout_p,
)

def test_all_blocks(self, processor_chunk):
6 changes: 6 additions & 0 deletions tests/layers/processor/test_transformer_processor.py
@@ -21,6 +21,7 @@ def transformer_processor_init():
cpu_offload = False
num_heads = 16
mlp_hidden_ratio = 4
dropout_p = 0.1
return (
num_layers,
window_size,
@@ -30,6 +31,7 @@ def transformer_processor_init():
cpu_offload,
num_heads,
mlp_hidden_ratio,
dropout_p,
)


@@ -44,6 +46,7 @@ def transformer_processor(transformer_processor_init):
cpu_offload,
num_heads,
mlp_hidden_ratio,
dropout_p,
) = transformer_processor_init
return TransformerProcessor(
num_layers=num_layers,
@@ -54,6 +57,7 @@ def transformer_processor(transformer_processor_init):
cpu_offload=cpu_offload,
num_heads=num_heads,
mlp_hidden_ratio=mlp_hidden_ratio,
dropout_p=dropout_p,
)


@@ -67,6 +71,7 @@ def test_transformer_processor_init(transformer_processor, transformer_processor
_cpu_offload,
_num_heads,
_mlp_hidden_ratio,
_dropout_p,
) = transformer_processor_init
assert isinstance(transformer_processor, TransformerProcessor)
assert transformer_processor.num_chunks == num_chunks
@@ -84,6 +89,7 @@ def test_transformer_processor_forward(transformer_processor, transformer_proces
_cpu_offload,
_num_heads,
_mlp_hidden_ratio,
_dropout_p,
) = transformer_processor_init
gridsize = 100
batch_size = 1
16 changes: 10 additions & 6 deletions tests/layers/test_attention.py
@@ -18,29 +18,32 @@
@given(
num_heads=st.integers(min_value=1, max_value=50),
embed_dim_multiplier=st.integers(min_value=1, max_value=10),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
)
def test_multi_head_self_attention_init(num_heads, embed_dim_multiplier):
def test_multi_head_self_attention_init(num_heads, embed_dim_multiplier, dropout_p):
embed_dim = (
num_heads * embed_dim_multiplier
) # TODO: Make assert in MHSA to check if embed_dim is divisible by num_heads
mhsa = MultiHeadSelfAttention(num_heads, embed_dim)
mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p)

assert isinstance(mhsa, nn.Module)
assert mhsa.num_heads == num_heads
assert mhsa.embed_dim == embed_dim
assert mhsa.head_dim == embed_dim // num_heads
assert dropout_p == mhsa.dropout_p


@pytest.mark.gpu
@given(
batch_size=st.integers(min_value=1, max_value=64),
num_heads=st.integers(min_value=1, max_value=20),
embed_dim_multiplier=st.integers(min_value=1, max_value=10),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
)
@settings(deadline=None)
def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_multiplier):
def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_multiplier, dropout_p):
embed_dim = num_heads * embed_dim_multiplier
mhsa = MultiHeadSelfAttention(num_heads, embed_dim)
mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p)

x = torch.randn(batch_size * 2, embed_dim)
shapes = [list(x.shape)]
@@ -54,10 +57,11 @@ def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_mult
batch_size=st.integers(min_value=1, max_value=64),
num_heads=st.integers(min_value=1, max_value=20),
embed_dim_multiplier=st.integers(min_value=1, max_value=10),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
)
def test_multi_head_self_attention_backward(batch_size, num_heads, embed_dim_multiplier):
def test_multi_head_self_attention_backward(batch_size, num_heads, embed_dim_multiplier, dropout_p):
embed_dim = num_heads * embed_dim_multiplier
mhsa = MultiHeadSelfAttention(num_heads, embed_dim)
mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p)

x = torch.randn(batch_size * 2, embed_dim, requires_grad=True)
shapes = [list(x.shape)]
