diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index b0b521836..f20a222a3 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -1,5 +1,5 @@ from enum import Enum -from pydantic import BaseModel, field_validator, Field, ConfigDict +from pydantic import BaseModel, field_validator, model_validator, Field, ConfigDict from typing import Optional, List, Dict, Any, OrderedDict, Union from lorax.errors import ValidationError @@ -116,12 +116,13 @@ class Parameters(BaseModel): # Optional response format specification to constrain the generated text response_format: Optional[ResponseFormat] = None - @field_validator("adapter_id") - def valid_adapter_id(cls, v, values): - merged_adapters = values.get("merged_adapters") - if v is not None and merged_adapters is not None: + @model_validator(mode="after") + def valid_adapter_id(self): + adapter_id = self.adapter_id + merged_adapters = self.merged_adapters + if adapter_id is not None and merged_adapters is not None: raise ValidationError("you must specify at most one of `adapter_id` or `merged_adapters`") - return v + return self @field_validator("adapter_source") def valid_adapter_source(cls, v): diff --git a/clients/python/tests/test_types.py b/clients/python/tests/test_types.py index ff6c39f99..a779d7bb8 100644 --- a/clients/python/tests/test_types.py +++ b/clients/python/tests/test_types.py @@ -1,6 +1,6 @@ import pytest -from lorax.types import Parameters, Request +from lorax.types import Parameters, Request, MergedAdapters from lorax.errors import ValidationError @@ -68,6 +68,13 @@ def test_parameters_validation(): with pytest.raises(ValidationError): Parameters(typical_p=1) + # Test adapter_id and merged_adapters + merged_adapters = MergedAdapters(ids=["test/adapter-id-1", "test/adapter-id-2"], weights=[0.5, 0.5], density=0.5) + Parameters(adapter_id="test/adapter-id") + Parameters(merged_adapters=merged_adapters) + with pytest.raises(ValidationError): + Parameters(adapter_id="test/adapter-id", merged_adapters=merged_adapters) + def test_request_validation(): Request(inputs="test") @@ -79,6 +86,4 @@ def test_request_validation(): Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True)) with pytest.raises(ValidationError): - Request( - inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True - ) + Request(inputs="test", parameters=Parameters(best_of=2, do_sample=True), stream=True)