Skip to content

Commit

Permalink
API + Model: Add support for JSON schema constraints
Browse files Browse the repository at this point in the history
Add the ability to constrain the return value of a model to be JSON.
Built using the JSON schema standard to define the properties of what
the model should return.

This feature should be more accurate than using GBNF/EBNF to yield
the same results due to the use of lmformatenforcer.

GBNF/EBNF will be added in a different commit/branch.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Feb 23, 2024
1 parent 3608027 commit 8a22b53
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 0 deletions.
61 changes: 61 additions & 0 deletions backends/exllamav2/grammar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from common.logger import init_logger
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2Sampler

# Temporary, remove once the exllama version is bumped
try:
from exllamav2.generator.filters import ExLlamaV2PrefixFilter

_exllama_filter_available = True
except ImportError:
_exllama_filter_available = False

try:
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter

_lmformatenforcer_available = True
except ImportError:
_lmformatenforcer_available = False


logger = init_logger(__name__)


class ExLlamaV2Grammar:
"""ExLlamaV2 class for various grammar filters/parsers."""

def add_json_schema_filter(
self,
json_schema: dict,
gen_settings: ExLlamaV2Sampler.Settings,
model: ExLlamaV2,
tokenizer: ExLlamaV2Tokenizer,
):
"""Adds an ExllamaV2 filter based on a JSON schema."""

# Check if the required dependencies can be imported
if not _exllama_filter_available:
logger.warning(
"ExllamaV2PrefixFilter is not available "
"in the currently installed ExllamaV2 version."
)

return

if not _lmformatenforcer_available:
logger.error(
"lmformatenforcer must be installed to parse a json schema.\n"
"Please run the following command: pip install lm-format-enforcer"
)

return

# Create the parser
schema_parser = JsonSchemaParser(json_schema)
lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer)
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, "{")

# Append the filters
gen_settings.filters += [lmfilter, prefix_filter]
gen_settings.filter_prefer_eos = True
11 changes: 11 additions & 0 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
from typing import List, Optional, Union

from backends.exllamav2.grammar import ExLlamaV2Grammar
from common.gen_logging import log_generation_params, log_prompt, log_response
from common.templating import (
PromptTemplate,
Expand Down Expand Up @@ -758,6 +759,16 @@ def generate_gen(self, prompt: str, **kwargs):
"in the model's vocab. Skipping."
)

# Initialize grammar handler
grammar_handler = ExLlamaV2Grammar()

# Add JSON schema filter if it exists
json_schema = unwrap(kwargs.get("json_schema"))
if json_schema:
grammar_handler.add_json_schema_filter(
json_schema, gen_settings, self.model, self.tokenizer
)

# Ban the EOS token if specified. If not, append to stop conditions
# as well.
# Set this below logging to avoid polluting the stop strings array
Expand Down
5 changes: 5 additions & 0 deletions common/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ class BaseSamplerRequest(BaseModel):
default_factory=lambda: get_default_sampler_value("negative_prompt")
)

json_schema: Optional[object] = Field(
default_factory=lambda: get_default_sampler_value("json_schema"),
)

# Aliased variables
typical: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("typical", 1.0),
Expand Down Expand Up @@ -261,6 +265,7 @@ def to_gen_params(self, **kwargs):
"mirostat_eta": self.mirostat_eta,
"cfg_scale": self.cfg_scale,
"negative_prompt": self.negative_prompt,
"json_schema": self.json_schema,
}

return {**gen_params, **kwargs}
Expand Down

0 comments on commit 8a22b53

Please sign in to comment.