
Commit

docs: update docstrings
theissenhelen committed Oct 3, 2024
1 parent c04e641 commit 6523b47
Showing 4 changed files with 32 additions and 13 deletions.
25 changes: 18 additions & 7 deletions src/anemoi/models/layers/attention.py
@@ -36,7 +36,7 @@ def __init__(
dropout_p: float = 0.0,
use_flash_attention: bool = False,
softcap: float | None = 0.0,
alibi_slopes: Tensor | None = None,
use_alibi_slopes: bool | None = None,
):
"""Initialize MultiHeadSelfAttention.
@@ -56,11 +56,10 @@ def __init__(
dropout probability, by default 0.0
softcap : float, optional
Anything > 0 activates softcapping attention, by default 0.0
alibi_slopes : Tensor, optional
(nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j,
by default None
use_alibi_slopes : bool, optional
Adds bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
to the attention score of query i and key j, where alibi_slope
is calculated using get_alibi_slopes, by default None
"""
super().__init__()

@@ -78,7 +77,7 @@ def __init__(
self.dropout_p = dropout_p
self.is_causal = is_causal
self.softcap = softcap
self.use_alibi_slopes = True # use_alibi_slopes
self.use_alibi_slopes = use_alibi_slopes

if self.use_alibi_slopes is not None:
self.alibi_slopes = get_alibi_slopes(num_heads)
@@ -161,6 +160,18 @@ def forward(


def get_alibi_slopes(num_heads: int) -> Tensor:
"""Calculates linearly decreasing slopes for alibi attention.
Parameters
----------
num_heads : int
number of attention heads
Returns
-------
Tensor
aLiBi slopes
"""
n = 2 ** math.floor(math.log2(num_heads))
slope_0 = 2.0 ** (-8.0 / n)
alibi_slopes = torch.pow(slope_0, torch.arange(1, 1 + n))
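Note: the last hunk above shows only the start of get_alibi_slopes together with its new docstring. A minimal, self-contained sketch of the visible formula, assuming num_heads is a power of two (the only case the shown lines cover; the helper name below is made up for illustration):

import math
import torch
from torch import Tensor

def alibi_slopes_sketch(num_heads: int) -> Tensor:
    # Largest power of two not exceeding num_heads, as in the diff.
    n = 2 ** math.floor(math.log2(num_heads))
    # First (largest) slope; subsequent heads decay geometrically.
    slope_0 = 2.0 ** (-8.0 / n)
    return torch.pow(slope_0, torch.arange(1, 1 + n))

# For 8 heads this yields 1/2, 1/4, ..., 1/256 (linearly decreasing in log space).
print(alibi_slopes_sketch(8))

Inside MultiHeadSelfAttention the boolean flag then resolves to these per-head slopes via self.alibi_slopes = get_alibi_slopes(num_heads), as shown at the end of the third hunk.
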
4 changes: 2 additions & 2 deletions src/anemoi/models/layers/block.py
@@ -65,7 +65,7 @@ def __init__(
dropout_p: float = 0.0,
use_flash_attention: bool = False,
softcap: float | None = 0.0,
alibi_slopes: Tensor | None = None,
use_alibi_slopes: bool | None = None,
):
super().__init__()

@@ -86,7 +86,7 @@ def __init__(
dropout_p=dropout_p,
use_flash_attention=use_flash_attention,
softcap=softcap,
alibi_slopes=alibi_slopes,
use_alibi_slopes=use_alibi_slopes,
)

self.mlp = nn.Sequential(
8 changes: 6 additions & 2 deletions src/anemoi/models/layers/chunk.py
@@ -76,7 +76,7 @@ def __init__(
dropout_p: float = 0.0,
use_flash_attention: bool = False,
softcap: float | None = 0.0,
alibi_slopes: Tensor | None = None,
use_alibi_slopes: bool | None = None,
) -> None:
"""Initialize TransformerProcessor.
@@ -94,6 +94,10 @@ def __init__(
Activation function, by default "GELU"
dropout_p: float
Dropout probability used for multi-head self attention, default 0.0
softcap : float, optional
Anything > 0 activates softcapping flash attention, by default 0.0
use_alibi_slopes : bool, optional
Use aLiBI option, only used for flash attention, by default None
"""
super().__init__(num_channels=num_channels, num_layers=num_layers)

@@ -107,7 +111,7 @@ def __init__(
dropout_p=dropout_p,
use_flash_attention=use_flash_attention,
softcap=softcap,
alibi_slopes=alibi_slopes,
use_alibi_slopes=use_alibi_slopes,
)

def forward(
8 changes: 6 additions & 2 deletions src/anemoi/models/layers/processor.py
@@ -98,7 +98,7 @@ def __init__(
dropout_p: float = 0.1,
use_flash_attention: bool = False,
softcap: float | None = 0.0,
alibi_slopes: Tensor | None = None,
use_alibi_slopes: Tensor | None = None,
**kwargs,
) -> None:
"""Initialize TransformerProcessor.
@@ -119,6 +119,10 @@ def __init__(
Activation function, by default "GELU"
dropout_p: float, optional
Dropout probability used for multi-head self attention, default 0.0
softcap : float, optional
Anything > 0 activates softcapping flash attention, by default 0.0
use_alibi_slopes : bool, optional
Use aLiBI option, only used for flash attention, by default None
"""
super().__init__(
num_channels=num_channels,
@@ -142,7 +146,7 @@ def __init__(
dropout_p=dropout_p,
use_flash_attention=use_flash_attention,
softcap=softcap,
alibi_slopes=alibi_slopes,
use_alibi_slopes=use_alibi_slopes,
)

self.offload_layers(cpu_offload)
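Note: the new docstrings in chunk.py and processor.py describe use_alibi_slopes only as an aLiBi option for flash attention; the bias it enables is the one spelled out in the attention.py docstring above. A hedged sketch of that bias term, reconstructed from the docstring formula only (not the fused flash-attention kernel; function name and shapes are assumptions):

import torch

def alibi_bias_sketch(slopes: torch.Tensor, seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    # Bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) added to the
    # attention score of query i and key j, one slope per head.
    i = torch.arange(seqlen_q).unsqueeze(-1)    # query positions, column vector
    j = torch.arange(seqlen_k).unsqueeze(0)     # key positions, row vector
    dist = (i + seqlen_k - seqlen_q - j).abs()  # relative distance matrix
    return -slopes.view(-1, 1, 1) * dist        # shape: (num_heads, seqlen_q, seqlen_k)

# Usage sketch: scores = q @ k.transpose(-2, -1) / d**0.5 + alibi_bias_sketch(slopes, Lq, Lk)
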
