Skip to content

Commit

Permalink
add description for model fields
Browse files Browse the repository at this point in the history
  • Loading branch information
SecretiveShell committed Sep 29, 2024
1 parent 3e1ef55 commit 677d14f
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 24 deletions.
4 changes: 2 additions & 2 deletions endpoints/core/types/auth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Types for auth requests."""

from pydantic import BaseModel
from pydantic import BaseModel, Field


class AuthPermissionResponse(BaseModel):
permission: str
permission: str = Field(description="The permission level of the API key")
49 changes: 38 additions & 11 deletions endpoints/core/types/download.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydantic import BaseModel, Field
from typing import List, Optional
from typing import List, Literal, Optional


def _generate_include_list():
Expand All @@ -9,18 +9,45 @@ def _generate_include_list():
class DownloadRequest(BaseModel):
"""Parameters for a HuggingFace repo download."""

repo_id: str
repo_type: str = "model"
folder_name: Optional[str] = None
revision: Optional[str] = None
token: Optional[str] = None
include: List[str] = Field(default_factory=_generate_include_list)
exclude: List[str] = Field(default_factory=list)
chunk_limit: Optional[int] = None
timeout: Optional[int] = None
repo_id: str = Field(
description="The repo ID to download from",
examples=[
"royallab/TinyLlama-1.1B-2T-exl2",
"royallab/LLaMA2-13B-TiefighterLR-exl2",
"turboderp/Llama-3.1-8B-Instruct-exl2",
],
)
repo_type: Literal["model", "lora"] = Field("model", description="The model type")
folder_name: Optional[str] = Field(
default=None,
description="The folder name to save the repo to "
+ "(this is used to load the model)",
)
revision: Optional[str] = Field(
default=None, description="The revision to download from"
)
token: Optional[str] = Field(
default=None,
description="The HuggingFace API token to use, "
+ "required for private/gated repos",
)
include: List[str] = Field(
default_factory=_generate_include_list,
description="A list of file patterns to include in the download",
)
exclude: List[str] = Field(
default_factory=list,
description="A list of file patterns to exclude from the download",
)
chunk_limit: Optional[int] = Field(
None, description="The maximum chunk size to download in bytes"
)
timeout: Optional[int] = Field(
None, description="The timeout for the download in seconds"
)


class DownloadResponse(BaseModel):
"""Response for a download request."""

download_path: str
download_path: str = Field(description="The path to the downloaded repo")
2 changes: 1 addition & 1 deletion endpoints/core/types/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ class TemplateList(BaseModel):
class TemplateSwitchRequest(BaseModel):
"""Request to switch a template."""

name: str
name: str = Field(description="The name of the template to switch to")
22 changes: 12 additions & 10 deletions endpoints/core/types/token.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
"""Tokenization types"""

from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing import Dict, List, Union


class CommonTokenRequest(BaseModel):
"""Represents a common tokenization request."""

add_bos_token: bool = True
encode_special_tokens: bool = True
decode_special_tokens: bool = True
add_bos_token: bool = Field(
True, description="Add the BOS (beginning of sequence) token"
)
encode_special_tokens: bool = Field(True, description="Encode special tokens")
decode_special_tokens: bool = Field(True, description="Decode special tokens")

def get_params(self):
"""Get the parameters for tokenization."""
Expand All @@ -23,29 +25,29 @@ def get_params(self):
class TokenEncodeRequest(CommonTokenRequest):
"""Represents a tokenization request."""

text: Union[str, List[Dict[str, str]]]
text: Union[str, List[Dict[str, str]]] = Field(description="The string to encode")


class TokenEncodeResponse(BaseModel):
"""Represents a tokenization response."""

tokens: List[int]
length: int
tokens: List[int] = Field(description="The tokens")
length: int = Field(description="The length of the tokens")


class TokenDecodeRequest(CommonTokenRequest):
""" " Represents a detokenization request."""

tokens: List[int]
tokens: List[int] = Field(description="The string to encode")


class TokenDecodeResponse(BaseModel):
"""Represents a detokenization response."""

text: str
text: str = Field(description="The decoded text")


class TokenCountResponse(BaseModel):
"""Represents a token count response."""

length: int
length: int = Field(description="The length of the text")

0 comments on commit 677d14f

Please sign in to comment.