Add block with Llama-like implementations #346

Merged
13 commits merged on Nov 2, 2023
12 changes: 12 additions & 0 deletions olmo/config.py
@@ -177,6 +177,12 @@ class BlockType(StrEnum):
sequential = "sequential"
parallel = "parallel"

llama = "llama"
"""
A block similar to the sequential block with slightly different
implementations of operations like attention to imitate the behavior of Llama.
"""


class InitFnType(StrEnum):
mitchell = "mitchell"
@@ -275,6 +281,12 @@ class ModelConfig(BaseConfig):
Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
"""

rope_full_precision: bool = True
"""
If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
apply RoPE at the precision of the input.
"""

flash_attention: bool = False
"""
If ``True``, use ``FlashAttention``.
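Not part of the diff: a minimal sketch of how the two new config options above (the ``llama`` block type and ``rope_full_precision``) might be selected together, assuming the ``ModelConfig`` fields shown in this file; everything else is left at its defaults.

```python
from olmo.config import ModelConfig, BlockType

# Illustrative config only: pick the new Llama-style block and keep RoPE in full precision.
# Field names come from the diff above; all other fields keep their default values.
cfg = ModelConfig(
    block_type=BlockType.llama,
    rope=True,                 # the Llama-style block relies on rotary embeddings
    rope_full_precision=True,  # apply RoPE in fp32 regardless of the input dtype
)
```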
229 changes: 186 additions & 43 deletions olmo/model.py
@@ -312,10 +312,16 @@ def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t:
return out.to(t.dtype)

def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
q_, k_ = q.float(), k.float()
if self.config.rope_full_precision:
q_, k_ = q.float(), k.float()
else:
q_, k_ = q, k

with torch.autocast(q.device.type, enabled=False):
query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
pos_sin = pos_sin.type_as(q_)
Member commented:
Is this going to change any existing runs that use Rope?

@2015aroras (Collaborator, Author) replied on Oct 30, 2023:
The rotary embeddings are fp32 by construction, so if rope_precision_type = fp32 (the default) then this shouldn't change the type.
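A standalone sketch (not from the PR) of the dtype behavior described in this reply, using plain PyTorch; the ``rope_full_precision`` flag and the fp32 sin/cos buffer mirror the diff, the rest is illustrative:

```python
import torch

# The rotary sin/cos buffers are built in fp32 ("fp32 by construction"), so type_as(q_)
# is a no-op when rope_full_precision is True (the default) and q_ has been upcast to fp32.
rope_full_precision = True
q = torch.randn(1, 4, 8, 16, dtype=torch.bfloat16)
pos_sin = torch.randn(1, 1, 8, 16)  # fp32 buffer
q_ = q.float() if rope_full_precision else q
pos_sin = pos_sin.type_as(q_)       # stays fp32 on the default path; downcasts to bf16 otherwise
print(q_.dtype, pos_sin.dtype)      # torch.float32 torch.float32
```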

pos_cos = pos_cos.type_as(q_)
q_ = self.apply_rotary_pos_emb(
pos_sin[:, :, key_len - query_len : key_len, :],
pos_cos[:, :, key_len - query_len : key_len, :],
@@ -373,6 +379,42 @@ def output_multiplier(self) -> float:
return 0.5


def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
att_bias = torch.triu(
torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
diagonal=1,
)
att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
return att_bias.view(1, 1, seq_len, seq_len) # type: ignore


def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
if causal_bias.device != device:
causal_bias = causal_bias.to(device)
cache["causal_attention_bias"] = causal_bias
return causal_bias
with torch.autocast(device.type, enabled=False):
causal_bias = causal_attention_bias(seq_len, device)
cache["causal_attention_bias"] = causal_bias
return causal_bias


def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)

# shape: (1, 1, seq_len, seq_len)
alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
alibi_bias.abs_().mul_(-1)

# shape: (n_heads,)
m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
m.mul_(config.alibi_bias_max / config.n_heads)

# shape: (1, n_heads, seq_len, seq_len)
return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore
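The causal-bias helpers are now module-level so the new block can share them with ``Olmo``. A small usage sketch (illustrative, not from the PR), assuming ``BufferCache`` behaves like the dict of tensors the model already uses:

```python
import torch

# Illustrative: the helper computes the bias once, caches it, and serves any request
# whose sequence length fits inside the cached tensor.
cache = BufferCache()
bias = get_causal_attention_bias(cache, seq_len=8, device=torch.device("cpu"))
print(bias.shape)  # torch.Size([1, 1, 8, 8])

# A shorter request reuses the cached tensor; callers slice it to the length they need.
reused = get_causal_attention_bias(cache, 4, torch.device("cpu"))
assert reused is cache["causal_attention_bias"]
```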


class OlmoBlock(nn.Module):
"""
A base class for transformer block implementations.
@@ -467,6 +509,30 @@ def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.
ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
return bias

def _scaled_dot_product_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
) -> torch.Tensor:
"""
Computes scaled dot product attention on query, key and value tensors, using an optional
attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.

This method is based on PyTorch's `scaled_dot_product_attention`.
"""
return F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
)

def attention(
self,
q: torch.Tensor,
@@ -526,7 +592,7 @@ def attention(

# Get the attention scores.
# shape: (B, nh, T, hs)
att = F.scaled_dot_product_attention(
att = self._scaled_dot_product_attention(
q,
k,
v,
@@ -557,6 +623,8 @@ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OlmoBl
return OlmoSequentialBlock(layer_id, config, cache)
elif config.block_type == BlockType.parallel:
return OlmoParallelBlock(layer_id, config, cache)
elif config.block_type == BlockType.llama:
return OlmoLlamaBlock(layer_id, config, cache)
else:
raise NotImplementedError(f"not sure how to handle block type '{config.block_type}'")

@@ -709,6 +777,118 @@ def forward(
)


class OlmoLlamaBlock(OlmoBlock):
"""
This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection). This block is similar to `OlmoSequentialBlock`
but some operations have slightly different implementations to imitate the
behavior of Llama.
"""

def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
super().__init__(layer_id, config, cache)
# Layer norms.
self.attn_norm = LayerNorm.build(config)
self.ff_norm = LayerNorm.build(config)
self.__cache = cache

# Attention input projection. Projects x -> (q, k, v)
if config.multi_query_attention:
q_proj_out_dim = config.d_model
k_proj_out_dim = config.d_model // config.n_heads
v_proj_out_dim = config.d_model // config.n_heads
else:
q_proj_out_dim = config.d_model
k_proj_out_dim = config.d_model
v_proj_out_dim = config.d_model
self.q_proj = nn.Linear(
config.d_model, q_proj_out_dim, bias=config.include_bias, device=config.init_device
)
self.k_proj = nn.Linear(
config.d_model, k_proj_out_dim, bias=config.include_bias, device=config.init_device
)
self.v_proj = nn.Linear(
config.d_model, v_proj_out_dim, bias=config.include_bias, device=config.init_device
)

# Feed-forward input projection.
self.ff_proj = nn.Linear(
config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
)

def reset_parameters(self):
super().reset_parameters()
self.attn_norm.reset_parameters()
self.ff_norm.reset_parameters()
# NOTE: the standard deviation for these weights does not depend on the layer.
init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None)

def _scaled_dot_product_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
) -> torch.Tensor:
attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))

if is_causal:
assert attn_mask is None

query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len]
elif attn_mask is not None:
attn_bias = attn_mask.to(q.dtype)
else:
attn_bias = torch.zeros_like(attn_weights)

attn_weights += attn_bias
attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout_p)
return torch.matmul(attn_weights, v)

def forward(
self,
x: torch.Tensor,
attention_bias: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Get query, key, value projections.
# shape:
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
# - for multi-query attn q: (batch_size, seq_len, d_model)
#   k, v: (batch_size, seq_len, d_model // n_heads)
x_normed = self.attn_norm(x)
q = self.q_proj(x_normed)
k = self.k_proj(x_normed)
v = self.v_proj(x_normed)

# Get attention scores.
att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)

# Add attention scores.
# shape: (B, T, C)
x = x + self.dropout(att)

# Add feed-forward projection.
# shape: (batch_size, seq_len, d_model)
og_x = x
x = self._activation_checkpoint_fn(self.ff_norm, x)
x = self.ff_proj(x)
x = self._activation_checkpoint_fn(self.act, x)
x = self.ff_out(x)
x = self.dropout(x)
x = og_x + x

return x, cache
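To make the intent of the ``_scaled_dot_product_attention`` override above concrete, here is a rough equivalence check (illustrative, not from the PR; fp32, no dropout) between the manual softmax path and PyTorch's fused kernel:

```python
import torch
import torch.nn.functional as F

# softmax(q @ k^T / sqrt(d) + causal_bias) @ v should match the fused SDPA result
# (up to numerical noise) for the causal, no-dropout case.
q, k, v = (torch.randn(1, 4, 8, 16) for _ in range(3))
bias = causal_attention_bias(8, q.device)  # (1, 1, 8, 8); module-level helper from this diff
manual = torch.softmax(q @ k.transpose(-2, -1) / 16 ** 0.5 + bias, dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(manual, fused, atol=1e-5)
```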


class OlmoOutput(NamedTuple):
logits: torch.FloatTensor
"""
@@ -791,30 +971,6 @@ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointin
block.set_activation_checkpointing(strategy)


def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
att_bias = torch.triu(
torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
diagonal=1,
)
att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
return att_bias.view(1, 1, seq_len, seq_len) # type: ignore


def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)

# shape: (1, 1, seq_len, seq_len)
alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
alibi_bias.abs_().mul_(-1)

# shape: (n_heads,)
m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
m.mul_(config.alibi_bias_max / config.n_heads)

# shape: (1, n_heads, seq_len, seq_len)
return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore


class Olmo(nn.Module):
def __init__(self, config: ModelConfig, init_params: bool = True):
super().__init__()
@@ -892,7 +1048,7 @@ def __init__(self, config: ModelConfig, init_params: bool = True):

# Warm up cache.
if self.config.alibi:
self.get_causal_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))

def enable_activation_checkpointing(self, enable: bool = True):
@@ -942,19 +1098,6 @@ def reset_parameters(self):
for block_group in self.transformer.block_groups:
block_group.reset_parameters()

def get_causal_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
if (causal_bias := self.__cache.get("causal_attention_bias")) is not None and causal_bias.shape[
-1
] >= seq_len:
if causal_bias.device != device:
causal_bias = causal_bias.to(device)
self.__cache["causal_attention_bias"] = causal_bias
return causal_bias
with torch.autocast(device.type, enabled=False):
causal_bias = causal_attention_bias(seq_len, device)
self.__cache["causal_attention_bias"] = causal_bias
return causal_bias

def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
-1
@@ -1049,11 +1192,11 @@ def forward(
or past_key_values is not None
):
if attention_bias is None and self.config.alibi:
attention_bias = self.get_causal_attention_bias(
past_length + seq_len, x.device
attention_bias = get_causal_attention_bias(
self.__cache, past_length + seq_len, x.device
) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
elif attention_bias is None:
attention_bias = self.get_causal_attention_bias(past_length + seq_len, x.device)
attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
elif attention_bias.dtype in (torch.int8, torch.bool):
attention_bias = attention_bias.to(dtype=torch.float)
attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)