Add block with Llama-like implementations #346

Merged (13 commits) on Nov 2, 2023
22 changes: 22 additions & 0 deletions olmo/config.py
@@ -187,6 +187,12 @@ class BlockType(StrEnum):
sequential = "sequential"
parallel = "parallel"

llama = "llama"
"""
A block similar to the sequential block with slightly different
implementations of operations like attention to imitate the behavior of Llama.
"""


class InitFnType(StrEnum):
mitchell = "mitchell"
@@ -280,6 +286,11 @@ class ModelConfig(BaseConfig):
Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
"""

rope_precision_type: str = "fp32"

Member: Consider making a StrEnum for the options here.
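
A minimal sketch of what such a StrEnum could look like (hypothetical; it mirrors how BlockType is defined in this file and is not what the PR ultimately shipped, see the thread below):

    class RoPEPrecisionType(StrEnum):
        """Hypothetical enum for the RoPE precision options (StrEnum as used elsewhere in olmo/config.py)."""

        amp_bf16 = "amp_bf16"
        amp_fp16 = "amp_fp16"
        fp32 = "fp32"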

Member: Or never mind if you go with my other suggestion.

"""
Precision with which to apply RoPE (e.g. "amp_bf16", "amp_fp16", or "fp32").

Member: I don't think "amp_*" is meaningful here. It seems like there should really be two options:
  • fp32, or
  • whatever type q and k are
So maybe change this to a flag called rope_full_precision or something?

Collaborator Author: I went with this suggestion instead of making it a StrEnum as in your other one.
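
A minimal sketch of that flag-based approach (the rope_full_precision name comes from the suggestion above; the exact wiring is an assumption, not necessarily this PR's final code):

    # Hypothetical replacement for rope_precision_type in ModelConfig:
    rope_full_precision: bool = True
    """
    If True, apply RoPE in full precision (fp32) regardless of the autocast dtype;
    otherwise apply it in whatever dtype q and k already have.
    """

    # ...and in RotaryEmbedding.forward, roughly:
    if self.config.rope_full_precision:
        q_, k_ = q.float(), k.float()
    else:
        q_, k_ = q, k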

"""

flash_attention: bool = False
"""
If ``True``, use ``FlashAttention``.
@@ -411,6 +422,17 @@ class ModelConfig(BaseConfig):
See :data:`TrainConfig.precision` instead.
"""

@property
def rope_precision(self) -> torch.dtype:
if self.rope_precision_type == "amp_bf16":
return torch.bfloat16
elif self.rope_precision_type == "amp_fp16":
return torch.float16
elif self.rope_precision_type == "fp32":
return torch.float32
else:
raise ValueError(f"Unexpected precision type '{self.rope_precision_type}'")


class OptimizerType(StrEnum):
lionw = "lionw"
219 changes: 176 additions & 43 deletions olmo/model.py
@@ -320,10 +320,12 @@ def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t:
return out.to(t.dtype)

def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
q_, k_ = q.float(), k.float()
q_, k_ = q.to(dtype=self.config.rope_precision), k.to(dtype=self.config.rope_precision)

Collaborator Author (@2015aroras, Oct 30, 2023): These changes to rope should not affect the overall result if rope_precision_type = fp32 (the default).
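
A quick standalone check of that claim (not part of the PR): with the default setting, config.rope_precision is torch.float32, so the new cast is exactly the old q.float().

    import torch

    q = torch.randn(2, 8, 16, 64, dtype=torch.bfloat16)
    # rope_precision_type="fp32" maps to torch.float32 via the rope_precision property,
    # so the new .to(dtype=...) cast produces the same tensor as the old .float() call.
    assert torch.equal(q.to(dtype=torch.float32), q.float())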

with torch.autocast(q.device.type, enabled=False):
query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
pos_sin = pos_sin.type_as(q_)

Member: Is this going to change any existing runs that use RoPE?

Collaborator Author (@2015aroras, Oct 30, 2023): The rotary embeddings are fp32 by construction, so if rope_precision_type = fp32 (the default) then this shouldn't change the type.
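
A small illustration of that point (standard RoPE construction, not OLMo's exact code): the sin/cos tables are built from float32 tensors, so type_as(q_) leaves their dtype unchanged when q_ is fp32.

    import torch

    dim, seq_len = 64, 16
    # Frequencies and positions are created in float32...
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(seq_len).float()
    freqs = torch.einsum("i,j->ij", positions, inv_freq)
    # ...so the resulting sin/cos embeddings are float32 by construction.
    pos_sin, pos_cos = freqs.sin(), freqs.cos()
    assert pos_sin.dtype == pos_cos.dtype == torch.float32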

pos_cos = pos_cos.type_as(q_)
q_ = self.apply_rotary_pos_emb(
pos_sin[:, :, key_len - query_len : key_len, :],
pos_cos[:, :, key_len - query_len : key_len, :],
@@ -381,6 +383,42 @@ def output_multiplier(self) -> float:
return 0.5


def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
att_bias = torch.triu(
torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
diagonal=1,
)
att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
return att_bias.view(1, 1, seq_len, seq_len) # type: ignore


def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
if causal_bias.device != device:
causal_bias = causal_bias.to(device)
cache["causal_attention_bias"] = causal_bias
return causal_bias
with torch.autocast(device.type, enabled=False):
causal_bias = causal_attention_bias(seq_len, device)
cache["causal_attention_bias"] = causal_bias
return causal_bias


def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)

# shape: (1, 1, seq_len, seq_len)
alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
alibi_bias.abs_().mul_(-1)

# shape: (n_heads,)
m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
m.mul_(config.alibi_bias_max / config.n_heads)

# shape: (1, n_heads, seq_len, seq_len)
return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore


class OlmoBlock(nn.Module):
"""
A base class for transformer block implementations.
@@ -465,6 +503,30 @@ def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.
ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
return bias

def _scaled_dot_product_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
) -> torch.Tensor:
"""
Computes scaled dot product attention on query, key and value tensors, using an optional
attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.

This method is based on PyTorch's `scaled_dot_product_attention`.
"""
return F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
)

def attention(
self,
q: torch.Tensor,
@@ -524,7 +586,7 @@ def attention(

# Get the attention scores.
# shape: (B, nh, T, hs)
att = F.scaled_dot_product_attention(
att = self._scaled_dot_product_attention(
q,
k,
v,
@@ -555,6 +617,8 @@ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OlmoBl
return OlmoSequentialBlock(layer_id, config, cache)
elif config.block_type == BlockType.parallel:
return OlmoParallelBlock(layer_id, config, cache)
elif config.block_type == BlockType.llama:
return OlmoLlamaBlock(layer_id, config, cache)
else:
raise NotImplementedError(f"not sure how to handle block type '{config.block_type}'")

@@ -682,6 +746,112 @@ def forward(
return x + self.dropout(self.ff_out(self.act(ff))) + self.dropout(att), cache


class OlmoLlamaBlock(OlmoBlock):
"""
This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection). This block is similar to `OlmoSequentialBlock`
but some operations have slightly different implementations to imitate the
behavior of Llama.
"""

def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
super().__init__(layer_id, config, cache)
# Layer norms.
self.attn_norm = LayerNorm.build(config)
self.ff_norm = LayerNorm.build(config)
self.__cache = cache

# Attention input projection. Projects x -> (q, k, v)
if config.multi_query_attention:
q_proj_out_dim = config.d_model
k_proj_out_dim = config.d_model // config.n_heads
v_proj_out_dim = config.d_model // config.n_heads
else:
q_proj_out_dim = config.d_model
k_proj_out_dim = config.d_model
v_proj_out_dim = config.d_model
self.q_proj = nn.Linear(
config.d_model, q_proj_out_dim, bias=config.include_bias, device=config.init_device
)
self.k_proj = nn.Linear(
config.d_model, k_proj_out_dim, bias=config.include_bias, device=config.init_device
)
self.v_proj = nn.Linear(
config.d_model, v_proj_out_dim, bias=config.include_bias, device=config.init_device
)

# Feed-forward input projection.
self.ff_proj = nn.Linear(
config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
)

def reset_parameters(self):
super().reset_parameters()
self.attn_norm.reset_parameters()
self.ff_norm.reset_parameters()
# NOTE: the standard deviation for these weights does not depend on the layer.
init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None)

def _scaled_dot_product_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
) -> torch.Tensor:
attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))

if is_causal:
assert attn_mask is None

query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len]
elif attn_mask is not None:
attn_bias = attn_mask.to(q.dtype)
else:
attn_bias = torch.zeros_like(attn_weights)

attn_weights += attn_bias
attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout_p)
return torch.matmul(attn_weights, v)

def forward(
self,
x: torch.Tensor,
attention_bias: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Get query, key, value projections.
# shape:
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
# - for multi-query attn q: (batch_size, seq_len, d_model)
# k, v: (batch_size, seq_len, d_model // n_heads)
x_normed = self.attn_norm(x)
q = self.q_proj(x_normed)
k = self.k_proj(x_normed)
v = self.v_proj(x_normed)

# Get attention scores.
att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)

# Add attention scores.
# shape: (B, T, C)
x = x + self.dropout(att)

# Add feed-forward projection.
# shape: (batch_size, seq_len, d_model)
x = x + self.dropout(self.ff_out(self.act(self.ff_proj(self.ff_norm(x)))))

return x, cache


class OlmoOutput(NamedTuple):
logits: torch.FloatTensor
"""
@@ -736,30 +906,6 @@ def reset_parameters(self):
block.reset_parameters()


def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
att_bias = torch.triu(
torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
diagonal=1,
)
att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
return att_bias.view(1, 1, seq_len, seq_len) # type: ignore


def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)

# shape: (1, 1, seq_len, seq_len)
alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
alibi_bias.abs_().mul_(-1)

# shape: (n_heads,)
m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
m.mul_(config.alibi_bias_max / config.n_heads)

# shape: (1, n_heads, seq_len, seq_len)
return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1))) # type: ignore


class Olmo(nn.Module):
def __init__(self, config: ModelConfig, init_params: bool = True):
super().__init__()
@@ -836,7 +982,7 @@ def __init__(self, config: ModelConfig, init_params: bool = True):

# Warm up cache.
if self.config.alibi:
self.get_causal_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))

def enable_activation_checkpointing(self, enable: bool = True):
Expand Down Expand Up @@ -895,19 +1041,6 @@ def reset_parameters(self):
for block_group in self.transformer.block_groups:
block_group.reset_parameters()

def get_causal_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
if (causal_bias := self.__cache.get("causal_attention_bias")) is not None and causal_bias.shape[
-1
] >= seq_len:
if causal_bias.device != device:
causal_bias = causal_bias.to(device)
self.__cache["causal_attention_bias"] = causal_bias
return causal_bias
with torch.autocast(device.type, enabled=False):
causal_bias = causal_attention_bias(seq_len, device)
self.__cache["causal_attention_bias"] = causal_bias
return causal_bias

def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
-1
@@ -1002,11 +1135,11 @@ def forward(
or past_key_values is not None
):
if attention_bias is None and self.config.alibi:
attention_bias = self.get_causal_attention_bias(
past_length + seq_len, x.device
attention_bias = get_causal_attention_bias(
self.__cache, past_length + seq_len, x.device
) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
elif attention_bias is None:
attention_bias = self.get_causal_attention_bias(past_length + seq_len, x.device)
attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
elif attention_bias.dtype in (torch.int8, torch.bool):
attention_bias = attention_bias.to(dtype=torch.float)
attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)