feat: logging #39

Merged: 6 commits, Dec 23, 2023
15 changes: 9 additions & 6 deletions auth.py
@@ -9,6 +9,10 @@
from pydantic import BaseModel
import yaml

from logger import init_logger

logger = init_logger(__name__)


class AuthKeys(BaseModel):
"""
@@ -44,11 +48,10 @@ def load_auth_keys(disable_from_config: bool):

DISABLE_AUTH = disable_from_config
if disable_from_config:
print(
"!! Warning: Disabling authentication",
"makes your instance vulnerable.",
"Set the 'disable_auth' flag to False in config.yml",
"if you want to share this instance with others.",
logger.warning(
"Disabling authentication makes your instance vulnerable. "
"Set the `disable_auth` flag to False in config.yml if you "
"want to share this instance with others."
)

return
@@ -66,7 +69,7 @@ def load_auth_keys(disable_from_config: bool):
with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)

print(
logger.info(
f"Your API key is: {AUTH_KEYS.api_key}\n"
f"Your admin key is: {AUTH_KEYS.admin_key}\n\n"
"If these keys get compromised, make sure to delete api_tokens.yml "
14 changes: 9 additions & 5 deletions gen_logging.py
@@ -4,6 +4,10 @@
from typing import Dict
from pydantic import BaseModel

from logger import init_logger

logger = init_logger(__name__)


class LogConfig(BaseModel):
"""Logging preference config."""
@@ -38,24 +42,24 @@ def broadcast_status():
enabled.append("generation params")

if len(enabled) > 0:
print("Generation logging is enabled for: " + ", ".join(enabled))
logger.info("Generation logging is enabled for: " + ", ".join(enabled))
else:
print("Generation logging is disabled")
logger.info("Generation logging is disabled")


def log_generation_params(**kwargs):
"""Logs generation parameters to console."""
if CONFIG.generation_params:
print(f"Generation options: {kwargs}\n")
logger.info(f"Generation options: {kwargs}\n")


def log_prompt(prompt: str):
"""Logs the prompt to console."""
if CONFIG.prompt:
print(f"Prompt: {prompt if prompt else 'Empty'}\n")
logger.info(f"Prompt: {prompt if prompt else 'Empty'}\n")


def log_response(response: str):
"""Logs the response to console."""
if CONFIG.prompt:
print(f"Response: {response if response else 'Empty'}\n")
logger.info(f"Response: {response if response else 'Empty'}\n")
71 changes: 71 additions & 0 deletions logger.py
@@ -0,0 +1,71 @@
"""
Logging utility.
https://github.com/PygmalionAI/aphrodite-engine/blob/main/aphrodite/common/logger.py
"""

import logging
import sys
import colorlog

_FORMAT = "%(log_color)s%(levelname)s: %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"


class ColoredFormatter(colorlog.ColoredFormatter):
"""Adds logging prefix to newlines to align multi-line messages."""

def __init__(self, fmt, datefmt=None, log_colors=None, reset=True, style="%"):
super().__init__(
fmt, datefmt=datefmt, log_colors=log_colors, reset=reset, style=style
)

def format(self, record):
msg = super().format(record)
if record.message != "":
parts = msg.split(record.message)
msg = msg.replace("\n", "\r\n" + parts[0])
return msg


_root_logger = logging.getLogger("aphrodite")
_default_handler = None


def _setup_logger():
_root_logger.setLevel(logging.DEBUG)
global _default_handler
if _default_handler is None:
_default_handler = logging.StreamHandler(sys.stdout)
_default_handler.flush = sys.stdout.flush # type: ignore
_default_handler.setLevel(logging.INFO)
_root_logger.addHandler(_default_handler)
fmt = ColoredFormatter(
_FORMAT,
datefmt=_DATE_FORMAT,
log_colors={
"DEBUG": "cyan",
"INFO": "green",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "red,bg_white",
},
reset=True,
)
_default_handler.setFormatter(fmt)
# Setting this will avoid the message
# being propagated to the parent logger.
_root_logger.propagate = False


# The logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once,
# guaranteed by the Python GIL.
_setup_logger()


def init_logger(name: str):
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.addHandler(_default_handler)
logger.propagate = False
return logger
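
For reference, a minimal usage sketch of the new logger module, mirroring the pattern the other files in this PR adopt (init_logger, the logging levels, and _FORMAT come from the diff above; the messages are illustrative only):

from logger import init_logger

# init_logger attaches the shared stdout handler configured above, so every
# module logs through the same colorlog formatter.
logger = init_logger(__name__)

# Rendered as "INFO: Model successfully loaded." (green) per _FORMAT.
logger.info("Model successfully loaded.")

# Rendered as a yellow "WARNING: ..." line per the log_colors map.
logger.warning("Disabling authentication makes your instance vulnerable.")
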
19 changes: 11 additions & 8 deletions main.py
@@ -41,6 +41,9 @@
)
from templating import get_prompt_from_template
from utils import get_generator_error, get_sse_packet, load_progress, unwrap
from logger import init_logger

logger = init_logger(__name__)

app = FastAPI()

@@ -210,8 +213,8 @@ async def generator():

yield get_sse_packet(response.model_dump_json())
except CancelledError:
print(
"\nError: Model load cancelled by user. "
logger.error(
"Model load cancelled by user. "
"Please make sure to run unload to free up resources."
)
except Exception as exc:
@@ -369,7 +372,7 @@ async def generator():

yield get_sse_packet(response.model_dump_json())
except CancelledError:
print("Error: Completion request cancelled by user.")
logger.error("Completion request cancelled by user.")
except Exception as exc:
yield get_generator_error(str(exc))

@@ -456,7 +459,7 @@ async def generator():

yield get_sse_packet(finish_response.model_dump_json())
except CancelledError:
print("Error: Chat completion cancelled by user.")
logger.error("Chat completion cancelled by user.")
except Exception as exc:
yield get_generator_error(str(exc))

@@ -481,10 +484,10 @@ async def generator():
with open("config.yml", "r", encoding="utf8") as config_file:
config = unwrap(yaml.safe_load(config_file), {})
except Exception as exc:
print(
"The YAML config couldn't load because of the following error:",
f"\n\n{exc}",
"\n\nTabbyAPI will start anyway and not parse this config file.",
logger.error(
"The YAML config couldn't load because of the following error: "
f"\n\n{exc}"
"\n\nTabbyAPI will start anyway and not parse this config file."
)
config = {}

76 changes: 38 additions & 38 deletions model.py
@@ -23,6 +23,9 @@
get_template_from_file,
)
from utils import coalesce, unwrap
from logger import init_logger

logger = init_logger(__name__)

# Bytes to reserve on first device when loading with auto split
AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2
@@ -143,10 +146,7 @@ def progress(loaded_modules: int, total_modules: int,
# Set prompt template override if provided
prompt_template_name = kwargs.get("prompt_template")
if prompt_template_name:
print(
"Attempting to load prompt template with name",
{prompt_template_name},
)
logger.info("Loading prompt template with name " f"{prompt_template_name}")
# Read the template
self.prompt_template = get_template_from_file(prompt_template_name)
else:
@@ -175,13 +175,13 @@ def progress(loaded_modules: int, total_modules: int,

# Catch all for template lookup errors
if self.prompt_template:
print(
f"Using template {self.prompt_template.name} for chat " "completions."
logger.info(
f"Using template {self.prompt_template.name} " "for chat completions."
)
else:
print(
"Chat completions are disabled because a prompt template",
"wasn't provided or auto-detected.",
logger.warning(
"Chat completions are disabled because a prompt "
"template wasn't provided or auto-detected."
)

# Set num of experts per token if provided
@@ -190,9 +190,9 @@ def progress(loaded_modules: int, total_modules: int,
if hasattr(self.config, "num_experts_per_token"):
self.config.num_experts_per_token = num_experts_override
else:
print(
" !! Warning: Currently installed ExLlamaV2 does not "
"support overriding MoE experts"
logger.warning(
"MoE experts per token override is not "
"supported by the current ExLlamaV2 version."
)

chunk_size = min(
@@ -207,9 +207,9 @@ def progress(loaded_modules: int, total_modules: int,

# Always disable draft if params are incorrectly configured
if draft_args and draft_model_name is None:
print(
"A draft config was found but a model name was not given. "
"Please check your config.yml! Skipping draft load."
logger.warning(
"Draft model is disabled because a model name "
"wasn't provided. Please check your config.yml!"
)
enable_draft = False

@@ -283,20 +283,20 @@ def load_loras(self, lora_directory: pathlib.Path, **kwargs):
lora_scaling = unwrap(lora.get("scaling"), 1.0)

if lora_name is None:
print(
logger.warning(
"One of your loras does not have a name. Please check your "
"config.yml! Skipping lora load."
)
failure.append(lora_name)
continue

print(f"Loading lora: {lora_name} at scaling {lora_scaling}")
logger.info(f"Loading lora: {lora_name} at scaling {lora_scaling}")
lora_path = lora_directory / lora_name
# FIXME(alpin): Does self.model need to be passed here?
self.active_loras.append(
ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
)
print("Lora successfully loaded.")
logger.info(f"Lora successfully loaded: {lora_name}")
success.append(lora_name)

# Return success and failure names
@@ -319,7 +319,7 @@ def progress(loaded_modules: int, total_modules: int)
if self.draft_config:
self.draft_model = ExLlamaV2(self.draft_config)
if not self.quiet:
print("Loading draft model: " + self.draft_config.model_dir)
logger.info("Loading draft model: " + self.draft_config.model_dir)

self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True)
reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
@@ -337,7 +337,7 @@ def progress(loaded_modules: int, total_modules: int)
# Load model
self.model = ExLlamaV2(self.config)
if not self.quiet:
print("Loading model: " + self.config.model_dir)
logger.info("Loading model: " + self.config.model_dir)

if not self.gpu_split_auto:
for value in self.model.load_gen(
@@ -373,7 +373,7 @@ def progress(loaded_modules: int, total_modules: int)
self.draft_cache,
)

print("Model successfully loaded.")
logger.info("Model successfully loaded.")

def unload(self, loras_only: bool = False):
"""
@@ -494,33 +494,33 @@ def generate_gen(self, prompt: str, **kwargs):
if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(
gen_settings, "mirostat"
):
print(
" !! Warning: Currently installed ExLlamaV2 does not support "
"Mirostat sampling"
logger.warning(
"Mirostat sampling is not supported by the currently "
"installed ExLlamaV2 version."
)

if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(
gen_settings, "min_p"
):
print(
" !! Warning: Currently installed ExLlamaV2 does not "
"support min-P sampling"
logger.warning(
"Min-P sampling is not supported by the currently "
"installed ExLlamaV2 version."
)

if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(
gen_settings, "tfs"
):
print(
" !! Warning: Currently installed ExLlamaV2 does not support "
"tail-free sampling (TFS)"
logger.warning(
"Tail-free sampling (TFS) is not supported by the currently "
"installed ExLlamaV2 version."
)

if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(
gen_settings, "temperature_last"
):
print(
" !! Warning: Currently installed ExLlamaV2 does not support "
"temperature_last"
logger.warning(
"Temperature last is not supported by the currently "
"installed ExLlamaV2 version."
)

# Apply settings
@@ -614,10 +614,10 @@ def generate_gen(self, prompt: str, **kwargs):
context_len = len(ids[0])

if context_len > self.config.max_seq_len:
print(
f"WARNING: The context length {context_len} is greater than "
f"the max_seq_len {self.config.max_seq_len}.",
"Generation is truncated and metrics may not be accurate.",
logger.warning(
f"Context length {context_len} is greater than max_seq_len "
f"{self.config.max_seq_len}. Generation is truncated and "
"metrics may not be accurate."
)

prompt_tokens = ids.shape[-1]
@@ -705,7 +705,7 @@ def generate_gen(self, prompt: str, **kwargs):
extra_parts.append("<-- Not accurate (truncated)")

# Print output
print(
logger.info(
initial_response
+ " ("
+ ", ".join(itemization)
1 change: 1 addition & 0 deletions requirements-amd.txt
@@ -13,3 +13,4 @@ PyYAML
progress
uvicorn
jinja2 >= 3.0.0
colorlog
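
The colorlog package added above supplies the ColoredFormatter that logger.py subclasses. For context only, and not part of this PR, the library's basic pattern looks roughly like this, reusing the same format string as _FORMAT in logger.py:

import colorlog

handler = colorlog.StreamHandler()
handler.setFormatter(
    colorlog.ColoredFormatter("%(log_color)s%(levelname)s: %(message)s")
)

logger = colorlog.getLogger("example")
logger.addHandler(handler)

# WARNING and above pass the default level, so this prints in yellow.
logger.warning("Shown with a colored WARNING: prefix")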