feat: logging (#39)
* add logging

* simplify the logger

* formatting

* final touches

* fix format

* Model: Add log to metrics

Signed-off-by: kingbri <[email protected]>

---------

Authored-by: AlpinDale <[email protected]>
AlpinDale authored Dec 23, 2023
1 parent f5314fc commit 6a5bbd2
Showing 11 changed files with 170 additions and 74 deletions.
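
Every file below follows the same mechanical pattern: bare print() calls are replaced with a module-level logger obtained from the new logger.py helper. A minimal before/after sketch of that pattern (the message text is illustrative):

# Before: unstructured, uncolored console output
print("!! Warning: something needs attention")

# After: a named, leveled logger shared across the codebase
from logger import init_logger

logger = init_logger(__name__)

logger.warning("Something needs attention")
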
15 changes: 9 additions & 6 deletions auth.py
@@ -9,6 +9,10 @@
 from pydantic import BaseModel
 import yaml
 
+from logger import init_logger
+
+logger = init_logger(__name__)
+
 
 class AuthKeys(BaseModel):
     """
@@ -44,11 +48,10 @@ def load_auth_keys(disable_from_config: bool):
 
     DISABLE_AUTH = disable_from_config
     if disable_from_config:
-        print(
-            "!! Warning: Disabling authentication",
-            "makes your instance vulnerable.",
-            "Set the 'disable_auth' flag to False in config.yml",
-            "if you want to share this instance with others.",
+        logger.warning(
+            "Disabling authentication makes your instance vulnerable. "
+            "Set the `disable_auth` flag to False in config.yml if you "
+            "want to share this instance with others."
         )
 
         return
@@ -66,7 +69,7 @@ def load_auth_keys(disable_from_config: bool):
     with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
         yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)
 
-    print(
+    logger.info(
         f"Your API key is: {AUTH_KEYS.api_key}\n"
         f"Your admin key is: {AUTH_KEYS.admin_key}\n\n"
         "If these keys get compromised, make sure to delete api_tokens.yml "
14 changes: 9 additions & 5 deletions gen_logging.py
@@ -4,6 +4,10 @@
 from typing import Dict
 from pydantic import BaseModel
 
+from logger import init_logger
+
+logger = init_logger(__name__)
+
 
 class LogConfig(BaseModel):
     """Logging preference config."""
@@ -38,24 +42,24 @@ def broadcast_status():
         enabled.append("generation params")
 
     if len(enabled) > 0:
-        print("Generation logging is enabled for: " + ", ".join(enabled))
+        logger.info("Generation logging is enabled for: " + ", ".join(enabled))
     else:
-        print("Generation logging is disabled")
+        logger.info("Generation logging is disabled")
 
 
 def log_generation_params(**kwargs):
     """Logs generation parameters to console."""
     if CONFIG.generation_params:
-        print(f"Generation options: {kwargs}\n")
+        logger.info(f"Generation options: {kwargs}\n")
 
 
 def log_prompt(prompt: str):
     """Logs the prompt to console."""
     if CONFIG.prompt:
-        print(f"Prompt: {prompt if prompt else 'Empty'}\n")
+        logger.info(f"Prompt: {prompt if prompt else 'Empty'}\n")
 
 
 def log_response(response: str):
     """Logs the response to console."""
     if CONFIG.prompt:
-        print(f"Response: {response if response else 'Empty'}\n")
+        logger.info(f"Response: {response if response else 'Empty'}\n")
71 changes: 71 additions & 0 deletions logger.py
@@ -0,0 +1,71 @@
"""
Logging utility.
https://github.com/PygmalionAI/aphrodite-engine/blob/main/aphrodite/common/logger.py
"""

import logging
import sys
import colorlog

_FORMAT = "%(log_color)s%(levelname)s: %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"


class ColoredFormatter(colorlog.ColoredFormatter):
"""Adds logging prefix to newlines to align multi-line messages."""

def __init__(self, fmt, datefmt=None, log_colors=None, reset=True, style="%"):
super().__init__(
fmt, datefmt=datefmt, log_colors=log_colors, reset=reset, style=style
)

def format(self, record):
msg = super().format(record)
if record.message != "":
parts = msg.split(record.message)
msg = msg.replace("\n", "\r\n" + parts[0])
return msg


_root_logger = logging.getLogger("aphrodite")
_default_handler = None


def _setup_logger():
_root_logger.setLevel(logging.DEBUG)
global _default_handler
if _default_handler is None:
_default_handler = logging.StreamHandler(sys.stdout)
_default_handler.flush = sys.stdout.flush # type: ignore
_default_handler.setLevel(logging.INFO)
_root_logger.addHandler(_default_handler)
fmt = ColoredFormatter(
_FORMAT,
datefmt=_DATE_FORMAT,
log_colors={
"DEBUG": "cyan",
"INFO": "green",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "red,bg_white",
},
reset=True,
)
_default_handler.setFormatter(fmt)
# Setting this will avoid the message
# being propagated to the parent logger.
_root_logger.propagate = False


# The logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once,
# guaranteed by the Python GIL.
_setup_logger()


def init_logger(name: str):
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.addHandler(_default_handler)
logger.propagate = False
return logger
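
For reference, a minimal usage sketch of the module above (messages are illustrative). init_logger attaches the shared stdout handler: the handler filters at INFO even though loggers are set to DEBUG, and ColoredFormatter re-inserts the level prefix after each newline so multi-line messages stay aligned:

from logger import init_logger

logger = init_logger(__name__)

logger.info("Model loaded")      # rendered as a green "INFO: Model loaded"
logger.debug("verbose details")  # dropped by the handler's INFO threshold

# ColoredFormatter repeats the "INFO: " prefix after the newline,
# keeping the second line aligned under the first.
logger.info("Your API key is: <redacted>\nYour admin key is: <redacted>")
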
19 changes: 11 additions & 8 deletions main.py
@@ -41,6 +41,9 @@
 )
 from templating import get_prompt_from_template
 from utils import get_generator_error, get_sse_packet, load_progress, unwrap
+from logger import init_logger
+
+logger = init_logger(__name__)
 
 app = FastAPI()
 
@@ -210,8 +213,8 @@ async def generator():
 
             yield get_sse_packet(response.model_dump_json())
         except CancelledError:
-            print(
-                "\nError: Model load cancelled by user. "
+            logger.error(
+                "Model load cancelled by user. "
                 "Please make sure to run unload to free up resources."
             )
         except Exception as exc:
@@ -369,7 +372,7 @@ async def generator():
 
             yield get_sse_packet(response.model_dump_json())
         except CancelledError:
-            print("Error: Completion request cancelled by user.")
+            logger.error("Completion request cancelled by user.")
         except Exception as exc:
             yield get_generator_error(str(exc))
 
@@ -456,7 +459,7 @@ async def generator():
 
             yield get_sse_packet(finish_response.model_dump_json())
         except CancelledError:
-            print("Error: Chat completion cancelled by user.")
+            logger.error("Chat completion cancelled by user.")
         except Exception as exc:
             yield get_generator_error(str(exc))
 
@@ -481,10 +484,10 @@ async def generator():
     with open("config.yml", "r", encoding="utf8") as config_file:
         config = unwrap(yaml.safe_load(config_file), {})
 except Exception as exc:
-    print(
-        "The YAML config couldn't load because of the following error:",
-        f"\n\n{exc}",
-        "\n\nTabbyAPI will start anyway and not parse this config file.",
+    logger.error(
+        "The YAML config couldn't load because of the following error: "
+        f"\n\n{exc}"
+        "\n\nTabbyAPI will start anyway and not parse this config file."
     )
     config = {}
 
76 changes: 38 additions & 38 deletions model.py
@@ -23,6 +23,9 @@
     get_template_from_file,
 )
 from utils import coalesce, unwrap
+from logger import init_logger
+
+logger = init_logger(__name__)
 
 # Bytes to reserve on first device when loading with auto split
 AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2
@@ -143,10 +146,7 @@ def progress(loaded_modules: int, total_modules: int,
         # Set prompt template override if provided
         prompt_template_name = kwargs.get("prompt_template")
         if prompt_template_name:
-            print(
-                "Attempting to load prompt template with name",
-                {prompt_template_name},
-            )
+            logger.info("Loading prompt template with name " f"{prompt_template_name}")
             # Read the template
             self.prompt_template = get_template_from_file(prompt_template_name)
         else:
@@ -175,13 +175,13 @@ def progress(loaded_modules: int, total_modules: int,
 
         # Catch all for template lookup errors
         if self.prompt_template:
-            print(
-                f"Using template {self.prompt_template.name} for chat " "completions."
+            logger.info(
+                f"Using template {self.prompt_template.name} " "for chat completions."
             )
         else:
-            print(
-                "Chat completions are disabled because a prompt template",
-                "wasn't provided or auto-detected.",
+            logger.warning(
+                "Chat completions are disabled because a prompt "
+                "template wasn't provided or auto-detected."
             )
 
         # Set num of experts per token if provided
@@ -190,9 +190,9 @@ def progress(loaded_modules: int, total_modules: int,
             if hasattr(self.config, "num_experts_per_token"):
                 self.config.num_experts_per_token = num_experts_override
             else:
-                print(
-                    " !! Warning: Currently installed ExLlamaV2 does not "
-                    "support overriding MoE experts"
+                logger.warning(
+                    "MoE experts per token override is not "
+                    "supported by the current ExLlamaV2 version."
                 )
 
         chunk_size = min(
@@ -207,9 +207,9 @@ def progress(loaded_modules: int, total_modules: int,
 
         # Always disable draft if params are incorrectly configured
         if draft_args and draft_model_name is None:
-            print(
-                "A draft config was found but a model name was not given. "
-                "Please check your config.yml! Skipping draft load."
+            logger.warning(
+                "Draft model is disabled because a model name "
+                "wasn't provided. Please check your config.yml!"
             )
             enable_draft = False
 
@@ -283,20 +283,20 @@ def load_loras(self, lora_directory: pathlib.Path, **kwargs):
             lora_scaling = unwrap(lora.get("scaling"), 1.0)
 
             if lora_name is None:
-                print(
+                logger.warning(
                     "One of your loras does not have a name. Please check your "
                     "config.yml! Skipping lora load."
                 )
                 failure.append(lora_name)
                 continue
 
-            print(f"Loading lora: {lora_name} at scaling {lora_scaling}")
+            logger.info(f"Loading lora: {lora_name} at scaling {lora_scaling}")
             lora_path = lora_directory / lora_name
             # FIXME(alpin): Does self.model need to be passed here?
             self.active_loras.append(
                 ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
             )
-            print("Lora successfully loaded.")
+            logger.info(f"Lora successfully loaded: {lora_name}")
             success.append(lora_name)
 
         # Return success and failure names
@@ -319,7 +319,7 @@ def progress(loaded_modules: int, total_modules: int)
         if self.draft_config:
             self.draft_model = ExLlamaV2(self.draft_config)
             if not self.quiet:
-                print("Loading draft model: " + self.draft_config.model_dir)
+                logger.info("Loading draft model: " + self.draft_config.model_dir)
 
             self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True)
             reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
@@ -337,7 +337,7 @@ def progress(loaded_modules: int, total_modules: int)
         # Load model
         self.model = ExLlamaV2(self.config)
         if not self.quiet:
-            print("Loading model: " + self.config.model_dir)
+            logger.info("Loading model: " + self.config.model_dir)
 
         if not self.gpu_split_auto:
             for value in self.model.load_gen(
@@ -373,7 +373,7 @@ def progress(loaded_modules: int, total_modules: int)
                 self.draft_cache,
             )
 
-        print("Model successfully loaded.")
+        logger.info("Model successfully loaded.")
 
     def unload(self, loras_only: bool = False):
         """
@@ -494,33 +494,33 @@ def generate_gen(self, prompt: str, **kwargs):
         if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(
             gen_settings, "mirostat"
         ):
-            print(
-                " !! Warning: Currently installed ExLlamaV2 does not support "
-                "Mirostat sampling"
+            logger.warning(
+                "Mirostat sampling is not supported by the currently "
+                "installed ExLlamaV2 version."
             )
 
         if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(
             gen_settings, "min_p"
         ):
-            print(
-                " !! Warning: Currently installed ExLlamaV2 does not "
-                "support min-P sampling"
+            logger.warning(
+                "Min-P sampling is not supported by the currently "
+                "installed ExLlamaV2 version."
             )
 
         if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(
             gen_settings, "tfs"
         ):
-            print(
-                " !! Warning: Currently installed ExLlamaV2 does not support "
-                "tail-free sampling (TFS)"
+            logger.warning(
+                "Tail-free sampling (TFS) is not supported by the currently "
+                "installed ExLlamaV2 version."
            )
 
         if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(
             gen_settings, "temperature_last"
         ):
-            print(
-                " !! Warning: Currently installed ExLlamaV2 does not support "
-                "temperature_last"
+            logger.warning(
+                "Temperature last is not supported by the currently "
+                "installed ExLlamaV2 version."
             )
 
         # Apply settings
@@ -614,10 +614,10 @@ def generate_gen(self, prompt: str, **kwargs):
         context_len = len(ids[0])
 
         if context_len > self.config.max_seq_len:
-            print(
-                f"WARNING: The context length {context_len} is greater than "
-                f"the max_seq_len {self.config.max_seq_len}.",
-                "Generation is truncated and metrics may not be accurate.",
+            logger.warning(
+                f"Context length {context_len} is greater than max_seq_len "
+                f"{self.config.max_seq_len}. Generation is truncated and "
+                "metrics may not be accurate."
             )
 
         prompt_tokens = ids.shape[-1]
@@ -705,7 +705,7 @@ def generate_gen(self, prompt: str, **kwargs):
             extra_parts.append("<-- Not accurate (truncated)")
 
         # Print output
-        print(
+        logger.info(
             initial_response
             + " ("
             + ", ".join(itemization)
1 change: 1 addition & 0 deletions requirements-amd.txt
@@ -13,3 +13,4 @@ PyYAML
 progress
 uvicorn
 jinja2 >= 3.0.0
+colorlog
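
colorlog is the one new dependency; it provides the per-level ANSI colors configured in logger.py. Assuming the standard PyPI package, it installs with pip install colorlog.
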
