From 422f2903245c004d23ee6a0090e2b190868715da Mon Sep 17 00:00:00 2001 From: Jeffrey Tang Date: Tue, 20 Feb 2024 15:29:31 -0600 Subject: [PATCH] enh: JSON schema for guided generation now optionally respects field order --- clients/python/lorax/types.py | 4 ++-- launcher/Cargo.toml | 2 +- router/Cargo.toml | 2 +- server/lorax_server/utils/logits_process.py | 1 + 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index 8ad776c70..8dd024c7c 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -1,6 +1,6 @@ from enum import Enum from pydantic import BaseModel, validator, Field, ConfigDict -from typing import Optional, List, Dict, Any +from typing import Optional, List, Dict, Any, OrderedDict, Union from lorax.errors import ValidationError @@ -64,7 +64,7 @@ class ResponseFormat(BaseModel): model_config = ConfigDict(use_enum_values=True) type: ResponseFormatType - schema_spec: Dict[str, Any] = Field(alias="schema") + schema_spec: Union[Dict[str, Any], OrderedDict] = Field(alias="schema") class Parameters(BaseModel): diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml index 803cc31ce..9053b1313 100644 --- a/launcher/Cargo.toml +++ b/launcher/Cargo.toml @@ -11,7 +11,7 @@ clap = { version = "4.1.4", features = ["derive", "env"] } ctrlc = { version = "3.2.5", features = ["termination"] } nix = "0.26.2" serde = { version = "1.0.152", features = ["derive"] } -serde_json = "1.0.93" +serde_json = { version = "1.0.93", features = ["preserve_order"] } tracing = "0.1.37" tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } diff --git a/router/Cargo.toml b/router/Cargo.toml index b6c750280..de69f34b4 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -30,7 +30,7 @@ opentelemetry-otlp = "0.12.0" rand = "0.8.5" reqwest = { version = "0.11.14", features = [] } serde = "1.0.152" -serde_json = "1.0.93" +serde_json = { version = "1.0.93", features = ["preserve_order"] } thiserror = "1.0.38" tokenizers = "0.13.4" tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } diff --git a/server/lorax_server/utils/logits_process.py b/server/lorax_server/utils/logits_process.py index 24a6ee0cd..232913ad9 100644 --- a/server/lorax_server/utils/logits_process.py +++ b/server/lorax_server/utils/logits_process.py @@ -485,6 +485,7 @@ def __init__(self, schema: str, tokenizer: PreTrainedTokenizerBase): self.tokenizer = self.adapt_tokenizer(tokenizer) regex_string = build_regex_from_object(schema) + regex_string = '[\\n ]*' + regex_string # Hack to allow preceding whitespace self.fsm = RegexFSM(regex_string, tokenizer) self.fsm_state = FSMState(0)