Features to match OpenLM #302

Merged: epwalsh merged 43 commits into main from petew/tweaks on Oct 10, 2023
Conversation

epwalsh (Member) commented Sep 29, 2023

Adds some features to allow us to more closely match the architecture from Mitchell's OpenLM.

  • Add option --model.weight_tying (bool, defaults to True) that allows us to disable weight tying of the input embedding with the output linear.
  • Add an option to restrict the block feed-forward hidden dimension (usually mlp_ratio * d_model) to a multiple of 256, like Mitchell does here (see the rounding sketch after this list).
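
For illustration, here is a minimal sketch of that rounding. The helper name mlp_hidden_size, the float mlp_ratio, and the defaults are assumptions for the example, not necessarily the option surface this PR adds:

```python
def mlp_hidden_size(d_model: int, mlp_ratio: float = 4, multiple_of: int = 256) -> int:
    """Round the feed-forward hidden dimension up to the nearest multiple of `multiple_of`."""
    hidden = int(mlp_ratio * d_model)
    # Ceiling-divide, then scale back up, so the result is always a multiple of 256.
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

# For example, with a SwiGLU-style ratio of 8/3 and d_model=2048:
# int(8/3 * 2048) = 5461, which rounds up to 5632 (= 22 * 256).
```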

The configuration I'm running in my mitch-ish runs (see W&B) is relatively slow compared to our defaults, but there are some low-hanging improvements we can make:

  • Cache the RoPE sin and cos tables for the positions. It's silly that we're not doing this already (160d143; a caching sketch follows this list).
  • Try torch-scripting the apply_rotary_pos_emb function like Mitchell does. Can we trust torchscript on AMD? We'll find out.
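
As a rough sketch of the caching idea: the sin/cos tables only need to be rebuilt when the requested sequence length exceeds what is cached. The class name RotaryEmbedding, the get_rotary_embedding method, and the private cache attributes here are illustrative, not necessarily the repository's exact API:

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn


class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.base = base
        # Cached as plain attributes here; whether to use registered buffers
        # is discussed later in this thread.
        self._cached_len = 0
        self._pos_sin: Optional[torch.Tensor] = None
        self._pos_cos: Optional[torch.Tensor] = None

    def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        # Recompute only when the cache is missing or too short.
        # (A fuller version would also handle device/dtype changes.)
        if self._pos_sin is None or seq_len > self._cached_len:
            inv_freq = 1.0 / (
                self.base ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
            )
            t = torch.arange(seq_len, device=device).float()
            freqs = torch.einsum("i,j->ij", t, inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            self._pos_sin, self._pos_cos = emb.sin(), emb.cos()
            self._cached_len = seq_len
        return self._pos_sin[:seq_len], self._pos_cos[:seq_len]
```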

Other changes:

  • Move all the RoPE logic to the RotaryEmbedding module. This makes way more sense in my opinion and simplifies the OlmoBlock.attention implementation. (7fc33c5; a generic sketch of the rotation follows below.)
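
For reference, one common (GPT-NeoX-style) formulation of the rotation that such a module can encapsulate. apply_rotary_pos_emb is the function named above, but this body and the rotate_half helper are a generic sketch, not the repository's exact implementation:

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Rotate `t` (queries or keys) by the cached position sin/cos tables.
    return (t * pos_cos) + (rotate_half(t) * pos_sin)
```

With this in one place, the attention code only has to hand q and k to the rotary module instead of carrying the sin/cos plumbing itself.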

epwalsh mentioned this pull request on Sep 29, 2023.
dirkgr (Member) left a comment

I worry about caching, because it seems to screw up FSDP every time we try it. So let's make sure this runs in LUMI before we merge.

epwalsh (Member, Author) commented Sep 29, 2023

> I worry about caching, because it seems to screw up FSDP every time we try it. So let's make sure this runs in LUMI before we merge.

Agreed, we should definitely test on LUMI before merging. I'm not too worried about these changes though because we've been doing the same thing with the ALiBi bias.

2015aroras (Collaborator) left a comment

Some initial comments, since this is still a draft.

Resolved review threads on olmo/model.py.
epwalsh (Member, Author) commented Oct 4, 2023

After avoiding buffers with RoPE, there is a huge improvement!

[image attached]

dirkgr (Member) commented Oct 4, 2023 via email (quoted in the reply below)

epwalsh (Member, Author) commented Oct 4, 2023

> Avoiding buffers? Why does that make a difference?

My first thought was that it's because buffers are stored in bf16 with our FSDP settings, so we lose some precision when RoPE is applied. And I think this is still true, but I found a bigger issue.

It turns out using meta-device deferred initialization introduced a bug with our RoPE "inv_freq" buffer. The buffer is initialized when its module is constructed, but since it's a meta-device tensor it holds no data, and later on FSDP calls Module.to_empty(), which materializes those buffers as all zeros. In other words, when model.init_device is set to "meta", this line is essentially ignored and "inv_freq" ends up all zeros:

https://github.com/allenai/LLM/blob/602968ae92294b5eeb70e7422d073cb0183166fd/olmo/model.py#L251
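
A minimal, self-contained sketch of that failure mode (the module and buffer here are illustrative, and the torch.device context manager assumes PyTorch 2.0+):

```python
import torch
import torch.nn as nn


class Rotary(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        # Computed at construction time and registered as a buffer.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)


with torch.device("meta"):
    m = Rotary()  # inv_freq is a meta tensor: shape only, no data.

# FSDP later materializes meta-device modules with Module.to_empty(), which
# allocates *uninitialized* storage; the values computed in __init__ are lost
# (in practice the buffer ends up all zeros).
m = m.to_empty(device="cpu")
print(m.inv_freq)  # not the frequencies computed above
```

This is why the PR moves away from registered buffers for the RoPE values ("avoiding buffers" above).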

dirkgr (Member) commented twice on Oct 4, 2023 via email (quoted in the reply below)

epwalsh (Member, Author) commented Oct 4, 2023

> Isn't the ALiBi stuff stored the same way?

No, the ALiBi bias is stored differently. We realized early on that buffers didn't work well for some reason and we made that fix with ALiBi, but I guess we never fixed the same issue with RoPE because we weren't using it at the time.

> Does that mean all my RoPE experiments didn't really work?

I think so. We should run those again.

epwalsh marked this pull request as ready for review on October 5, 2023 at 22:14.
epwalsh mentioned this pull request on Oct 8, 2023.
Resolved (outdated) review threads on configs/mcli/v1-mix-medium-mitch-ish.yaml and olmo/model.py.
dirkgr (Member) left a comment

Maybe changes because of the sequence length thing with Rope?

ssh_clone: true
command: |-
  pip install urllib3==1.26.17
Member:
What is this for?

Member Author (epwalsh):
This was MosaicML's recommendation for solving the SSLError. Didn't work, I removed it.

Resolved review threads on olmo/model.py.
        return pos_sin, pos_cos

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        q_, k_ = q.float(), k.float()
Member:
Since we're messing with precision here, do we need to disable autocast?

Member Author (epwalsh):
I don't think any of the operations in this forward method would autocast to bf16, but just to make sure: 62fcb47
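
For context, a rough sketch of what disabling autocast around this forward can look like; the structure, the get_rotary_embedding helper, and the apply_rotary_pos_emb call (from the generic sketch earlier in this thread) are assumptions for illustration, not necessarily what 62fcb47 does:

```python
from typing import Tuple

import torch


# A method on the rotary-embedding module (continuing the earlier sketches);
# `self.get_rotary_embedding` and `apply_rotary_pos_emb` come from those sketches.
def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Explicitly disable autocast so the rotation really runs in fp32, even when
    # the surrounding model is wrapped in torch.autocast(..., dtype=torch.bfloat16).
    with torch.autocast(q.device.type, enabled=False):
        q_, k_ = q.float(), k.float()
        pos_sin, pos_cos = self.get_rotary_embedding(k_.shape[-2], k_.device)
        q_ = apply_rotary_pos_emb(pos_sin, pos_cos, q_)
        k_ = apply_rotary_pos_emb(pos_sin, pos_cos, k_)
    # Cast back to the caller's dtype so downstream attention math is unchanged.
    return q_.type_as(q), k_.type_as(k)
```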

epwalsh requested a review from dirkgr on October 9, 2023 at 22:18.
epwalsh merged commit fddded5 into main on Oct 10, 2023. 10 checks passed.
epwalsh deleted the petew/tweaks branch on October 10, 2023 at 16:52.