diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index ad6dc3ab1..a34f0e612 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -22,15 +22,20 @@ class MergedAdapters(BaseModel): # Majority sign method majority_sign_method: Optional[str] - @validator("ids", "weights") - def validate_ids_weights(cls, ids, weights): - if not ids: + @validator("ids") + def validate_ids(cls, v): + if not v: raise ValidationError("`ids` cannot be empty") - if not weights: + return v + + @validator("weights") + def validate_weights(cls, v, values): + ids = values["ids"] + if not v: raise ValidationError("`weights` cannot be empty") - if len(ids) != len(weights): + if len(ids) != len(v): raise ValidationError("`ids` and `weights` must have the same length") - return ids, weights + return v @validator("merge_strategy") def validate_merge_strategy(cls, v): @@ -94,11 +99,12 @@ class Parameters(BaseModel): # Get decoder input token logprobs and ids decoder_input_details: bool = False - @validator("adapter_id", "merged_adapters") - def valid_adapter_id_merged_adapters(cls, adapter_id, merged_adapters): - if adapter_id is not None and merged_adapters is not None: + @validator("adapter_id") + def valid_adapter_id(cls, v, values): + merged_adapters = values.get("merged_adapters") + if v is not None and merged_adapters is not None: raise ValidationError("you must specify at most one of `adapter_id` or `merged_adapters`") - return adapter_id, merged_adapters + return v @validator("adapter_source") def valid_adapter_source(cls, v):