From 677d14fee8a50f3dc9114b19755065aef16d895c Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Sun, 29 Sep 2024 21:22:00 +0100 Subject: [PATCH] add description for model fields --- endpoints/core/types/auth.py | 4 +-- endpoints/core/types/download.py | 49 +++++++++++++++++++++++++------- endpoints/core/types/template.py | 2 +- endpoints/core/types/token.py | 22 +++++++------- 4 files changed, 53 insertions(+), 24 deletions(-) diff --git a/endpoints/core/types/auth.py b/endpoints/core/types/auth.py index b8f3aa2e..070ac88e 100644 --- a/endpoints/core/types/auth.py +++ b/endpoints/core/types/auth.py @@ -1,7 +1,7 @@ """Types for auth requests.""" -from pydantic import BaseModel +from pydantic import BaseModel, Field class AuthPermissionResponse(BaseModel): - permission: str + permission: str = Field(description="The permission level of the API key") diff --git a/endpoints/core/types/download.py b/endpoints/core/types/download.py index cf49501f..a5d73702 100644 --- a/endpoints/core/types/download.py +++ b/endpoints/core/types/download.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from typing import List, Optional +from typing import List, Literal, Optional def _generate_include_list(): @@ -9,18 +9,45 @@ def _generate_include_list(): class DownloadRequest(BaseModel): """Parameters for a HuggingFace repo download.""" - repo_id: str - repo_type: str = "model" - folder_name: Optional[str] = None - revision: Optional[str] = None - token: Optional[str] = None - include: List[str] = Field(default_factory=_generate_include_list) - exclude: List[str] = Field(default_factory=list) - chunk_limit: Optional[int] = None - timeout: Optional[int] = None + repo_id: str = Field( + description="The repo ID to download from", + examples=[ + "royallab/TinyLlama-1.1B-2T-exl2", + "royallab/LLaMA2-13B-TiefighterLR-exl2", + "turboderp/Llama-3.1-8B-Instruct-exl2", + ], + ) + repo_type: Literal["model", "lora"] = Field("model", description="The model type") + folder_name: Optional[str] = Field( + default=None, + description="The folder name to save the repo to " + + "(this is used to load the model)", + ) + revision: Optional[str] = Field( + default=None, description="The revision to download from" + ) + token: Optional[str] = Field( + default=None, + description="The HuggingFace API token to use, " + + "required for private/gated repos", + ) + include: List[str] = Field( + default_factory=_generate_include_list, + description="A list of file patterns to include in the download", + ) + exclude: List[str] = Field( + default_factory=list, + description="A list of file patterns to exclude from the download", + ) + chunk_limit: Optional[int] = Field( + None, description="The maximum chunk size to download in bytes" + ) + timeout: Optional[int] = Field( + None, description="The timeout for the download in seconds" + ) class DownloadResponse(BaseModel): """Response for a download request.""" - download_path: str + download_path: str = Field(description="The path to the downloaded repo") diff --git a/endpoints/core/types/template.py b/endpoints/core/types/template.py index d72d6210..c0932912 100644 --- a/endpoints/core/types/template.py +++ b/endpoints/core/types/template.py @@ -12,4 +12,4 @@ class TemplateList(BaseModel): class TemplateSwitchRequest(BaseModel): """Request to switch a template.""" - name: str + name: str = Field(description="The name of the template to switch to") diff --git a/endpoints/core/types/token.py b/endpoints/core/types/token.py index 945adbf5..87f3f022 100644 --- a/endpoints/core/types/token.py +++ b/endpoints/core/types/token.py @@ -1,15 +1,17 @@ """Tokenization types""" -from pydantic import BaseModel +from pydantic import BaseModel, Field from typing import Dict, List, Union class CommonTokenRequest(BaseModel): """Represents a common tokenization request.""" - add_bos_token: bool = True - encode_special_tokens: bool = True - decode_special_tokens: bool = True + add_bos_token: bool = Field( + True, description="Add the BOS (beginning of sequence) token" + ) + encode_special_tokens: bool = Field(True, description="Encode special tokens") + decode_special_tokens: bool = Field(True, description="Decode special tokens") def get_params(self): """Get the parameters for tokenization.""" @@ -23,29 +25,29 @@ def get_params(self): class TokenEncodeRequest(CommonTokenRequest): """Represents a tokenization request.""" - text: Union[str, List[Dict[str, str]]] + text: Union[str, List[Dict[str, str]]] = Field(description="The string to encode") class TokenEncodeResponse(BaseModel): """Represents a tokenization response.""" - tokens: List[int] - length: int + tokens: List[int] = Field(description="The tokens") + length: int = Field(description="The length of the tokens") class TokenDecodeRequest(CommonTokenRequest): """ " Represents a detokenization request.""" - tokens: List[int] + tokens: List[int] = Field(description="The string to encode") class TokenDecodeResponse(BaseModel): """Represents a detokenization response.""" - text: str + text: str = Field(description="The decoded text") class TokenCountResponse(BaseModel): """Represents a token count response.""" - length: int + length: int = Field(description="The length of the text")