diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 9f3b6f59068e2..71b91f51d1362 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -350,6 +350,11 @@ Text Embedding - :code:`BAAI/bge-multilingual-gemma2`, etc. - - ✅︎ + * - :code:`GritLM` + - GritLM + - :code:`parasail-ai/GritLM-7B-vllm`. + - + - * - :code:`LlamaModel`, :code:`LlamaForCausalLM`, :code:`MistralModel`, etc. - Llama-based - :code:`intfloat/e5-mistral-7b-instruct`, etc. diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py new file mode 100644 index 0000000000000..89c02e1a7951e --- /dev/null +++ b/tests/models/embedding/language/test_gritlm.py @@ -0,0 +1,144 @@ +import math +import os +from typing import List + +import openai +import pytest +import pytest_asyncio +from scipy.spatial.distance import cosine + +import vllm + +from ....utils import RemoteOpenAIServer + +MODEL_NAME = "parasail-ai/GritLM-7B-vllm" + +# GritLM implementation is only supported by XFormers backend. +os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--task", + "embedding", + ] + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +def run_llm_encode(llm: vllm.LLM, queries: List[str], instruction: str, + use_instruction_arg: bool) -> List[float]: + pooling_params = vllm.PoolingParams( + additional_data={"instruction_seq": instruction + }) if use_instruction_arg else None + outputs = llm.encode( + [instruction + q for q in queries], + pooling_params=pooling_params, + ) + return [output.outputs.embedding for output in outputs] + + +async def run_client_embeddings(client: vllm.LLM, queries: List[str], + instruction: str, + use_instruction_arg: bool) -> List[float]: + additional_data = { + "instruction_seq": instruction + } if use_instruction_arg else None + outputs = await client.embeddings.create( + model=MODEL_NAME, + input=[instruction + q for q in queries], + extra_body={"additional_data": additional_data}, + ) + return [data.embedding for data in outputs.data] + + +def gritlm_instruction(instruction): + return ("<|user|>\n" + instruction + + "\n<|embed|>\n" if instruction else "<|embed|>\n") + + +def get_test_data(): + """ + Grabbed this test data and the expected values from + README.md in https://github.com/ContextualAI/gritlm + """ + q_instruction = gritlm_instruction( + "Given a scientific paper title, retrieve the paper's abstract") + queries = [ + "Bitcoin: A Peer-to-Peer Electronic Cash System", + "Generative Representational Instruction Tuning", + ] + + d_instruction = gritlm_instruction("") + documents = [ + # ruff: noqa: E501 + "A purely peer-to-peer version of electronic cash would allow online payments to be sent directly from one party to another without going through a financial institution. Digital signatures provide part of the solution, but the main benefits are lost if a trusted third party is still required to prevent double-spending. We propose a solution to the double-spending problem using a peer-to-peer network. The network timestamps transactions by hashing them into an ongoing chain of hash-based proof-of-work, forming a record that cannot be changed without redoing the proof-of-work. The longest chain not only serves as proof of the sequence of events witnessed, but proof that it came from the largest pool of CPU power. As long as a majority of CPU power is controlled by nodes that are not cooperating to attack the network, they'll generate the longest chain and outpace attackers. The network itself requires minimal structure. Messages are broadcast on a best effort basis, and nodes can leave and rejoin the network at will, accepting the longest proof-of-work chain as proof of what happened while they were gone.", + "All text-based language problems can be reduced to either generation or embedding. Current models only perform well at one or the other. We introduce generative representational instruction tuning (GRIT) whereby a large language model is trained to handle both generative and embedding tasks by distinguishing between them through instructions. Compared to other open models, our resulting GritLM 7B sets a new state of the art on the Massive Text Embedding Benchmark (MTEB) and outperforms all models up to its size on a range of generative tasks. By scaling up further, GritLM 8X7B outperforms all open generative language models that we tried while still being among the best embedding models. Notably, we find that GRIT matches training on only generative or embedding data, thus we can unify both at no performance loss. Among other benefits, the unification via GRIT speeds up Retrieval-Augmented Generation (RAG) by > 60% for long documents, by no longer requiring separate retrieval and generation models. Models, code, etc. are freely available at https://github.com/ContextualAI/gritlm.", + ] + + return queries, q_instruction, documents, d_instruction + + +def validate_output(q_rep: List[float], d_rep: List[float]): + cosine_sim_q0_d0 = 1 - cosine(q_rep[0], d_rep[0]) + assert math.isclose(cosine_sim_q0_d0, 0.609, abs_tol=0.001) + + cosine_sim_q0_d1 = 1 - cosine(q_rep[0], d_rep[1]) + assert math.isclose(cosine_sim_q0_d1, 0.101, abs_tol=0.001) + + cosine_sim_q1_d0 = 1 - cosine(q_rep[1], d_rep[0]) + assert math.isclose(cosine_sim_q1_d0, 0.120, abs_tol=0.001) + + cosine_sim_q1_d1 = 1 - cosine(q_rep[1], d_rep[1]) + assert math.isclose(cosine_sim_q1_d1, 0.532, abs_tol=0.001) + + +@pytest.mark.parametrize("use_instruction_arg", [True, False]) +def test_gritlm_offline(use_instruction_arg: bool): + queries, q_instruction, documents, d_instruction = get_test_data() + + llm = vllm.LLM(MODEL_NAME, task="embedding") + + d_rep = run_llm_encode( + llm, + documents, + d_instruction, + use_instruction_arg=use_instruction_arg, + ) + q_rep = run_llm_encode( + llm, + queries, + q_instruction, + use_instruction_arg=use_instruction_arg, + ) + + validate_output(q_rep, d_rep) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_instruction_arg", [True, False]) +async def test_gritlm_api_server(client: openai.AsyncOpenAI, + use_instruction_arg: bool): + queries, q_instruction, documents, d_instruction = get_test_data() + + d_rep = await run_client_embeddings( + client, + documents, + d_instruction, + use_instruction_arg=use_instruction_arg, + ) + q_rep = await run_client_embeddings( + client, + queries, + q_instruction, + use_instruction_arg=use_instruction_arg, + ) + + validate_output(q_rep, d_rep) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index d23009dae01ee..1a3d34187ec35 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -12,6 +12,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceGroupMetadataDelta, @@ -523,7 +524,7 @@ def _schedule_running( chunked number of tokens are scheduled if `budget.num_batched_tokens` has not enough capacity to schedule all tokens. - + Returns: SchedulerRunningOutputs. """ @@ -841,10 +842,10 @@ def _schedule_priority_preemption( self._get_num_new_uncached_and_cached_tokens( seq_group, SequenceStatus.WAITING, False, budget)) - #Only preempt if priority inversion exists + # Only preempt if priority inversion exists while running_queue and self._get_priority( running_queue[-1]) > self._get_priority(seq_group): - #Only preempt if waiting sequence cannot be allocated + # Only preempt if waiting sequence cannot be allocated can_allocate = self.block_manager.can_allocate(seq_group) if (num_new_tokens_uncached > 0 and can_allocate == AllocStatus.OK @@ -854,7 +855,7 @@ def _schedule_priority_preemption( )): break - #Adjust budget to remove the victim sequence group + # Adjust budget to remove the victim sequence group vseq_group = running_queue.pop() num_running_tokens_uncached, _ = ( self._get_num_new_uncached_and_cached_tokens( @@ -865,11 +866,11 @@ def _schedule_priority_preemption( budget.subtract_num_seqs(vseq_group.request_id, num_running_seqs) - #Preempt out the victim sequence group + # Preempt out the victim sequence group self._preempt(vseq_group, blocks_to_swap_out) waiting_queue.appendleft(vseq_group) force_preemption_count += 1 - #Put the sequence back into the waiting queue + # Put the sequence back into the waiting queue waiting_queue.appendleft(seq_group) waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) @@ -1036,7 +1037,7 @@ def _schedule_prefills( def _schedule_default(self) -> SchedulerOutputs: """Schedule queued requests. - + The current policy is designed to optimize the throughput. First, it batches as many prefill requests as possible. And it schedules decodes. If there's a pressure on GPU memory, decode requests can @@ -1141,7 +1142,7 @@ def _schedule_default(self) -> SchedulerOutputs: def _schedule_chunked_prefill(self) -> SchedulerOutputs: """Schedule queued requests. - + Chunked prefill allows to chunk prefill requests, batch them together with decode requests. This policy 1. schedule as many decoding requests as possible. 2. schedule chunked prefill requests that are not @@ -1350,6 +1351,25 @@ def schedule( seqs[0].data.get_len()): do_sample = False + pooling_params = seq_group.pooling_params + + # Store instruction_seq in pooling_params. + instruction_seq = seq.inputs.inputs.get("instruction_seq") + if instruction_seq is not None: + if pooling_params is None: + pooling_params = PoolingParams() + pooling_params.additional_data = { + "instruction_seq": instruction_seq + } + elif pooling_params.additional_data is None: + pooling_params.additional_data = { + "instruction_seq": instruction_seq + } + else: + pooling_params.additional_data[ + "instruction_seq"] = seq.inputs.inputs.get( + "instruction_seq") + # It assumes the scheduled_seq_groups is ordered by # prefill < decoding. if is_first_prefill or not self.scheduler_config.send_delta_data: @@ -1360,7 +1380,7 @@ def schedule( sampling_params=seq_group.sampling_params, block_tables=block_tables, do_sample=do_sample, - pooling_params=seq_group.pooling_params, + pooling_params=pooling_params, token_chunk_size=token_chunk_size, lora_request=seq_group.lora_request, computed_block_nums=common_computed_block_nums, diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index fb7dbbebd7b90..141aaf27307ae 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -163,6 +163,14 @@ class TokenInputs(TypedDict): to pass the mm_processor_kwargs to each of them. """ + instruction_seq: NotRequired[Optional[str]] + """ + The instruction sequence that is usually prepended to the original prompt + when passing to the model. Certain models need to extract this instruction + sequence from the prompt in order to adjust certain operations of the + model such as the attention mask. + """ + def token_inputs( prompt_token_ids: List[int], @@ -171,6 +179,7 @@ def token_inputs( multi_modal_data: Optional["MultiModalDataDict"] = None, multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, + instruction_seq: Optional[str] = None, ) -> TokenInputs: """Construct :class:`TokenInputs` from optional values.""" inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids) @@ -185,6 +194,8 @@ def token_inputs( inputs["multi_modal_placeholders"] = multi_modal_placeholders if mm_processor_kwargs is not None: inputs["mm_processor_kwargs"] = mm_processor_kwargs + if instruction_seq is not None: + inputs["instruction_seq"] = instruction_seq return inputs diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py new file mode 100644 index 0000000000000..97d99c9057241 --- /dev/null +++ b/vllm/model_executor/models/gritlm.py @@ -0,0 +1,223 @@ +import re +from typing import List, Optional, Union + +import torch +from torch import nn +from xformers.ops.fmha.attn_bias import BlockDiagonalMask + +from vllm.attention import AttentionMetadata +from vllm.attention.backends.xformers import XFormersImpl +from vllm.config import ModelConfig, VllmConfig +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, + token_inputs) +from vllm.logger import init_logger +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.model_executor.pooling_metadata import (PoolingMetadata, + PoolingTensors) +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.pooling_params import PoolingParams +from vllm.sequence import (EmbeddingSequenceGroupOutput, IntermediateTensors, + PoolerOutput) + +logger = init_logger(__name__) + + +class GritLMPooler(nn.Module): + + def __init__( + self, + model_config: ModelConfig, + ): + super().__init__() + + self.model_config = model_config + + def _get_instruction_lens( + self, device: torch.device, + pooling_metadata: PoolingMetadata) -> torch.Tensor: + """ + Compute the number of tokens of each instruction using the tokenizer. + """ + self.tokenizer = cached_get_tokenizer( + self.model_config.tokenizer, + tokenizer_mode=self.model_config.tokenizer_mode, + tokenizer_revision=self.model_config.tokenizer_revision, + trust_remote_code=self.model_config.trust_remote_code, + truncation_side="left", + ) + + def query_instruction_missing(pooling_params: PoolingParams) -> bool: + return (pooling_params is None + or pooling_params.additional_data is None + or "instruction_seq" not in pooling_params.additional_data) + + for seq_group in pooling_metadata.seq_groups: + if query_instruction_missing(seq_group[1]): + logger.warning( + "Query instruction not found in prompt," + "thus using empty string instead. GritLM requires " + "query instruction in prompt.") + + instruction_lens = torch.tensor( + [ + len( + self.tokenizer( + ("" if query_instruction_missing(seq_group[1]) else + seq_group[1].additional_data["instruction_seq"]), + padding=False, + truncation=True, + add_special_tokens=True, + )["input_ids"]) + for seq_group in pooling_metadata.seq_groups + ], + device=device, + ) + + return instruction_lens + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + """ + Pool the hidden states by summing the embeddings of + non-instruction tokens. + """ + instruction_lens = self._get_instruction_lens( + device=hidden_states.device, pooling_metadata=pooling_metadata) + + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + mask = torch.zeros_like(hidden_states, dtype=torch.bool) + + start_idx = 0 + for prompt_len, instruction_len in zip(prompt_lens, instruction_lens): + end_idx = start_idx + prompt_len + mask[start_idx + instruction_len:end_idx] = True + start_idx = end_idx + + masked_hidden_states = hidden_states.masked_fill(~mask, 0.0) + + sum_embeddings = torch.zeros(len(prompt_lens), + hidden_states.size(1), + device=hidden_states.device) + + start_idx = 0 + for i, prompt_len in enumerate(prompt_lens): + end_idx = start_idx + prompt_len + sum_embeddings[i] = masked_hidden_states[start_idx:end_idx].sum( + dim=0) + start_idx = end_idx + + num_non_instruction_tokens = prompt_lens - instruction_lens + mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze( + 1) + + pooled_data = nn.functional.normalize(mean_embeddings, p=2, dim=1) + + pooled_outputs = [ + EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data + ] + + return PoolerOutput(outputs=pooled_outputs) + + +def input_processor_for_gritlm(ctx: InputContext, inputs: DecoderOnlyInputs): + """ + Extracts query instruction from prompt and adds it to token inputs. + """ + model_config = ctx.model_config + tokenizer = cached_get_tokenizer(model_config.tokenizer) + + prompt = inputs.get("prompt", None) + instruction = "" + + if prompt is None and "prompt_token_ids" in inputs: + prompt = tokenizer.decode(inputs["prompt_token_ids"]) + + if prompt is not None: + match_instruction = re.match(r"( )?(<\|user\|>\n.*\n<\|embed\|>\n)", + prompt) + match_empty_instruction = re.match(r"( )?(<\|embed\|>\n)", prompt) + + if match_instruction and match_instruction.group(2): + instruction = match_instruction.group(2) + elif match_empty_instruction: + instruction = match_empty_instruction.group(2) + else: + logger.warning("Query instruction not found in prompt," + "thus using empty string instead. GritLM requires " + "query instruction in prompt.") + + return token_inputs( + prompt_token_ids=inputs["prompt_token_ids"], + prompt=prompt, + instruction_seq=instruction, + ) + + +@INPUT_REGISTRY.register_input_processor(input_processor_for_gritlm) +class GritLM(LlamaForCausalLM): + """This class implements the embedding model for parasail-ai/GritLM-7B-vllm. + + The class inherits from LlamaForCausalLM and provides a custom pooling + layer. + + The task "embedding" must be specified in the server arguments. + + The main difference between the pooling layer in GritLM and the one in + LlamaForCausalLM is that GritLM ignores the query instruction in the prompt + when pooling the hidden states. + + Instructions can be passed to the model in two ways: + 1. By prepending the instruction to the prompt. The instruction should be + in the format "<|user|>\n\n<|embed|>\n". + 2. By passing the instruction as additional data in the pooling parameters + (e.g. extra_body of client.embeddings.create). + """ + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + **kwargs, + ) -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + + self._pooler = GritLMPooler(model_config=vllm_config.model_config) + + assert isinstance( + self.model.layers[0].self_attn.attn.impl, + XFormersImpl), "GritLM is only supported by XFormers backend, " + "which can be forced by VLLM_ATTENTION_BACKEND=XFORMERS" + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + # Change attention to non-causal. + assert attn_metadata.prefill_metadata.attn_bias is None + attn_metadata.prefill_metadata.attn_bias = [ + BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens) + ] + + return super().forward( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + **kwargs, + ) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 2b7b69e8c3a95..1a0aec7b3188e 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -110,6 +110,7 @@ "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), + "GritLM": ("gritlm", "GritLM"), "LlamaModel": ("llama", "LlamaForCausalLM"), **{ # Multiple models share the same architecture, so we include them all