From 7e2e4bf4c99fbdb0e77132e6e7578e8f37d6ec9d Mon Sep 17 00:00:00 2001
From: Alon Ziv <30550331+lonzi@users.noreply.github.com>
Date: Wed, 17 Jan 2024 14:08:48 +0200
Subject: [PATCH] Magnet xformers 0.0.22 compatibility fix (#394)

MAGNeT - fix for xformers 0.0.22 compatibility, thanks to @nateraw for the catch.

In addition, this PR contains the following smaller fixes:

* MAGNeT notebook - use the stride1 span arrangement by default.
* MAGNeT doc - typo fix.
* MAGNeT music training grid - typo fix.
---
 grids/magnet/magnet_32khz.py |  2 +-
 models/lm_magnet.py          | 20 ++++++++++++++------
 modules/transformer.py       |  7 +++++--
 3 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/grids/magnet/magnet_32khz.py b/grids/magnet/magnet_32khz.py
index c3575b30..036de25d 100644
--- a/grids/magnet/magnet_32khz.py
+++ b/grids/magnet/magnet_32khz.py
@@ -12,7 +12,7 @@ def explorer(launcher):
     partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
     launcher.slurm_(gpus=32, partition=partitions)
 
-    launcher.bind_(solver='magnet/magnet_base_32khz')
+    launcher.bind_(solver='magnet/magnet_32khz')
 
     # replace this by the desired music dataset
     launcher.bind_(dset='internal/music_400k_32khz')
diff --git a/models/lm_magnet.py b/models/lm_magnet.py
index 4c2ab9ee..201a3b76 100644
--- a/models/lm_magnet.py
+++ b/models/lm_magnet.py
@@ -40,7 +40,9 @@ def __init__(self, subcodes_context: int = 5, compression_model_framerate: int =
         self.causal = kwargs['causal']
         self.subcodes_context = subcodes_context
         self.span_len = span_len
-        self._build_attn_masks(compression_model_framerate, segment_duration,
+        self._build_attn_masks(compression_model_framerate=compression_model_framerate,
+                               segment_duration=segment_duration,
+                               num_heads=kwargs['num_heads'],
                                device=kwargs['device'], dtype=kwargs['dtype'])
 
     def restricted_context_attn_mask(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
@@ -64,12 +66,13 @@ def restricted_context_attn_mask(self, seq_len: int, device: torch.device, dtype
                            torch.zeros([], device=device, dtype=dtype),
                            torch.full([], float('-inf'), device=device, dtype=dtype))
 
-    def _stage_attn_mask(self, stage: int, seq_len: int,
+    def _stage_attn_mask(self, stage: int, seq_len: int, num_heads: int,
                          device: torch.device, dtype: torch.dtype) -> tp.Optional[torch.Tensor]:
         """Creates a restricted attention mask given the stage (codebook index).
         Args:
             stage (int): The codebook index. Takes values in [0, n_q].
             seq_len (int): Token sequence length.
+            num_heads (int): Num transformer attention heads.
             device (torch.device): device of the output tensor.
             dtype (torch.dtype): data type of the output tensor.
         Returns:
@@ -82,29 +85,34 @@ def _stage_attn_mask(self, stage: int, seq_len: int,
             sa_mask = self.restricted_context_attn_mask(seq_len, device=device, dtype=dtype)
 
         if sa_mask is not None:
+            # Repeat for each attention head
+            sa_mask = sa_mask.repeat((1, num_heads, 1, 1))
+
             # align8 to enable memory efficient attention
             MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR = 8
             seq_len_aligned = \
                 int(np.ceil(seq_len / MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR)) * MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR
 
-            sa_mask_aligned = torch.zeros((seq_len_aligned, seq_len_aligned), device=device, dtype=dtype)
-            sa_mask_aligned[:seq_len, :seq_len] = sa_mask
+            sa_mask_aligned = torch.zeros((1, num_heads, seq_len_aligned, seq_len_aligned), device=device, dtype=dtype)
+            sa_mask_aligned[..., :seq_len, :seq_len] = sa_mask
             sa_mask = sa_mask_aligned
 
         return sa_mask
 
-    def _build_attn_masks(self, compression_model_framerate: int, segment_duration: int,
+    def _build_attn_masks(self, compression_model_framerate: int, segment_duration: int, num_heads: int,
                           device: torch.device, dtype: torch.dtype):
         """Construct attention mask per stage. For each of the RVQ codebook levels in the [0, n_q] range,
         either a local attention map or None would be stored as an entry in the self.attn_mask_per_stage list.
         Args:
             compression_model_framerate (int): The frame rate of the tokenizer.
             segment_duration (int): Sample length in seconds.
+            num_heads (int): Num transformer attention heads.
             device (torch.device): device of the output tensor.
             dtype (torch.dtype): data type of the output tensor.
         """
         seq_len = compression_model_framerate * segment_duration
-        self.attn_mask_per_stage = [self._stage_attn_mask(stage, seq_len, device, dtype) for stage in range(self.n_q)]
+        self.attn_mask_per_stage = [self._stage_attn_mask(stage, seq_len, num_heads,
+                                                          device, dtype) for stage in range(self.n_q)]
 
     @torch.no_grad()
     def generate(self,
diff --git a/modules/transformer.py b/modules/transformer.py
index 818e98c0..4d44b39e 100644
--- a/modules/transformer.py
+++ b/modules/transformer.py
@@ -401,10 +401,13 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                 q, k, v = [x.float() for x in [q, k, v]]
             if self.memory_efficient:
                 if custom_attn_mask:
-                    # When using a custom attn mask: move to query's device + remove align8 padding
+                    # When using a custom attn mask:
+                    # Move to query's device, repeat for each sample, remove align8 padding
                     seq_len = query.shape[1]
                     attn_mask = attn_mask.to(q.dtype)
-                    attn_mask = attn_mask[:seq_len, :seq_len]
+                    attn_mask = attn_mask.repeat((q.shape[0], 1, 1, 1))
+                    attn_mask = attn_mask[..., :seq_len, :seq_len]
+
                 p = self.dropout if self.training else 0
                 if _efficient_attention_backend == 'torch':
                     x = torch.nn.functional.scaled_dot_product_attention(
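
For reference, the pattern this patch implements can be exercised end to end with plain PyTorch: build the restricted-context mask once as a 4-D additive bias of shape (1, num_heads, aligned_len, aligned_len), padded up to a multiple of 8 so the memory-efficient kernels accept it, then at attention time repeat it over the batch and crop the align8 padding back to the true sequence length. The sketch below is a minimal illustration assuming only torch; the names build_stage_mask, attend, ALIGN and the context parameter are illustrative and do not appear in the patched files, and the simple symmetric band stands in for MAGNeT's actual restricted-context mask.

import math

import torch

ALIGN = 8  # plays the role of MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR in the patch


def build_stage_mask(seq_len: int, num_heads: int, context: int,
                     device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Banded additive mask, expanded per head and padded to a multiple of ALIGN."""
    idx = torch.arange(seq_len, device=device)
    keep = (idx[None, :] - idx[:, None]).abs() <= context  # (L, L) boolean band
    band = torch.where(keep,
                       torch.zeros([], device=device, dtype=dtype),
                       torch.full([], float('-inf'), device=device, dtype=dtype))
    seq_len_aligned = int(math.ceil(seq_len / ALIGN)) * ALIGN
    mask = torch.zeros((1, num_heads, seq_len_aligned, seq_len_aligned), device=device, dtype=dtype)
    mask[..., :seq_len, :seq_len] = band  # the align8 padding keeps a zero bias
    return mask


def attend(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, stage_mask: torch.Tensor) -> torch.Tensor:
    """q, k, v: (B, H, L, D). Repeat the cached mask per sample, drop the padding, run SDPA."""
    B, _, L, _ = q.shape
    mask = stage_mask.to(q.dtype).repeat((B, 1, 1, 1))[..., :L, :L]
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)


if __name__ == "__main__":
    B, H, L, D = 2, 4, 50, 16
    stage_mask = build_stage_mask(L, H, context=5, device=torch.device("cpu"), dtype=torch.float32)
    q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
    print(attend(q, k, v, stage_mask).shape)  # torch.Size([2, 4, 50, 16])

Caching the head-expanded, align8-padded mask once per stage (as _build_attn_masks now does) keeps the per-forward cost down to a cheap repeat over the batch plus a crop, which mirrors the split between lm_magnet.py and modules/transformer.py in this patch.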