# Add block with Llama-like implementations #346
Changes from 10 commits
```diff
@@ -187,6 +187,12 @@ class BlockType(StrEnum):
     sequential = "sequential"
     parallel = "parallel"
 
+    llama = "llama"
+    """
+    A block similar to the sequential block with slightly different
+    implementations of operations like attention to imitate the behavior of Llama.
+    """
+
 
 class InitFnType(StrEnum):
     mitchell = "mitchell"
```
```diff
@@ -280,6 +286,11 @@ class ModelConfig(BaseConfig):
     Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
     """
 
+    rope_precision_type: str = "fp32"
+    """
+    Precision with which to apply RoPE (e.g. "amp_bf16", "amp_fp16", or "fp32").
+    """
+
     flash_attention: bool = False
     """
     If ``True``, use ``FlashAttention``.
```

> **Review comment:** I don't think "amp_*" is meaningful here. It seems like there should really be two options: […] So maybe change this to a flag called […]
>
> **Author reply:** I went with this suggestion instead of making it a `StrEnum`.
```diff
@@ -411,6 +422,17 @@ class ModelConfig(BaseConfig):
     See :data:`TrainConfig.precision` instead.
     """
 
+    @property
+    def rope_precision(self) -> torch.dtype:
+        if self.rope_precision_type == "amp_bf16":
+            return torch.bfloat16
+        elif self.rope_precision_type == "amp_fp16":
+            return torch.float16
+        elif self.rope_precision_type == "fp32":
+            return torch.float32
+        else:
+            raise ValueError(f"Unexpected precision type '{self.rope_precision_type}'")
+
 
 class OptimizerType(StrEnum):
     lionw = "lionw"
```
The remaining hunks are from a second file, the model implementation:
```diff
@@ -320,10 +320,12 @@ def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t:
         return out.to(t.dtype)
 
     def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        q_, k_ = q.float(), k.float()
+        q_, k_ = q.to(dtype=self.config.rope_precision), k.to(dtype=self.config.rope_precision)
         with torch.autocast(q.device.type, enabled=False):
             query_len, key_len = q_.shape[-2], k_.shape[-2]  # could be different if layer_past not None
             pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
+            pos_sin = pos_sin.type_as(q_)
+            pos_cos = pos_cos.type_as(q_)
             q_ = self.apply_rotary_pos_emb(
                 pos_sin[:, :, key_len - query_len : key_len, :],
                 pos_cos[:, :, key_len - query_len : key_len, :],
```

> **Review comment:** These changes to RoPE should not affect the overall result if `rope_precision_type` is left at the default `"fp32"`.
>
> **Review comment:** Is this going to change any existing runs that use RoPE?
>
> **Author reply:** The rotary embeddings are fp32 by construction, so if the precision stays fp32 the new `type_as` casts are no-ops and existing runs are unaffected.
```diff
@@ -381,6 +383,42 @@ def output_multiplier(self) -> float:
         return 0.5
 
 
+def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
+    att_bias = torch.triu(
+        torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
+        diagonal=1,
+    )
+    att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
+    return att_bias.view(1, 1, seq_len, seq_len)  # type: ignore
+
+
+def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
+    if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
+        if causal_bias.device != device:
+            causal_bias = causal_bias.to(device)
+            cache["causal_attention_bias"] = causal_bias
+        return causal_bias
+    with torch.autocast(device.type, enabled=False):
+        causal_bias = causal_attention_bias(seq_len, device)
+    cache["causal_attention_bias"] = causal_bias
+    return causal_bias
+
+
+def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
+    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
+
+    # shape: (1, 1, seq_len, seq_len)
+    alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
+    alibi_bias.abs_().mul_(-1)
+
+    # shape: (n_heads,)
+    m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
+    m.mul_(config.alibi_bias_max / config.n_heads)
+
+    # shape: (1, n_heads, seq_len, seq_len)
+    return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1)))  # type: ignore
+
+
 class OlmoBlock(nn.Module):
     """
     A base class for transformer block implementations.
```
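As a quick illustration of what the causal helper produces, here is a self-contained sketch that reimplements `causal_attention_bias` locally (so it runs without the package):

```python
import torch

def causal_attention_bias(seq_len: int, device: torch.device) -> torch.Tensor:
    # Ones strictly above the diagonal mark the positions to be masked out.
    att_bias = torch.triu(
        torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
        diagonal=1,
    )
    att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
    return att_bias.view(1, 1, seq_len, seq_len)

bias = causal_attention_bias(4, torch.device("cpu"))
print(bias.shape)  # torch.Size([1, 1, 4, 4])
print(bias[0, 0])  # zeros on/below the diagonal, large negatives above
```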
```diff
@@ -465,6 +503,30 @@ def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.
         ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
         return bias
 
+    def _scaled_dot_product_attention(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+    ) -> torch.Tensor:
+        """
+        Computes scaled dot product attention on query, key and value tensors, using an optional
+        attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
+
+        This method is based on PyTorch's `scaled_dot_product_attention`.
+        """
+        return F.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=attn_mask,
+            dropout_p=dropout_p,
+            is_causal=is_causal,
+        )
+
     def attention(
         self,
         q: torch.Tensor,
```
```diff
@@ -524,7 +586,7 @@ def attention(
 
         # Get the attention scores.
         # shape: (B, nh, T, hs)
-        att = F.scaled_dot_product_attention(
+        att = self._scaled_dot_product_attention(
             q,
             k,
             v,
```
```diff
@@ -555,6 +617,8 @@ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OlmoBl
             return OlmoSequentialBlock(layer_id, config, cache)
         elif config.block_type == BlockType.parallel:
             return OlmoParallelBlock(layer_id, config, cache)
+        elif config.block_type == BlockType.llama:
+            return OlmoLlamaBlock(layer_id, config, cache)
         else:
             raise NotImplementedError(f"not sure how to handle block type '{config.block_type}'")
```
```diff
@@ -682,6 +746,112 @@ def forward(
         return x + self.dropout(self.ff_out(self.act(ff))) + self.dropout(att), cache
 
 
+class OlmoLlamaBlock(OlmoBlock):
+    """
+    This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
+    (plus another skip connection). This block is similar to `OlmoSequentialBlock`
+    but some operations have slightly different implementations to imitate the
+    behavior of Llama.
+    """
+
+    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
+        super().__init__(layer_id, config, cache)
+        # Layer norms.
+        self.attn_norm = LayerNorm.build(config)
+        self.ff_norm = LayerNorm.build(config)
+        self.__cache = cache
+
+        # Attention input projection. Projects x -> (q, k, v)
+        if config.multi_query_attention:
+            q_proj_out_dim = config.d_model
+            k_proj_out_dim = config.d_model // config.n_heads
+            v_proj_out_dim = config.d_model // config.n_heads
+        else:
+            q_proj_out_dim = config.d_model
+            k_proj_out_dim = config.d_model
+            v_proj_out_dim = config.d_model
+        self.q_proj = nn.Linear(
+            config.d_model, q_proj_out_dim, bias=config.include_bias, device=config.init_device
+        )
+        self.k_proj = nn.Linear(
+            config.d_model, k_proj_out_dim, bias=config.include_bias, device=config.init_device
+        )
+        self.v_proj = nn.Linear(
+            config.d_model, v_proj_out_dim, bias=config.include_bias, device=config.init_device
+        )
+
+        # Feed-forward input projection.
+        self.ff_proj = nn.Linear(
+            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
+        )
+
+    def reset_parameters(self):
+        super().reset_parameters()
+        self.attn_norm.reset_parameters()
+        self.ff_norm.reset_parameters()
+        # NOTE: the standard deviation for these weights does not depend on the layer.
+        init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
+        init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
+        init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
+        init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None)
+
+    def _scaled_dot_product_attention(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+    ) -> torch.Tensor:
+        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
+
+        if is_causal:
+            assert attn_mask is None
+
+            query_len, key_len = q.shape[-2], k.shape[-2]  # could be different if layer_past not None
+            attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len]
+        elif attn_mask is not None:
+            attn_bias = attn_mask.to(q.dtype)
+        else:
+            attn_bias = torch.zeros_like(attn_weights)
+
+        attn_weights += attn_bias
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(q.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=dropout_p)
+        return torch.matmul(attn_weights, v)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_bias: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        # Get query, key, value projections.
+        # shape:
+        #  - for regular attn q, k, v: (batch_size, seq_len, d_model)
+        #  - for multi-query attn q: (batch_size, seq_len, d_model)
+        #                      k, v: (batch_size, seq_len, d_model // n_heads)
+        x_normed = self.attn_norm(x)
+        q = self.q_proj(x_normed)
+        k = self.k_proj(x_normed)
+        v = self.v_proj(x_normed)
+
+        # Get attention scores.
+        att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)
+
+        # Add attention scores.
+        # shape: (B, T, C)
+        x = x + self.dropout(att)
+
+        # Add feed-forward projection.
+        # shape: (batch_size, seq_len, d_model)
+        x = x + self.dropout(self.ff_out(self.act(self.ff_proj(self.ff_norm(x)))))
+
+        return x, cache
+
+
 class OlmoOutput(NamedTuple):
     logits: torch.FloatTensor
     """
```
The module-level bias helpers were moved earlier in the file (see the `@@ -381,6 +383,42 @@` hunk above), with `get_causal_attention_bias` now taking an explicit `BufferCache`, so the old copies are deleted here:

```diff
@@ -736,30 +906,6 @@ def reset_parameters(self):
         for block_group in self.transformer.block_groups:
             block_group.reset_parameters()
 
-def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
-    att_bias = torch.triu(
-        torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
-        diagonal=1,
-    )
-    att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
-    return att_bias.view(1, 1, seq_len, seq_len)  # type: ignore
-
-
-def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
-    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
-
-    # shape: (1, 1, seq_len, seq_len)
-    alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
-    alibi_bias.abs_().mul_(-1)
-
-    # shape: (n_heads,)
-    m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
-    m.mul_(config.alibi_bias_max / config.n_heads)
-
-    # shape: (1, n_heads, seq_len, seq_len)
-    return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1)))  # type: ignore
-
-
 class Olmo(nn.Module):
     def __init__(self, config: ModelConfig, init_params: bool = True):
         super().__init__()
```
```diff
@@ -836,7 +982,7 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
 
         # Warm up cache.
         if self.config.alibi:
-            self.get_causal_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
+            get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
             self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
 
     def enable_activation_checkpointing(self, enable: bool = True):
```
The corresponding `Olmo.get_causal_attention_bias` method is likewise removed in favor of the module-level function:

```diff
@@ -895,19 +1041,6 @@ def reset_parameters(self):
         for block_group in self.transformer.block_groups:
             block_group.reset_parameters()
 
-    def get_causal_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
-        if (causal_bias := self.__cache.get("causal_attention_bias")) is not None and causal_bias.shape[
-            -1
-        ] >= seq_len:
-            if causal_bias.device != device:
-                causal_bias = causal_bias.to(device)
-                self.__cache["causal_attention_bias"] = causal_bias
-            return causal_bias
-        with torch.autocast(device.type, enabled=False):
-            causal_bias = causal_attention_bias(seq_len, device)
-        self.__cache["causal_attention_bias"] = causal_bias
-        return causal_bias
-
     def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
         if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
             -1
```
```diff
@@ -1002,11 +1135,11 @@ def forward(
             or past_key_values is not None
         ):
             if attention_bias is None and self.config.alibi:
-                attention_bias = self.get_causal_attention_bias(
-                    past_length + seq_len, x.device
+                attention_bias = get_causal_attention_bias(
+                    self.__cache, past_length + seq_len, x.device
                 ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
             elif attention_bias is None:
-                attention_bias = self.get_causal_attention_bias(past_length + seq_len, x.device)
+                attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
             elif attention_bias.dtype in (torch.int8, torch.bool):
                 attention_bias = attention_bias.to(dtype=torch.float)
                 attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
```
> **Review comment** (on `rope_precision_type`): Consider making a `StrEnum` for the options here.
>
> **Follow-up:** Or never mind if you go with my other suggestion.
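For illustration, the suggested enum might have looked something like the sketch below. This is hypothetical: the PR kept a plain `str` field instead, and the location of the project's `StrEnum` base is assumed.

```python
from olmo.config import StrEnum  # assumed location of the project's StrEnum base

class RoPEPrecisionType(StrEnum):
    # Mirrors the string values accepted by rope_precision_type.
    amp_bf16 = "amp_bf16"
    amp_fp16 = "amp_fp16"
    fp32 = "fp32"
```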