From 8205d85448b2750b845a3810ccd792770f981678 Mon Sep 17 00:00:00 2001 From: Honglu Fan <64070721+honglu2875@users.noreply.github.com> Date: Wed, 13 Dec 2023 21:17:46 +0100 Subject: [PATCH] Completely refactor and test YaRN finetuning (#78) * add yarn finetuning * fixing type hints * yarn refactor * fix bugs * bug and format * bug fix * in-place RoPE; major fix for YaRN * format * clean up; remove redundant because we stick to NeoX style. * dropping too long error because we do YaRN * update config * fix yarn * fix yarn; change some defaults * revamp kv cache * fixing sample function; removing max length constraint (dynamic yarn can go a little further) * format * minor bug * fixing batch size * fix cache device * fix attention mask * bug fix * bug fix * revamp sampling code; refactor kv cache * format * fix mask * fix bugs * format * fix typo on mscale default and dynamic scaling * format --- aria/model/cache.py | 67 +++++++ aria/model/dynamic_yarn.py | 165 ---------------- aria/model/model.py | 287 +++++++++++++-------------- aria/model/utils.py | 18 ++ aria/model/yarn_rotary_embedding.py | 223 +++++++++++++++++++++ aria/run.py | 13 +- aria/sample.py | 293 +++++++++++++++------------- config/models/large_yarn.json | 4 +- requirements.txt | 1 + tests/__init__.py | 0 tests/reference_implementations.py | 46 +++++ tests/test_models.py | 64 +++++- 12 files changed, 725 insertions(+), 456 deletions(-) create mode 100644 aria/model/cache.py delete mode 100644 aria/model/dynamic_yarn.py create mode 100644 aria/model/utils.py create mode 100644 aria/model/yarn_rotary_embedding.py create mode 100644 tests/__init__.py create mode 100644 tests/reference_implementations.py diff --git a/aria/model/cache.py b/aria/model/cache.py new file mode 100644 index 0000000..e60b399 --- /dev/null +++ b/aria/model/cache.py @@ -0,0 +1,67 @@ +from typing import Optional + +import torch + + +class KVCache(torch.nn.Module): + def __init__( + self, max_batch_size, n_head, d_head, dtype=torch.float16, max_size=8192 + ): + super().__init__() + self.shape = (max_batch_size, max_size, n_head, d_head) + self.register_buffer( + "k_cache", torch.empty(self.shape, dtype=dtype), persistent=False + ) + self.register_buffer( + "v_cache", torch.empty(self.shape, dtype=dtype), persistent=False + ) + self.next_pos = 0 + + def update( + self, + k, + v, + pos: Optional[torch.Tensor] = None, + start_pos: int = 0, + max_pos: Optional[int] = None, + ): + """ + Update the kv cache and return the new k, v sequences of vectors + + Args: + k: key to update. Shape: (batch_size, num_positions, n_head, d_head) + v: value to update. Shape: (batch_size, num_positions, n_head, d_head) + pos: positions to update. Shape: (num_positions,). + Example: None to append to the end of the cache. + [0, 1, 2, 3, 4] to update the first 5 positions. + [5] to only update the 6th position. + start_pos: the starting position of the cache. Default to 0 + max_pos: the maximum position to update. Default to None. + Only used when pos is *NOT* None. Can be inferred from pos.max(), + but such an operation causes a sync with massive overhead + due to dynamic shape. + """ + if pos is None: + self.k_cache[ + : k.size(0), self.next_pos : self.next_pos + k.size(1) + ] = k + self.v_cache[ + : v.size(0), self.next_pos : self.next_pos + v.size(1) + ] = v + self.next_pos += k.size(1) + else: + assert pos.size(0) == k.size(1) + assert max_pos is not None, ( + "Need to pass in `pos.max()` explicitly. " + "Doing `pos.max()` creates massive overhead." 
+ ) + self.k_cache[: k.size(0), pos] = k + self.v_cache[: v.size(0), pos] = v + # Update next_pos using the max entry. + # Note: `self.next_pos = pos.max() + 1` could have worked, but it + # causes the shape to be dynamic and creates a massive overhead. + self.next_pos = max_pos + 1 + return ( + self.k_cache[: k.size(0), start_pos : self.next_pos], + self.v_cache[: v.size(0), start_pos : self.next_pos], + ) diff --git a/aria/model/dynamic_yarn.py b/aria/model/dynamic_yarn.py deleted file mode 100644 index 49f66c5..0000000 --- a/aria/model/dynamic_yarn.py +++ /dev/null @@ -1,165 +0,0 @@ -"""Dynamic Yarn""" - -# https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaDynamicYaRNScaledRotaryEmbedding.py - -import torch -import math - - -# Inverse dim formula to find dim based on number of rotations -def find_correction_dim( - num_rotations, dim, base=10000, max_position_embeddings=2048 -): - return ( - dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) - ) / (2 * math.log(base)) - - -# Find dim range bounds based on rotations -def find_correction_range( - low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 -): - low = math.floor( - find_correction_dim(low_rot, dim, base, max_position_embeddings) - ) - high = math.ceil( - find_correction_dim(high_rot, dim, base, max_position_embeddings) - ) - return max(low, 0), min(high, dim - 1) # Clamp values just in case - - -def linear_ramp_mask(min, max, dim): - if min == max: - max += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - -def get_mscale(scale=1, coeff=0.1): - if scale <= 1: - return 1.0 - return coeff * math.log(scale) + 1.0 - - -class DynamicYaRNScaledRotaryEmbedding(torch.nn.Module): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - original_max_position_embeddings=2048, - extrapolation_factor=1, - attn_factor=1, - mscale_coeff=0.1, - beta_fast=32, - beta_slow=1, - finetuned=False, - device=None, - ): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.original_max_position_embeddings = original_max_position_embeddings - self.extrapolation_factor = extrapolation_factor - self.attn_factor = attn_factor - self.mscale_coeff = mscale_coeff - self.beta_fast = beta_fast - self.beta_slow = beta_slow - - if finetuned: - self.yarn( - self.max_position_embeddings - / self.original_max_position_embeddings, - device, - ) - else: - inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2).float().to(device) / dim) - ) - self.register_buffer("inv_freq", inv_freq) - self.mscale = 1 - - # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - t = torch.arange( - self.max_seq_len_cached, - device=self.inv_freq.device, - dtype=self.inv_freq.dtype, - ) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to - # obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - dtype = torch.get_default_dtype() - - self.register_buffer( - "cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False - ) - self.register_buffer( - "sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False - ) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - # This `if` block is unlikely to be run after we build sin/cos in - # `__init__`. 
Keep the logic here just in case. - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - - self.yarn(seq_len / self.original_max_position_embeddings, x.device) - - t = torch.arange( - self.max_seq_len_cached, - device=x.device, - dtype=self.inv_freq.dtype, - ) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order - # to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - - self.register_buffer( - "cos_cached", - (emb.cos() * self.mscale).to(x.dtype), - persistent=False, - ) - self.register_buffer( - "sin_cached", - (emb.sin() * self.mscale).to(x.dtype), - persistent=False, - ) - return ( - self.cos_cached[:seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:seq_len, ...].to(dtype=x.dtype), - ) - - def yarn(self, scale, device): - pos_freqs = self.base ** ( - torch.arange(0, self.dim, 2).float().to(device) / self.dim - ) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (scale * pos_freqs) - - low, high = find_correction_range( - self.beta_fast, - self.beta_slow, - self.dim, - self.base, - self.original_max_position_embeddings, - ) - inv_freq_mask = ( - 1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device) - ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation - inv_freq = ( - inv_freq_interpolation * (1 - inv_freq_mask) - + inv_freq_extrapolation * inv_freq_mask - ) - - self.register_buffer("inv_freq", inv_freq) - self.mscale = float( - get_mscale(scale, self.mscale_coeff) * self.attn_factor - ) # Get n-d magnitude scaling corrected for interpolation diff --git a/aria/model/model.py b/aria/model/model.py index dae9686..c18f9d3 100644 --- a/aria/model/model.py +++ b/aria/model/model.py @@ -1,11 +1,14 @@ """Includes (PyTorch) transformer model and config classes.""" from dataclasses import dataclass +from typing import Optional, Union + import torch import torch.utils.checkpoint from torch import nn as nn from torch.nn import functional as F -from aria.model.dynamic_yarn import DynamicYaRNScaledRotaryEmbedding +from aria.model.yarn_rotary_embedding import YaRNScaledRotaryEmbedding +from aria.model.cache import KVCache @dataclass @@ -25,9 +28,14 @@ class YaRNConfig: beta_fast: int = 16 beta_slow: int = 1 - scale: int = 1.0 - mscale_coeff: int = 0.07 + # `max_len * scale` would be the actual max context length for the run + scale: float = 1.0 + mscale_coeff: float = 0.1 base: float = 10000.0 + # Whether the underlying weights are already finetuned with YaRN + finetuned: bool = False + # Whether to use dynamic YaRN beyond the context length * scale + dynamic: bool = True @dataclass @@ -37,9 +45,10 @@ class ModelConfig: n_layers: int ff_mult: int drop_p: float - max_seq_len: int + max_seq_len: int # The original context length *WITHOUT* considering YaRN grad_checkpoint: bool - yarn_config: dict | YaRNConfig | None = None + yarn_config: Optional[Union[dict, YaRNConfig]] = None + vocab_size: Optional[int] = None def __post_init__(self): if self.yarn_config is not None and isinstance(self.yarn_config, dict): @@ -49,80 +58,6 @@ def set_vocab_size(self, vocab_size: int): self.vocab_size = vocab_size -# Taken from GPT-NeoX see: -# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py -class RotaryEmbedding(torch.nn.Module): - def __init__( - self, - dim, - max_position_embeddings=2048, - base=10000, - ): - super().__init__() - - self.dim = dim - 
self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / ( - self.base ** (torch.arange(0, self.dim, 2).float() / self.dim) - ) - self.register_buffer("inv_freq", inv_freq) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype(), - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype - ) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to - # obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - "cos_cached", emb.cos().to(dtype), persistent=False - ) - self.register_buffer( - "sin_cached", emb.sin().to(dtype), persistent=False - ) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache( - seq_len=seq_len, device=x.device, dtype=x.dtype - ) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -def rotate_half(x): - x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] - - return torch.cat( - (-x2, x1), dim=x1.ndim - 1 - ) # dim=-1 triggers a bug in earlier torch versions - - -@torch.jit.script -def apply_rotary_pos_emb(q, k, cos, sin, past_len: int = 0): - """Returns tuple (xq, xk). Expects shape (s_len, b_sz, n_head, d_head).""" - cos = cos[past_len : past_len + q.size(0), None, None] - sin = sin[past_len : past_len + q.size(0), None, None] - return (q * cos) + (rotate_half(q) * sin), (k * cos) + ( - rotate_half(k) * sin - ) - - class FusedEncoderBlock(nn.Module): """Transformer block using F.scaled_dot_product_attention(). 
@@ -147,20 +82,18 @@ def __init__(self, model_config: ModelConfig): self.max_seq_len = model_config.max_seq_len # Positional embeddings - if model_config.yarn_config is not None: - # TODO: Need more testing on this; - cfg = model_config.yarn_config - self.rotary_emb = DynamicYaRNScaledRotaryEmbedding( - self.d_head, - max_position_embeddings=round(self.max_seq_len * cfg.scale), - original_max_position_embeddings=self.max_seq_len, - beta_fast=cfg.beta_fast, - beta_slow=cfg.beta_slow, - base=cfg.base, - mscale_coeff=cfg.mscale_coeff, - ) - else: - self.rotary_emb = RotaryEmbedding(self.d_head) + cfg = model_config.yarn_config or YaRNConfig() + self.rotary_emb = YaRNScaledRotaryEmbedding( + self.d_head, + original_context_length=self.max_seq_len, + scaling_factor=cfg.scale, + beta_fast=cfg.beta_fast, + beta_slow=cfg.beta_slow, + base=cfg.base, + mscale_coeff=cfg.mscale_coeff, + finetuned=cfg.finetuned, + dynamic=cfg.dynamic, + ) # Attention self.mixed_qkv = nn.Linear( @@ -190,43 +123,83 @@ def __init__(self, model_config: ModelConfig): self.norm1 = nn.LayerNorm(model_config.d_model) self.norm2 = nn.LayerNorm(model_config.d_model) - def forward(self, x: torch.Tensor, use_cache=False, past_kv=None): - att, kv = self._att_block( - self.norm1(x), use_cache=use_cache, past_kv=past_kv + def forward(self, x: torch.Tensor, attn_mask=None, past_kv=None): + att = self._att_block( + self.norm1(x), attn_mask=attn_mask, past_kv=past_kv ) x = x + att x = x + self._ff_block(self.norm2(x)) - return x, kv + return x + + @staticmethod + def _create_mask( + q_len: int, + k_len: int, + attn_mask: Optional[torch.Tensor] = None, + device=None, + ): + # Could have cached some of these masks (not the entire (seq_len, seq_len)!!). + # But profiler seems to show that their impact is negligible. + + # attn_mask: (b_sz, k_len) + mask = torch.ones(q_len, k_len, dtype=torch.bool, device=device) + mask = torch.tril(mask, diagonal=k_len - q_len) + if attn_mask is not None: + # (1, q_len, k_len) & (b_sz, 1, k_len) + mask = mask[None, ...] & attn_mask[:, None, :] + return mask[:, None] + else: + return mask - def _att_block(self, x: torch.Tensor, use_cache=False, past_kv=None): + def _att_block( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + input_positions: Optional[torch.Tensor] = None, + max_pos: Optional[int] = None, + past_kv: Optional[KVCache] = None, + ): + """ + Args: + x: (b_sz, s_len, d_model) + attn_mask: (b_sz, s_len). The attention mask. `False` masks the column(keys) + in the attention matrix. + input_positions: (s_len,). The absolute position of each token. + If None, we assume that the input positions are contiguous. + max_pos: The maximum position of the input. Only used when input_positions + is not None. Can be inferred as input_positions.max(), but such an + operation makes the cache update slower due to dynamic shape. + past_kv: A KVCache object. 
+ """ batch_size, seq_len, _ = x.shape + past_len = 0 if past_kv is None else past_kv.next_pos + mixed_qkv = self.mixed_qkv(x) xq, xk, xv = mixed_qkv.chunk(3, -1) # Reshape for rotary embeddings - xq = xq.view(batch_size, seq_len, self.n_heads, self.d_head) - xk = xk.view(batch_size, seq_len, self.n_heads, self.d_head) + # Need contiguous for q, k since in-place RoPE cannot be applied on a view + xq = xq.reshape( + batch_size, seq_len, self.n_heads, self.d_head + ).contiguous() + xk = xk.reshape( + batch_size, seq_len, self.n_heads, self.d_head + ).contiguous() xv = xv.view(batch_size, seq_len, self.n_heads, self.d_head) - past_len = 0 if past_kv is None else past_kv[0].size(1) - # apply_rotary_post_emb expects: (s_len, b_sz, n_head, d_head) - cos, sin = self.rotary_emb(x=xv, seq_len=seq_len + past_len) - xq, xk = xq.transpose(0, 1), xk.transpose(0, 1) - xq, xk = apply_rotary_pos_emb( - q=xq, k=xk, cos=cos, sin=sin, past_len=past_len + # apply_rotary_post_emb expects: (b_sz, s_len, n_head, d_head) + xq, xk = self.rotary_emb( + xq, xk, input_positions=input_positions, past_len=past_len ) - xq, xk = xq.transpose(0, 1), xk.transpose(0, 1) # xq, xk: (b_sz, s_len, n_head, d_head) if past_kv is not None: - assert len(past_kv) == 2 - xk = torch.concat([past_kv[0], xk], axis=1) - xv = torch.concat([past_kv[1], xv], axis=1) - kv = (xk, xv) + xk, xv = past_kv.update( + xk, xv, pos=input_positions, max_pos=max_pos + ) + # Reshape for attention calculation: (b_sz, n_head, s_len, d_head) - xq = xq.transpose(1, 2) - xk = xk.transpose(1, 2) - xv = xv.transpose(1, 2) + xq, xk, xv = map(lambda t: t.transpose(1, 2), (xq, xk, xv)) # Required as we are not using a nn.Dropout layer if self.training: @@ -234,9 +207,10 @@ def _att_block(self, x: torch.Tensor, use_cache=False, past_kv=None): else: att_dropout = 0.0 - # Using beta torch functionality (subject to change) - # See - https://shorturl.at/jtI17 - if past_kv is None: + # Calculate attention + # Note: we avoid explicitly saving a (seq_len, seq_len) cache in order to + # save vRAM. + if past_kv is None and attn_mask is None: att = F.scaled_dot_product_attention( query=xq, key=xk, @@ -245,8 +219,9 @@ def _att_block(self, x: torch.Tensor, use_cache=False, past_kv=None): is_causal=True, ) else: - assert xq.size(2) == 1 - mask = torch.ones(1, xk.size(2), dtype=bool, device=xk.device) + mask = self._create_mask( + xq.size(2), xk.size(2), attn_mask=attn_mask, device=xk.device + ) att = F.scaled_dot_product_attention( query=xq, key=xk, @@ -255,15 +230,14 @@ def _att_block(self, x: torch.Tensor, use_cache=False, past_kv=None): is_causal=False, attn_mask=mask, ) + # If masked token show up in query, they come out as nan. Need to set to zero. 
+ att = torch.nan_to_num(att, nan=0.0) # Reshape for out: (b_sz, s_len, n_head, d_head) out = att.transpose(1, 2).contiguous() out = out.view(batch_size, seq_len, self.n_heads * self.d_head) - return ( - self.resid_dropout(self.att_proj_linear(out)), - kv if use_cache else None, - ) + return self.resid_dropout(self.att_proj_linear(out)) def _ff_block(self, x: torch.Tensor): x = self.ff_linear_2(self.ff_activation(self.ff_linear_1(x))) @@ -292,12 +266,21 @@ def __init__(self, model_config: ModelConfig): for _ in range(model_config.n_layers): self.encode_layers.append(FusedEncoderBlock(model_config)) - def forward(self, src: torch.Tensor, use_cache=False, past_kv=None): + def forward( + self, + src: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + past_kv: Optional[list[KVCache]] = None, + ): """Forward pass of Transformer. Args: src (torch.tensor): Input to encoder block, of shape (batch_size, seq_len, d_model). + attn_mask (Optional[torch.tensor]): Attention mask of shape + (batch_size, seq_len). Defaults to None. + past_kv (Optional[list[KVCache]]): a list of kv caches. The list index + corresponds to the layer index. Returns: torch.tensor: Model outputs with shape (batch_size, seq_len, @@ -305,13 +288,11 @@ def forward(self, src: torch.Tensor, use_cache=False, past_kv=None): """ hidden_states = self.tok_embeddings(src) - assert src.shape[1] <= self.model_config.max_seq_len, "Too long." - # NOTE: If you want to use gradient checkpointing then you must # remove torch.compile from the train script as this is not currently # supported. # Implements gradient checkpoints on Encoder Layers. - if self.model_config.grad_checkpoint is True and not use_cache: + if self.model_config.grad_checkpoint is True: for layer in self.encode_layers: def create_custom_forward(module): @@ -320,27 +301,23 @@ def custom_forward(*args): return custom_forward - hidden_states, _ = torch.utils.checkpoint.checkpoint( + hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(layer), hidden_states, + attn_mask, preserve_rng_state=True, use_reentrant=True, ) else: - new_past_kv = [] past_kv = ( [None] * len(self.encode_layers) if past_kv is None else past_kv ) for layer, _kv in zip(self.encode_layers, past_kv): - hidden_states, kv = layer( - hidden_states, use_cache=use_cache, past_kv=_kv + hidden_states = layer( + hidden_states, attn_mask=attn_mask, past_kv=_kv ) - new_past_kv.append(kv) - return ( - self.out_layer_norm(hidden_states), - new_past_kv if use_cache else None, - ) + return self.out_layer_norm(hidden_states) class TransformerLM(nn.Module): @@ -359,21 +336,47 @@ def __init__(self, model_config: ModelConfig): model_config.d_model, model_config.vocab_size, bias=False ) - def forward(self, src: torch.Tensor, use_cache=False, past_kv=None): + def forward( + self, + src: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + past_kv: Optional[list[KVCache]] = None, + ): """Forward pass of Transformer decoder with LM head. Args: src (torch.tensor): Input to encoder block, of shape (batch_size, seq_len, d_model). + attn_mask (Optional[torch.tensor]): Attention mask of shape + (batch_size, seq_len). Defaults to None. + past_kv (Optional[list[KVCache]]): a list of kv caches. The list index + corresponds to the layer index. Returns: torch.tensor: Forward pass of src through Transformer and LM head. Has shape (batch_size, seq_len, vocab_size). 
""" - hidden, past_kv = self.model(src, use_cache=use_cache, past_kv=past_kv) + hidden = self.model(src, attn_mask=attn_mask, past_kv=past_kv) logits = self.lm_head(hidden) - if use_cache: - return logits, past_kv - else: - return logits + return logits + + def get_cache( + self, max_batch_size: int = 16, max_len: int = 2048, device=None + ): + """ + Initialize an empty kv cache according to the model parameters. + We do not make KVCache a part of the model because one may apply techniques + such as CFG utilizing multiple caches. + """ + return [ + KVCache( + max_batch_size=max_batch_size, + max_size=max_len, + n_head=self.model.model_config.n_heads, + d_head=self.model.model_config.d_model + // self.model.model_config.n_heads, + dtype=next(self.parameters()).dtype, + ).to(device) + for _ in range(self.model.model_config.n_layers) + ] diff --git a/aria/model/utils.py b/aria/model/utils.py new file mode 100644 index 0000000..c29ab70 --- /dev/null +++ b/aria/model/utils.py @@ -0,0 +1,18 @@ +import torch + + +@torch.jit.script +def apply_rotary_pos_emb(x, cos, sin, past_len: int = 0): + """ + In-place RoPE. Credits to Katherine Crowson: + x shape (b_sz, s_len, n_head, d_head). + cos, sin shape (s_len, d_head // 2). + """ + d = cos.shape[-1] + cos = cos[None, past_len : past_len + x.size(1), None] + sin = sin[None, past_len : past_len + x.size(1), None] + x1, x2 = x[..., :d], x[..., d : d * 2] + tmp = x1.clone() + x1.mul_(cos).addcmul_(x2, sin, value=-1) + x2.mul_(cos).addcmul_(tmp, sin, value=1) + return x diff --git a/aria/model/yarn_rotary_embedding.py b/aria/model/yarn_rotary_embedding.py new file mode 100644 index 0000000..6a3e872 --- /dev/null +++ b/aria/model/yarn_rotary_embedding.py @@ -0,0 +1,223 @@ +from typing import Tuple, Optional + +import torch +import math +from aria.model.utils import apply_rotary_pos_emb + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return ( + dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) + ) / (2 * math.log(base)) + + +# Find dim range bounds based on rotations +def _yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def _yarn_get_mscale(scale=1.0, coeff=0.1): + if scale <= 1: + return 1.0 + return coeff * math.log(scale) + 1.0 + + +class YaRNScaledRotaryEmbedding(torch.nn.Module): + """ + Adapted from: + https://github.com/jquesnelle/yarn/blob/master/scaled_rope/modeling_llama_together_yarn.py + """ + + def __init__( + self, + dim: int, + base=10000.0, + pos_idx_in_fp32=True, + original_context_length=2048, + scaling_factor=1.0, + extrapolation_factor=1.0, + attn_factor=1.0, + mscale_coeff=0.1, + beta_fast=32, + beta_slow=1, + dynamic=False, + finetuned=False, + device=None, + ): + """ + pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, + otherwise they might be in lower precision. 
+ """ + super().__init__() + + self.dim = dim + self.base = float(base) + self.original_context_length = original_context_length + self.scaling_factor = scaling_factor + + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.mscale_coeff = mscale_coeff + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.pos_idx_in_fp32 = pos_idx_in_fp32 + # Get n-d magnitude scaling corrected for interpolation + self.mscale = float( + _yarn_get_mscale(self.scaling_factor, self.mscale_coeff) + * attn_factor + ) + self.dynamic = dynamic + self.finetuned = finetuned + + # Generate and save the inverse frequency buffer (non-trainable) + if not dynamic: + self._compute_inv_freq(self.scaling_factor, device) + else: + self._compute_inv_freq_original(device) + + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def _compute_inv_freq(self, scaling_factor, device=None): + pos_freqs = self.base ** ( + torch.arange(0, self.dim, 2).float().to(device) / self.dim + ) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.dim, + self.base, + self.original_context_length, + ) + inv_freq_mask = ( + 1 + - _yarn_linear_ramp_mask(low, high, self.dim // 2) + .float() + .to(device) + ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) + self.register_buffer("inv_freq", inv_freq) + + def _compute_inv_freq_original(self, device=None): + inv_freq = 1 / ( + self.base + ** ( + torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) + / self.dim + ) + ) + self.register_buffer("inv_freq", inv_freq) + + def _update_cos_sin_cache(self, seq_len, device=None, dtype=None): + # Reset the tables if the sequence length has changed, + # if we're on a new device (possibly due to tracing for instance), + # or if we're switching from inference mode to training + if ( + seq_len > self._seq_len_cached + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + or (self.training and self._cos_cached.is_inference()) + ): + self._seq_len_cached = seq_len + + if self.dynamic: + scaling_factor = None + if ( + seq_len + <= self.original_context_length * self.scaling_factor + ): + if self.finetuned: + scaling_factor = self.scaling_factor + else: + scaling_factor = seq_len / (self.original_context_length) + if scaling_factor: + self._compute_inv_freq(scaling_factor, device) + self.mscale = float( + _yarn_get_mscale(scaling_factor, self.mscale_coeff) + * self.attn_factor + ) + else: + self._compute_inv_freq_original(device) + + # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 + # And the output of arange can be quite large, so bf16 would lose a lot of precision. + # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. + if self.pos_idx_in_fp32: + t = torch.arange(seq_len, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. 
+ # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + inv_freq = self.inv_freq.to(torch.float32) + else: + inv_freq = self.inv_freq + else: + t = torch.arange( + seq_len, device=device, dtype=self.inv_freq.dtype + ) + inv_freq = self.inv_freq + # Don't do einsum, it converts fp32 to fp16 under AMP + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, inv_freq) + self._cos_cached = (torch.cos(freqs) * self.mscale).to(dtype) + self._sin_cached = (torch.sin(freqs) * self.mscale).to(dtype) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + input_positions: Optional[torch.Tensor] = None, + past_len: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + q: (batch, q_len, n_heads, head_dim) + k: (batch, k_len, n_heads, head_dim) + input_positions: (batch, *) + past_len: the length before the second axis of q (usually it is just the kv length) + """ + self._update_cos_sin_cache( + max( + q.size(1) + past_len, + self.original_context_length * self.scaling_factor, + ), + device=q.device, + dtype=q.dtype, + ) + return apply_rotary_pos_emb( + q, + self._cos_cached[past_len : past_len + q.size(1)], + self._sin_cached[past_len : past_len + q.size(1)], + ), apply_rotary_pos_emb( + k, + self._cos_cached[past_len : past_len + k.size(1)], + self._sin_cached[past_len : past_len + k.size(1)], + ) diff --git a/aria/run.py b/aria/run.py index 5721c44..8bcb857 100644 --- a/aria/run.py +++ b/aria/run.py @@ -176,14 +176,8 @@ def _quantize(module, key, input_shape): args.p ) # let user input midi path if not provided - if args.l and 0 < args.l < model.max_seq_len: - max_gen_len = args.l - else: - max_gen_len = model.max_seq_len - - assert ( - truncate_len < model_config.max_seq_len - ), "Truncate length longer than maximum length supported by the model." + assert args.l > 0, "Generation length must be positive." + max_new_tokens = args.l # Load and format prompts midi_dict = MidiDict.from_midi(mid_path=midi_path) @@ -199,8 +193,7 @@ def _quantize(module, key, input_shape): prompts, device=device, force_end=force_end, - max_seq_len=model_config.max_seq_len, - max_gen_len=max_gen_len, + max_new_tokens=max_new_tokens, ) if os.path.isdir("samples") is False: diff --git a/aria/sample.py b/aria/sample.py index 886af41..26a70cd 100644 --- a/aria/sample.py +++ b/aria/sample.py @@ -1,5 +1,4 @@ """Contains generation/sampling code""" - # This file contains code from https://github.com/facebookresearch/llama which # is available under the following licence: @@ -16,9 +15,6 @@ from aria.model import TransformerLM from aria.tokenizer import Tokenizer -# TODO: -# - Enable sampling sequences longer than max_seq_len by truncating - def _get_cfg_coeff(cfg_gamma, cfg_mode, cur_pos, start_pos, total_len): if cfg_mode is None: @@ -39,16 +35,67 @@ def _get_cfg_coeff(cfg_gamma, cfg_mode, cur_pos, start_pos, total_len): raise ValueError(f"Unknown cfg_mode: {cfg_mode}") +def _process_prompts( + prompts, + pad_token="
<P>
", + neg_prompts=None, + use_cfg=False, + neg_prompt_len=None, +) -> list: + """ + Preprocess prompts for generation. + If cfg is used, + 1. the prompts and negative prompts will be combined. + 2. the negative prompts will be truncated for at most as long as the longest prompt. + Args: + prompts: list of prompts + pad_token: pad token ('
<P>
') + neg_prompts: list of negative prompts + use_cfg: whether to use cfg + neg_prompt_len: max length of negative prompts. If more than the longest prompt, + pad to this length. + Returns: + list of padded prompts + """ + processed_prompts = [] + max_len = max(len(t) for t in prompts) + pad_len = max(max_len, neg_prompt_len or 0) + + if use_cfg: + if neg_prompts is None: + neg_prompts = [t[-1:] for t in prompts] + assert len(prompts) == len( + neg_prompts + ), "Prompts and neg_prompts must have the same count." + + for prompt in prompts + neg_prompts: + processed_prompts.append( + [pad_token] * max(0, pad_len - len(prompt)) + prompt[:pad_len] + ) + else: + max_len = max(len(t) for t in prompts) + for prompt in prompts: + processed_prompts.append( + [pad_token] * (max_len - len(prompt)) + prompt + ) + + return processed_prompts + + +def _batch_encode(tokenizer, prompts: list[list]) -> torch.Tensor: + return torch.stack([tokenizer.encode(p) for p in prompts], dim=0) + + # Some good settings: # temp=0.85, top_p=0.9, cfg_gamma=1.4 +@torch.no_grad() def greedy_sample( model: TransformerLM, tokenizer: Tokenizer, prompts: List[list], - max_seq_len: int, - max_gen_len: int, + max_new_tokens: int, device: torch.device | None = None, cfg_gamma: float | None = 1.4, cfg_mode: str | None = None, @@ -65,8 +112,7 @@ def greedy_sample( model (TransformerLM): Model to sample from. tokenizer (Tokenizer): Tokenizer corresponding to model. prompts (List[list]): A list of prompts to sample as a batch. - max_seq_len (int): Maximum sequence length supported by the model. - max_gen_len (int): Maximum desired sequence length of the samples. + max_new_tokens (int): Maximum number of new generated tokens. device (torch.device, optional): Device to use. Defaults to None. cfg_gamma (float, optional): CFG gamma parameter. Defaults to 1.2. This parameter *determines* whether parameters related to CFG are used. @@ -77,8 +123,9 @@ def greedy_sample( "sine": sine curve from 1 -> gamma -> 1 neg_prompts (List[list], optional): Alternative prompts to sample from. Defaults to None ("unconditioned" model is approximated using only the last tokens of prompts). - neg_prompt_len (int, optional): Length of the negative prompts. - Defaults to None (minimal length of neg_prompts). + neg_prompt_len (int, optional): Max length used for the negative prompts. + Defaults to None (align to prompts). + When set, if `neg_prompt_len > max(t for t in prompts)`, we pad to `neg_prompt_len`. alpha (float, optional): an alpha parameter during interpolation. Only takes effect when neg_prompt_len < minimal length of neg_prompts. Defaults to 0.4. force_end (bool, optional): Whether to force the end of the prompt. Defaults to False. 
@@ -93,148 +140,126 @@ def greedy_sample( model.eval() pad_id = tokenizer.pad_id + pad_tok = tokenizer.pad_tok eos_id = tokenizer.tok_to_id[tokenizer.eos_tok] - bsz = len(prompts) - min_prompt_size = min([len(t) for t in prompts]) - max_prompt_size = max([len(t) for t in prompts]) - total_len = min(max_seq_len, max_gen_len + max_prompt_size) + padded_combined_prompts = _process_prompts( + prompts, + pad_tok, + neg_prompts, + cfg_gamma is not None, + neg_prompt_len=neg_prompt_len, + ) + if neg_prompts is not None: + padded_negative_prompts = _process_prompts( + neg_prompts, pad_tok, None, False + ) + else: + padded_negative_prompts = [t[-1:] for t in prompts] + + prompt_len = len(padded_combined_prompts[0]) + if neg_prompts is not None: + neg_offset_len = max(0, prompt_len - max(len(t) for t in prompts)) + else: + neg_offset_len = 0 if force_end: - assert ( - total_len - max_prompt_size > 130 - ), "prompt too long to use force_end=True" + assert max_new_tokens > 130, "prompt too long to use force_end=True" print( - f"Using hyperparams: temp={temperature}, top_p={top_p}, gamma={cfg_gamma}, gen_len={max_gen_len}" + f"Using hyperparams: temp={temperature}, top_p={top_p}, gamma={cfg_gamma}, gen_len={max_new_tokens}" ) - if cfg_gamma is not None: - # todo: maybe it already works with varying prompts - assert ( - min_prompt_size == max_prompt_size - ), "CFG not supported with varying prompts" - if neg_prompts is None: - neg_prompts = [prompts[-1] for _ in range(bsz)] - - neg_min_len = min(total_len, min(len(a) for a in neg_prompts)) - neg_max_len = max(total_len, max(len(a) for a in neg_prompts)) - neg_prompt_tensors = torch.stack( - [ - torch.concat( - [ - torch.full( - (neg_max_len - len(neg_seq),), pad_id, device=device - ), - tokenizer.encode(neg_seq).to(device), - ] - ) - for neg_seq in neg_prompts - ], - axis=0, - ) - neg_len = ( - neg_min_len - if neg_prompt_len is None - else min(neg_min_len, neg_prompt_len) - ) - neg_tokens = neg_prompt_tensors[:, :neg_len] + total_len = prompt_len + max_new_tokens + tokens = torch.full( + (len(padded_combined_prompts), total_len), pad_id, device=device + ) + tokens[:, :prompt_len] = _batch_encode( + tokenizer, padded_combined_prompts + ).to(device) + full_neg_tokens = _batch_encode(tokenizer, padded_negative_prompts).to( + device + ) - tokens = torch.full((bsz, total_len), pad_id, device=device) - for idx, unencoded_seq in enumerate(prompts): - tokens[idx, : len(unencoded_seq)] = tokenizer.encode(unencoded_seq).to( - device - ) + dim_tok_inserted = [False for _ in range(tokens.size(0))] + attn_mask = torch.ones( + (len(padded_combined_prompts), total_len), + device=device, + dtype=torch.bool, + ) + attn_mask[:, :prompt_len] = tokens[:, :prompt_len] != pad_id + start_pos = prompt_len - dim_tok_inserted = [False for _ in range(bsz)] - input_text_mask = tokens != pad_id - start_pos = min_prompt_size + past_kv = model.get_cache( + max_batch_size=tokens.size(0), max_len=total_len, device=device + ) + + for cur_pos in ( + pbar := tqdm( + range(start_pos, total_len), + total=total_len - start_pos, + leave=False, + ) + ): + if cur_pos == start_pos: + token = tokens[:, :start_pos] + else: + token = tokens[:, cur_pos - 1 : cur_pos] - past_kv = None - cfg_kv = None - neg_previous_token = None + logits = model.forward( + token, attn_mask=attn_mask[:, :cur_pos], past_kv=past_kv + ) + logits = logits[:, -1, :] - with torch.inference_mode(): - for cur_pos in ( - pbar := tqdm( - range(start_pos, total_len), - total=total_len - start_pos, - leave=False, + if cfg_gamma 
is not None: + coeff = _get_cfg_coeff( + cfg_gamma, cfg_mode, cur_pos, start_pos, total_len ) + cond_logits = logits[: logits.size(0) // 2] + uncond_logits = logits[logits.size(0) // 2 :] + logits = uncond_logits + coeff * (cond_logits - uncond_logits) + + if temperature > 0: + probs = torch.softmax(logits / temperature, dim=-1) + next_token = sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits, dim=-1) + next_token = next_token.reshape(-1) + + # When alpha is used, in the first `max_new_tokens * alpha` generations, the negative + # prompt completions still use its original content (if not exceeding). After that, the + # negative prompt completions will be updated by the new tokens. + if ( + alpha is not None + and cur_pos - neg_offset_len < full_neg_tokens.size(0) + and cur_pos - start_pos < max_new_tokens * alpha ): - token = ( - tokens[:, :start_pos] - if cur_pos == start_pos - else tokens[:, cur_pos - 1 : cur_pos] - ) - logits, past_kv = model.forward( - token, use_cache=True, past_kv=past_kv - ) - logits = logits[:, -1, :] - if cfg_gamma is not None and max_prompt_size < cur_pos: - coeff = _get_cfg_coeff( - cfg_gamma, cfg_mode, cur_pos, start_pos, total_len - ) - - if cur_pos == start_pos: - neg_tok = neg_tokens - elif neg_previous_token is None: - neg_tok = tokens[ - :, (cur_pos - start_pos) + neg_len - 1 - ].unsqueeze(1) - else: - neg_tok = neg_previous_token.unsqueeze(1) - uncond_logits, cfg_kv = model.forward( - neg_tok, use_cache=True, past_kv=cfg_kv - ) - uncond_logits = uncond_logits[:, -1, :] - logits = uncond_logits + coeff * (logits - uncond_logits) - - if temperature > 0: - probs = torch.softmax(logits / temperature, dim=-1) - next_token = sample_top_p(probs, top_p) - else: - next_token = torch.argmax(logits, dim=-1) - next_token = next_token.reshape(-1) - # Only replace token if prompt has already been generated - next_token = torch.where( - input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token - ) + neg_slice = full_neg_tokens[:, cur_pos - neg_offset_len] + next_token = torch.concat([next_token, neg_slice], dim=0) + else: + if cfg_gamma is not None: + next_token = next_token.repeat(2) # Also update neg prompts + + # Insert dim tokens + if force_end and cur_pos >= total_len - 130: + for _idx in range(tokens.size(0)): + if ( + dim_tok_inserted[_idx] is False + and tokenizer.id_to_tok[next_token[_idx].item()][0] != "dur" + ): + next_token[_idx] = tokenizer.tok_to_id[tokenizer.dim_tok] + + # Update dim_tok_inserted + for _idx in range(tokens.size(0)): + if next_token[_idx] == tokenizer.tok_to_id[tokenizer.dim_tok]: + dim_tok_inserted[_idx] = True - # Insert dim tokens - if force_end and cur_pos >= total_len - 130: - for _idx in range(bsz): - if ( - dim_tok_inserted[_idx] is False - and tokenizer.id_to_tok[next_token[_idx].item()][0] - != "dur" - ): - next_token[_idx] = tokenizer.tok_to_id[ - tokenizer.dim_tok - ] - - # Update dim_tok_inserted - for _idx in range(bsz): - if next_token[_idx] == tokenizer.tok_to_id[tokenizer.dim_tok]: - dim_tok_inserted[_idx] = True - - tokens[:, cur_pos] = next_token - if alpha is not None and cur_pos - start_pos < min( - neg_max_len - neg_len, alpha * (total_len - start_pos) - ): - _neg_tokens = neg_prompt_tensors[ - :, cur_pos - start_pos + neg_len - ] - neg_previous_token = torch.where( - _neg_tokens != pad_id, _neg_tokens, next_token - ) - else: - neg_previous_token = next_token + tokens[:, cur_pos] = next_token decoded = [] for idx, seq in enumerate(tokens.tolist()): - # Cut to max gen len - seq = seq[: 
len(prompts[idx]) + max_gen_len] + if cfg_gamma is not None and 2 * idx >= tokens.size(0): + break # Cut to eos tok if any try: seq = seq[: seq.index(eos_id)] diff --git a/config/models/large_yarn.json b/config/models/large_yarn.json index 9497273..fd049f2 100644 --- a/config/models/large_yarn.json +++ b/config/models/large_yarn.json @@ -4,7 +4,7 @@ "n_layers": 64, "ff_mult": 4, "drop_p": 0.1, - "max_seq_len": 4096, + "max_seq_len": 2048, "grad_checkpoint": false, - "yarn_config": {} + "yarn_config": {"scale": 4.0, "mscale_coeff": 0.1} } \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 534179c..1464769 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ accelerate mido jsonlines pydub +einops safetensors \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/reference_implementations.py b/tests/reference_implementations.py new file mode 100644 index 0000000..3d3f90f --- /dev/null +++ b/tests/reference_implementations.py @@ -0,0 +1,46 @@ +# Reference implementations from +# https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py + + +def rotate_half(x, interleaved=False): + # Lazy import only when needed + import torch + from einops import rearrange + + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange( + torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 + ) + + +def apply_rotary_pos_emb_reference(x, cos, sin, interleaved=False): + """ + x: (batch_size, seq_len, n_heads, head_dim) + cos, sin: (seq_len, rotary_dim / 2) or (batch_size, seq_len, rotary_dim / 2) + """ + # Lazy import only when needed + import torch + from einops import repeat + + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)", + ) + sin = repeat( + sin, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)", + ) + return torch.cat( + [ + x[..., :ro_dim] * cos + + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], + ], + dim=-1, + ) \ No newline at end of file diff --git a/tests/test_models.py b/tests/test_models.py index 90621bd..916e65f 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,9 +1,14 @@ import logging +import torch import unittest +from aria.data.midi import MidiDict from aria.model import ModelConfig, TransformerLM from aria.config import load_model_config +from aria.sample import greedy_sample from aria.model.model import YaRNConfig +from aria.model.utils import apply_rotary_pos_emb +from .reference_implementations import apply_rotary_pos_emb_reference from aria.tokenizer import TokenizerLazy @@ -18,14 +23,67 @@ def test_yarn_config(self): model_config.set_vocab_size(tokenizer.vocab_size) model = TransformerLM(model_config) assert isinstance(model.model.model_config.yarn_config, YaRNConfig) - max_len = model.model.encode_layers[0].rotary_emb.max_position_embeddings - org_max_len = model.model.encode_layers[0].rotary_emb.original_max_position_embeddings - assert max_len == org_max_len assert model.model.encode_layers[0].rotary_emb.mscale_coeff == 0.07 assert model.model.encode_layers[0].rotary_emb.beta_fast == 32.0 assert model.model.encode_layers[0].rotary_emb.beta_slow == 1.0 assert model.model.encode_layers[0].rotary_emb.base == 10000.0 + def test_rope_util_fns(self): + q = torch.rand(4, 8, 12, 64) + inv_freq = 1 / (10000 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64)) + t = torch.arange(8, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + cos = torch.cos(freqs) + sin = torch.sin(freqs) + q_ref = apply_rotary_pos_emb_reference(q.clone(), cos, sin) + q = apply_rotary_pos_emb(q.clone(), cos, sin) + assert torch.allclose(q, q_ref, atol=1e-5) + + def test_attn_mask(self): + tokenizer = TokenizerLazy(return_tensors=True) + model_config = ModelConfig(**load_model_config("test")) + model_config.set_vocab_size(tokenizer.vocab_size) + model = TransformerLM(model_config) + assert model.model.model_config.yarn_config is None + model_config = ModelConfig(**load_model_config("test_yarn")) + model_config.set_vocab_size(tokenizer.vocab_size) + model = TransformerLM(model_config).eval() + + inp = torch.randint(0, 10000, (1, 10)) + attn_mask = torch.concat([torch.zeros((1, 5), dtype=torch.bool), torch.ones((1, 5), dtype=torch.bool)], dim=-1) + out = model(inp, attn_mask=attn_mask) + out2 = model(inp[:, -5:]) + assert torch.allclose(out[:, -5:], out2, atol=1e-5) + + def test_generation(self): + tokenizer = TokenizerLazy(return_tensors=True) + model_config = ModelConfig(**load_model_config("test")) + model_config.set_vocab_size(tokenizer.vocab_size) + model = TransformerLM(model_config) + assert model.model.model_config.yarn_config is None + model_config = ModelConfig(**load_model_config("test_yarn")) + model_config.set_vocab_size(tokenizer.vocab_size) + model = TransformerLM(model_config).eval() + + midi_dict = MidiDict.from_midi(mid_path="tests/test_data/basic.mid") + prompts = [tokenizer.tokenize(midi_dict=midi_dict)[:50]] * 3 + out = greedy_sample( + model, + tokenizer, + prompts, + device=torch.device("cpu"), + max_new_tokens=50, + ) + prompts = [[tokenizer.pad_tok] + tokenizer.tokenize(midi_dict=midi_dict)[:50]] * 3 + out2 = greedy_sample( + model, + tokenizer, + prompts, + device=torch.device("cpu"), + max_new_tokens=50, + ) + assert [u == v for u, v in zip(out, out2[1:])] + if __name__ == "__main__": 
logging.basicConfig(level=logging.INFO)
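Reviewer note (not part of the patch): a compact restatement of the math implemented in aria/model/yarn_rotary_embedding.py, purely as a reading aid for _compute_inv_freq, _yarn_find_correction_range and _yarn_get_mscale. The symbols below are mine; the behavior is taken from the code above. With head dimension d, base b, scaling factor s, original context length L, and the default extrapolation_factor = 1, each even dimension index 2i uses

\[
\theta_i = b^{-2i/d}, \qquad
\hat{\theta}_i = \gamma_i\,\frac{\theta_i}{s} + (1 - \gamma_i)\,\theta_i ,
\]

where \(\gamma_i\) is the linear ramp from _yarn_linear_ramp_mask, clamped to [0, 1] between the correction bounds

\[
i_{\text{low}} = \Big\lfloor \frac{d \,\ln\!\big(L/(2\pi \beta_{\text{fast}})\big)}{2 \ln b} \Big\rfloor ,
\qquad
i_{\text{high}} = \Big\lceil \frac{d \,\ln\!\big(L/(2\pi \beta_{\text{slow}})\big)}{2 \ln b} \Big\rceil ,
\]

so high-frequency dimensions keep their original rotation speed while low-frequency dimensions are interpolated by 1/s. The cos/sin tables are additionally multiplied by the magnitude correction m(s) = 1 + mscale_coeff * ln(s) for s > 1 (and 1 otherwise), times attn_factor. In dynamic mode the tables are rebuilt in _update_cos_sin_cache as the cached sequence length grows, rather than being fixed at construction.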
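Reviewer note (not part of the patch): a minimal usage sketch of the refactored interfaces, inferred from the signatures introduced above (TransformerLM.get_cache, the attn_mask/past_kv arguments, and greedy_sample's max_new_tokens). The "test_yarn" config name and tests/test_data/basic.mid come from tests/test_models.py in this patch; the prompt length, cache size and greedy next-token pick are illustrative only.

import torch

from aria.config import load_model_config
from aria.data.midi import MidiDict
from aria.model import ModelConfig, TransformerLM
from aria.sample import greedy_sample
from aria.tokenizer import TokenizerLazy

tokenizer = TokenizerLazy(return_tensors=True)
model_config = ModelConfig(**load_model_config("test_yarn"))
model_config.set_vocab_size(tokenizer.vocab_size)
model = TransformerLM(model_config).eval()  # assumes grad_checkpoint is false in this config

# Low-level incremental decoding: the KV cache now lives outside the model.
device = torch.device("cpu")
past_kv = model.get_cache(max_batch_size=1, max_len=128, device=device)
prompt_ids = torch.randint(0, tokenizer.vocab_size, (1, 16))
logits = model(prompt_ids, past_kv=past_kv)        # prefill: fills cache positions 0..15
nxt = logits[:, -1].argmax(dim=-1, keepdim=True)   # greedy pick, for illustration only
logits = model(nxt, past_kv=past_kv)               # one decode step: cache position 16

# High-level sampling: greedy_sample builds its own caches internally and now
# takes max_new_tokens instead of max_seq_len/max_gen_len.
midi_dict = MidiDict.from_midi(mid_path="tests/test_data/basic.mid")
prompts = [tokenizer.tokenize(midi_dict=midi_dict)[:50]]
out = greedy_sample(
    model,
    tokenizer,
    prompts,
    device=device,
    max_new_tokens=50,
)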