api: add json_schema to OpenAI server (#915)
AlpinDale authored Dec 18, 2024
1 parent b1492c1 commit f61acdd
Showing 6 changed files with 113 additions and 35 deletions.
14 changes: 12 additions & 2 deletions aphrodite/endpoints/openai/protocol.py
@@ -68,9 +68,19 @@ class UsageInfo(OpenAIBaseModel):
completion_tokens: Optional[int] = 0


class JsonSchemaResponseFormat(OpenAIBaseModel):
name: str
description: Optional[str] = None
    # "schema" is the field name in the OpenAI API, but it conflicts with
    # pydantic, so use json_schema with an alias instead
json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema')
strict: Optional[bool] = None


class ResponseFormat(OpenAIBaseModel):
# type must be "json_object" or "text"
type: Literal["text", "json_object"]
# type must be "json_schema", "json_object" or "text"
type: Literal["text", "json_object", "json_schema"]
json_schema: Optional[JsonSchemaResponseFormat] = None


class StreamOptions(OpenAIBaseModel):
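The alias here is what lets clients keep sending the OpenAI wire name `schema` while the parsed model avoids shadowing pydantic's reserved `schema` attribute. A minimal sketch of the round-trip, assuming pydantic v2 semantics (`OpenAIBaseModel` swapped for a plain `BaseModel` so the snippet runs standalone):

```python
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field


class JsonSchemaResponseFormat(BaseModel):
    name: str
    description: Optional[str] = None
    json_schema: Optional[Dict[str, Any]] = Field(default=None,
                                                  alias='schema')
    strict: Optional[bool] = None


# Clients send the OpenAI wire name "schema"; the parsed model exposes the
# same data as json_schema, and serializing by alias restores the wire name.
rf = JsonSchemaResponseFormat.model_validate({
    "name": "foo_test",
    "schema": {"type": "object"},
})
assert rf.json_schema == {"type": "object"}
assert rf.model_dump(by_alias=True)["schema"] == {"type": "object"}
```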
17 changes: 7 additions & 10 deletions aphrodite/modeling/guided_decoding/__init__.py
@@ -6,12 +6,6 @@
CompletionRequest)
from aphrodite.modeling.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from aphrodite.triton_utils import HAS_TRITON

if HAS_TRITON:
from aphrodite.modeling.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor,
get_outlines_guided_decoding_logits_processor)


async def get_guided_decoding_logits_processor(
@@ -20,10 +14,11 @@ async def get_guided_decoding_logits_processor(
tokenizer) -> Optional[LogitsProcessorFunc]:
request = _adapt_request_for_tool_use(request)
if guided_decoding_backend == 'outlines':
if HAS_TRITON:
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
else:
from aphrodite.modeling.guided_decoding.outlines_decoding import (
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
pass
if guided_decoding_backend == 'lm-format-enforcer':
from aphrodite.modeling.guided_decoding.lm_format_enforcer_decoding import ( # noqa
@@ -42,6 +37,8 @@ def get_local_guided_decoding_logits_processor(
# request = _adapt_request_for_tool_use(request)

if guided_decoding_backend == 'outlines':
from aphrodite.modeling.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
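The change above is the standard deferred-import pattern: moving the `outlines` import from module scope into the branch that needs it means the heavy dependency is only loaded when guided decoding is actually requested (this is what the new `test_lazy_outlines.py` below verifies). A toy illustration of the mechanism, using `decimal` as a stand-in for the heavy module; it assumes a fresh interpreter where `decimal` has not already been preloaded:

```python
import sys


def get_processor(backend: str):
    if backend == 'outlines':
        # The import lives inside the branch, so the module is only
        # loaded on the first call that actually needs it.
        import decimal  # stand-in for the heavy outlines import
        return decimal.Decimal
    return None


assert 'decimal' not in sys.modules   # not loaded at module import time
get_processor('lm-format-enforcer')
assert 'decimal' not in sys.modules   # other backends never pay the cost
get_processor('outlines')
assert 'decimal' in sys.modules       # loaded only once it is requested
```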
18 changes: 12 additions & 6 deletions aphrodite/modeling/guided_decoding/lm_format_enforcer_decoding.py
@@ -16,12 +16,6 @@
from aphrodite.modeling.guided_decoding.lm_format_enforcer_logits_processors import ( # noqa: E501
build_aphrodite_logits_processor,
build_aphrodite_token_enforcer_tokenizer_data)
from aphrodite.triton_utils import HAS_TRITON

if HAS_TRITON:
from aphrodite.modeling.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor,
get_outlines_guided_decoding_logits_processor)


async def get_lm_format_enforcer_guided_decoding_logits_processor(
@@ -47,12 +41,21 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
character_level_parser = RegexParser(request.guided_regex)
elif request.guided_grammar:
# CFG grammar not supported by LMFE, revert to outlines
from aphrodite.modeling.guided_decoding.outlines_decoding import (
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
elif (request.response_format is not None
and request.response_format.type == "json_object"):
character_level_parser = JsonSchemaParser(
None) # None means any json object
elif (request.response_format is not None
and request.response_format.type == "json_schema"
and request.response_format.json_schema is not None
and request.response_format.json_schema.json_schema is not None):
schema = _normalize_json_schema_object(
request.response_format.json_schema.json_schema)
character_level_parser = JsonSchemaParser(schema)
else:
return None

@@ -83,6 +86,9 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
elif guided_options.guided_regex:
character_level_parser = RegexParser(guided_options.guided_regex)
elif guided_options.guided_grammar:
from aphrodite.modeling.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor)

# CFG grammar not supported by LMFE, revert to outlines
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
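The net effect of the new branch above: `json_object` keeps mapping to an unconstrained parser, while `json_schema` now feeds the user's schema to LMFE. A minimal sketch, assuming `lm-format-enforcer` is installed; `_normalize_json_schema_object` is not shown in this diff, so a plain dict schema is assumed here:

```python
from lmformatenforcer import JsonSchemaParser

# response_format={"type": "json_object"}: None means "any JSON object",
# exactly as in the branch above.
any_json = JsonSchemaParser(None)

# response_format={"type": "json_schema", ...}: the embedded schema
# constrains the shape of the generated JSON.
typed = JsonSchemaParser({
    "type": "object",
    "properties": {"result": {"type": "integer"}},
})
```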
7 changes: 7 additions & 0 deletions aphrodite/modeling/guided_decoding/outlines_decoding.py
@@ -126,6 +126,13 @@ def _get_guide_and_mode(
and request.response_format is not None
and request.response_format.type == "json_object"):
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
elif (not isinstance(request, GuidedDecodingRequest)
and request.response_format is not None
and request.response_format.type == "json_schema"
and request.response_format.json_schema is not None
and request.response_format.json_schema.json_schema is not None):
json = json_dumps(request.response_format.json_schema.json_schema)
return json, GuidedDecodingMode.JSON
else:
return None, None

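For the outlines backend the schema is not interpreted here at all: it is serialized back to a string and tagged with JSON mode, so downstream code handles it exactly like a `guided_json` request. A sketch of just that transformation, with `GuidedDecodingMode` mocked as a plain `Enum` since only the shape of the return value matters:

```python
import json
from enum import Enum


class GuidedDecodingMode(Enum):
    # Stand-in for the real enum in outlines_decoding.py.
    JSON = "json"
    GRAMMAR = "grammar"


def guide_for_json_schema(schema: dict):
    # Mirrors the new branch: dump the dict to a string, tag it JSON.
    return json.dumps(schema), GuidedDecodingMode.JSON


guide, mode = guide_for_json_schema(
    {"type": "object", "properties": {"result": {"type": "integer"}}})
assert mode is GuidedDecodingMode.JSON
```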
43 changes: 43 additions & 0 deletions tests/endpoints/llm/test_lazy_outlines.py
@@ -0,0 +1,43 @@
import sys

from aphrodite import LLM, SamplingParams


def test_lazy_outlines(sample_regex):
"""If users don't use guided decoding, outlines should not be imported.
"""
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(model="facebook/opt-125m",
enforce_eager=True,
gpu_memory_utilization=0.3)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# make sure outlines is not imported
assert 'outlines' not in sys.modules
llm = LLM(model="facebook/opt-125m",
enforce_eager=True,
guided_decoding_backend="lm-format-enforcer",
gpu_memory_utilization=0.3)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(
prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# make sure outlines is not imported
assert 'outlines' not in sys.modules
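The test relies on a `sample_regex` fixture defined elsewhere in the test suite. A hypothetical conftest.py entry with the shape the test assumes (the IPv4 regex is illustrative, not necessarily the repo's actual fixture value):

```python
import pytest


@pytest.fixture
def sample_regex():
    # Hypothetical fixture: an IPv4-address regex matching the prompt in
    # test_lazy_outlines above.
    return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
            r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
```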
49 changes: 32 additions & 17 deletions tests/endpoints/openai/test_chat.py
@@ -17,7 +17,7 @@
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
LORA_NAME = "alpindale/zephyr-7b-beta-lora"


@pytest.fixture(scope="module")
@@ -755,22 +755,6 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
assert loaded == {"result": 2}, loaded


@pytest.mark.asyncio
async def test_extra_fields(client: openai.AsyncOpenAI):
with pytest.raises(BadRequestError) as exc_info:
await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role": "system",
"content": "You are a helpful assistant.",
"extra_field": "0",
}], # type: ignore
temperature=0,
seed=0)

assert "extra_forbidden" in exc_info.value.message


@pytest.mark.asyncio
async def test_complex_message_content(client: openai.AsyncOpenAI):
resp = await client.chat.completions.create(
@@ -840,3 +824,34 @@ async def test_long_seed(client: openai.AsyncOpenAI):

assert ("greater_than_equal" in exc_info.value.message
or "less_than_equal" in exc_info.value.message)


@pytest.mark.asyncio
async def test_response_format_json_schema(client: openai.AsyncOpenAI):
for _ in range(2):
resp = await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role":
"user",
"content": ('what is 1+1? please respond with a JSON object, '
'the format is {"result": 2}')
}],
response_format={
"type": "json_schema",
"json_schema": {
"name": "foo_test",
"schema": {
"type": "object",
"properties": {
"result": {
"type": "integer"
},
},
},
}
})
content = resp.choices[0].message.content
assert content is not None
loaded = json.loads(content)
assert loaded == {"result": 2}, loaded
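Outside the test suite, the same payload shape works against a running server with the synchronous client. A sketch assuming a local Aphrodite deployment (the base URL and its default port 2242 are assumptions; adjust to your setup):

```python
import openai

client = openai.OpenAI(base_url="http://localhost:2242/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    messages=[{"role": "user", "content": "what is 1+1? answer in JSON"}],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "name": "sum",
            "schema": {
                "type": "object",
                "properties": {"result": {"type": "integer"}},
            },
        },
    },
)
print(resp.choices[0].message.content)  # e.g. {"result": 2}
```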
