Commit

Add support for embedding model parasail-ai/GritLM-7B-vllm
This model is a fork of GritLM/GritLM-7B. The main change in the fork,
relative to the original repository, is the architecture name, which was
changed to make vLLM adoption easier.

Signed-off-by: Pooya Davoodi <[email protected]>
pooyadavoodi committed Dec 2, 2024
1 parent c11f172 commit 1d60ff6
Showing 6 changed files with 413 additions and 9 deletions.
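
Before the per-file diffs, here is a minimal offline-usage sketch of the newly supported model. It mirrors the test added in this commit (the task="embedding" argument, the encode() call, and the XFormers requirement are taken from that test); it is an illustration, not part of the change.

import os

import vllm

# Per the test added in this commit, the GritLM implementation currently
# requires the XFormers attention backend, so pin it before building the engine.
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"

llm = vllm.LLM("parasail-ai/GritLM-7B-vllm", task="embedding")
outputs = llm.encode(["<|embed|>\nBitcoin: A Peer-to-Peer Electronic Cash System"])
embedding = outputs[0].outputs.embedding  # one embedding vector (a list of floats)
print(len(embedding))
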
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.rst
@@ -350,6 +350,11 @@ Text Embedding
    - :code:`BAAI/bge-multilingual-gemma2`, etc.
    -
    - ✅︎
  * - :code:`GritLM`
    - GritLM
    - :code:`parasail-ai/GritLM-7B-vllm`.
    -
    -
  * - :code:`LlamaModel`, :code:`LlamaForCausalLM`, :code:`MistralModel`, etc.
    - Llama-based
    - :code:`intfloat/e5-mistral-7b-instruct`, etc.
144 changes: 144 additions & 0 deletions tests/models/embedding/language/test_gritlm.py
@@ -0,0 +1,144 @@
import math
import os
from typing import List

import openai
import pytest
import pytest_asyncio
from scipy.spatial.distance import cosine

import vllm

from ....utils import RemoteOpenAIServer

MODEL_NAME = "parasail-ai/GritLM-7B-vllm"

# GritLM implementation is only supported by XFormers backend.
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"


@pytest.fixture(scope="module")
def server():
    args = [
        "--task",
        "embedding",
    ]
    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


def run_llm_encode(llm: vllm.LLM, queries: List[str], instruction: str,
                   use_instruction_arg: bool) -> List[List[float]]:
    pooling_params = vllm.PoolingParams(
        additional_data={"instruction_seq": instruction
                         }) if use_instruction_arg else None
    outputs = llm.encode(
        [instruction + q for q in queries],
        pooling_params=pooling_params,
    )
    return [output.outputs.embedding for output in outputs]


async def run_client_embeddings(
        client: openai.AsyncOpenAI, queries: List[str], instruction: str,
        use_instruction_arg: bool) -> List[List[float]]:
    additional_data = {
        "instruction_seq": instruction
    } if use_instruction_arg else None
    outputs = await client.embeddings.create(
        model=MODEL_NAME,
        input=[instruction + q for q in queries],
        extra_body={"additional_data": additional_data},
    )
    return [data.embedding for data in outputs.data]


def gritlm_instruction(instruction):
    return ("<|user|>\n" + instruction +
            "\n<|embed|>\n" if instruction else "<|embed|>\n")


def get_test_data():
    """
    Grabbed this test data and the expected values from
    README.md in https://github.com/ContextualAI/gritlm
    """
    q_instruction = gritlm_instruction(
        "Given a scientific paper title, retrieve the paper's abstract")
    queries = [
        "Bitcoin: A Peer-to-Peer Electronic Cash System",
        "Generative Representational Instruction Tuning",
    ]

    d_instruction = gritlm_instruction("")
    documents = [
        # ruff: noqa: E501
        "A purely peer-to-peer version of electronic cash would allow online payments to be sent directly from one party to another without going through a financial institution. Digital signatures provide part of the solution, but the main benefits are lost if a trusted third party is still required to prevent double-spending. We propose a solution to the double-spending problem using a peer-to-peer network. The network timestamps transactions by hashing them into an ongoing chain of hash-based proof-of-work, forming a record that cannot be changed without redoing the proof-of-work. The longest chain not only serves as proof of the sequence of events witnessed, but proof that it came from the largest pool of CPU power. As long as a majority of CPU power is controlled by nodes that are not cooperating to attack the network, they'll generate the longest chain and outpace attackers. The network itself requires minimal structure. Messages are broadcast on a best effort basis, and nodes can leave and rejoin the network at will, accepting the longest proof-of-work chain as proof of what happened while they were gone.",
        "All text-based language problems can be reduced to either generation or embedding. Current models only perform well at one or the other. We introduce generative representational instruction tuning (GRIT) whereby a large language model is trained to handle both generative and embedding tasks by distinguishing between them through instructions. Compared to other open models, our resulting GritLM 7B sets a new state of the art on the Massive Text Embedding Benchmark (MTEB) and outperforms all models up to its size on a range of generative tasks. By scaling up further, GritLM 8X7B outperforms all open generative language models that we tried while still being among the best embedding models. Notably, we find that GRIT matches training on only generative or embedding data, thus we can unify both at no performance loss. Among other benefits, the unification via GRIT speeds up Retrieval-Augmented Generation (RAG) by > 60% for long documents, by no longer requiring separate retrieval and generation models. Models, code, etc. are freely available at https://github.com/ContextualAI/gritlm.",
    ]

    return queries, q_instruction, documents, d_instruction


def validate_output(q_rep: List[List[float]], d_rep: List[List[float]]):
    cosine_sim_q0_d0 = 1 - cosine(q_rep[0], d_rep[0])
    assert math.isclose(cosine_sim_q0_d0, 0.609, abs_tol=0.001)

    cosine_sim_q0_d1 = 1 - cosine(q_rep[0], d_rep[1])
    assert math.isclose(cosine_sim_q0_d1, 0.101, abs_tol=0.001)

    cosine_sim_q1_d0 = 1 - cosine(q_rep[1], d_rep[0])
    assert math.isclose(cosine_sim_q1_d0, 0.120, abs_tol=0.001)

    cosine_sim_q1_d1 = 1 - cosine(q_rep[1], d_rep[1])
    assert math.isclose(cosine_sim_q1_d1, 0.532, abs_tol=0.001)


@pytest.mark.parametrize("use_instruction_arg", [True, False])
def test_gritlm_offline(use_instruction_arg: bool):
    queries, q_instruction, documents, d_instruction = get_test_data()

    llm = vllm.LLM(MODEL_NAME, task="embedding")

    d_rep = run_llm_encode(
        llm,
        documents,
        d_instruction,
        use_instruction_arg=use_instruction_arg,
    )
    q_rep = run_llm_encode(
        llm,
        queries,
        q_instruction,
        use_instruction_arg=use_instruction_arg,
    )

    validate_output(q_rep, d_rep)


@pytest.mark.asyncio
@pytest.mark.parametrize("use_instruction_arg", [True, False])
async def test_gritlm_api_server(client: openai.AsyncOpenAI,
                                 use_instruction_arg: bool):
    queries, q_instruction, documents, d_instruction = get_test_data()

    d_rep = await run_client_embeddings(
        client,
        documents,
        d_instruction,
        use_instruction_arg=use_instruction_arg,
    )
    q_rep = await run_client_embeddings(
        client,
        queries,
        q_instruction,
        use_instruction_arg=use_instruction_arg,
    )

    validate_output(q_rep, d_rep)
38 changes: 29 additions & 9 deletions vllm/core/scheduler.py
@@ -12,6 +12,7 @@
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                           SequenceGroupMetadata, SequenceGroupMetadataDelta,
@@ -523,7 +524,7 @@ def _schedule_running(
                chunked number of tokens are scheduled if
                `budget.num_batched_tokens` has not enough capacity to schedule
                all tokens.

        Returns:
            SchedulerRunningOutputs.
        """
@@ -841,10 +842,10 @@ def _schedule_priority_preemption(
                    self._get_num_new_uncached_and_cached_tokens(
                        seq_group, SequenceStatus.WAITING, False, budget))

            #Only preempt if priority inversion exists
            # Only preempt if priority inversion exists
            while running_queue and self._get_priority(
                    running_queue[-1]) > self._get_priority(seq_group):
                #Only preempt if waiting sequence cannot be allocated
                # Only preempt if waiting sequence cannot be allocated
                can_allocate = self.block_manager.can_allocate(seq_group)
                if (num_new_tokens_uncached > 0
                        and can_allocate == AllocStatus.OK
@@ -854,7 +855,7 @@
                        )):
                    break

                #Adjust budget to remove the victim sequence group
                # Adjust budget to remove the victim sequence group
                vseq_group = running_queue.pop()
                num_running_tokens_uncached, _ = (
                    self._get_num_new_uncached_and_cached_tokens(
@@ -865,11 +866,11 @@
                budget.subtract_num_seqs(vseq_group.request_id,
                                         num_running_seqs)

                #Preempt out the victim sequence group
                # Preempt out the victim sequence group
                self._preempt(vseq_group, blocks_to_swap_out)
                waiting_queue.appendleft(vseq_group)
                force_preemption_count += 1
            #Put the sequence back into the waiting queue
            # Put the sequence back into the waiting queue
            waiting_queue.appendleft(seq_group)

        waiting_queue = deque(sorted(waiting_queue, key=self._get_priority))
@@ -1036,7 +1037,7 @@ def _schedule_prefills(

    def _schedule_default(self) -> SchedulerOutputs:
        """Schedule queued requests.

        The current policy is designed to optimize the throughput. First,
        it batches as many prefill requests as possible. And it schedules
        decodes. If there's a pressure on GPU memory, decode requests can
@@ -1141,7 +1142,7 @@ def _schedule_default(self) -> SchedulerOutputs:

    def _schedule_chunked_prefill(self) -> SchedulerOutputs:
        """Schedule queued requests.

        Chunked prefill allows to chunk prefill requests, batch them together
        with decode requests. This policy 1. schedule as many decoding requests
        as possible. 2. schedule chunked prefill requests that are not
@@ -1350,6 +1351,25 @@ def schedule(
                        seqs[0].data.get_len()):
                    do_sample = False

            pooling_params = seq_group.pooling_params

            # Store instruction_seq in pooling_params.
            instruction_seq = seq.inputs.inputs.get("instruction_seq")
            if instruction_seq is not None:
                if pooling_params is None:
                    pooling_params = PoolingParams()
                    pooling_params.additional_data = {
                        "instruction_seq": instruction_seq
                    }
                elif pooling_params.additional_data is None:
                    pooling_params.additional_data = {
                        "instruction_seq": instruction_seq
                    }
                else:
                    pooling_params.additional_data[
                        "instruction_seq"] = seq.inputs.inputs.get(
                            "instruction_seq")

            # It assumes the scheduled_seq_groups is ordered by
            # prefill < decoding.
            if is_first_prefill or not self.scheduler_config.send_delta_data:
@@ -1360,7 +1380,7 @@
                    sampling_params=seq_group.sampling_params,
                    block_tables=block_tables,
                    do_sample=do_sample,
                    pooling_params=seq_group.pooling_params,
                    pooling_params=pooling_params,
                    token_chunk_size=token_chunk_size,
                    lora_request=seq_group.lora_request,
                    computed_block_nums=common_computed_block_nums,
11 changes: 11 additions & 0 deletions vllm/inputs/data.py
@@ -163,6 +163,14 @@ class TokenInputs(TypedDict):
    to pass the mm_processor_kwargs to each of them.
    """

    instruction_seq: NotRequired[Optional[str]]
    """
    The instruction sequence that is usually prepended to the original prompt
    when it is passed to the model. Certain models need to extract this
    instruction sequence from the prompt in order to adjust model operations
    such as the attention mask.
    """


def token_inputs(
    prompt_token_ids: List[int],
@@ -171,6 +179,7 @@ def token_inputs(
    multi_modal_data: Optional["MultiModalDataDict"] = None,
    multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None,
    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    instruction_seq: Optional[str] = None,
) -> TokenInputs:
    """Construct :class:`TokenInputs` from optional values."""
    inputs = TokenInputs(type="token", prompt_token_ids=prompt_token_ids)
@@ -185,6 +194,8 @@ def token_inputs(
        inputs["multi_modal_placeholders"] = multi_modal_placeholders
    if mm_processor_kwargs is not None:
        inputs["mm_processor_kwargs"] = mm_processor_kwargs
    if instruction_seq is not None:
        inputs["instruction_seq"] = instruction_seq

    return inputs
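
As a rough illustration of how the new field travels through the system (a sketch under assumptions, not part of the diff): instruction_seq can be attached when token inputs are constructed, and the scheduler change above copies it into PoolingParams.additional_data so that a pooler can, per the docstring, adjust operations such as the attention mask. The token IDs below are placeholders, not a real tokenization.

from vllm.inputs.data import token_inputs

# Hypothetical example: attach the GritLM embed instruction to the inputs.
instruction = "<|user|>\nGiven a title, retrieve the abstract\n<|embed|>\n"
inputs = token_inputs(
    prompt_token_ids=[101, 102, 103],  # placeholder token IDs
    instruction_seq=instruction,
)
assert inputs["instruction_seq"] == instruction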

