diff --git a/olmo/config.py b/olmo/config.py
index 30fe651f8..13f4c1494 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -177,6 +177,12 @@ class BlockType(StrEnum):
     sequential = "sequential"
     parallel = "parallel"
 
+    llama = "llama"
+    """
+    A block similar to the sequential block with slightly different
+    implementations of operations like attention to imitate the behavior of Llama.
+    """
+
 
 class InitFnType(StrEnum):
     mitchell = "mitchell"
@@ -275,6 +281,12 @@ class ModelConfig(BaseConfig):
     Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
     """
 
+    rope_full_precision: bool = True
+    """
+    If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
+    apply RoPE at the precision of the input.
+    """
+
     flash_attention: bool = False
     """
     If ``True``, use ``FlashAttention``.
diff --git a/olmo/model.py b/olmo/model.py
index ff4a0ec3c..2fad6f348 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -312,10 +312,16 @@ def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t:
         return out.to(t.dtype)
 
     def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        q_, k_ = q.float(), k.float()
+        if self.config.rope_full_precision:
+            q_, k_ = q.float(), k.float()
+        else:
+            q_, k_ = q, k
+
         with torch.autocast(q.device.type, enabled=False):
             query_len, key_len = q_.shape[-2], k_.shape[-2]  # could be different if layer_past not None
             pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
+            pos_sin = pos_sin.type_as(q_)
+            pos_cos = pos_cos.type_as(q_)
             q_ = self.apply_rotary_pos_emb(
                 pos_sin[:, :, key_len - query_len : key_len, :],
                 pos_cos[:, :, key_len - query_len : key_len, :],
@@ -373,6 +379,42 @@ def output_multiplier(self) -> float:
         return 0.5
 
 
+def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
+    att_bias = torch.triu(
+        torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
+        diagonal=1,
+    )
+    att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
+    return att_bias.view(1, 1, seq_len, seq_len)  # type: ignore
+
+
+def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
+    if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
+        if causal_bias.device != device:
+            causal_bias = causal_bias.to(device)
+            cache["causal_attention_bias"] = causal_bias
+        return causal_bias
+    with torch.autocast(device.type, enabled=False):
+        causal_bias = causal_attention_bias(seq_len, device)
+    cache["causal_attention_bias"] = causal_bias
+    return causal_bias
+
+
+def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
+    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
+
+    # shape: (1, 1, seq_len, seq_len)
+    alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
+    alibi_bias.abs_().mul_(-1)
+
+    # shape: (n_heads,)
+    m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
+    m.mul_(config.alibi_bias_max / config.n_heads)
+
+    # shape: (1, n_heads, seq_len, seq_len)
+    return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1)))  # type: ignore
+
+
 class OlmoBlock(nn.Module):
     """
     A base class for transformer block implementations.
@@ -467,6 +509,30 @@ def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.
         ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
         return bias
 
+    def _scaled_dot_product_attention(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+    ) -> torch.Tensor:
+        """
+        Computes scaled dot product attention on query, key and value tensors, using an optional
+        attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
+
+        This method is based on PyTorch's `scaled_dot_product_attention`.
+        """
+        return F.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+        )
+
     def attention(
         self,
         q: torch.Tensor,
@@ -526,7 +592,7 @@ def attention(
 
         # Get the attention scores.
         # shape: (B, nh, T, hs)
-        att = F.scaled_dot_product_attention(
+        att = self._scaled_dot_product_attention(
             q,
             k,
             v,
@@ -557,6 +623,8 @@ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OlmoBl
             return OlmoSequentialBlock(layer_id, config, cache)
         elif config.block_type == BlockType.parallel:
             return OlmoParallelBlock(layer_id, config, cache)
+        elif config.block_type == BlockType.llama:
+            return OlmoLlamaBlock(layer_id, config, cache)
         else:
             raise NotImplementedError(f"not sure how to handle block type '{config.block_type}'")
 
@@ -709,6 +777,118 @@ def forward(
         )
 
 
+class OlmoLlamaBlock(OlmoBlock):
+    """
+    This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
+    (plus another skip connection). This block is similar to `OlmoSequentialBlock`
+    but some operations have slightly different implementations to imitate the
+    behavior of Llama.
+    """
+
+    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
+        super().__init__(layer_id, config, cache)
+        # Layer norms.
+        self.attn_norm = LayerNorm.build(config)
+        self.ff_norm = LayerNorm.build(config)
+        self.__cache = cache
+
+        # Attention input projection. Projects x -> (q, k, v)
+        if config.multi_query_attention:
+            q_proj_out_dim = config.d_model
+            k_proj_out_dim = config.d_model // config.n_heads
+            v_proj_out_dim = config.d_model // config.n_heads
+        else:
+            q_proj_out_dim = config.d_model
+            k_proj_out_dim = config.d_model
+            v_proj_out_dim = config.d_model
+        self.q_proj = nn.Linear(
+            config.d_model, q_proj_out_dim, bias=config.include_bias, device=config.init_device
+        )
+        self.k_proj = nn.Linear(
+            config.d_model, k_proj_out_dim, bias=config.include_bias, device=config.init_device
+        )
+        self.v_proj = nn.Linear(
+            config.d_model, v_proj_out_dim, bias=config.include_bias, device=config.init_device
+        )
+
+        # Feed-forward input projection.
+        self.ff_proj = nn.Linear(
+            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
+        )
+
+    def reset_parameters(self):
+        super().reset_parameters()
+        self.attn_norm.reset_parameters()
+        self.ff_norm.reset_parameters()
+        # NOTE: the standard deviation for these weights does not depend on the layer.
+        init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
+        init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
+        init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
+        init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None)
+
+    def _scaled_dot_product_attention(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+    ) -> torch.Tensor:
+        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
+
+        if is_causal:
+            assert attn_mask is None
+
+            query_len, key_len = q.shape[-2], k.shape[-2]  # could be different if layer_past not None
+            attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len]
+        elif attn_mask is not None:
+            attn_bias = attn_mask.to(q.dtype)
+        else:
+            attn_bias = torch.zeros_like(attn_weights)
+
+        attn_weights += attn_bias
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=dropout_p)
+        return torch.matmul(attn_weights, v)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_bias: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        # Get query, key, value projections.
+        # shape:
+        #  - for regular attn q, k, v: (batch_size, seq_len, d_model)
+        #  - for multi-query attn q: (batch_size, seq_len, d_model)
+        #                      k, v: (batch_size, seq_len, d_model // n_heads)
+        x_normed = self.attn_norm(x)
+        q = self.q_proj(x_normed)
+        k = self.k_proj(x_normed)
+        v = self.v_proj(x_normed)
+
+        # Get attention scores.
+        att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
+
+        # Add attention scores.
+        # shape: (B, T, C)
+        x = x + self.dropout(att)
+
+        # Add feed-forward projection.
+        # shape: (batch_size, seq_len, d_model)
+        og_x = x
+        x = self._activation_checkpoint_fn(self.ff_norm, x)
+        x = self.ff_proj(x)
+        x = self._activation_checkpoint_fn(self.act, x)
+        x = self.ff_out(x)
+        x = self.dropout(x)
+        x = og_x + x
+
+        return x, cache
+
+
 class OlmoOutput(NamedTuple):
     logits: torch.FloatTensor
     """
@@ -791,30 +971,6 @@ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointin
             block.set_activation_checkpointing(strategy)
 
 
-def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
-    att_bias = torch.triu(
-        torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
-        diagonal=1,
-    )
-    att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
-    return att_bias.view(1, 1, seq_len, seq_len)  # type: ignore
-
-
-def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
-    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
-
-    # shape: (1, 1, seq_len, seq_len)
-    alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
-    alibi_bias.abs_().mul_(-1)
-
-    # shape: (n_heads,)
-    m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
-    m.mul_(config.alibi_bias_max / config.n_heads)
-
-    # shape: (1, n_heads, seq_len, seq_len)
-    return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1)))  # type: ignore
-
-
 class Olmo(nn.Module):
     def __init__(self, config: ModelConfig, init_params: bool = True):
         super().__init__()
@@ -892,7 +1048,7 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
 
         # Warm up cache.
         if self.config.alibi:
-            self.get_causal_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
+            get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
             self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
 
     def enable_activation_checkpointing(self, enable: bool = True):
@@ -942,19 +1098,6 @@ def reset_parameters(self):
         for block_group in self.transformer.block_groups:
             block_group.reset_parameters()
 
-    def get_causal_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
-        if (causal_bias := self.__cache.get("causal_attention_bias")) is not None and causal_bias.shape[
-            -1
-        ] >= seq_len:
-            if causal_bias.device != device:
-                causal_bias = causal_bias.to(device)
-                self.__cache["causal_attention_bias"] = causal_bias
-            return causal_bias
-        with torch.autocast(device.type, enabled=False):
-            causal_bias = causal_attention_bias(seq_len, device)
-        self.__cache["causal_attention_bias"] = causal_bias
-        return causal_bias
-
     def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
         if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
             -1
@@ -1049,11 +1192,11 @@ def forward(
             or past_key_values is not None
         ):
             if attention_bias is None and self.config.alibi:
-                attention_bias = self.get_causal_attention_bias(
-                    past_length + seq_len, x.device
+                attention_bias = get_causal_attention_bias(
+                    self.__cache, past_length + seq_len, x.device
                 ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
             elif attention_bias is None:
-                attention_bias = self.get_causal_attention_bias(past_length + seq_len, x.device)
+                attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
             elif attention_bias.dtype in (torch.int8, torch.bool):
                 attention_bias = attention_bias.to(dtype=torch.float)
                 attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
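
The two user-facing knobs this patch adds are the ``llama`` block type and the ``rope_full_precision`` flag. As a rough sketch of how they might be exercised together (the field names come from this diff; the small model dimensions and the smoke-test forward pass are illustrative assumptions, not a vetted configuration):

# Illustrative sketch only: tiny dimensions chosen for a quick smoke test.
import torch

from olmo.config import BlockType, ModelConfig
from olmo.model import Olmo

config = ModelConfig(
    d_model=128,
    n_heads=4,
    n_layers=2,
    max_sequence_length=64,
    block_type=BlockType.llama,   # dispatches OlmoBlock.build to OlmoLlamaBlock
    rope=True,
    rope_full_precision=False,    # apply RoPE in the activation dtype instead of fp32
    init_device="cpu",
)
model = Olmo(config, init_params=True)
out = model(torch.randint(0, config.vocab_size, (1, 32)))
print(out.logits.shape)  # (batch_size, seq_len, vocab/embedding size)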
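
What ``rope_full_precision`` trades off can be seen with a toy rotation. This is a simplified stand-in, not the patch's ``apply_rotary_pos_emb``: computing the rotation in float32 and casting back generally differs slightly from rotating directly in the low-precision activation dtype.

# Toy illustration of the precision difference the flag controls; the rotation
# below is a simplified stand-in, not OLMo's RoPE implementation.
import torch

torch.manual_seed(0)
t = torch.randn(1, 1, 128, 64, dtype=torch.bfloat16)
inv_freq = 1.0 / (10_000 ** (torch.arange(0, 64, 2).float() / 64))
angles = torch.outer(torch.arange(128).float(), inv_freq)
sin, cos = angles.sin(), angles.cos()

def rotate(x, sin, cos):
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

full = rotate(t.float(), sin, cos).to(t.dtype)     # like rope_full_precision=True
low = rotate(t, sin.to(t.dtype), cos.to(t.dtype))  # like rope_full_precision=False
print((full.float() - low.float()).abs().max())    # small but nonzero drift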
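
OlmoLlamaBlock overrides ``_scaled_dot_product_attention`` with an explicit matmul/softmax formulation in place of the fused ``F.scaled_dot_product_attention`` call used by the base block. A standalone check (an illustration, not part of the patch) that the two formulations agree for the causal case:

# Verifies numerically that eager softmax attention with an additive causal bias
# matches PyTorch's fused scaled_dot_product_attention.
import math

import torch
import torch.nn.functional as F

B, n_heads, T, hs = 2, 4, 16, 32
q, k, v = (torch.randn(B, n_heads, T, hs) for _ in range(3))

# Eager path: scores, additive causal bias, softmax, weighted sum of values.
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
causal_bias = torch.triu(torch.full((T, T), torch.finfo(torch.float).min), diagonal=1)
eager = torch.matmul(F.softmax(scores + causal_bias, dim=-1), v)

# Fused path.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

torch.testing.assert_close(eager, fused, rtol=1e-4, atol=1e-5)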
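
Since ``get_causal_attention_bias`` is now a module-level helper keyed by a ``BufferCache``, its reuse behavior can be poked at directly. The sketch below assumes ``BufferCache`` acts as a plain mutable mapping here and substitutes a ``dict``:

# Sketch of the caching behavior: the helper rebuilds the bias only when the
# requested sequence length exceeds what is already cached.
import torch

from olmo.model import get_causal_attention_bias

cache: dict = {}
device = torch.device("cpu")

bias8 = get_causal_attention_bias(cache, 8, device)    # builds and caches a (1, 1, 8, 8) bias
bias4 = get_causal_attention_bias(cache, 4, device)    # returns the cached (1, 1, 8, 8) tensor; callers slice it
assert bias4 is cache["causal_attention_bias"]

bias16 = get_causal_attention_bias(cache, 16, device)  # rebuilds because 16 > 8
assert bias16.shape == (1, 1, 16, 16)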