Skip to content

Commit

Permalink
Magnet xformers 0.0.22 compatibility fix (facebookresearch#394)
Browse files Browse the repository at this point in the history
MAGNeT - Fix for xformers 0.0.22 compatibility, thanks to @nateraw catch.

In addition, the following smaller fixes are also contained in this PR:
* MAGNeT notebook - change to stride1 span arrangement by default.
* MAGNeT doc fix of a typo.
* MAGNeT music training grid typo fix.
  • Loading branch information
lonzi authored Jan 17, 2024
1 parent 3fc3c9f commit 7e2e4bf
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
2 changes: 1 addition & 1 deletion grids/magnet/magnet_32khz.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
def explorer(launcher):
partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
launcher.slurm_(gpus=32, partition=partitions)
launcher.bind_(solver='magnet/magnet_base_32khz')
launcher.bind_(solver='magnet/magnet_32khz')
# replace this by the desired music dataset
launcher.bind_(dset='internal/music_400k_32khz')

Expand Down
20 changes: 14 additions & 6 deletions models/lm_magnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def __init__(self, subcodes_context: int = 5, compression_model_framerate: int =
self.causal = kwargs['causal']
self.subcodes_context = subcodes_context
self.span_len = span_len
self._build_attn_masks(compression_model_framerate, segment_duration,
self._build_attn_masks(compression_model_framerate=compression_model_framerate,
segment_duration=segment_duration,
num_heads=kwargs['num_heads'],
device=kwargs['device'], dtype=kwargs['dtype'])

def restricted_context_attn_mask(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
Expand All @@ -64,12 +66,13 @@ def restricted_context_attn_mask(self, seq_len: int, device: torch.device, dtype
torch.zeros([], device=device, dtype=dtype),
torch.full([], float('-inf'), device=device, dtype=dtype))

def _stage_attn_mask(self, stage: int, seq_len: int,
def _stage_attn_mask(self, stage: int, seq_len: int, num_heads: int,
device: torch.device, dtype: torch.dtype) -> tp.Optional[torch.Tensor]:
"""Creates a restricted attention mask given the stage (codebook index).
Args:
stage (int): The codebook index. Takes values in [0, n_q].
seq_len (int): Token sequence length.
num_heads (int): Num transformer attention heads.
device (torch.device): device of the output tensor.
dtype (torch.dtype): data type of the output tensor.
Returns:
Expand All @@ -82,29 +85,34 @@ def _stage_attn_mask(self, stage: int, seq_len: int,
sa_mask = self.restricted_context_attn_mask(seq_len, device=device, dtype=dtype)

if sa_mask is not None:
# Repeat for each attention head
sa_mask = sa_mask.repeat((1, num_heads, 1, 1))

# align8 to enable memory efficient attention
MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR = 8
seq_len_aligned = \
int(np.ceil(seq_len / MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR)) * MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR

sa_mask_aligned = torch.zeros((seq_len_aligned, seq_len_aligned), device=device, dtype=dtype)
sa_mask_aligned[:seq_len, :seq_len] = sa_mask
sa_mask_aligned = torch.zeros((1, num_heads, seq_len_aligned, seq_len_aligned), device=device, dtype=dtype)
sa_mask_aligned[..., :seq_len, :seq_len] = sa_mask
sa_mask = sa_mask_aligned

return sa_mask

def _build_attn_masks(self, compression_model_framerate: int, segment_duration: int,
def _build_attn_masks(self, compression_model_framerate: int, segment_duration: int, num_heads: int,
device: torch.device, dtype: torch.dtype):
"""Construct attention mask per stage. For each of the RVQ codebook levels in the [0, n_q] range,
either a local attention map or None would be stored as an entry in the self.attn_mask_per_stage list.
Args:
compression_model_framerate (int): The frame rate of the tokenizer.
segment_duration (int): Sample length in seconds.
num_heads (int): Num transformer attention heads.
device (torch.device): device of the output tensor.
dtype (torch.dtype): data type of the output tensor.
"""
seq_len = compression_model_framerate * segment_duration
self.attn_mask_per_stage = [self._stage_attn_mask(stage, seq_len, device, dtype) for stage in range(self.n_q)]
self.attn_mask_per_stage = [self._stage_attn_mask(stage, seq_len, num_heads,
device, dtype) for stage in range(self.n_q)]

@torch.no_grad()
def generate(self,
Expand Down
7 changes: 5 additions & 2 deletions modules/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,10 +401,13 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
q, k, v = [x.float() for x in [q, k, v]]
if self.memory_efficient:
if custom_attn_mask:
# When using a custom attn mask: move to query's device + remove align8 padding
# When using a custom attn mask:
# Move to query's device, repeat for each sample, remove align8 padding
seq_len = query.shape[1]
attn_mask = attn_mask.to(q.dtype)
attn_mask = attn_mask[:seq_len, :seq_len]
attn_mask = attn_mask.repeat((q.shape[0], 1, 1, 1))
attn_mask = attn_mask[..., :seq_len, :seq_len]

p = self.dropout if self.training else 0
if _efficient_attention_backend == 'torch':
x = torch.nn.functional.scaled_dot_product_attention(
Expand Down

0 comments on commit 7e2e4bf

Please sign in to comment.