
[Model] Add OLMo November 2024 model #10503

Merged: 14 commits into vllm-project:main on Nov 25, 2024
Conversation

2015aroras
Contributor

@2015aroras commented Nov 20, 2024

An updated OLMo model will be released in November. The new model has a few small architecture changes compared to the original model (sketched below):

  • RMSNorm is used instead of standard layer norm.
  • Norm is applied to attention queries and keys.
  • Norm is applied after attention/feedforward rather than before.

The model has been implemented in transformers (huggingface/transformers#34551). This PR implements the model in vLLM.
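
To make these changes concrete, here is a rough PyTorch-style sketch of a decoder layer with all three modifications. It is illustrative only, not the transformers or vLLM implementation; it omits RoPE, grouped-query attention, and the KV cache, and the module names are assumptions.

import torch
import torch.nn as nn

class Olmo2LayerSketch(nn.Module):
    """Illustrative decoder layer reflecting the three changes listed above."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size, bias=False),
            nn.SiLU(),
            nn.Linear(4 * hidden_size, hidden_size, bias=False),
        )
        # Change 1: parametric RMSNorm instead of non-parametric layer norm.
        # Change 2: norms applied to the attention query and key projections.
        self.q_norm = nn.RMSNorm(hidden_size)
        self.k_norm = nn.RMSNorm(hidden_size)
        # Change 3: norms applied after attention/feedforward (post-norm).
        self.post_attention_layernorm = nn.RMSNorm(hidden_size)
        self.post_feedforward_layernorm = nn.RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, _ = x.shape
        residual = x
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        q, k = self.q_norm(q), self.k_norm(k)  # QK-norm before splitting heads
        q = q.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        attn = nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn = attn.transpose(1, 2).reshape(b, s, -1)
        x = residual + self.post_attention_layernorm(self.o_proj(attn))
        residual = x
        return residual + self.post_feedforward_layernorm(self.mlp(x))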


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mergify bot added the documentation (Improvements or additions to documentation) label on Nov 20, 2024
Collaborator

do you think it would make sense to have the existing olmo model definition (with expanded functionality) cover this one?

Contributor Author

I'm happy to do that if you prefer. We went that route first with transformers, and they told us to do separate models instead 😆.

Collaborator

If you think the model definitions are sufficiently similar, then I think that would be a good move!

Contributor Author

@2015aroras Nov 20, 2024

It's pretty messy since:

  • The new norm is parametric while the old one is not.
  • Supporting norm placement both before and after attention/feedforward (depending on the model) makes the forward pass pretty messy (see the sketch below). Also, the modules have different names depending on whether they come before or after.

We would strongly prefer to have separate models, but if you insist, we will follow whatever you folks decide.
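
A rough sketch of the branching a single shared layer would need (illustrative only; norm_after and the placeholder modules are not names from either codebase):

import torch.nn as nn

class SharedLayerSketch(nn.Module):
    """Hypothetical combined OLMo/OLMo2 layer, showing the repeated branching."""

    def __init__(self, hidden_size: int, norm_after: bool):
        super().__init__()
        self.norm_after = norm_after
        # OLMo uses a non-parametric layer norm, OLMo2 a parametric RMSNorm;
        # plain LayerNorms stand in for both here.
        self.attn_norm = nn.LayerNorm(hidden_size)
        self.mlp_norm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.Identity()  # placeholder for the attention block
        self.mlp = nn.Identity()        # placeholder for the feedforward block

    def forward(self, hidden_states):
        # Attention sub-block: norm placement depends on the flag.
        residual = hidden_states
        if not self.norm_after:
            hidden_states = self.attn_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states)
        if self.norm_after:
            hidden_states = self.attn_norm(hidden_states)
        hidden_states = residual + hidden_states
        # Feedforward sub-block: the same branching again.
        residual = hidden_states
        if not self.norm_after:
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        if self.norm_after:
            hidden_states = self.mlp_norm(hidden_states)
        return residual + hidden_states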

Collaborator

Is this what the model definition would look like if we kept them together: https://github.com/huggingface/transformers/pull/34497/files#diff-bcd9325f22ada9d41cdb22d8497a1c31dd874d1a6c2ea4315f1bf795aabb9a43?

I see what you mean about the if self.norm_after: checks all over the place. OK by me.

Contributor Author

@2015aroras Nov 20, 2024

Yes, that's the model definition (barring minor improvements/cleanup). Thank you!

Member

I agree it should be fine to have separate models in this case

num_kv_heads=self.num_kv_heads,
cache_config=vllm_config.cache_config,
quant_config=vllm_config.quant_config,
)
Member

Please pass in the prefix; it is required for attention now.

Contributor Author

19e26b7: I wasn't sure what to pass in for this, since it doesn't correspond to a separate set of weights. I followed exaone and just passed prefix as is.
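
For readers following along, the pattern under discussion looks roughly like this: the module prefix is threaded from the model constructor down into vLLM's Attention layer unchanged. This is an illustrative sketch, not the PR's code; the names, dimensions, and exact Attention signature may differ between vLLM versions.

import torch.nn as nn
from vllm.attention import Attention

class AttentionBlockSketch(nn.Module):
    """Hypothetical attention block showing the prefix being passed through as-is."""

    def __init__(self, vllm_config, hidden_size: int = 4096, num_heads: int = 32,
                 num_kv_heads: int = 32, prefix: str = ""):
        super().__init__()
        head_dim = hidden_size // num_heads
        self.attn = Attention(
            num_heads,
            head_dim,
            scale=head_dim**-0.5,
            num_kv_heads=num_kv_heads,
            cache_config=vllm_config.cache_config,
            quant_config=vllm_config.quant_config,
            # The prefix does not map to a separate set of weights here,
            # so it is forwarded unchanged, following the exaone definition.
            prefix=prefix,
        )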

@2015aroras
Contributor Author

We renamed the model to OLMo2 in transformers (huggingface/transformers#34864). I have updated the model here accordingly.

Member

@mgoin left a comment

This looks good to me, thanks for the quick responses

@mgoin added the new model (Requests to new models) and ready (ONLY add when PR is ready to merge/full CI is needed) labels on Nov 25, 2024
@2015aroras
Contributor Author

2015aroras commented Nov 25, 2024

A test fails because Olmo2Config is not in the version of transformers being used in the vLLM test. The rename was only merged a few hours ago. I believe transformers is looking to do a new release very soon.
https://buildkite.com/vllm/ci-aws/builds/11766#01936444-39a8-48ad-9085-7d964b2f56fc/277-553

@youkaichao
Member

A test fails because Olmo2Config is not in the version of transformers being used in the vLLM test. The rename was only merged a few hours ago. I believe transformers is looking to do a new release very soon. buildkite.com/vllm/ci-aws/builds/11766#01936444-39a8-48ad-9085-7d964b2f56fc/277-553

It is also fine to directly copy the config file into vllm, e.g.:

class ChatGLMConfig(PretrainedConfig):

@2015aroras
Contributor Author

it is also fine to directly copy the config file into vllm, e.g.

class ChatGLMConfig(PretrainedConfig):

Wow, I didn't know! I've taken this approach now (b50a3bf) and verified that vllm serve works with an old transformers version.
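
For reference, a vendored config in this style is just a PretrainedConfig subclass kept inside vLLM, so that transformers releases without Olmo2Config still work; vLLM can then resolve the olmo2 model type to this class instead of relying on the installed transformers version. The sketch below is illustrative only; the fields and defaults are placeholders rather than the actual contents of b50a3bf.

from transformers import PretrainedConfig

class Olmo2Config(PretrainedConfig):
    """Illustrative stand-in for a config class copied into vLLM."""

    model_type = "olmo2"

    def __init__(
        self,
        vocab_size: int = 100352,
        hidden_size: int = 4096,
        num_hidden_layers: int = 32,
        num_attention_heads: int = 32,
        rms_norm_eps: float = 1e-5,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.rms_norm_eps = rms_norm_eps
        super().__init__(**kwargs)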

@youkaichao
Member

Another option is to keep an AutoConfig field in your config.json, like https://huggingface.co/deepseek-ai/DeepSeek-V2.5/blob/98b11844770b2c3ffc18b175c758a803640f4e77/config.json#L8

@2015aroras
Contributor Author

I'm happy with the current state of the implementation, and all checks are passing. Please let me know if anything else needs to be done on my end to get this merged.

@mgoin merged commit 9db713a into vllm-project:main on Nov 25, 2024
53 checks passed
afeldman-nm pushed a commit to neuralmagic/vllm that referenced this pull request Dec 2, 2024
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
Labels
documentation (Improvements or additions to documentation), new model (Requests to new models), ready (ONLY add when PR is ready to merge/full CI is needed)
4 participants