Skip to content

Commit

Permalink
Fixed validation
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair committed Jan 30, 2024
1 parent 5d7c087 commit 8e9feb1
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions clients/python/lorax/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,20 @@ class MergedAdapters(BaseModel):
# Majority sign method
majority_sign_method: Optional[str]

@validator("ids", "weights")
def validate_ids_weights(cls, ids, weights):
if not ids:
@validator("ids")
def validate_ids(cls, v):
if not v:
raise ValidationError("`ids` cannot be empty")
if not weights:
return v

@validator("weights")
def validate_weights(cls, v, values):
ids = values["ids"]
if not v:
raise ValidationError("`weights` cannot be empty")
if len(ids) != len(weights):
if len(ids) != len(v):
raise ValidationError("`ids` and `weights` must have the same length")
return ids, weights
return v

@validator("merge_strategy")
def validate_merge_strategy(cls, v):
Expand Down Expand Up @@ -94,11 +99,12 @@ class Parameters(BaseModel):
# Get decoder input token logprobs and ids
decoder_input_details: bool = False

@validator("adapter_id", "merged_adapters")
def valid_adapter_id_merged_adapters(cls, adapter_id, merged_adapters):
if adapter_id is not None and merged_adapters is not None:
@validator("adapter_id")
def valid_adapter_id(cls, v, values):
merged_adapters = values.get("merged_adapters")
if v is not None and merged_adapters is not None:
raise ValidationError("you must specify at most one of `adapter_id` or `merged_adapters`")
return adapter_id, merged_adapters
return v

@validator("adapter_source")
def valid_adapter_source(cls, v):
Expand Down

0 comments on commit 8e9feb1

Please sign in to comment.