
Commit

Model: Remove and format comments
The comment in __init__ was outdated and all the kwargs are the
config options anyways.

Signed-off-by: kingbri <[email protected]>
bdashore3 committed Aug 28, 2024
1 parent 80198ca commit 4958c06
Showing 1 changed file with 22 additions and 57 deletions.
79 changes: 22 additions & 57 deletions backends/exllamav2/model.py
@@ -101,51 +101,9 @@ class ExllamaV2Container:

def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
"""
Create model container
Primary initializer for model container.
Args:
model_dir (int): Model directory containing config.json,
tokenizer.model etc.
quiet (bool): Suppress console output
load_progress_callback (function, optional): A function to call for
each module loaded. Prototype:
def progress(loaded_modules: int, total_modules: int,
loading_draft: bool)
**kwargs:
`cache_mode` (str): Sets cache mode: "FP16"/"Q8"/"Q6"/"Q4"
(default: "FP16")
'max_seq_len' (int): Override model's default max sequence
length (default: 4096)
'cache_size' (int): Num of tokens to allocate space for in the k/v cache
(default: max_seq_len)
'rope_scale' (float): Set RoPE scaling factor for model
(default: 1.0)
'rope_alpha' (float): Set RoPE alpha (NTK) factor for model
(default: 1.0)
'prompt_template' (str): Manually sets the prompt template for
this model (default: None)
'chunk_size' (int): Sets the maximum chunk size for the model
(default: 2048)
Inferencing in chunks reduces overall VRAM overhead by
processing very long sequences in smaller batches. This
limits the size of temporary buffers needed for the hidden
state and attention weights.
'draft_model_dir' (str): Draft model directory
'draft_rope_scale' (float): Set RoPE scaling factor for draft
model (default: 1.0)
'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft
model. By default, the draft model's alpha value is
calculated automatically to scale to the size of the
full model.
'draft_cache_mode' (str): Sets draft cache mode: "FP16"/"Q8"/"Q6"/"Q4"
(default: "FP16")
'lora_dir' (str): LoRA directory
'loras' (list[dict]): List of loras to be loaded, consisting of
'name' and 'scaling'
'gpu_split_auto' (bool): Automatically split model across
available devices (default: True)
'gpu_split' (list[float]): Allocation for weights and (some)
tensors, per device
Kwargs are located in config_sample.yml
"""

self.quiet = quiet
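
For orientation (not part of the diff), a minimal sketch of constructing the container after this change. The keyword names come from the removed docstring and from config_sample.yml; the directory path and values below are placeholders, and the surrounding server code is omitted.

import pathlib

# Hypothetical values; the kwargs mirror the model options in config_sample.yml.
container = ExllamaV2Container(
    pathlib.Path("models/example-exl2"),  # placeholder model directory
    quiet=False,
    cache_mode="Q4",        # one of "FP16", "Q8", "Q6", "Q4"
    max_seq_len=8192,       # overrides the model's default context length
    rope_alpha=1.0,
    chunk_size=2048,
    gpu_split_auto=True,
)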
@@ -386,7 +344,7 @@ def progress(loaded_modules: int, total_modules: int,
self.draft_config.max_attention_size = chunk_size**2

def find_prompt_template(self, prompt_template_name, model_directory):
"""Tries to find a prompt template using various methods"""
"""Tries to find a prompt template using various methods."""

logger.info("Attempting to load a prompt template if present.")

@@ -428,6 +386,7 @@ def find_prompt_template(self, prompt_template_name, model_directory):

def calculate_rope_alpha(self, base_seq_len):
"""Calculate the rope alpha value for a given sequence length."""

ratio = self.config.max_seq_len / base_seq_len

# Default to a 1 alpha if the sequence length is ever less
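
The rest of this method is collapsed in the diff. As a point of reference only, here is a standalone sketch of a commonly used NTK-aware alpha fit; the exact formula and coefficients in model.py may differ.

def calculate_rope_alpha_sketch(max_seq_len: int, base_seq_len: int) -> float:
    # Ratio of the requested context length to the model's native length.
    ratio = max_seq_len / base_seq_len

    # Default to a 1.0 alpha when the target length does not exceed the base.
    if ratio <= 1.0:
        return 1.0

    # Quadratic fit of alpha against the length ratio (approximation only;
    # not necessarily the coefficients used in this codebase).
    return -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2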
@@ -504,7 +463,9 @@ async def load(self, progress_callback=None):
Args:
progress_callback (function, optional): A function to call for each
module loaded. Prototype:
module loaded.
Prototype:
def progress(loaded_modules: int, total_modules: int)
"""

@@ -549,11 +510,13 @@ async def load_gen(self, progress_callback=None, **kwargs):
@torch.inference_mode()
def load_model_sync(self, progress_callback=None):
"""
Load model, generator function
Synchronous generator for loading.
Args:
progress_callback (function, optional): A function to call for each
module loaded. Prototype:
module loaded.
Prototype:
def progress(loaded_modules: int, total_modules: int)
Runs under a shared inference mode context.
@@ -695,6 +658,8 @@ def create_cache(
)

async def create_generator(self):
"""Create and save a Exllama generator class."""

try:
# Don't acquire locks unless a model is loaded
if self.model_loaded:
@@ -728,9 +693,7 @@ def get_loras(self):
return unwrap(self.generator.generator.current_loras, [])

async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
"""
Load loras
"""
"""Load loras."""

loras = unwrap(kwargs.get("loras"), [])

@@ -777,9 +740,7 @@ async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
self.load_condition.notify_all()

async def unload(self, loras_only: bool = False, **kwargs):
"""
Free all VRAM resources used by this model
"""
"""Free all VRAM resources used by the model (and loras)."""

# Shutdown immediately unloads and bypasses all locks
do_shutdown = kwargs.get("shutdown")
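
A brief usage sketch (not from the diff), assuming a loaded container: passing loras_only=True frees only the loras while keeping the base model resident.

# Inside an async function:
await container.unload(loras_only=True)  # unload loras only
await container.unload()                 # free the model and its VRAM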
@@ -836,7 +797,7 @@ async def unload(self, loras_only: bool = False, **kwargs):
self.load_condition.notify_all()

def encode_tokens(self, text: str, **kwargs):
"""Wrapper to encode tokens from a text string"""
"""Wrapper to encode tokens from a text string."""

return (
self.tokenizer.encode(
@@ -888,7 +849,7 @@ def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
async def generate(
self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
):
"""Generate a response to a prompt"""
"""Generate a response to a prompt."""
generations = []
async for generation in self.generate_gen(
prompt, request_id, abort_event, **kwargs
@@ -939,7 +900,11 @@ async def generate(
return joined_generation
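
For context, generate collects the streamed chunks from generate_gen and returns the joined result; a hedged usage sketch with placeholder arguments:

# Inside an async function; extra kwargs (temperature, etc.) are forwarded
# to generate_gen and are illustrative only.
result = await container.generate(
    "Write a haiku about autumn.",
    request_id="example-request-1",
    temperature=0.7,
)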

def check_unsupported_settings(self, **kwargs):
"""Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
"""
Check and warn the user if a sampler is unsupported.
Meant for dev wheels!
"""

return kwargs
