Added Support for guided decoding in offline interface #4130

Closed
wants to merge 37 commits
Changes from 4 commits
Commits
37 commits
a5aeec2
first commit, added extra parameters for offline LLM, issue with work…
kevinbu233 Apr 16, 2024
dcbfc69
clean up code
kevinbu233 Apr 16, 2024
bd4d84c
add skip cleanup
simon-mo Apr 17, 2024
76c0924
Changed model and sampling parameters, cleaned up naming
kevinbu233 Apr 17, 2024
4709a98
cleaned up code and created helper functions
kevinbu233 Apr 18, 2024
aab046d
first merge resolved
kevinbu233 Apr 18, 2024
84b1442
fix format
kevinbu233 Apr 19, 2024
1d6abe4
fix merge conflict
kevinbu233 Apr 19, 2024
daa2c8f
added docstrings for sampling params guided options
kevinbu233 Apr 23, 2024
35c73ee
fix merge conflict with main
kevinbu233 Apr 23, 2024
9e454c5
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 Apr 23, 2024
ef4cf6f
fixed support for multiple sampling params for LLM
kevinbu233 Apr 23, 2024
945125d
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 Apr 23, 2024
1c0769d
added noqa for extra long line
kevinbu233 Apr 24, 2024
062f0bc
Update tests/entrypoints/test_local_LLM.py
simon-mo May 1, 2024
05a2512
Merge branch 'main' of github.com:vllm-project/vllm into yihuan_issue…
simon-mo May 2, 2024
438ab37
fix typing
simon-mo May 2, 2024
06c2205
fix test and more refactoring
simon-mo May 3, 2024
4158d78
use x2
simon-mo May 3, 2024
0d9e5a5
lint
simon-mo May 3, 2024
ff9ba7f
fix isort
simon-mo May 3, 2024
bbb59bf
merge with main
kevinbu233 May 27, 2024
d779f86
fixing merge issues
kevinbu233 May 30, 2024
f923677
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 May 30, 2024
a15b511
fixed merge
kevinbu233 May 30, 2024
292264a
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 May 30, 2024
77d42a8
format
kevinbu233 May 30, 2024
60ab6f6
finished merge first draft
kevinbu233 Jun 14, 2024
ae23772
merged main
kevinbu233 Jun 14, 2024
3fb6258
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 Jun 19, 2024
75cf9a7
fixed merge conflict and fixed suggestions
kevinbu233 Jun 19, 2024
c429ef8
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 Jun 21, 2024
da89c1b
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 Jun 22, 2024
6c8b82a
fix test_openai error
kevinbu233 Jun 23, 2024
4e759d9
Merge remote-tracking branch 'upstream/main' into yihuan_issue3536
kevinbu233 Jun 24, 2024
4ac8abb
fixing response_format test
kevinbu233 Jun 24, 2024
2eedeba
temporary push
kevinbu233 Jun 25, 2024
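For orientation before the diffs, here is a minimal sketch of the offline guided-decoding usage this PR adds, based on the tests in tests/entrypoints/test_local_LLM.py below; the guided_options field on SamplingParams is the new hook introduced by this change, and the model name and regex are taken from those tests:

from vllm import LLM, SamplingParams

# Constrain generation to strings matching an IPv4-address regex.
ipv4_regex = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
              r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")

llm = LLM(model="HuggingFaceH4/zephyr-7b-beta")
sampling_params = SamplingParams(temperature=0.8,
                                 top_p=0.95,
                                 guided_options=dict(guided_regex=ipv4_regex))
outputs = llm.generate(
    prompts=["Give an example IPv4 address:"],
    sampling_params=sampling_params,
)
# Every completion is constrained to match the regex, e.g. "192.168.0.1".
print(outputs[0].outputs[0].text)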
223 changes: 223 additions & 0 deletions tests/entrypoints/test_local_LLM.py
@@ -0,0 +1,223 @@
# imports for guided decoding tests
import json
import os
import re

import jsonschema
import pytest

# downloading lora to test lora requests
from huggingface_hub import snapshot_download

from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams

from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.outputs import RequestOutput

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

TEST_SCHEMA = {
    "type": "object",
    "properties": {
        "name": {
            "type": "string"
        },
        "age": {
            "type": "integer"
        },
        "skills": {
            "type": "array",
            "items": {
                "type": "string",
                "maxLength": 10
            },
            "minItems": 3
        },
        "work history": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "company": {
                        "type": "string"
                    },
                    "duration": {
                        "type": "string"
                    },
                    "position": {
                        "type": "string"
                    }
                },
                "required": ["company", "position"]
            }
        }
    },
    "required": ["name", "age", "skills", "work history"]
}

TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
              r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")

TEST_CHOICE = [
    "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby",
    "Swift", "Kotlin"
]

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


@pytest.fixture(scope="session")
def llm():
    return LLM(model=MODEL_NAME, max_model_len=15000)


@pytest.mark.skip_global_cleanup
def test_simple_prompts(llm):
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    outputs = llm.generate(
        prompts=prompts,
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_regex_(llm):
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     guided_options=dict(guided_regex=TEST_REGEX))
    outputs = llm.generate(
        prompts=[
            f"Give an example IPv4 address with this regex: {TEST_REGEX}"
        ],
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        assert generated_text is not None
        assert re.fullmatch(TEST_REGEX, generated_text) is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_json_completion(llm):
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     guided_options=dict(guided_json=TEST_SCHEMA),
                                     max_tokens=1000)
    outputs = llm.generate(
        prompts=[
            f"Give an example JSON for an employee profile "
            f"that fits this schema: {TEST_SCHEMA}"
        ],
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None
    print(outputs)
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)

@pytest.mark.skip_global_cleanup
def test_guided_choice_completion(llm):
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        guided_options=dict(guided_choice=TEST_CHOICE))
    outputs = llm.generate(
        prompts="The best language for type-safe systems programming is ",
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        assert generated_text is not None
        assert generated_text in TEST_CHOICE
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
def test_guided_grammar(llm):
simple_sql_grammar = """
start: select_statement

select_statement: "SELECT" column "from" table "where" condition

column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number

number: "1" | "2"
"""

    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        guided_options=dict(guided_grammar=simple_sql_grammar))
    outputs = llm.generate(
        prompts=("Generate a sql state that select col_1 from "
                 "table_1 where it is equals to 1"),
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None

        # use Lark to parse the output, and make sure it's a valid parse tree
        from lark import Lark
        parser = Lark(simple_sql_grammar)
        parser.parse(generated_text)

        # remove spaces for comparison b/c we removed them in the grammar
        ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
            " ", "")

        assert generated_text.strip() == ground_truth

        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    pytest.main([__file__])
5 changes: 3 additions & 2 deletions vllm/core/block_manager_v1.py
@@ -1,9 +1,10 @@
"""A block manager that manages token blocks."""
from abc import ABC, abstractmethod
from collections.abc import Sequence as GenericSequence
from itertools import count, takewhile
from os.path import commonprefix
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Set

from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
3 changes: 1 addition & 2 deletions vllm/core/block_manager_v2.py
@@ -1,7 +1,6 @@
"""A block manager that manages token blocks."""
from collections.abc import Sequence as GenericSequence
from typing import Dict, List, Optional

from typing import Sequence as GenericSequence
from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
3 changes: 2 additions & 1 deletion vllm/core/interfaces.py
@@ -1,7 +1,8 @@
from __future__ import annotations
import enum
from abc import ABC, abstractmethod
from collections.abc import Sequence as GenericSequence
from typing import Dict, List
from typing import Sequence as GenericSequence

from vllm.sequence import Sequence, SequenceGroup

9 changes: 9 additions & 0 deletions vllm/entrypoints/llm.py
@@ -13,6 +13,8 @@
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter

from vllm.model_executor.guided_decoding import get_local_guided_decoding_logits_processor


class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
@@ -176,10 +178,17 @@ def generate(
        assert prompt_token_ids is not None
        num_requests = len(prompt_token_ids)

        guided_decode_logits_processor = get_local_guided_decoding_logits_processor(sampling_params, self.get_tokenizer())
        if guided_decode_logits_processor:
            if sampling_params.logits_processors is None:
                sampling_params.logits_processors = []
            sampling_params.logits_processors.append(
                guided_decode_logits_processor)
        for i in range(num_requests):
            prompt = prompts[i] if prompts is not None else None
            token_ids = None if prompt_token_ids is None else prompt_token_ids[
                i]

            self._add_request(
                prompt,
                sampling_params,
57 changes: 56 additions & 1 deletion vllm/model_executor/guided_decoding.py
@@ -5,7 +5,7 @@
from functools import lru_cache
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Tuple, Union
from typing import Tuple, Union, Dict

from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
@@ -16,6 +16,7 @@
JSONLogitsProcessor,
RegexLogitsProcessor)

from vllm.sampling_params import SamplingParams

class GuidedDecodingMode(Enum):
JSON = "json"
@@ -82,6 +83,60 @@ async def get_guided_decoding_logits_processor(
    logits_processor.init_state()
    return logits_processor

def get_local_guided_decoding_logits_processor(sampling_params, tokenizer):
    """
    Given guided decoding options from SamplingParams, check for guided
    decoding parameters and get the necessary logits processor for the guide.
    We cache logit processors by (guide, tokenizer), and on cache hit
    we make a shallow copy to reuse the same underlying FSM.
    """
    # global global_thread_pool

    guide, mode = _get_guide_and_mode_from_sampling_params(sampling_params.guided_options)
    if not guide:
        return None

    # if global_thread_pool is None:
    #     global_thread_pool = concurrent.futures.ThreadPoolExecutor(
    #         max_workers=2)

    result = _get_cached_logits_processor(guide, tokenizer, mode)

    logits_processor = copy(result)
    # reset logits processor's internal state
    logits_processor.init_state()
    return logits_processor


def _get_guide_and_mode_from_sampling_params(
        guided_options: Dict[str, str]
) -> Tuple[str, GuidedDecodingMode]:
    if not guided_options:
        return None, None

    if "guided_json" in guided_options:
        json = guided_options["guided_json"]
        if isinstance(json, dict):
            # turn dict into hashable string
            json = json_dumps(json)
        elif isinstance(json, BaseModel):
            # use pydantic signature so that different model classes
            # with the same fields will get hashed the same
            json = str(json.__signature__)
        return json, GuidedDecodingMode.JSON
    elif "guided_regex" in guided_options:
        return guided_options["guided_regex"], GuidedDecodingMode.REGEX
    elif "guided_choice" in guided_options:
        # choice just uses regex
        choices = [
            regex_escape(str(choice)) for choice in guided_options["guided_choice"]
        ]
        choices_regex = "(" + "|".join(choices) + ")"
        return choices_regex, GuidedDecodingMode.CHOICE
    elif "guided_grammar" in guided_options:
        return guided_options["guided_grammar"], GuidedDecodingMode.GRAMMAR
    else:
        return None, None

def _get_guide_and_mode(
        request: Union[CompletionRequest, ChatCompletionRequest]
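As a standalone illustration (not part of the diff) of the caching strategy described in the docstring above — compile a logits processor once per (guide, tokenizer) pair, then hand each request a shallow copy so the expensive FSM is reused while per-request state is reset — here is a hedged sketch with a hypothetical ToyProcessor standing in for the outlines-backed processors:

from copy import copy
from functools import lru_cache


class ToyProcessor:
    """Hypothetical stand-in for JSONLogitsProcessor / RegexLogitsProcessor."""

    def __init__(self, guide: str):
        self.fsm = {"guide": guide}  # imagine an expensively compiled FSM here
        self.state = None

    def init_state(self):
        self.state = 0  # per-request decoding state


@lru_cache(maxsize=32)
def _cached_processor(guide: str) -> ToyProcessor:
    # Expensive step: runs once per distinct guide (vLLM also keys on the tokenizer).
    return ToyProcessor(guide)


def get_processor(guide: str) -> ToyProcessor:
    # Cheap step per request: the shallow copy shares the compiled FSM,
    # while init_state() resets only the per-request bookkeeping.
    processor = copy(_cached_processor(guide))
    processor.init_state()
    return processor


a, b = get_processor("a|b"), get_processor("a|b")
assert a is not b      # each request gets its own processor object
assert a.fsm is b.fsm  # but the compiled FSM is shared via the cache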
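Similarly, to make the guided_choice branch above concrete — each choice is regex-escaped and the escaped choices are joined into a single alternation, which then drives the same regex-guided path — a small worked example (standalone, not part of the diff):

from re import escape as regex_escape

choices = ["C++", "C#", "Python"]
choices_regex = "(" + "|".join(regex_escape(str(c)) for c in choices) + ")"
print(choices_regex)  # -> (C\+\+|C\#|Python)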