
Commit

merge
magdyksaleh committed Apr 12, 2024
2 parents e902c8a + cf2625a commit bc38276
Showing 36 changed files with 1,479 additions and 649 deletions.
118 changes: 61 additions & 57 deletions clients/python/lorax/types.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from pydantic import BaseModel, validator, Field, ConfigDict
+from pydantic import BaseModel, field_validator, model_validator, Field, ConfigDict
 from typing import Optional, List, Dict, Any, OrderedDict, Union
 
 from lorax.errors import ValidationError
@@ -16,40 +16,40 @@ class MergedAdapters(BaseModel):
     # Weights of the adapters to merge
     weights: List[float]
     # Merge strategy
-    merge_strategy: Optional[str]
+    merge_strategy: Optional[str] = None
     # Density
     density: float
     # Majority sign method
-    majority_sign_method: Optional[str]
+    majority_sign_method: Optional[str] = None
 
-    @validator("ids")
+    @field_validator("ids")
     def validate_ids(cls, v):
         if not v:
             raise ValidationError("`ids` cannot be empty")
         return v
 
-    @validator("weights")
+    @field_validator("weights")
     def validate_weights(cls, v, values):
-        ids = values["ids"]
+        ids = values.data["ids"]
         if not v:
             raise ValidationError("`weights` cannot be empty")
         if len(ids) != len(v):
             raise ValidationError("`ids` and `weights` must have the same length")
         return v
 
-    @validator("merge_strategy")
+    @field_validator("merge_strategy")
     def validate_merge_strategy(cls, v):
         if v is not None and v not in MERGE_STRATEGIES:
             raise ValidationError(f"`merge_strategy` must be one of {MERGE_STRATEGIES}")
         return v
 
-    @validator("density")
+    @field_validator("density")
     def validate_density(cls, v):
         if v < 0 or v > 1.0:
             raise ValidationError("`density` must be >= 0.0 and <= 1.0")
         return v
 
-    @validator("majority_sign_method")
+    @field_validator("majority_sign_method")
     def validate_majority_sign_method(cls, v):
         if v is not None and v not in MAJORITY_SIGN_METHODS:
             raise ValidationError(f"`majority_sign_method` must be one of {MAJORITY_SIGN_METHODS}")
@@ -69,13 +69,13 @@ class ResponseFormat(BaseModel):
 
 class Parameters(BaseModel):
     # The ID of the adapter to use
-    adapter_id: Optional[str]
+    adapter_id: Optional[str] = None
     # The source of the adapter to use
-    adapter_source: Optional[str]
+    adapter_source: Optional[str] = None
     # Adapter merge parameters
-    merged_adapters: Optional[MergedAdapters]
+    merged_adapters: Optional[MergedAdapters] = None
     # API token for accessing private adapters
-    api_token: Optional[str]
+    api_token: Optional[str] = None
     # Activate logits sampling
     do_sample: bool = False
     # Maximum number of generated tokens
@@ -90,107 +90,108 @@ class Parameters(BaseModel):
     # Stop generating tokens if a member of `stop_sequences` is generated
     stop: List[str] = []
     # Random sampling seed
-    seed: Optional[int]
+    seed: Optional[int] = None
     # The value used to module the logits distribution.
-    temperature: Optional[float]
+    temperature: Optional[float] = None
     # The number of highest probability vocabulary tokens to keep for top-k-filtering.
-    top_k: Optional[int]
+    top_k: Optional[int] = None
     # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
     # higher are kept for generation.
-    top_p: Optional[float]
+    top_p: Optional[float] = None
     # truncate inputs tokens to the given size
-    truncate: Optional[int]
+    truncate: Optional[int] = None
     # Typical Decoding mass
     # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
-    typical_p: Optional[float]
+    typical_p: Optional[float] = None
     # Generate best_of sequences and return the one if the highest token logprobs
-    best_of: Optional[int]
+    best_of: Optional[int] = None
     # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
     watermark: bool = False
     # Get generation details
     details: bool = False
     # Get decoder input token logprobs and ids
     decoder_input_details: bool = False
     # The number of highest probability vocabulary tokens to return as alternative tokens in the generation result
-    return_k_alternatives: Optional[int]
+    return_k_alternatives: Optional[int] = None
     # Optional response format specification to constrain the generated text
-    response_format: Optional[ResponseFormat]
+    response_format: Optional[ResponseFormat] = None
 
-    @validator("adapter_id")
-    def valid_adapter_id(cls, v, values):
-        merged_adapters = values.get("merged_adapters")
-        if v is not None and merged_adapters is not None:
+    @model_validator(mode="after")
+    def valid_adapter_id(self):
+        adapter_id = self.adapter_id
+        merged_adapters = self.merged_adapters
+        if adapter_id is not None and merged_adapters is not None:
             raise ValidationError("you must specify at most one of `adapter_id` or `merged_adapters`")
-        return v
+        return self
 
-    @validator("adapter_source")
+    @field_validator("adapter_source")
     def valid_adapter_source(cls, v):
         if v is not None and v not in ADAPTER_SOURCES:
             raise ValidationError(f"`adapter_source` must be one of {ADAPTER_SOURCES}")
         return v
 
-    @validator("best_of")
+    @field_validator("best_of")
     def valid_best_of(cls, field_value, values):
         if field_value is not None:
             if field_value <= 0:
                 raise ValidationError("`best_of` must be strictly positive")
-            if field_value > 1 and values["seed"] is not None:
+            if field_value > 1 and values.data["seed"] is not None:
                 raise ValidationError("`seed` must not be set when `best_of` is > 1")
             sampling = (
-                values["do_sample"]
-                | (values["temperature"] is not None)
-                | (values["top_k"] is not None)
-                | (values["top_p"] is not None)
-                | (values["typical_p"] is not None)
+                values.data["do_sample"]
+                | (values.data["temperature"] is not None)
+                | (values.data["top_k"] is not None)
+                | (values.data["top_p"] is not None)
+                | (values.data["typical_p"] is not None)
             )
             if field_value > 1 and not sampling:
                 raise ValidationError("you must use sampling when `best_of` is > 1")
 
         return field_value
 
-    @validator("repetition_penalty")
+    @field_validator("repetition_penalty")
     def valid_repetition_penalty(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`repetition_penalty` must be strictly positive")
         return v
 
-    @validator("seed")
+    @field_validator("seed")
     def valid_seed(cls, v):
         if v is not None and v < 0:
             raise ValidationError("`seed` must be positive")
         return v
 
-    @validator("temperature")
+    @field_validator("temperature")
     def valid_temp(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`temperature` must be strictly positive")
         return v
 
-    @validator("top_k")
+    @field_validator("top_k")
     def valid_top_k(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`top_k` must be strictly positive")
         return v
 
-    @validator("top_p")
+    @field_validator("top_p")
     def valid_top_p(cls, v):
         if v is not None and (v <= 0 or v >= 1.0):
             raise ValidationError("`top_p` must be > 0.0 and < 1.0")
         return v
 
-    @validator("truncate")
+    @field_validator("truncate")
     def valid_truncate(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`truncate` must be strictly positive")
         return v
 
-    @validator("typical_p")
+    @field_validator("typical_p")
     def valid_typical_p(cls, v):
         if v is not None and (v <= 0 or v >= 1.0):
             raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
         return v
 
-    @validator("return_k_alternatives")
+    @field_validator("return_k_alternatives")
     def valid_return_k_alternatives(cls, v):
         if v is not None and v <= 0:
             raise ValidationError("`return_k_alternatives` must be strictly positive")
@@ -201,19 +202,19 @@ class Request(BaseModel):
     # Prompt
     inputs: str
     # Generation parameters
-    parameters: Optional[Parameters]
+    parameters: Optional[Parameters] = None
     # Whether to stream output tokens
     stream: bool = False
 
-    @validator("inputs")
+    @field_validator("inputs")
     def valid_input(cls, v):
         if not v:
             raise ValidationError("`inputs` cannot be empty")
         return v
 
-    @validator("stream")
+    @field_validator("stream")
     def valid_best_of_stream(cls, field_value, values):
-        parameters = values["parameters"]
+        parameters = values.data["parameters"]
         if parameters is not None and parameters.best_of is not None and parameters.best_of > 1 and field_value:
             raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
         return field_value
@@ -227,7 +228,7 @@ class InputToken(BaseModel):
     text: str
     # Logprob
     # Optional since the logprob of the first token cannot be computed
-    logprob: Optional[float]
+    logprob: Optional[float] = None
 
 
 # Alternative Tokens
@@ -252,7 +253,7 @@ class Token(BaseModel):
     # Can be used to ignore tokens when concatenating
     special: bool
     # Alternative tokens
-    alternative_tokens: Optional[List[AlternativeToken]]
+    alternative_tokens: Optional[List[AlternativeToken]] = None
 
 
 # Generation finish reason
@@ -274,7 +275,7 @@ class BestOfSequence(BaseModel):
     # Number of generated tokens
     generated_tokens: int
     # Sampling seed if sampling was activated
-    seed: Optional[int]
+    seed: Optional[int] = None
     # Decoder input tokens, empty if decoder_input_details is False
     prefill: List[InputToken]
     # Generated tokens
@@ -290,21 +291,21 @@ class Details(BaseModel):
     # Number of generated tokens
     generated_tokens: int
     # Sampling seed if sampling was activated
-    seed: Optional[int]
+    seed: Optional[int] = None
    # Decoder input tokens, empty if decoder_input_details is False
     prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
     # Additional sequences when using the `best_of` parameter
-    best_of_sequences: Optional[List[BestOfSequence]]
+    best_of_sequences: Optional[List[BestOfSequence]] = None
 
 
 # `generate` return value
 class Response(BaseModel):
     # Generated text
     generated_text: str
     # Generation details
-    details: Optional[Details]
+    details: Optional[Details] = None
 
 
 # `generate_stream` details
@@ -316,7 +317,7 @@ class StreamDetails(BaseModel):
     # Number of generated tokens
     generated_tokens: int
     # Sampling seed if sampling was activated
-    seed: Optional[int]
+    seed: Optional[int] = None
 
 
 # `generate_stream` return value
@@ -325,17 +326,20 @@ class StreamResponse(BaseModel):
     token: Token
     # Complete generated text
    # Only available when the generation is finished
-    generated_text: Optional[str]
+    generated_text: Optional[str] = None
     # Generation details
     # Only available when the generation is finished
-    details: Optional[StreamDetails]
+    details: Optional[StreamDetails] = None
 
 
 # Inference API currently deployed model
 class DeployedModel(BaseModel):
     model_id: str
     sha: str
+    # Suppress pydantic warning over `model_id` field
+    model_config = ConfigDict(protected_namespaces=())
 
 
 class EmbedResponse(BaseModel):
     # Embeddings
-    embeddings: Optional[List[float]]
+    embeddings: Optional[List[float]]
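
The substance of this diff is a Pydantic v1 to v2 migration: `Optional` fields gain explicit `= None` defaults (v2 no longer infers a None default from an `Optional` annotation), `@validator` becomes `@field_validator` with cross-field access through `values.data`, the `adapter_id`/`merged_adapters` exclusivity check moves to a `@model_validator(mode="after")` that runs on the constructed instance, and `DeployedModel` clears `protected_namespaces` to suppress v2's warning about `model_`-prefixed fields. Below is a minimal sketch of the same patterns; the class and field names are illustrative, not the repository's own, and where the repository raises its custom `lorax.errors.ValidationError` the sketch uses the plain `ValueError` that Pydantic v2 converts into validation errors.

    from typing import Optional

    from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator, model_validator


    class SamplingConfig(BaseModel):
        # Pydantic v2 requires explicit defaults; a bare `Optional[float]`
        # annotation would make the field *required*, unlike v1.
        temperature: Optional[float] = None
        seed: Optional[int] = None
        best_of: Optional[int] = None

        @field_validator("temperature")
        def temp_strictly_positive(cls, v):
            if v is not None and v <= 0:
                raise ValueError("`temperature` must be strictly positive")
            return v

        # Cross-field checks inside a field validator read previously
        # validated fields from `info.data`; this is the v2 replacement
        # for indexing into v1's `values` dict.
        @field_validator("best_of")
        def best_of_excludes_seed(cls, v, info: ValidationInfo):
            if v is not None and v > 1 and info.data.get("seed") is not None:
                raise ValueError("`seed` must not be set when `best_of` is > 1")
            return v

        # Whole-model invariants move to an after-mode model validator,
        # which receives the constructed instance and must return it.
        @model_validator(mode="after")
        def seed_requires_sampling(self):
            if self.seed is not None and self.temperature is None:
                raise ValueError("`seed` requires `temperature` to be set")
            return self


    class DeployedThing(BaseModel):
        # `model_`-prefixed fields trigger v2's protected-namespace warning;
        # clearing the namespaces silences it, as the diff does for DeployedModel.
        model_config = ConfigDict(protected_namespaces=())
        model_id: str


    print(SamplingConfig(temperature=0.7, seed=42))  # ok
    print(DeployedThing(model_id="m1"))              # ok, no warning
    # SamplingConfig(seed=42) raises: seed without temperature
    # SamplingConfig(temperature=0.7, seed=1, best_of=2) raises: seed with best_of > 1

Moving the exclusivity check into an after-mode model validator also sidesteps field-ordering concerns: a `field_validator` only sees fields declared before the one being validated in `values.data`, whereas the model validator sees the fully built instance.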