This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Feature/transformer sequence sharding #90

Open · wants to merge 7 commits into base: develop
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -33,6 +33,7 @@ Keep it human-readable, your future self will thank you!
- GraphTransformerMapperBlock chunking to reduce memory usage during inference [#46](https://github.com/ecmwf/anemoi-models/pull/46)
- New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64)
- Contributors file [#69](https://github.com/ecmwf/anemoi-models/pull/69)
- Add sequence sharding strategy for TransformerProcessor [#90](https://github.com/ecmwf/anemoi-models/pull/90)

### Changed
- Bugfixes for CI
123 changes: 123 additions & 0 deletions src/anemoi/models/distributed/transformer.py
@@ -8,6 +8,7 @@
# nor does it submit to any jurisdiction.


import logging
from typing import Optional

import torch
@@ -17,6 +18,8 @@

from anemoi.models.distributed.utils import get_memory_format

LOGGER = logging.getLogger(__name__)


def _headsalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = None) -> Tensor:
"""Apply all_to_all along the head dimension.
@@ -82,6 +85,72 @@ def _seqalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = None) -> Tensor:
return torch.cat(output_list, dim=-3).contiguous(memory_format=input_format)


def _halo_exchange(input_: Tensor, halo_size: int, mgroup: ProcessGroup, bwd: bool = False) -> Tensor:
"""Exchange halo regions between neighboring ranks.

Expected format is (batch_size, halo_size + sequence_length + halo_size, channels).

Parameters
----------
input_ : Tensor
Input tensor
halo_size : int
Halo size (left, right)
mgroup : ProcessGroup
Model communication group
bwd : bool
Flag to indicate if backward pass

Returns
-------
Tensor
Tensor with halo regions from neighboring ranks
"""
end = input_.shape[-2]

left_halo_slice = slice(0, halo_size)
right_halo_slice = slice(end - halo_size, end)
left_send_slice = slice(halo_size, 2 * halo_size)
right_send_slice = slice(end - 2 * halo_size, end - halo_size)

if bwd: # reverse halo exchange direction for gradient accumulation
left_halo_slice, left_send_slice = left_send_slice, left_halo_slice
right_halo_slice, right_send_slice = right_send_slice, right_halo_slice

left_send = input_[:, left_send_slice, :]
right_send = input_[:, right_send_slice, :]

# setup neighbor ranks and tensor lists for all_to_all communication
group_rank = dist.get_rank(mgroup)
group_size = dist.get_world_size(mgroup)
left_rank = group_rank - 1 if group_rank > 0 else None
right_rank = group_rank + 1 if group_rank < group_size - 1 else None

input_list = [torch.empty(0, device=input_.device) for _ in range(group_size)]
if left_rank is not None:
input_list[left_rank] = left_send
if right_rank is not None:
input_list[right_rank] = right_send
output_list = [torch.empty_like(input_i, device=input_.device) for input_i in input_list]

dist.all_to_all(output_list, input_list, group=mgroup)

if bwd: # add gradient contributions to halo regions and zero out send regions
if left_rank is not None:
input_[:, left_send_slice, :] = 0
input_[:, left_halo_slice, :] += output_list[left_rank]
if right_rank is not None:
input_[:, right_send_slice, :] = 0
input_[:, right_halo_slice, :] += output_list[right_rank]
else: # add halo regions to input tensor
if left_rank is not None:
input_[:, left_halo_slice, :] = output_list[left_rank]
if right_rank is not None:
input_[:, right_halo_slice, :] = output_list[right_rank]

return input_


def shard_heads(input_: Tensor, shapes: list, mgroup: ProcessGroup) -> Tensor:
"""Sync tensor.

@@ -130,6 +199,36 @@ def shard_sequence(input_: Tensor, shapes: list, mgroup: ProcessGroup) -> Tensor:
return _SplitSequenceParallelSection.apply(input_, shapes, mgroup)


def halo_exchange(x: Tensor, halo_size: int, mgroup: ProcessGroup) -> tuple[Tensor, int, int]:
    """Exchange halo regions between neighboring ranks.

Parameters
----------
x : Tensor
Input tensor
halo_size : int
Halo size (left, right)
mgroup : ProcessGroup
Model communication group

Returns
-------
Tensor, int, int
Tensor appended with halo regions from neighboring ranks, left halo size, right halo size
"""
if mgroup is None or dist.get_world_size(mgroup) == 1:
return x, 0, 0

# pad tensor with halo regions
halo_size_left = halo_size if dist.get_rank(mgroup) != 0 else 0
halo_size_right = halo_size if dist.get_rank(mgroup) != dist.get_world_size(mgroup) - 1 else 0
x_pad = torch.nn.functional.pad(x, pad=(0, 0, halo_size_left, halo_size_right), mode="constant", value=0)

out = _HaloExchange.apply(x_pad, halo_size, mgroup)

return out, halo_size_left, halo_size_right


class _SplitHeadsParallelSection(torch.autograd.Function):
"""Sync the input from parallel section."""

@@ -172,3 +271,27 @@ def backward(ctx, grad_output):
None,
)
return grad_output, None, None


class _HaloExchange(torch.autograd.Function):
"""Exchange halo regions between ranks."""

@staticmethod
def forward(ctx, input_, halo_size_, mgroup_):
ctx.halo_size = halo_size_
ctx.mgroup = mgroup_

if mgroup_:
return _halo_exchange(input_, halo_size_, mgroup_)
return input_

@staticmethod
def backward(ctx, grad_output):
if ctx.mgroup:
return (
_halo_exchange(grad_output, ctx.halo_size, ctx.mgroup, bwd=True),
None,
None,
)

return grad_output, None, None
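
As a side note, the slice bookkeeping in `_halo_exchange` can be checked without `torch.distributed` by emulating two neighbouring ranks on one process. The sketch below is illustrative only (sizes and variable names are made up); the `dist.all_to_all` call is replaced by copying the send regions directly between the two local tensors, following the same slice layout as above.

```python
import torch

# Emulate the forward halo exchange between two neighbouring "ranks" locally.
batch, seq, channels, halo = 1, 8, 4, 2

base = torch.arange(seq * channels, dtype=torch.float32).reshape(batch, seq, channels)
# rank 0 (first rank): right halo only; rank 1 (last rank): left halo only,
# mirroring the padding applied in halo_exchange().
x0 = torch.nn.functional.pad(base, pad=(0, 0, 0, halo))
x1 = torch.nn.functional.pad(base + 100, pad=(0, 0, halo, 0))

# rank 0's right_send_slice goes into rank 1's left halo ...
x1[:, :halo, :] = x0[:, -2 * halo : -halo, :]
# ... and rank 1's left_send_slice goes into rank 0's right halo.
x0[:, -halo:, :] = x1[:, halo : 2 * halo, :]

print(x0[:, -halo:, 0])  # rank 0 now sees the first rows of rank 1's shard
```

In the real implementation both copies happen in a single `dist.all_to_all`, and the backward pass swaps the halo and send slices so that gradients landing in the halos are accumulated back onto the rows they originated from.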
70 changes: 57 additions & 13 deletions src/anemoi/models/layers/attention.py
@@ -25,6 +25,8 @@
else:
_FLASH_ATTENTION_AVAILABLE = True


from anemoi.models.distributed.transformer import halo_exchange
from anemoi.models.distributed.transformer import shard_heads
from anemoi.models.distributed.transformer import shard_sequence

@@ -42,6 +44,7 @@ def __init__(
is_causal: bool = False,
window_size: Optional[int] = None,
dropout_p: float = 0.0,
shard_strategy: str = "shard_heads",
):
super().__init__()

@@ -55,38 +58,75 @@ def __init__(
self.window_size = (window_size, window_size) # flash attention
self.dropout_p = dropout_p
self.is_causal = is_causal
self.shard_strategy = shard_strategy

self.lin_qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
self.attention = attn_func

if not _FLASH_ATTENTION_AVAILABLE:
LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention")

if shard_strategy not in ["shard_heads", "shard_sequence"]:
raise ValueError(f"Invalid shard_strategy: {shard_strategy}")

if shard_strategy == "shard_sequence": # remove this after PR #47 is merged (sliding window support)
assert _FLASH_ATTENTION_AVAILABLE, "Flash attention is required for shard_sequence strategy"

self.projection = nn.Linear(embed_dim, embed_dim, bias=True)

def forward(
self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
) -> Tensor:
        if model_comm_group:
            assert (
                model_comm_group.size() == 1 or batch_size == 1
            ), "Only batch size of 1 is supported when model is sharded across GPUs"

        if self.shard_strategy == "shard_sequence":
            assert (
                shapes[-1][0] // 2 >= self.window_size[0]
            ), "Sharded sequence length must be at least twice the window size"

            # unpack grid dimension first to allow for halo exchange
            x_bgc = einops.rearrange(
                x,
                "(batch grid) channels -> batch grid channels",
                batch=batch_size,
            )

            # communicate halos (adds halos to x)
            x_plus_halos, halo_size_left, halo_size_right = halo_exchange(
                x_bgc, halo_size=self.window_size[0], mgroup=model_comm_group
            )

            # compute q, k, v (on local sequence shards with halos)
            query, key, value = self.lin_qkv(x_plus_halos).chunk(3, -1)

            query, key, value = (
                einops.rearrange(
                    t,
                    "batch grid (heads vars) -> batch heads grid vars",
                    heads=self.num_heads,
                )
                for t in (query, key, value)
            )
        else:  # shard_heads
            query, key, value = self.lin_qkv(x).chunk(3, -1)

            query, key, value = (
                einops.rearrange(
                    t,
                    "(batch grid) (heads vars) -> batch heads grid vars",
                    batch=batch_size,
                    heads=self.num_heads,
                )
                for t in (query, key, value)
            )

            query = shard_heads(query, shapes=shapes, mgroup=model_comm_group)
            key = shard_heads(key, shapes=shapes, mgroup=model_comm_group)
            value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)

dropout_p = self.dropout_p if self.training else 0.0

if _FLASH_ATTENTION_AVAILABLE:
@@ -104,7 +144,11 @@ def forward(
dropout_p=dropout_p,
) # expects (batch heads grid variable) format

        if self.shard_strategy == "shard_sequence":
            out = out[:, :, halo_size_left : out.shape[-2] - halo_size_right, :]  # remove halos
        else:  # shard_heads
            out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)

out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")

out = self.projection(out)
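
For orientation, the single-rank data flow of the new `shard_sequence` branch can be sketched with plain PyTorch: on one rank the halo exchange is a no-op (both halo sizes are zero), so the remaining steps are just the rearranges, the fused QKV projection, attention, and the halo crop. The sizes, the standalone `lin_qkv` layer, and the use of `scaled_dot_product_attention` in place of flash attention are assumptions for illustration, not the module's API.

```python
import einops
import torch
from torch.nn.functional import scaled_dot_product_attention

batch_size, grid, embed_dim, num_heads = 1, 16, 32, 4
halo_left = halo_right = 0  # single rank: halo_exchange() returns the input unchanged

x = torch.randn(batch_size * grid, embed_dim)  # (batch grid) channels, as in the module
lin_qkv = torch.nn.Linear(embed_dim, 3 * embed_dim)

# unpack the grid dimension, as the shard_sequence branch does before the halo exchange
x_bgc = einops.rearrange(x, "(batch grid) channels -> batch grid channels", batch=batch_size)

query, key, value = lin_qkv(x_bgc).chunk(3, -1)
query, key, value = (
    einops.rearrange(t, "batch grid (heads vars) -> batch heads grid vars", heads=num_heads)
    for t in (query, key, value)
)

out = scaled_dot_product_attention(query, key, value)  # (batch, heads, grid, vars)
out = out[:, :, halo_left : out.shape[-2] - halo_right, :]  # crop halos (no-op here)
out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")
```

With more than one rank, `grid` would be the local shard length plus the halos, and only the cropped interior rows contribute to the projection output.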
2 changes: 2 additions & 0 deletions src/anemoi/models/layers/block.py
@@ -69,6 +69,7 @@ def __init__(
activation: str,
window_size: int,
dropout_p: float = 0.0,
shard_strategy: str = "shard_heads",
):
super().__init__()

@@ -87,6 +88,7 @@ def __init__(
bias=False,
is_causal=False,
dropout_p=dropout_p,
shard_strategy=shard_strategy,
)

self.mlp = nn.Sequential(
2 changes: 2 additions & 0 deletions src/anemoi/models/layers/chunk.py
@@ -75,6 +75,7 @@ def __init__(
mlp_hidden_ratio: int = 4,
activation: str = "GELU",
dropout_p: float = 0.0,
shard_strategy: str = "shard_heads",
) -> None:
"""Initialize TransformerProcessor.

@@ -103,6 +104,7 @@ def __init__(
activation=activation,
window_size=window_size,
dropout_p=dropout_p,
shard_strategy=shard_strategy,
)

def forward(
2 changes: 2 additions & 0 deletions src/anemoi/models/layers/processor.py
@@ -97,6 +97,7 @@ def __init__(
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
dropout_p: float = 0.1,
shard_strategy: str = "shard_heads",
**kwargs,
) -> None:
"""Initialize TransformerProcessor.
@@ -138,6 +139,7 @@ def __init__(
window_size=window_size,
activation=activation,
dropout_p=dropout_p,
shard_strategy=shard_strategy,
)

self.offload_layers(cpu_offload)
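
Whether the new strategy is applicable depends on the assertion in the attention layer: each rank's sequence shard must be at least twice the window size, since the halo on each side is one window wide. A small back-of-the-envelope check (the helper name and the even-split assumption are mine, not part of the package):

```python
def sequence_sharding_feasible(grid_size: int, num_ranks: int, window_size: int) -> bool:
    """Mirror of the shard_sequence assertion, assuming an even split of the grid."""
    local_shard = grid_size // num_ranks
    return local_shard // 2 >= window_size


assert sequence_sharding_feasible(40_000, num_ranks=4, window_size=512)        # shard of 10_000 points
assert not sequence_sharding_feasible(40_000, num_ranks=4, window_size=5_001)  # halo would not fit
```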