[RFC]: Let every model be a reward model/embedding model for PRMs #9314

Closed
zhuzilin opened this issue Oct 12, 2024 · 4 comments

@zhuzilin
Contributor

Motivation.

As OpenAI's o1 series of models has given us a peek at the great potential of RL, interest in reward models, a core component of modern RL algorithms, is rising. We have recently introduced new reward models into vLLM by reusing the embedding APIs (#8896), and there is an ongoing RFC discussing a better server API specifically for reward models (#8967).

This RFC aims to support a broader class of reward models (RMs) by allowing every generation model to be served as an embedding model. This way, vLLM could easily support process-supervised reward models (PRMs).

A PRM is a kind of reward model that scores the intermediate steps of an LLM response. For example, in chain-of-thought (CoT) reasoning, where the model thinks step by step, a PRM can judge each step, which gives finer-grained feedback for RL optimization. Many also suspect that PRMs are a crucial component for replicating o1, e.g. GAIR-NLP/O1-Journey.

A common way to train a PRM is to insert a special token after each step and use the model's output logits at that token. More specifically, a few tokens are chosen as the scoring levels, and a classifier is trained on the logits of those tokens. This is called hard estimation in PRM training (as opposed to soft estimation, which does regression instead of classification). It is used in OpenAI's well-known Let's Verify Step by Step (https://arxiv.org/abs/2305.20050), which was released together with the open-source PRM800K dataset (https://github.com/openai/prm800k), and in Math-Shepherd (https://arxiv.org/abs/2312.08935). To illustrate further with Math-Shepherd: it uses Mistral 7B as the base model, ки as the special token, and + and - as the scoring tokens, so whenever the PRM encounters ки, we can check the logits of + and - to see how good the previous step is. The Math-Shepherd PRM is open-sourced as math-shepherd-mistral-7b-prm.
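A minimal sketch of the hard-estimation scoring described above, in the Math-Shepherd style: at every position holding the step tag, take the logits of the two scoring tokens, apply a softmax over just those two, and read off the probability of the "good" token as that step's score. Plain Python lists stand in for tensors, and the token ids are made-up placeholders, not the real ids of ки, + and - in any tokenizer; `step_scores` is a hypothetical helper name.

```python
import math
from typing import List, Sequence


def step_scores(
    token_ids: Sequence[int],
    logits: Sequence[Sequence[float]],
    step_tag_id: int,
    good_id: int,
    bad_id: int,
) -> List[float]:
    """Score every step of a response, Math-Shepherd style (hard estimation).

    logits[i] is the full-vocabulary logit vector the PRM produces at
    position i. One score is emitted per position whose token is the step
    tag; it is P(good) after a softmax restricted to {good, bad}.
    """
    scores = []
    for pos, tok in enumerate(token_ids):
        if tok != step_tag_id:
            continue
        good, bad = logits[pos][good_id], logits[pos][bad_id]
        # Two-way softmax; subtracting the max keeps exp() from overflowing.
        m = max(good, bad)
        e_good, e_bad = math.exp(good - m), math.exp(bad - m)
        scores.append(e_good / (e_good + e_bad))
    return scores


if __name__ == "__main__":
    # Toy 5-token vocabulary with made-up ids: 4 = step tag,
    # 2 = good ("+"), 3 = bad ("-").
    tokens = [1, 4, 2, 4]
    logits = [[0.0] * 5, [0.0] * 5, [0.0] * 5, [0.0, 0.0, 2.0, 0.0, 0.0]]
    print(step_scores(tokens, logits, step_tag_id=4, good_id=2, bad_id=3))
    # two scores: 0.5 for the first step, about 0.881 for the second
```

A two-way softmax equals a sigmoid of the logit difference, so only good minus bad matters for each step's score.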

To support this kind of PRM, we can simply add a default pooler method to all models (math-shepherd-mistral-7b-prm, for instance, is just a LlamaForCausalLM) and let users specify the scoring tokens used to truncate the logits (otherwise each output would span the whole vocabulary, i.e. vocab_size values per token, which is far too large).
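Why truncating is cheap: a causal LM's logits at a position are the dot products of the final hidden state with each row of lm_head (assuming an lm_head without bias, as in Llama and Mistral), so the scoring-token logits can be computed from just those rows instead of the whole vocabulary. A minimal sketch with plain Python lists; `full_logits` and `trimmed_logits` are hypothetical names, not vLLM functions.

```python
from typing import List, Sequence


def full_logits(hidden: Sequence[float], lm_head: Sequence[Sequence[float]]) -> List[float]:
    """One logit per vocabulary row; lm_head is [vocab_size][hidden_size]."""
    return [sum(h * w for h, w in zip(hidden, row)) for row in lm_head]


def trimmed_logits(
    hidden: Sequence[float],
    lm_head: Sequence[Sequence[float]],
    scoring_ids: Sequence[int],
) -> List[float]:
    """Logits for the scoring tokens only, touching just those lm_head rows.

    Cost is len(scoring_ids) * hidden_size multiply-adds instead of
    vocab_size * hidden_size, and the output has len(scoring_ids) entries.
    """
    return [sum(h * w for h, w in zip(hidden, lm_head[i])) for i in scoring_ids]


if __name__ == "__main__":
    hidden = [1.0, 2.0]
    lm_head = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]  # toy 4-token vocab
    print(full_logits(hidden, lm_head))             # [1.0, 2.0, 3.0, 0.0]
    print(trimmed_logits(hidden, lm_head, [2, 3]))  # [3.0, 0.0]
```

The trimmed result is exactly the full logits indexed at the scoring ids, so nothing is approximated; with a vocabulary of tens of thousands of tokens and two scoring tokens, each per-token output shrinks to two values.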

Proposed Change.

  1. Replace the split between _GENERATION_MODELS and _EMBEDDING_MODELS with a classification by whether a model can do generation, since every model can then serve as an embedding model.
  2. Add a default pooler function to _GENERATION_MODELS. For example:

     ```python
     from typing import Optional
     import torch
     from vllm.model_executor.pooling_metadata import PoolingMetadata
     from vllm.sequence import PoolerOutput

     def pooler(
         self,
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Optional[PoolerOutput]:
         # Trim the lm_head: compute hidden_states @ lm_head only for the
         # scoring tokens and do the pooling in the same step. self._pooler
         # is the proposed helper, not an existing vLLM API.
         return self._pooler(self.lm_head, hidden_states, pooling_metadata)
     ```

     where self._pooler trims the logits to the scoring tokens provided by pooling_metadata.
  3. Use the Math-Shepherd reward model, math-shepherd-mistral-7b-prm, as an example.
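One way the proposed self._pooler could behave, sketched under stated assumptions: the hidden states of all sequences in a batch are concatenated along the token dimension, and the metadata carries each sequence's prompt length plus the scoring token ids. `ToyPoolingMetadata` and `toy_pooler` are hypothetical stand-ins, not vLLM classes; the sketch returns the trimmed logits for every token and leaves picking the step-tag positions (and any softmax) to the caller.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class ToyPoolingMetadata:
    # Hypothetical stand-in for PoolingMetadata: per-sequence prompt
    # lengths plus the scoring tokens this RFC proposes to pass in.
    prompt_lens: List[int]
    scoring_ids: List[int]


def toy_pooler(
    lm_head: Sequence[Sequence[float]],
    hidden_states: Sequence[Sequence[float]],
    meta: ToyPoolingMetadata,
) -> List[List[List[float]]]:
    """Per sequence, per token: logits restricted to meta.scoring_ids."""
    if sum(meta.prompt_lens) != len(hidden_states):
        raise ValueError("prompt_lens must add up to the number of hidden states")
    # Keep only the lm_head rows of the scoring tokens.
    rows = [lm_head[i] for i in meta.scoring_ids]
    outputs = []
    start = 0
    for length in meta.prompt_lens:
        # Slice this sequence's tokens out of the flattened batch.
        seq_hidden = hidden_states[start:start + length]
        outputs.append(
            [[sum(h * w for h, w in zip(hs, row)) for row in rows] for hs in seq_hidden]
        )
        start += length
    return outputs


if __name__ == "__main__":
    meta = ToyPoolingMetadata(prompt_lens=[2, 1], scoring_ids=[2, 3])
    hidden_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 tokens, 2 sequences
    lm_head = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
    print(toy_pooler(lm_head, hidden_states, meta))
    # [[[1.0, 2.0], [1.0, -1.0]], [[2.0, 1.0]]]
```

Returning the trimmed logits for all tokens keeps the pooler generic: the same output serves Math-Shepherd-style scoring and any other PRM whose scoring tokens differ.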

Feedback Period.

No response

CC List.

cc @DarkLight1337 @mgoin @youkaichao @noamgat @natolambert

Any Other Things.

Thank you for your time on this RFC:)

@zhuzilin zhuzilin added the RFC label Oct 12, 2024
@DarkLight1337
Member

DarkLight1337 commented Oct 12, 2024

vLLM can only run either the generation model runner or the embedding model runner, not both at the same time. If we make every model support both generation and embedding, then we'll have to add a CLI flag telling vLLM which mode to use. I'm already planning to do this, because there have already been some cases of conflicting architecture names (where the model registered under that name can be used for both generation and embedding), such as #6282 (comment) and #9303. Perhaps the user could also set the pooling mode this way?

@zhuzilin
Contributor Author

zhuzilin commented Oct 12, 2024

> then we'll have to add a CLI flag to indicate to vLLM which mode to use

This will be good enough for me :)

> Perhaps the user can also set the pooling mode this way?

If you mean letting users pass a CLI flag so that a model is served as an embedding model and, in that mode, letting them truncate the logits (or set some other pooler option), then yes, I think that would be nice.

@DarkLight1337
Member

DarkLight1337 commented Oct 22, 2024


#9424 has been merged, so any model can now be used as an embedding model once you add a pooler method to it. However, the pooling mode itself is not yet configurable.

@zhuzilin
Contributor Author

That's great! Closing, as this is basically solved.
