Phi 1.5 support #1167

Closed
axel578 opened this issue Sep 24, 2023 · 15 comments · Fixed by #1664
Labels: new model (Requests to new models)

Comments

@axel578

axel578 commented Sep 24, 2023

Phi 1.5 is a new model from Microsoft; supporting this model would be extremely useful.

Detailed information about Phi 1.5 can be found here: https://huggingface.co/microsoft/phi-1_5

It basically uses MixFormerSequentialConfig.
Phi 1.5 has some unusual features. Also, 4-bit support would be great (and not only on GPU, but on CPU too, please; at this size the model should run fine on CPU)!

@WoosukKwon added the new model label Sep 27, 2023
@viktor-ferenczi
Contributor

Please consider adding this to the Roadmap.

It is also low-hanging fruit for new contributors getting into vLLM's code base.

@adivoj

adivoj commented Oct 4, 2023

We need it.

@Bojun-Feng

Hi, I am interested in integrating Phi 1.5. However, I am new to the project and a bit overwhelmed by the large code base. Are there any recommended resources to get started?

@adivoj

adivoj commented Oct 6, 2023

Hi Bojun,

Good to hear that you're interested. The models are defined in /vllm/model_executor/models and registered in /vllm/model_executor/model_loader.py.
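If it helps, here is a minimal sketch of what the registration step looks like (the registry name and the new Phi module are illustrative; check model_loader.py in your checkout for the exact names):

```python
# Sketch only -- the exact registry name may differ between vLLM versions.
# In vllm/model_executor/model_loader.py, HF architecture names map to model classes:
from vllm.model_executor.models.phi import PhiForCausalLM  # hypothetical new module

_MODEL_REGISTRY = {
    # ... existing entries, e.g. "GPTNeoXForCausalLM": GPTNeoXForCausalLM, ...
    # "MixFormerSequentialForCausalLM" is the architecture name in phi-1_5's config.json
    "MixFormerSequentialForCausalLM": PhiForCausalLM,
}
```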

Here are some phi1.5 sources you can use:
https://huggingface.co/microsoft/phi-1_5/tree/main
PR: https://huggingface.co/microsoft/phi-1_5/discussions/22/files
https://github.com/OpenAccess-AI-Collective/axolotl/tree/2d60ba3a6ea4def14e6ab974299322a0bf90d5bb/src/axolotl/models/phi

I was thinking of feeding all the existing models to GPT-4 and having it come up with the new one based on modeling_mixformer_sequential.py :)

@Bojun-Feng

Hi adivoj, thank you so much for the quick response! Unfortunately, I found out that my machine is not compatible with vLLM, so I cannot test the code locally. As a result, other people might be a better fit...

@Linzecong

I tried to adapt it and successfully aligned the model weights. The outputs were correct before the attention layer but incorrect after it. Later I found that the model uses CrossAttention (https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_mixformer_sequential.py#L518); I am not sure whether that is the cause of the problem.

@andersonbcdefg

andersonbcdefg commented Oct 17, 2023

I was looking into this. I believe the "cross-attention" used in Phi-1.5 is not true cross-attention; it's just used for the current token to attend to the past KV cache during autoregressive generation. From what I can see, the Phi-1.5 architecture is basically the same as GPT-NeoX with attention/FFN in parallel, except:

  1. In each transformer block the pre-FFN and pre-Attention layernorms are the same (GPT-NeoX doesn't share these params, as noted in their paper).
  2. Phi-1.5 has a bias on the output linear layer, and GPT-NeoX does not.

Alternatively, Phi-1.5 has the same architecture as GPT-J, except:

  1. GPT-J has separate q, k, v projections, while Phi-1.5 has a fused W_qkv.
  2. GPT-J does not have a bias on the q, k, v layers, while Phi-1.5 does have a bias on W_qkv.

It shouldn't be too crazy to adapt these, given that vLLM already supports GPT-NeoX and GPT-J. The MixFormer modeling file on Hugging Face is unnecessarily complicated given how small the architectural changes are relative to GPT-NeoX/GPT-J. To me, the easiest way would be:

  • Start with GPT-NeoX model, with the right configuration to match Phi-1.5 (including use_parallel_residual)
  • When loading weights in from Phi-1.5, just copy the same layernorm weight and bias from Phi into both input_layernorm and post_attn_layernorm in GPT-NeoX
  • Modify the output linear layer to have a bias so that the bias can be loaded in as well.

This is all it should take! I don't have time to do this myself right now because I haven't contributed to vLLM before, but for anyone who has added a model before, this should be easy!
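To make this concrete, here is a rough sketch of the weight remapping idea (the checkpoint key names are illustrative guesses, not the real phi-1_5 names; note also that GPT-NeoX lays out its fused QKV per-head, so the fused projection may need reordering rather than a straight copy):

```python
import torch

def remap_phi_to_neox(phi_state: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Illustrative remapping of Phi-1.5 weights onto a GPT-NeoX layout."""
    neox_state = {}
    for name, tensor in phi_state.items():
        if ".ln." in name:
            # Phi shares one layernorm per block; GPT-NeoX expects two.
            # Copy the same weight/bias into both layernorms.
            neox_state[name.replace(".ln.", ".input_layernorm.")] = tensor.clone()
            neox_state[name.replace(".ln.", ".post_attention_layernorm.")] = tensor.clone()
        elif "Wqkv" in name:
            # Both models use a fused QKV projection (with bias in Phi), but
            # GPT-NeoX orders it per-head, so a reorder may be needed here.
            neox_state[name.replace("Wqkv", "attention.query_key_value")] = tensor
        else:
            neox_state[name] = tensor
    return neox_state
```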

@check39

check39 commented Nov 5, 2023

Any update on this? Could anyone who is familiar with the code base make the change?

@maximzubkov
Contributor

Hey guys, I just opened a PR integrating Phi-1.5 into the codebase. Even though the model works, the output on some sequences does not exactly match the HF implementation, so if anyone could have a look, it would speed up the merge process a lot.

Thanks! @andersonbcdefg @Linzecong @adivoj @Bojun-Feng

@adivoj

adivoj commented Nov 14, 2023

Wow, thanks a lot man! All the best to you :) Could it be the temperature or top_p? I guess you had those at minimum while testing.

@WoosukKwon linked a pull request Nov 15, 2023 that will close this issue
@Linzecong

Linzecong commented Nov 15, 2023

> Hey guys, I just opened a PR integrating Phi-1.5 into the codebase. Even though the model works, the output on some sequences does not exactly match the HF implementation, so if anyone could have a look, it would speed up the merge process a lot.
>
> Thanks! @andersonbcdefg @Linzecong @adivoj @Bojun-Feng

Thank you very much for your support! I did a quick verification: I set top_k to 1 and the outputs aligned in most cases. I think the mismatch is caused by floating-point precision issues, because when I load the model in float32, the HF output is very stable with top_k = 1, but when loading in half precision the output is still unstable even with top_k = 1.
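For anyone who wants to reproduce the check, a minimal sketch of the comparison (greedy decoding on both sides; trust_remote_code is assumed necessary for the original MixFormer checkpoint):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

prompt = "def fibonacci(n):"

# HF reference in float32 -- stable under greedy decoding.
tok = AutoTokenizer.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
hf = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-1_5", torch_dtype=torch.float32, trust_remote_code=True
)
ids = tok(prompt, return_tensors="pt").input_ids
hf_text = tok.decode(hf.generate(ids, max_new_tokens=64, do_sample=False)[0])

# vLLM output; temperature=0 is greedy, equivalent to top_k=1.
llm = LLM(model="microsoft/phi-1_5", trust_remote_code=True)
out = llm.generate([prompt], SamplingParams(temperature=0.0, max_tokens=64))
vllm_text = out[0].outputs[0].text

print(hf_text)
print(prompt + vllm_text)
```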

@adivoj

adivoj commented Nov 16, 2023

Man, they just announced Phi 2.0: https://youtu.be/S2gn_EoSWac?feature=shared

@maximzubkov
Contributor

I almost saw it coming when @sytelus reacted to the PR 😅

@olsaarik

Just a heads up: the microsoft/phi-1_5 model has been updated: https://huggingface.co/microsoft/phi-1_5/commit/271c3397ab4e1f8f4e49868b1e8ba0be95363c88 The relevant changes for vLLM are only in naming. The architecture is now called "PhiForCausalLM", and the names of the model weights have also changed.
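A quick way to confirm which revision of the checkpoint you have:

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("microsoft/phi-1_5", trust_remote_code=True)
# Prints ["PhiForCausalLM"] on the updated revision,
# ["MixFormerSequentialForCausalLM"] on the old one.
print(cfg.architectures)
```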

@WoosukKwon
Collaborator

WoosukKwon commented Nov 17, 2023

Hi @olsaarik, thanks for the heads up! We noticed the change and adapted to it in #1664. I've checked that vLLM works with the new Phi model naming and weights.
