Add block with Llama-like implementations #346

Merged: 13 commits from shanea/llama-block into main on Nov 2, 2023

Conversation

@2015aroras 2015aroras (Collaborator) commented Oct 26, 2023

While investigating why OLMo and Llama produce different results, we found a few distinct causes. This change adds a Llama block that we can use to address the following:

  • Fused output dimensions cause differing results on CUDA despite attempts to make the computation deterministic.
  • Torch's attention output does not match Llama's. This appears to be caused by torch's F.scaled_dot_product_attention.

Also, OLMo always applies rotary embeddings in fp32, whereas Llama applies them in the current dtype (which can be bf16). I have added a config option that lets us control how rotary embeddings are applied.
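
For illustration, a minimal sketch of the idea (not this PR's actual code): the helper below and its rope_fn parameter are hypothetical stand-ins, and it only shows how q and k can be cast to a configured precision before the rotary embeddings are applied and cast back afterwards.

import torch

def apply_rope_in_precision(q: torch.Tensor, k: torch.Tensor, rope_fn, dtype: torch.dtype):
    # Hypothetical helper: cast q/k to the configured precision (e.g. torch.float32 or
    # torch.bfloat16) before applying the rotary embeddings, then cast back so downstream
    # attention sees the original dtype.
    q_, k_ = q.to(dtype=dtype), k.to(dtype=dtype)
    q_, k_ = rope_fn(q_, k_)
    return q_.to(q.dtype), k_.to(k.dtype)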

Closes #345.

@2015aroras 2015aroras marked this pull request as ready for review October 30, 2023 19:44
olmo/model.py Outdated
@@ -309,10 +309,12 @@ def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t:
        return out.to(t.dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-       q_, k_ = q.float(), k.float()
+       q_, k_ = q.to(dtype=self.config.rope_precision), k.to(dtype=self.config.rope_precision)
@2015aroras 2015aroras (Collaborator, Author) commented Oct 30, 2023

These changes to rope should not affect the overall result if rope_precision_type = fp32 (the default).

@dirkgr dirkgr (Member) left a comment

No requested changes, but some questions.

        with torch.autocast(q.device.type, enabled=False):
            query_len, key_len = q_.shape[-2], k_.shape[-2]  # could be different if layer_past not None
            pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
            pos_sin = pos_sin.type_as(q_)
Member

Is this going to change any existing runs that use Rope?

@2015aroras 2015aroras (Collaborator, Author) commented Oct 30, 2023

The rotary embeddings are fp32 by construction, so if rope_precision_type = fp32 (the default) then this shouldn't change the type.

olmo/model.py Outdated
    ) -> torch.Tensor:
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))

        attn_bias = torch.zeros_like(attn_weights)
Member

It does this every batch?

Collaborator Author

Yes. I can change this to re-use the attn_bias (Llama does this). I think there might be some code in place already that I can leverage.

Member

The problem is that it changes attn_bias in place?

@2015aroras 2015aroras (Collaborator, Author) commented Oct 31, 2023

I've updated the code to leverage the existing get_causal_attention_bias. The bias should now be re-used between calls (in the Llama implementation only), though moving get_causal_attention_bias around to achieve this is not the cleanest.
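
Roughly, the re-use amounts to something like the sketch below; the cache dict and helper name are illustrative stand-ins, not the repo's actual get_causal_attention_bias.

import torch

# Illustrative module-level cache keyed by (seq_len, device, dtype).
_causal_bias_cache: dict = {}

def cached_causal_attention_bias(seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    key = (seq_len, device, dtype)
    bias = _causal_bias_cache.get(key)
    if bias is None:
        # Additive causal bias: 0 on and below the diagonal, a very negative value strictly above it,
        # built once instead of allocating torch.zeros_like(attn_weights) and masking it every batch.
        bias = torch.full((1, 1, seq_len, seq_len), torch.finfo(dtype).min, device=device, dtype=dtype)
        bias = torch.triu(bias, diagonal=1)
        _causal_bias_cache[key] = bias
    # Callers should combine extra masks out-of-place (e.g. bias + attn_mask) rather than
    # mutating the cached tensor in place.
    return bias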

olmo/model.py Outdated
        attn_bias.masked_fill_(context_mask.logical_not(), torch.finfo(attn_bias.dtype).min)

        if attn_mask is not None:
            attn_bias += attn_mask.to(q.dtype)
Member

This happens every batch?

Collaborator Author

Yes. I'm not sure how we can avoid doing this addition, though. Llama does this addition every time it computes attention.

@2015aroras 2015aroras (Collaborator, Author) commented Oct 31, 2023

Once this change is in, using Llama will require setting model.rope_precision_type = amp_bf16 and model.block_type = llama (in addition to any other changes that were needed before).
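
Purely as an illustration of those two settings (the repo's actual config schema and override mechanism may differ, and rope_precision_type is replaced by a rope_full_precision flag later in this review), a config fragment might look like:

from omegaconf import OmegaConf

# Hypothetical illustration only; option names follow this comment, not necessarily the merged code.
overrides = OmegaConf.create(
    {
        "model": {
            "block_type": "llama",
            "rope_precision_type": "amp_bf16",
        }
    }
)
print(OmegaConf.to_yaml(overrides))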

@epwalsh epwalsh (Member) left a comment

Just a couple comments

olmo/config.py Outdated
@@ -280,6 +286,11 @@ class ModelConfig(BaseConfig):
    Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
    """

    rope_precision_type: str = "fp32"
Member

Consider making a StrEnum for the options here.

Member

Or never mind if you go with my other suggestion.

olmo/config.py Outdated
@@ -280,6 +286,11 @@ class ModelConfig(BaseConfig):
    Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
    """

    rope_precision_type: str = "fp32"
    """
    Precision with which to apply RoPE (e.g. "amp_bf16", "amp_fp16", or "fp32").
Member

I don't think "amp_*" is meaningful here. It seems like there should really be two options:

  • fp32, or
  • whatever type q and k are

So maybe change this to a flag called rope_full_precision or something?

Collaborator Author

I went with this suggestion instead of making it a StrEnum as in your other one.
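
A minimal sketch of what the flag-based approach might look like; the RopeConfig dataclass and rope_inputs helper below are assumptions for illustration, not the merged code.

from dataclasses import dataclass

import torch

@dataclass
class RopeConfig:
    # Assumed flag: when True, apply RoPE in fp32 (OLMo's original behavior); when False,
    # apply it in whatever dtype q and k already are (as Llama does, e.g. bf16).
    rope_full_precision: bool = True

def rope_inputs(q: torch.Tensor, k: torch.Tensor, config: RopeConfig):
    if config.rope_full_precision:
        q_, k_ = q.float(), k.float()
    else:
        q_, k_ = q, k
    # ... the rotary embeddings would be applied to q_ and k_ here ...
    # Cast back so downstream attention sees the original dtype.
    return q_.to(q.dtype), k_.to(k.dtype)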

@epwalsh epwalsh (Member) left a comment

LGTM

@2015aroras 2015aroras merged commit 4ccf2bd into main Nov 2, 2023
10 checks passed
@2015aroras 2015aroras deleted the shanea/llama-block branch November 2, 2023 17:23