Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Pydantic v2 adapter_id and merged_adapters validation #408

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions clients/python/lorax/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from pydantic import BaseModel, field_validator, Field, ConfigDict
from pydantic import BaseModel, field_validator, model_validator, Field, ConfigDict
from typing import Optional, List, Dict, Any, OrderedDict, Union

from lorax.errors import ValidationError
Expand Down Expand Up @@ -116,12 +116,13 @@ class Parameters(BaseModel):
# Optional response format specification to constrain the generated text
response_format: Optional[ResponseFormat] = None

@field_validator("adapter_id")
def valid_adapter_id(cls, v, values):
merged_adapters = values.get("merged_adapters")
if v is not None and merged_adapters is not None:
@model_validator(mode="after")
def valid_adapter_id(self):
adapter_id = self.adapter_id
merged_adapters = self.merged_adapters
if adapter_id is not None and merged_adapters is not None:
raise ValidationError("you must specify at most one of `adapter_id` or `merged_adapters`")
return v
return self

@field_validator("adapter_source")
def valid_adapter_source(cls, v):
Expand Down
13 changes: 9 additions & 4 deletions clients/python/tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from lorax.types import Parameters, Request
from lorax.types import Parameters, Request, MergedAdapters
from lorax.errors import ValidationError


Expand Down Expand Up @@ -68,6 +68,13 @@ def test_parameters_validation():
with pytest.raises(ValidationError):
Parameters(typical_p=1)

# Test adapter_id and merged_adapters
merged_adapters = MergedAdapters(ids=["test/adapter-id-1", "test/adapter-id-2"], weights=[0.5, 0.5], density=0.5)
Parameters(adapter_id="test/adapter-id")
Parameters(merged_adapters=merged_adapters)
with pytest.raises(ValidationError):
Parameters(adapter_id="test/adapter-id", merged_adapters=merged_adapters)


def test_request_validation():
Request(inputs="test")
Expand All @@ -79,6 +86,4 @@ def test_request_validation():
Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True))

with pytest.raises(ValidationError):
Request(
inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True
)
Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True)
Loading