Commit
adding global and local window mask
ShashankMosaicML committed Dec 2, 2024
1 parent 67aa900 commit 65a0425
Showing 2 changed files with 65 additions and 3 deletions.
31 changes: 31 additions & 0 deletions llmfoundry/models/layers/attention.py
@@ -661,6 +661,33 @@ def sequence_id_mask_fn(
return sequence_id_mask_fn


def _get_local_global_mask_mod_fn(
sequence_id_transform: dict[str, torch.Tensor],
sliding_window_size: int,
global_window_size: int,
) -> _mask_mod_signature:
sequence_id = sequence_id_transform['sequence_id']
pos_in_seq = sequence_id_transform['pos_in_seq']

def local_global_mask_fn(
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
) -> torch.Tensor:
del h
# Check if the query and key belong to the same sequence and the query token is not a padding token.

sequence_id_mask = (sequence_id[b, q_idx] == sequence_id[b, kv_idx]
) & (sequence_id[b, q_idx] != -1)
global_window_mask = (pos_in_seq[b, kv_idx] <= global_window_size)
sliding_window_mask = (q_idx - kv_idx <= sliding_window_size)

return sequence_id_mask & (global_window_mask | sliding_window_mask)

return local_global_mask_fn


def _generate_score_mod(score_mod_list: list[dict[str, Any]],):
score_mod = flex_attention_score_mods.get('noop')()
for mod_dict in score_mod_list:
@@ -1410,3 +1437,7 @@ def build_alibi_bias(
'sequence_id',
func=_get_sequence_id_mask_mod_fn,
)
flex_attention_mask_mods.register(
'local_global_mask',
func=_get_local_global_mask_mod_fn,
)
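
For readers unfamiliar with flex attention mask mods, the following minimal sketch (not part of this commit; the tensors, window sizes, and the explicit causal check are illustrative assumptions) reproduces the boolean rule that local_global_mask_fn encodes using plain PyTorch broadcasting: a key is visible to a query only if both tokens share a sequence id, the query is not a padding token (sequence id -1), and the key either sits within the first global_window_size + 1 positions of its sequence or within sliding_window_size tokens of the query.

import torch

# Toy batch row: one sequence of length 5 followed by a padding token (-1).
sequence_id = torch.tensor([[0, 0, 0, 0, 0, -1]])
# Position of each token inside its own sequence (hand-written for this sketch).
pos_in_seq = torch.tensor([[0, 1, 2, 3, 4, 0]])

sliding_window_size = 1  # illustrative value
global_window_size = 0   # illustrative value: only each sequence's first token is "global"

b = 0
q_idx = torch.arange(sequence_id.shape[1]).unsqueeze(-1)  # queries as rows
kv_idx = torch.arange(sequence_id.shape[1]).unsqueeze(0)  # keys as columns

# Same boolean expressions as local_global_mask_fn, broadcast over all (q, kv) pairs.
sequence_id_mask = (sequence_id[b, q_idx] == sequence_id[b, kv_idx]
                   ) & (sequence_id[b, q_idx] != -1)
global_window_mask = pos_in_seq[b, kv_idx] <= global_window_size
sliding_window_mask = (q_idx - kv_idx) <= sliding_window_size

# local_global_mask_fn does not itself impose causality; a causal check is added here
# on the assumption that one is applied elsewhere when mask mods are combined.
causal_mask = kv_idx <= q_idx
mask = causal_mask & sequence_id_mask & (global_window_mask | sliding_window_mask)
print(mask.int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [1, 0, 1, 1, 0, 0],
#         [1, 0, 0, 1, 1, 0],
#         [0, 0, 0, 0, 0, 0]])
# Each query keeps its sequence's first token (the "global" window) plus a local sliding
# window around itself, and the padding token attends to nothing.
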
37 changes: 34 additions & 3 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -198,6 +198,7 @@ def attn_mask_in_len_transformer(
sequence_id: Union[torch.Tensor, None],
S: int,
attention_mask: Union[torch.Tensor, None],
return_pos_in_seq: bool = False,
):
"""Generates the attention mask used for sequence masking in FA v2.
@@ -210,6 +211,7 @@ def attn_mask_in_len_transformer(
sequence_id (Union[None, torch.Tensor]): Tensor containing the sequence id for each token. Shape (batch_size, seq_len).
S (int): Sequence length
attention_mask (Union[torch.Tensor, None]): Attention mask tensor of shape (batch_size, seq_len)
return_pos_in_seq (bool): If True, returns the position in sequence tensor instead of attn mask in length. Default is False.
Returns:
attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none. For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
@@ -259,13 +261,16 @@ def attn_mask_in_len_transformer(
# We replace those -1 with 0 to prevent `torch.nn.functional.one_hot(sequence_id)` in the next line from failing.
# We apply the attention mask again after the one_hot operation.
sequence_id = sequence_id.masked_fill(~attention_mask, 0)
-    attention_mask_in_length = torch.nn.functional.one_hot(sequence_id)
+    one_hot_seq_id = torch.nn.functional.one_hot(sequence_id)
    if attention_mask is not None:
-        attention_mask_in_length = attention_mask_in_length.masked_fill(
+        one_hot_seq_id = one_hot_seq_id.masked_fill(
            ~attention_mask.unsqueeze(-1),
            0,
        )
-    attention_mask_in_length = attention_mask_in_length.sum(dim=1)
+    if return_pos_in_seq:
+        return one_hot_seq_id.cumsum(dim=1).sum(dim=-1)
+
+    attention_mask_in_length = one_hot_seq_id.sum(dim=1)
attention_mask_in_length = torch.nn.functional.pad(
attention_mask_in_length,
(0, S - attention_mask_in_length.shape[-1]),
@@ -276,6 +281,28 @@
return attention_mask_in_length


def pos_in_seq_transformer(
sequence_id: Union[torch.Tensor, None],
S: int,
attention_mask: Union[torch.Tensor, None],
):
return {
'sequence_id':
seq_id_noop_transformer(
sequence_id,
S,
attention_mask,
),
'pos_in_seq':
attn_mask_in_len_transformer(
sequence_id,
S,
attention_mask,
return_pos_in_seq=True,
),
}


def seq_id_noop_transformer(
sequence_id: Union[torch.Tensor, None],
S: int,
@@ -1605,3 +1632,7 @@ def get_attention_flops(self, msl: int) -> int:
'attention_mask_in_length',
func=attn_mask_in_len_transformer,
)
sequence_id_transformer_registry.register(
'pos_in_seq',
func=pos_in_seq_transformer,
)
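
To make the bookkeeping above concrete, here is a minimal sketch (not part of this commit; the toy tensors are made up) of what attn_mask_in_len_transformer computes in its default mode (return_pos_in_seq=False): a one-hot encoding of the sequence ids, zeroed at padding positions, summed over the sequence dimension and padded to length S, giving the per-sequence lengths described in its docstring. The new pos_in_seq_transformer calls the same function with return_pos_in_seq=True and packages the result, together with the sequence_id routed through seq_id_noop_transformer, into the dict consumed by local_global_mask_fn; at run time the registered entries would presumably be looked up with the same .get(...) pattern that _generate_score_mod uses for score mods.

import torch

# Toy batch row packing three sequences of lengths 3, 2, and 1 (no padding).
sequence_id = torch.tensor([[0, 0, 0, 1, 1, 2]])
attention_mask = torch.ones_like(sequence_id, dtype=torch.bool)
S = sequence_id.shape[1]

one_hot_seq_id = torch.nn.functional.one_hot(sequence_id)            # shape (1, 6, 3)
one_hot_seq_id = one_hot_seq_id.masked_fill(~attention_mask.unsqueeze(-1), 0)

# Summing over the sequence dimension counts the tokens carrying each sequence id,
# and padding to length S gives the attention_mask_in_length layout from the docstring.
attention_mask_in_length = one_hot_seq_id.sum(dim=1)                  # tensor([[3, 2, 1]])
attention_mask_in_length = torch.nn.functional.pad(
    attention_mask_in_length,
    (0, S - attention_mask_in_length.shape[-1]),
)
print(attention_mask_in_length)                                        # tensor([[3, 2, 1, 0, 0, 0]])
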
