From a57a05078f2e502616a10bbebf2ab75391399690 Mon Sep 17 00:00:00 2001 From: Shane A Date: Wed, 25 Oct 2023 15:00:59 -0700 Subject: [PATCH 01/11] Add OlmoLlamaBlock --- olmo/model.py | 121 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 120 insertions(+), 1 deletion(-) diff --git a/olmo/model.py b/olmo/model.py index 1c0fe0a6a..6c68ad462 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -454,6 +454,23 @@ def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch. ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False) return bias + @classmethod + def _scaled_dot_product_attention(cls, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False) -> torch.Tensor: + """ + Computes scaled dot product attention on query, key and value tensors, using an optional + attention mask if passed, and applying dropout if a probability greater than 0.0 is specified. + + This method is based on PyTorch's `scaled_dot_product_attention`. + """ + return F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + ) + def attention( self, q: torch.Tensor, @@ -513,7 +530,7 @@ def attention( # Get the attention scores. # shape: (B, nh, T, hs) - att = F.scaled_dot_product_attention( + att = self._scaled_dot_product_attention( q, k, v, @@ -669,6 +686,108 @@ def forward( return x + self.dropout(self.ff_out(self.act(ff))) + self.dropout(att), cache +class OlmoLlamaBlock(OlmoBlock): + """ + This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). This block is similar to `OlmoSequentialBlock` + but some operations have slightly different implementations to imitate the + behavior of Llama. + """ + + def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): + super().__init__(layer_id, config, cache) + # Layer norms. + self.attn_norm = LayerNorm.build(config) + self.ff_norm = LayerNorm.build(config) + + # Attention input projection. Projects x -> (q, k, v) + if config.multi_query_attention: + q_proj_out_dim = config.d_model + k_proj_out_dim = config.d_model // config.n_heads + v_proj_out_dim = config.d_model // config.n_heads + else: + q_proj_out_dim = config.d_model + k_proj_out_dim = config.d_model + v_proj_out_dim = config.d_model + self.q_proj = nn.Linear( + config.d_model, q_proj_out_dim, bias=config.include_bias, device=config.init_device + ) + self.k_proj = nn.Linear( + config.d_model, k_proj_out_dim, bias=config.include_bias, device=config.init_device + ) + self.v_proj = nn.Linear( + config.d_model, v_proj_out_dim, bias=config.include_bias, device=config.init_device + ) + + # Feed-forward input projection. + self.ff_proj = nn.Linear( + config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device + ) + + def reset_parameters(self): + super().reset_parameters() + self.attn_norm.reset_parameters() + self.ff_norm.reset_parameters() + # NOTE: the standard deviation for these weights does not depend on the layer. 
+ init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None) + init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None) + init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None) + init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None) + + @classmethod + def _scaled_dot_product_attention(cls, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False) -> torch.Tensor: + query_len, key_len = q.size(-2), k.size(-2) + + attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.size(-1)) + + attn_bias = torch.zeros(query_len, key_len, dtype=q.dtype) + if is_causal: + assert attn_mask is None + + diagonal = key_len - query_len + 1 + context_mask = 1 - torch.triu(torch.ones_like(attn_bias, dtype=torch.int), diagonal=diagonal) + attn_bias.masked_fill_(context_mask.bool(), torch.finfo(q.dtype).min) + + if attn_mask is not None: + attn_bias += attn_mask.to(q.dtype) + + attn_weights += attn_bias + attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype) + attn_weights = torch.matmul(attn_weights, v) + attn_weights = nn.functional.dropout(attn_weights, p=dropout_p) + return attn_weights + + def forward( + self, + x: torch.Tensor, + attention_bias: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Get query, key, value projections. + # shape: + # - for regular attn q, k, v: (batch_size, seq_len, d_model) + # - for multi-query attn q: (batch_size, seq_len, d_model) + # k, v: (batch_size, seq_len, d_model // n_heads) + x_normed = self.attn_norm(x) + q = self.q_proj(x_normed) + k = self.k_proj(x_normed) + v = self.v_proj(x_normed) + + # Get attention scores. + att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache) + + # Add attention scores. + # shape: (B, T, C) + x = x + self.dropout(att) + + # Add feed-forward projection. + # shape: (batch_size, seq_len, d_model) + x = x + self.dropout(self.ff_out(self.act(self.ff_proj(self.ff_norm(x))))) + + return x, cache + + class OlmoOutput(NamedTuple): logits: torch.FloatTensor """ From 53d68c8fd889581a85495fa0a62543c48eb1c1b6 Mon Sep 17 00:00:00 2001 From: Shane A Date: Wed, 25 Oct 2023 15:02:54 -0700 Subject: [PATCH 02/11] Add config for using the Llama block --- olmo/config.py | 6 ++++++ olmo/model.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/olmo/config.py b/olmo/config.py index 01173cc77..72f2f6cdf 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -185,6 +185,12 @@ class BlockType(StrEnum): sequential = "sequential" parallel = "parallel" + llama = "llama" + """ + A block similar to the sequential block with slightly different + implementations of operations like attention to imitate the behavior of Llama. 
+ """ + class InitFnType(StrEnum): mitchell = "mitchell" diff --git a/olmo/model.py b/olmo/model.py index 6c68ad462..b8ed362a2 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -559,6 +559,8 @@ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OlmoBl return OlmoSequentialBlock(layer_id, config, cache) elif config.block_type == BlockType.parallel: return OlmoParallelBlock(layer_id, config, cache) + elif config.block_type == BlockType.llama: + return OlmoLlamaBlock(layer_id, config, cache) else: raise NotImplementedError(f"not sure how to handle block type '{config.block_type}'") From 6c4b8e1548f55ad1f3e638693c26de9ee932829f Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 26 Oct 2023 16:10:22 -0700 Subject: [PATCH 03/11] Fix inverted context mask in attention impl, further clean impl --- olmo/model.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/olmo/model.py b/olmo/model.py index b8ed362a2..3ff737b35 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -738,26 +738,22 @@ def reset_parameters(self): @classmethod def _scaled_dot_product_attention(cls, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False) -> torch.Tensor: - query_len, key_len = q.size(-2), k.size(-2) + attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1)) - attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(q.size(-1)) - - attn_bias = torch.zeros(query_len, key_len, dtype=q.dtype) + attn_bias = torch.zeros_like(attn_weights) if is_causal: assert attn_mask is None - diagonal = key_len - query_len + 1 - context_mask = 1 - torch.triu(torch.ones_like(attn_bias, dtype=torch.int), diagonal=diagonal) - attn_bias.masked_fill_(context_mask.bool(), torch.finfo(q.dtype).min) + context_mask = torch.ones_like(attn_bias, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(context_mask.logical_not(), torch.finfo(attn_bias.dtype).min) if attn_mask is not None: attn_bias += attn_mask.to(q.dtype) attn_weights += attn_bias attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype) - attn_weights = torch.matmul(attn_weights, v) attn_weights = nn.functional.dropout(attn_weights, p=dropout_p) - return attn_weights + return torch.matmul(attn_weights, v) def forward( self, From 7743b0f07605da3583d17bf41b6e8fe753e608ee Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 26 Oct 2023 16:38:37 -0700 Subject: [PATCH 04/11] Run black --- olmo/model.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/olmo/model.py b/olmo/model.py index 3ff737b35..2b9b0ad15 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -455,7 +455,15 @@ def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch. return bias @classmethod - def _scaled_dot_product_attention(cls, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False) -> torch.Tensor: + def _scaled_dot_product_attention( + cls, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + ) -> torch.Tensor: """ Computes scaled dot product attention on query, key and value tensors, using an optional attention mask if passed, and applying dropout if a probability greater than 0.0 is specified. 
@@ -737,7 +745,15 @@ def reset_parameters(self): init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None) @classmethod - def _scaled_dot_product_attention(cls, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False) -> torch.Tensor: + def _scaled_dot_product_attention( + cls, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + ) -> torch.Tensor: attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1)) attn_bias = torch.zeros_like(attn_weights) From 07eb67c5d03fbfaf0e250a74b0546de2c37cdce5 Mon Sep 17 00:00:00 2001 From: Shane A Date: Fri, 27 Oct 2023 16:33:07 -0700 Subject: [PATCH 05/11] Add rope precision --- olmo/config.py | 16 ++++++++++++++++ olmo/model.py | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/olmo/config.py b/olmo/config.py index 72f2f6cdf..1eaf765b9 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -277,6 +277,11 @@ class ModelConfig(BaseConfig): Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``. """ + rope_precision_type: str = "fp32" + """ + Precision with which to apply RoPE (e.g. "amp_bf16", "amp_fp16", or "fp32"). + """ + flash_attention: bool = False """ If ``True``, use ``FlashAttention``. @@ -408,6 +413,17 @@ class ModelConfig(BaseConfig): See :data:`TrainConfig.precision` instead. """ + @property + def rope_precision(self) -> torch.dtype: + if self.rope_precision_type == "amp_bf16": + return torch.bfloat16 + elif self.rope_precision_type == "amp_fp16": + return torch.float16 + elif self.rope_precision_type == "fp32": + return torch.float32 + else: + raise ValueError(f"Unexpected precision type '{self.rope_precision_type}'") + class OptimizerType(StrEnum): lionw = "lionw" diff --git a/olmo/model.py b/olmo/model.py index 2b9b0ad15..59d3f1be7 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -309,7 +309,7 @@ def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: return out.to(t.dtype) def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - q_, k_ = q.float(), k.float() + q_, k_ = q.to(dtype=self.config.rope_precision), k.to(dtype=self.config.rope_precision) with torch.autocast(q.device.type, enabled=False): query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device) From cebdbe53559a09f2bd2de9ecd80ff2e9fff55503 Mon Sep 17 00:00:00 2001 From: Shane A Date: Mon, 30 Oct 2023 11:05:17 -0700 Subject: [PATCH 06/11] Add missing type cast in updated RoPE --- olmo/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/olmo/model.py b/olmo/model.py index 59d3f1be7..047225b14 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -313,6 +313,8 @@ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch with torch.autocast(q.device.type, enabled=False): query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device) + pos_sin = pos_sin.type_as(q_) + pos_cos = pos_cos.type_as(q_) q_ = self.apply_rotary_pos_emb( pos_sin[:, :, key_len - query_len : key_len, :], pos_cos[:, :, key_len - query_len : key_len, :], From b8938d5b41f56d97600d619af8446d11cb0d0140 Mon Sep 17 00:00:00 2001 From: Shane A Date: Mon, 30 Oct 2023 17:43:38 -0700 Subject: 
[PATCH 07/11] Move some attention bias logic out of Olmo to use in OlmoLlamaBlock --- olmo/model.py | 100 +++++++++++++++++++++++++------------------------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/olmo/model.py b/olmo/model.py index 047225b14..d7d014cc9 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -372,6 +372,44 @@ def output_multiplier(self) -> float: return 0.5 +def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor: + att_bias = torch.triu( + torch.ones(seq_len, seq_len, device=device, dtype=torch.float), + diagonal=1, + ) + att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min) + return att_bias.view(1, 1, seq_len, seq_len) # type: ignore + + +def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor: + if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[ + -1 + ] >= seq_len: + if causal_bias.device != device: + causal_bias = causal_bias.to(device) + cache["causal_attention_bias"] = causal_bias + return causal_bias + with torch.autocast(device.type, enabled=False): + causal_bias = causal_attention_bias(seq_len, device) + cache["causal_attention_bias"] = causal_bias + return causal_bias + + +def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor: + alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len) + + # shape: (1, 1, seq_len, seq_len) + alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1) + alibi_bias.abs_().mul_(-1) + + # shape: (n_heads,) + m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device) + m.mul_(config.alibi_bias_max / config.n_heads) + + # shape: (1, n_heads, seq_len, seq_len) + return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore + + class OlmoBlock(nn.Module): """ A base class for transformer block implementations. @@ -456,9 +494,8 @@ def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch. 
ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False) return bias - @classmethod def _scaled_dot_product_attention( - cls, + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -746,9 +783,8 @@ def reset_parameters(self): init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None) init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None) - @classmethod def _scaled_dot_product_attention( - cls, + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -758,15 +794,15 @@ def _scaled_dot_product_attention( ) -> torch.Tensor: attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1)) - attn_bias = torch.zeros_like(attn_weights) if is_causal: assert attn_mask is None - context_mask = torch.ones_like(attn_bias, dtype=torch.bool).tril(diagonal=0) - attn_bias.masked_fill_(context_mask.logical_not(), torch.finfo(attn_bias.dtype).min) - - if attn_mask is not None: - attn_bias += attn_mask.to(q.dtype) + query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None + attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len] + elif attn_mask is not None: + attn_bias = attn_mask.to(q.dtype) + else: + attn_bias = torch.zeros_like(attn_weights) attn_weights += attn_bias attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype) @@ -830,30 +866,6 @@ class OlmoGenerateOutput(NamedTuple): """ -def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor: - att_bias = torch.triu( - torch.ones(seq_len, seq_len, device=device, dtype=torch.float), - diagonal=1, - ) - att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min) - return att_bias.view(1, 1, seq_len, seq_len) # type: ignore - - -def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor: - alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len) - - # shape: (1, 1, seq_len, seq_len) - alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1) - alibi_bias.abs_().mul_(-1) - - # shape: (n_heads,) - m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device) - m.mul_(config.alibi_bias_max / config.n_heads) - - # shape: (1, n_heads, seq_len, seq_len) - return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore - - class Olmo(nn.Module): def __init__(self, config: ModelConfig, init_params: bool = True): super().__init__() @@ -912,7 +924,7 @@ def __init__(self, config: ModelConfig, init_params: bool = True): # Warm up cache. 
if self.config.alibi: - self.get_causal_attention_bias(config.max_sequence_length, _non_meta_init_device(config)) + get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config)) self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config)) @property @@ -945,19 +957,6 @@ def reset_parameters(self): for block in self.transformer.blocks: # type: ignore block.reset_parameters() # type: ignore - def get_causal_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor: - if (causal_bias := self.__cache.get("causal_attention_bias")) is not None and causal_bias.shape[ - -1 - ] >= seq_len: - if causal_bias.device != device: - causal_bias = causal_bias.to(device) - self.__cache["causal_attention_bias"] = causal_bias - return causal_bias - with torch.autocast(device.type, enabled=False): - causal_bias = causal_attention_bias(seq_len, device) - self.__cache["causal_attention_bias"] = causal_bias - return causal_bias - def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor: if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[ -1 @@ -1052,11 +1051,12 @@ def forward( or past_key_values is not None ): if attention_bias is None and self.config.alibi: - attention_bias = self.get_causal_attention_bias( + attention_bias = get_causal_attention_bias( + self.__cache, past_length + seq_len, x.device ) + self.get_alibi_attention_bias(past_length + seq_len, x.device) elif attention_bias is None: - attention_bias = self.get_causal_attention_bias(past_length + seq_len, x.device) + attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device) elif attention_bias.dtype in (torch.int8, torch.bool): attention_bias = attention_bias.to(dtype=torch.float) attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min) From 2da1a0ade1adecc2a9fb0fc67d6d9c5a7d360fc7 Mon Sep 17 00:00:00 2001 From: Shane A Date: Mon, 30 Oct 2023 17:45:38 -0700 Subject: [PATCH 08/11] Run black --- olmo/model.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/olmo/model.py b/olmo/model.py index d7d014cc9..d00b5bbac 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -382,9 +382,7 @@ def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTens def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor: - if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[ - -1 - ] >= seq_len: + if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len: if causal_bias.device != device: causal_bias = causal_bias.to(device) cache["causal_attention_bias"] = causal_bias @@ -1052,8 +1050,7 @@ def forward( ): if attention_bias is None and self.config.alibi: attention_bias = get_causal_attention_bias( - self.__cache, - past_length + seq_len, x.device + self.__cache, past_length + seq_len, x.device ) + self.get_alibi_attention_bias(past_length + seq_len, x.device) elif attention_bias is None: attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device) From 57caf05a28bfe361583b89497b8004dd7d184062 Mon Sep 17 00:00:00 2001 From: Shane A Date: Mon, 30 Oct 2023 17:46:33 -0700 Subject: [PATCH 09/11] Store buffer cache in OlmoLlamaBlock --- olmo/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/olmo/model.py b/olmo/model.py index d00b5bbac..d567a89a3 100644 --- 
a/olmo/model.py +++ b/olmo/model.py @@ -746,6 +746,7 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache): # Layer norms. self.attn_norm = LayerNorm.build(config) self.ff_norm = LayerNorm.build(config) + self.__cache = cache # Attention input projection. Projects x -> (q, k, v) if config.multi_query_attention: From 7d443b61909090816eb02dd670ff953e0a7f9149 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 2 Nov 2023 09:31:17 -0700 Subject: [PATCH 10/11] Change RoPE precision to a boolean config --- olmo/config.py | 16 +++------------- olmo/model.py | 6 +++++- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/olmo/config.py b/olmo/config.py index 00e936aa0..13f4c1494 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -281,9 +281,10 @@ class ModelConfig(BaseConfig): Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``. """ - rope_precision_type: str = "fp32" + rope_full_precision: bool = True """ - Precision with which to apply RoPE (e.g. "amp_bf16", "amp_fp16", or "fp32"). + If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise, + apply RoPE at the precision of the input. """ flash_attention: bool = False @@ -417,17 +418,6 @@ class ModelConfig(BaseConfig): See :data:`TrainConfig.precision` instead. """ - @property - def rope_precision(self) -> torch.dtype: - if self.rope_precision_type == "amp_bf16": - return torch.bfloat16 - elif self.rope_precision_type == "amp_fp16": - return torch.float16 - elif self.rope_precision_type == "fp32": - return torch.float32 - else: - raise ValueError(f"Unexpected precision type '{self.rope_precision_type}'") - class OptimizerType(StrEnum): lionw = "lionw" diff --git a/olmo/model.py b/olmo/model.py index 07ba2a0c4..c5e631e25 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -312,7 +312,11 @@ def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: return out.to(t.dtype) def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - q_, k_ = q.to(dtype=self.config.rope_precision), k.to(dtype=self.config.rope_precision) + if self.config.rope_full_precision: + q_, k_ = q.float(), k.float() + else: + q_, k_ = q, k + with torch.autocast(q.device.type, enabled=False): query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device) From 6b1a77cbbba5eefeba9e547338e8ccb8963d761e Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 2 Nov 2023 09:31:55 -0700 Subject: [PATCH 11/11] Add activation checkpointing to Llama block --- olmo/model.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/olmo/model.py b/olmo/model.py index c5e631e25..2fad6f348 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -878,7 +878,13 @@ def forward( # Add feed-forward projection. # shape: (batch_size, seq_len, d_model) - x = x + self.dropout(self.ff_out(self.act(self.ff_proj(self.ff_norm(x))))) + og_x = x + x = self._activation_checkpoint_fn(self.ff_norm, x) + x = self.ff_proj(x) + x = self._activation_checkpoint_fn(self.act, x) + x = self.ff_out(x) + x = self.dropout(x) + x = og_x + x return x, cache
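---

Note (not part of the patches above): the following is a minimal, self-contained sketch of the manual attention path that `OlmoLlamaBlock._scaled_dot_product_attention` arrives at after PATCH 07, restated outside the class so it can be run and checked directly. The function name `manual_sdpa` is ours, and the `BufferCache` lookup via `get_causal_attention_bias` is replaced with an inline causal bias for self-containment; this is an illustration under those assumptions, not the repository's API.

```python
import math

import torch
import torch.nn.functional as F


def manual_sdpa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
) -> torch.Tensor:
    # Raw attention scores scaled by sqrt(head_dim): (B, n_heads, query_len, key_len).
    attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))

    if is_causal:
        assert attn_mask is None
        query_len, key_len = q.shape[-2], k.shape[-2]
        # Additive causal bias: 0 on/below the diagonal, a very large negative
        # value above it (mirrors `causal_attention_bias` in the patches, but
        # built inline instead of being cached in a BufferCache).
        bias = torch.triu(
            torch.ones(key_len, key_len, device=q.device, dtype=q.dtype), diagonal=1
        )
        bias.masked_fill_(bias == 1, torch.finfo(q.dtype).min)
        attn_bias = bias.view(1, 1, key_len, key_len)[:, :, :query_len, :key_len]
    elif attn_mask is not None:
        attn_bias = attn_mask.to(q.dtype)
    else:
        attn_bias = torch.zeros_like(attn_weights)

    attn_weights = F.softmax(attn_weights + attn_bias, dim=-1).to(q.dtype)
    attn_weights = F.dropout(attn_weights, p=dropout_p)
    return torch.matmul(attn_weights, v)


# Sanity check against PyTorch's fused implementation on a square causal case
# (dropout disabled so the comparison is deterministic).
q, k, v = (torch.randn(1, 4, 8, 16) for _ in range(3))
expected = F.scaled_dot_product_attention(q, k, v, is_causal=True)
torch.testing.assert_close(manual_sdpa(q, k, v, is_causal=True), expected, rtol=1e-4, atol=1e-4)
```

The point of keeping this path softmax-explicit (rather than calling `F.scaled_dot_product_attention`) is that it reproduces Llama's reference attention arithmetic step by step, which is what the `llama` block type is imitating.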