Skip to content

Commit

Permalink
Fix Pydantic v2 adapter_id and merged_adapters validation (#408)
Browse files Browse the repository at this point in the history
  • Loading branch information
claudioMontanari authored Apr 11, 2024
1 parent 70db455 commit 30174d7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
13 changes: 7 additions & 6 deletions clients/python/lorax/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from pydantic import BaseModel, field_validator, Field, ConfigDict
from pydantic import BaseModel, field_validator, model_validator, Field, ConfigDict
from typing import Optional, List, Dict, Any, OrderedDict, Union

from lorax.errors import ValidationError
Expand Down Expand Up @@ -116,12 +116,13 @@ class Parameters(BaseModel):
# Optional response format specification to constrain the generated text
response_format: Optional[ResponseFormat] = None

@field_validator("adapter_id")
def valid_adapter_id(cls, v, values):
merged_adapters = values.get("merged_adapters")
if v is not None and merged_adapters is not None:
@model_validator(mode="after")
def valid_adapter_id(self):
adapter_id = self.adapter_id
merged_adapters = self.merged_adapters
if adapter_id is not None and merged_adapters is not None:
raise ValidationError("you must specify at most one of `adapter_id` or `merged_adapters`")
return v
return self

@field_validator("adapter_source")
def valid_adapter_source(cls, v):
Expand Down
13 changes: 9 additions & 4 deletions clients/python/tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from lorax.types import Parameters, Request
from lorax.types import Parameters, Request, MergedAdapters
from lorax.errors import ValidationError


Expand Down Expand Up @@ -68,6 +68,13 @@ def test_parameters_validation():
with pytest.raises(ValidationError):
Parameters(typical_p=1)

# Test adapter_id and merged_adapters
merged_adapters = MergedAdapters(ids=["test/adapter-id-1", "test/adapter-id-2"], weights=[0.5, 0.5], density=0.5)
Parameters(adapter_id="test/adapter-id")
Parameters(merged_adapters=merged_adapters)
with pytest.raises(ValidationError):
Parameters(adapter_id="test/adapter-id", merged_adapters=merged_adapters)


def test_request_validation():
Request(inputs="test")
Expand All @@ -79,6 +86,4 @@ def test_request_validation():
Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True))

with pytest.raises(ValidationError):
Request(
inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True
)
Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True)

0 comments on commit 30174d7

Please sign in to comment.