From 34174713f16e43b5d33e1f53d447f88fe272860a Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 22 Dec 2023 19:14:02 -0500 Subject: [PATCH 1/6] add logging --- auth.py | 15 ++++---- gen_logging.py | 13 ++++--- logger.py | 77 ++++++++++++++++++++++++++++++++++++++++ main.py | 20 ++++++----- model.py | 77 +++++++++++++++++++--------------------- requirements-amd.txt | 1 + requirements-cu118.txt | 1 + requirements-nowheel.txt | 1 + requirements.txt | 1 + tests/wheel_test.py | 28 ++++++++------- utils.py | 6 +++- 11 files changed, 166 insertions(+), 74 deletions(-) create mode 100644 logger.py diff --git a/auth.py b/auth.py index 451ba0ec..b881823f 100644 --- a/auth.py +++ b/auth.py @@ -9,6 +9,9 @@ from pydantic import BaseModel import yaml +from logger import init_logger + +logger = init_logger(__name__) class AuthKeys(BaseModel): """ @@ -44,12 +47,12 @@ def load_auth_keys(disable_from_config: bool): DISABLE_AUTH = disable_from_config if disable_from_config: - print( - "!! Warning: Disabling authentication", - "makes your instance vulnerable.", - "Set the 'disable_auth' flag to False in config.yml", - "if you want to share this instance with others.", + logger.warning( + "Disabling authentication makes your instance vulnerable. " + "Set the `disable_auth` flag to False in config.yml if you " + "want to share this instance with others." ) + return @@ -66,7 +69,7 @@ def load_auth_keys(disable_from_config: bool): with open("api_tokens.yml", "w", encoding="utf8") as auth_file: yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False) - print( + logger.info( f"Your API key is: {AUTH_KEYS.api_key}\n" f"Your admin key is: {AUTH_KEYS.admin_key}\n\n" "If these keys get compromised, make sure to delete api_tokens.yml " diff --git a/gen_logging.py b/gen_logging.py index c8d36209..e14cbefd 100644 --- a/gen_logging.py +++ b/gen_logging.py @@ -4,6 +4,9 @@ from typing import Dict from pydantic import BaseModel +from logger import init_logger + +logger = init_logger(__name__) class LogConfig(BaseModel): """Logging preference config.""" @@ -38,24 +41,24 @@ def broadcast_status(): enabled.append("generation params") if len(enabled) > 0: - print("Generation logging is enabled for: " + ", ".join(enabled)) + logger.info("Generation logging is enabled for: " + ", ".join(enabled)) else: - print("Generation logging is disabled") + logger.info("Generation logging is disabled") def log_generation_params(**kwargs): """Logs generation parameters to console.""" if CONFIG.generation_params: - print(f"Generation options: {kwargs}\n") + logger.info(f"Generation options: {kwargs}\n") def log_prompt(prompt: str): """Logs the prompt to console.""" if CONFIG.prompt: - print(f"Prompt: {prompt if prompt else 'Empty'}\n") + logger.info(f"Prompt: {prompt if prompt else 'Empty'}\n") def log_response(response: str): """Logs the response to console.""" if CONFIG.prompt: - print(f"Response: {response if response else 'Empty'}\n") + logger.info(f"Response: {response if response else 'Empty'}\n") diff --git a/logger.py b/logger.py new file mode 100644 index 00000000..25b642f6 --- /dev/null +++ b/logger.py @@ -0,0 +1,77 @@ +""" +Logging utility. 
+https://github.com/PygmalionAI/aphrodite-engine/blob/main/aphrodite/common/logger.py +""" + +import logging +import sys +import colorlog + +# pylint: disable=line-too-long +_FORMAT = "%(log_color)s%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" +_DATE_FORMAT = "%m-%d %H:%M:%S" + + +class ColoredFormatter(colorlog.ColoredFormatter): + """Adds logging prefix to newlines to align multi-line messages.""" + + def __init__(self, + fmt, + datefmt=None, + log_colors=None, + reset=True, + style="%"): + super().__init__(fmt, + datefmt=datefmt, + log_colors=log_colors, + reset=reset, + style=style) + + def format(self, record): + msg = super().format(record) + if record.message != "": + parts = msg.split(record.message) + msg = msg.replace("\n", "\r\n" + parts[0]) + return msg + + +_root_logger = logging.getLogger("aphrodite") +_default_handler = None + + +def _setup_logger(): + _root_logger.setLevel(logging.DEBUG) + global _default_handler + if _default_handler is None: + _default_handler = logging.StreamHandler(sys.stdout) + _default_handler.flush = sys.stdout.flush # type: ignore + _default_handler.setLevel(logging.INFO) + _root_logger.addHandler(_default_handler) + fmt = ColoredFormatter(_FORMAT, + datefmt=_DATE_FORMAT, + log_colors={ + "DEBUG": "cyan", + "INFO": "green", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "red,bg_white", + }, + reset=True) + _default_handler.setFormatter(fmt) + # Setting this will avoid the message + # being propagated to the parent logger. + _root_logger.propagate = False + + +# The logger is initialized when the module is imported. +# This is thread-safe as the module is only imported once, +# guaranteed by the Python GIL. +_setup_logger() + + +def init_logger(name: str): + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + logger.addHandler(_default_handler) + logger.propagate = False + return logger diff --git a/main.py b/main.py index fb1fe261..285a281b 100644 --- a/main.py +++ b/main.py @@ -41,6 +41,9 @@ ) from templating import get_prompt_from_template from utils import get_generator_error, get_sse_packet, load_progress, unwrap +from logger import init_logger + +logger = init_logger(__name__) app = FastAPI() @@ -210,8 +213,8 @@ async def generator(): yield get_sse_packet(response.model_dump_json()) except CancelledError: - print( - "\nError: Model load cancelled by user. " + logger.error( + "Model load cancelled by user. " "Please make sure to run unload to free up resources." 
         )
     except Exception as exc:
@@ -369,7 +372,7 @@ async def generator():
             yield get_sse_packet(response.model_dump_json())
 
     except CancelledError:
-        print("Error: Completion request cancelled by user.")
+        logger.error("Completion request cancelled by user.")
     except Exception as exc:
         yield get_generator_error(str(exc))
 
@@ -456,7 +459,7 @@ async def generator():
             yield get_sse_packet(finish_response.model_dump_json())
 
     except CancelledError:
-        print("Error: Chat completion cancelled by user.")
+        logger.error("Chat completion cancelled by user.")
     except Exception as exc:
         yield get_generator_error(str(exc))
 
@@ -481,11 +484,10 @@ async def generator():
     with open("config.yml", "r", encoding="utf8") as config_file:
         config = unwrap(yaml.safe_load(config_file), {})
 except Exception as exc:
-    print(
-        "The YAML config couldn't load because of the following error:",
-        f"\n\n{exc}",
-        "\n\nTabbyAPI will start anyway and not parse this config file.",
-    )
+    logger.error(
+        "The YAML config couldn't load because of the following error: "
+        f"\n\n{exc}"
+        "\n\nTabbyAPI will start anyway and not parse this config file.")
     config = {}
 
 network_config = unwrap(config.get("network"), {})
diff --git a/model.py b/model.py
index e4dde491..b2766b90 100644
--- a/model.py
+++ b/model.py
@@ -23,6 +23,9 @@
     get_template_from_file,
 )
 from utils import coalesce, unwrap
+from logger import init_logger
+
+logger = init_logger(__name__)
 
 # Bytes to reserve on first device when loading with auto split
 AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2
@@ -143,10 +146,8 @@ def progress(loaded_modules: int, total_modules: int,
         # Set prompt template override if provided
         prompt_template_name = kwargs.get("prompt_template")
         if prompt_template_name:
-            print(
-                "Attempting to load prompt template with name",
-                {prompt_template_name},
-            )
+            logger.info("Loading prompt template with name "
+                        f"{prompt_template_name}")
             # Read the template
             self.prompt_template = get_template_from_file(prompt_template_name)
         else:
@@ -175,14 +176,11 @@ def progress(loaded_modules: int, total_modules: int,
 
         # Catch all for template lookup errors
         if self.prompt_template:
-            print(
-                f"Using template {self.prompt_template.name} for chat " "completions."
-            )
+            logger.info(f"Using template {self.prompt_template.name} "
+                        "for chat completions.")
         else:
-            print(
-                "Chat completions are disabled because a prompt template",
-                "wasn't provided or auto-detected.",
-            )
+            logger.warning("Chat completions are disabled because a prompt "
+                           "template wasn't provided or auto-detected.")
 
         # Set num of experts per token if provided
         num_experts_override = kwargs.get("num_experts_per_token")
@@ -190,10 +188,8 @@ def progress(loaded_modules: int, total_modules: int,
             if hasattr(self.config, "num_experts_per_token"):
                 self.config.num_experts_per_token = num_experts_override
             else:
-                print(
-                    " !! Warning: Currently installed ExLlamaV2 does not "
-                    "support overriding MoE experts"
-                )
+                logger.warning("MoE experts per token override is not "
+                               "supported by the current ExLlamaV2 version.")
 
         chunk_size = min(
             unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len
         )
@@ -207,10 +203,8 @@ def progress(loaded_modules: int, total_modules: int,
 
         # Always disable draft if params are incorrectly configured
         if draft_args and draft_model_name is None:
-            print(
-                "A draft config was found but a model name was not given. "
-                "Please check your config.yml! Skipping draft load."
-            )
+            logger.warning("Draft model is disabled because a model name "
+                           "wasn't provided. Please check your config.yml!")
             enable_draft = False
 
         if enable_draft:
@@ -283,20 +277,20 @@ def load_loras(self, lora_directory: pathlib.Path, **kwargs):
             lora_scaling = unwrap(lora.get("scaling"), 1.0)
 
             if lora_name is None:
-                print(
+                logger.warning(
                     "One of your loras does not have a name. Please check your "
                     "config.yml! Skipping lora load."
                 )
                 failure.append(lora_name)
                 continue
 
-            print(f"Loading lora: {lora_name} at scaling {lora_scaling}")
+            logger.info(f"Loading lora: {lora_name} at scaling {lora_scaling}")
             lora_path = lora_directory / lora_name
             # FIXME(alpin): Does self.model need to be passed here?
             self.active_loras.append(
                 ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
             )
-            print("Lora successfully loaded.")
+            logger.info(f"Lora successfully loaded: {lora_name}")
             success.append(lora_name)
 
         # Return success and failure names
@@ -319,7 +313,8 @@ def progress(loaded_modules: int, total_modules: int)
         if self.draft_config:
             self.draft_model = ExLlamaV2(self.draft_config)
             if not self.quiet:
-                print("Loading draft model: " + self.draft_config.model_dir)
+                logger.info(
+                    "Loading draft model: " + self.draft_config.model_dir)
 
             self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True)
             reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
@@ -337,7 +332,7 @@ def progress(loaded_modules: int, total_modules: int)
         # Load model
         self.model = ExLlamaV2(self.config)
         if not self.quiet:
-            print("Loading model: " + self.config.model_dir)
+            logger.info("Loading model: " + self.config.model_dir)
 
         if not self.gpu_split_auto:
             for value in self.model.load_gen(
@@ -373,7 +368,7 @@ def progress(loaded_modules: int, total_modules: int)
                 self.draft_cache,
             )
 
-        print("Model successfully loaded.")
+        logger.info("Model successfully loaded.")
 
     def unload(self, loras_only: bool = False):
         """
@@ -494,33 +489,33 @@ def generate_gen(self, prompt: str, **kwargs):
         if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(
             gen_settings, "mirostat"
         ):
-            print(
-                " !! Warning: Currently installed ExLlamaV2 does not support "
-                "Mirostat sampling"
+            logger.warning(
+                "Mirostat sampling is not supported by the currently "
+                "installed ExLlamaV2 version."
             )
 
         if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(
             gen_settings, "min_p"
         ):
-            print(
-                " !! Warning: Currently installed ExLlamaV2 does not "
-                "support min-P sampling"
+            logger.warning(
+                "Min-P sampling is not supported by the currently "
+                "installed ExLlamaV2 version."
             )
 
         if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(
             gen_settings, "tfs"
         ):
-            print(
-                " !! Warning: Currently installed ExLlamaV2 does not support "
-                "tail-free sampling (TFS)"
+            logger.warning(
+                "Tail-free sampling (TFS) is not supported by the currently "
+                "installed ExLlamaV2 version."
            )
 
         if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(
             gen_settings, "temperature_last"
         ):
-            print(
-                " !! Warning: Currently installed ExLlamaV2 does not support "
-                "temperature_last"
+            logger.warning(
+                "Temperature last is not supported by the currently "
+                "installed ExLlamaV2 version."
             )
 
         # Apply settings
@@ -614,10 +609,10 @@ def generate_gen(self, prompt: str, **kwargs):
 
         context_len = len(ids[0])
         if context_len > self.config.max_seq_len:
-            print(
-                f"WARNING: The context length {context_len} is greater than "
-                f"the max_seq_len {self.config.max_seq_len}.",
-                "Generation is truncated and metrics may not be accurate.",
-            )
+            logger.warning(
+                f"Context length {context_len} is greater than max_seq_len "
+                f"{self.config.max_seq_len}. Generation is truncated and "
+                "metrics may not be accurate."
+            )
 
         prompt_tokens = ids.shape[-1]
 
diff --git a/requirements-amd.txt b/requirements-amd.txt
index 24e0b0bd..d3239e9a 100644
--- a/requirements-amd.txt
+++ b/requirements-amd.txt
@@ -13,3 +13,4 @@ PyYAML
 progress
 uvicorn
 jinja2 >= 3.0.0
+colorlog
\ No newline at end of file
diff --git a/requirements-cu118.txt b/requirements-cu118.txt
index d5eac3fc..d804226a 100644
--- a/requirements-cu118.txt
+++ b/requirements-cu118.txt
@@ -19,6 +19,7 @@ PyYAML
 progress
 uvicorn
 jinja2 >= 3.0.0
+colorlog
 
 # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
 https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.6/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
diff --git a/requirements-nowheel.txt b/requirements-nowheel.txt
index 1f195301..c36272e1 100644
--- a/requirements-nowheel.txt
+++ b/requirements-nowheel.txt
@@ -5,3 +5,4 @@ PyYAML
 progress
 uvicorn
 jinja2 >= 3.0.0
+colorlog
diff --git a/requirements.txt b/requirements.txt
index f87bc74b..a0bd32f6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,6 +19,7 @@ PyYAML
 progress
 uvicorn
 jinja2 >= 3.0.0
+colorlog
 
 # Flash attention v2
diff --git a/tests/wheel_test.py b/tests/wheel_test.py
index c343ed50..acc6f81a 100644
--- a/tests/wheel_test.py
+++ b/tests/wheel_test.py
@@ -2,41 +2,45 @@
 from importlib.metadata import version
 from importlib.util import find_spec
 
+from logger import init_logger
+
+logger = init_logger(__name__)
+
 successful_packages = []
 errored_packages = []
 
 if find_spec("flash_attn") is not None:
-    print(f"Flash attention on version {version('flash_attn')} successfully imported")
+    logger.info(f"Flash attention on version {version('flash_attn')} "
+                "successfully imported")
     successful_packages.append("flash_attn")
 else:
-    print("Flash attention 2 is not found in your environment.")
+    logger.error("Flash attention 2 is not found in your environment.")
     errored_packages.append("flash_attn")
 
 if find_spec("exllamav2") is not None:
-    print(f"Exllamav2 on version {version('exllamav2')} successfully imported")
+    logger.info(f"Exllamav2 on version {version('exllamav2')} "
+                "successfully imported")
     successful_packages.append("exllamav2")
 else:
-    print("Exllamav2 is not found in your environment.")
+    logger.error("Exllamav2 is not found in your environment.")
     errored_packages.append("exllamav2")
 
 if find_spec("torch") is not None:
-    print(f"Torch on version {version('torch')} successfully imported")
+    logger.info(f"Torch on version {version('torch')} successfully imported")
     successful_packages.append("torch")
 else:
-    print("Torch is not found in your environment.")
+    logger.error("Torch is not found in your environment.")
     errored_packages.append("torch")
 
 if find_spec("jinja2") is not None:
-    print(f"Jinja2 on version {version('jinja2')} successfully imported")
+    logger.info(f"Jinja2 on version {version('jinja2')} successfully imported")
     successful_packages.append("jinja2")
 else:
-    print("Jinja2 is not found in your environment.")
+    logger.error("Jinja2 is not found in your environment.")
     errored_packages.append("jinja2")
 
-print(
-    f"\nSuccessful imports: {', '.join(successful_packages)}",
-    f"\nErrored imports: {''.join(errored_packages)}",
-)
+logger.info(f"\nSuccessful imports: {', '.join(successful_packages)}")
+logger.error(f"Errored imports: {''.join(errored_packages)}"
 
 if len(errored_packages) > 0:
     print(
diff --git a/utils.py b/utils.py
index 6f00d3ea..529afe0c 100644
--- a/utils.py
+++ b/utils.py
@@ -4,6 +4,10 @@
 
 from pydantic import BaseModel
 
+from logger import init_logger
+
+logger = init_logger(__name__)
+
 
 def load_progress(module, modules):
     """Wrapper callback for load progress."""
@@ -32,7 +36,7 @@ def get_generator_error(message: str):
     generator_error = TabbyGeneratorError(error=error_message)
 
     # Log and send the exception
-    print(f"\n{generator_error.error.trace}")
+    logger.error(generator_error.error.message)
 
     return get_sse_packet(generator_error.model_dump_json())
 

From ba76331c3e8ea2fa01a48947b6fdcceb9df18c9d Mon Sep 17 00:00:00 2001
From: AlpinDale
Date: Fri, 22 Dec 2023 19:16:34 -0500
Subject: [PATCH 2/6] simplify the logger

---
 logger.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/logger.py b/logger.py
index 25b642f6..376e54d0 100644
--- a/logger.py
+++ b/logger.py
@@ -8,7 +8,7 @@
 import colorlog
 
 # pylint: disable=line-too-long
-_FORMAT = "%(log_color)s%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
+_FORMAT = "%(log_color)s%(levelname)s] %(message)s"
 _DATE_FORMAT = "%m-%d %H:%M:%S"
 

From 628960be95534a62c270b3f5064b552071210216 Mon Sep 17 00:00:00 2001
From: AlpinDale
Date: Fri, 22 Dec 2023 19:19:55 -0500
Subject: [PATCH 3/6] formatting

---
 auth.py             |  2 +-
 gen_logging.py      |  1 +
 logger.py           | 38 ++++++++++++++++----------------------
 main.py             |  3 ++-
 model.py            | 29 +++++++++++++++++------------
 tests/wheel_test.py | 10 +++++-----
 6 files changed, 42 insertions(+), 41 deletions(-)

diff --git a/auth.py b/auth.py
index b881823f..4185ddb6 100644
--- a/auth.py
+++ b/auth.py
@@ -13,6 +13,7 @@
 
 logger = init_logger(__name__)
 
+
 class AuthKeys(BaseModel):
     """
     This class represents the authentication keys for the application.
@@ -52,7 +53,6 @@ def load_auth_keys(disable_from_config: bool):
             "Set the `disable_auth` flag to False in config.yml if you "
             "want to share this instance with others."
) - return diff --git a/gen_logging.py b/gen_logging.py index e14cbefd..77313573 100644 --- a/gen_logging.py +++ b/gen_logging.py @@ -8,6 +8,7 @@ logger = init_logger(__name__) + class LogConfig(BaseModel): """Logging preference config.""" diff --git a/logger.py b/logger.py index 376e54d0..af20bdfd 100644 --- a/logger.py +++ b/logger.py @@ -7,7 +7,6 @@ import sys import colorlog -# pylint: disable=line-too-long _FORMAT = "%(log_color)s%(levelname)s] %(message)s" _DATE_FORMAT = "%m-%d %H:%M:%S" @@ -15,17 +14,10 @@ class ColoredFormatter(colorlog.ColoredFormatter): """Adds logging prefix to newlines to align multi-line messages.""" - def __init__(self, - fmt, - datefmt=None, - log_colors=None, - reset=True, - style="%"): - super().__init__(fmt, - datefmt=datefmt, - log_colors=log_colors, - reset=reset, - style=style) + def __init__(self, fmt, datefmt=None, log_colors=None, reset=True, style="%"): + super().__init__( + fmt, datefmt=datefmt, log_colors=log_colors, reset=reset, style=style + ) def format(self, record): msg = super().format(record) @@ -47,16 +39,18 @@ def _setup_logger(): _default_handler.flush = sys.stdout.flush # type: ignore _default_handler.setLevel(logging.INFO) _root_logger.addHandler(_default_handler) - fmt = ColoredFormatter(_FORMAT, - datefmt=_DATE_FORMAT, - log_colors={ - "DEBUG": "cyan", - "INFO": "green", - "WARNING": "yellow", - "ERROR": "red", - "CRITICAL": "red,bg_white", - }, - reset=True) + fmt = ColoredFormatter( + _FORMAT, + datefmt=_DATE_FORMAT, + log_colors={ + "DEBUG": "cyan", + "INFO": "green", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "red,bg_white", + }, + reset=True, + ) _default_handler.setFormatter(fmt) # Setting this will avoid the message # being propagated to the parent logger. diff --git a/main.py b/main.py index 285a281b..6d5d3040 100644 --- a/main.py +++ b/main.py @@ -487,7 +487,8 @@ async def generator(): logger.error( "The YAML config couldn't load because of the following error: " f"\n\n{exc}" - "\n\nTabbyAPI will start anyway and not parse this config file.") + "\n\nTabbyAPI will start anyway and not parse this config file." + ) config = {} network_config = unwrap(config.get("network"), {}) diff --git a/model.py b/model.py index b2766b90..93d0120a 100644 --- a/model.py +++ b/model.py @@ -146,8 +146,7 @@ def progress(loaded_modules: int, total_modules: int, # Set prompt template override if provided prompt_template_name = kwargs.get("prompt_template") if prompt_template_name: - logger.info("Loading prompt template with name " - f"{prompt_template_name}") + logger.info("Loading prompt template with name " f"{prompt_template_name}") # Read the template self.prompt_template = get_template_from_file(prompt_template_name) else: @@ -176,11 +175,14 @@ def progress(loaded_modules: int, total_modules: int, # Catch all for template lookup errors if self.prompt_template: - logger.info(f"Using template {self.prompt_template.name} " - "for chat completions.") + logger.info( + f"Using template {self.prompt_template.name} " "for chat completions." + ) else: - logger.warning("Chat completions are disabled because a prompt " - "template wasn't provided or auto-detected.") + logger.warning( + "Chat completions are disabled because a prompt " + "template wasn't provided or auto-detected." 
+ ) # Set num of experts per token if provided num_experts_override = kwargs.get("num_experts_per_token") @@ -188,8 +190,10 @@ def progress(loaded_modules: int, total_modules: int, if hasattr(self.config, "num_experts_per_token"): self.config.num_experts_per_token = num_experts_override else: - logger.warning("MoE experts per token override is not " - "supported by the current ExLlamaV2 version.") + logger.warning( + "MoE experts per token override is not " + "supported by the current ExLlamaV2 version." + ) chunk_size = min( unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len @@ -203,8 +207,10 @@ def progress(loaded_modules: int, total_modules: int, # Always disable draft if params are incorrectly configured if draft_args and draft_model_name is None: - logger.warning("Draft model is disabled because a model name " - "wasn't provided. Please check your config.yml!") + logger.warning( + "Draft model is disabled because a model name " + "wasn't provided. Please check your config.yml!" + ) enable_draft = False if enable_draft: @@ -313,8 +319,7 @@ def progress(loaded_modules: int, total_modules: int) if self.draft_config: self.draft_model = ExLlamaV2(self.draft_config) if not self.quiet: - logger.info( - "Loading draft model: " + self.draft_config.model_dir) + logger.info("Loading draft model: " + self.draft_config.model_dir) self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True) reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16 diff --git a/tests/wheel_test.py b/tests/wheel_test.py index acc6f81a..d9173f73 100644 --- a/tests/wheel_test.py +++ b/tests/wheel_test.py @@ -10,16 +10,16 @@ errored_packages = [] if find_spec("flash_attn") is not None: - logger.info(f"Flash attention on version {version('flash_attn')} " - "successfully imported") + logger.info( + f"Flash attention on version {version('flash_attn')} " "successfully imported" + ) successful_packages.append("flash_attn") else: logger.error("Flash attention 2 is not found in your environment.") errored_packages.append("flash_attn") if find_spec("exllamav2") is not None: - logger.info(f"Exllamav2 on version {version('exllamav2')} " - "successfully imported") + logger.info(f"Exllamav2 on version {version('exllamav2')} " "successfully imported") successful_packages.append("exllamav2") else: logger.error("Exllamav2 is not found in your environment.") @@ -40,7 +40,7 @@ errored_packages.append("jinja2") logger.info(f"\nSuccessful imports: {', '.join(successful_packages)}") -logger.error(f"Errored imports: {''.join(errored_packages)}" +logger.error(f"Errored imports: {''.join(errored_packages)}") if len(errored_packages) > 0: print( From cb0b682f9ad5c165ad1804383ef7b6ff90057f6f Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 22 Dec 2023 19:22:14 -0500 Subject: [PATCH 4/6] final touches --- tests/wheel_test.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/wheel_test.py b/tests/wheel_test.py index d9173f73..f9ee15e5 100644 --- a/tests/wheel_test.py +++ b/tests/wheel_test.py @@ -43,8 +43,11 @@ logger.error(f"Errored imports: {''.join(errored_packages)}") if len(errored_packages) > 0: - print( - "\nIf packages are installed, but not found on this test, please " - "check the wheel versions for the correct python version and CUDA " - "version (if applicable)." + logger.warning( + "If all packages are installed, but not found " + "on this test, please check the wheel versions for the " + "correct python version and CUDA version (if " + "applicable)." 
) +else: + logger.info("All wheels are installed correctly.") From 682f8452a231430a2396632bbbdbab527588af9c Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Fri, 22 Dec 2023 19:28:09 -0500 Subject: [PATCH 5/6] fix format --- logger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/logger.py b/logger.py index af20bdfd..2c6248fd 100644 --- a/logger.py +++ b/logger.py @@ -7,7 +7,7 @@ import sys import colorlog -_FORMAT = "%(log_color)s%(levelname)s] %(message)s" +_FORMAT = "%(log_color)s%(levelname)s: %(message)s" _DATE_FORMAT = "%m-%d %H:%M:%S" From 645186eb7105b7491ced31f32abc5bcb1dc71420 Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 22 Dec 2023 23:31:39 -0500 Subject: [PATCH 6/6] Model: Add log to metrics Signed-off-by: kingbri --- model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model.py b/model.py index 93d0120a..08bc23f9 100644 --- a/model.py +++ b/model.py @@ -705,7 +705,7 @@ def generate_gen(self, prompt: str, **kwargs): extra_parts.append("<-- Not accurate (truncated)") # Print output - print( + logger.info( initial_response + " (" + ", ".join(itemization)
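
Usage sketch (not part of the patches above): the snippet below shows how the
logger introduced in PATCH 1/6 is consumed, assuming logger.py is importable
from the project root as it is in the auth.py, model.py, and utils.py hunks.
init_logger() returns a named logger wired to the shared colorlog handler on
stdout, so with the PATCH 5/6 format string a record renders as, for example,
"INFO: Model successfully loaded." in the level's color:

    from logger import init_logger

    # One logger per module, keyed by the module's import name
    logger = init_logger(__name__)

    logger.info("Model successfully loaded.")         # green "INFO: ..." line
    logger.warning("Draft model is disabled.")        # yellow "WARNING: ..." line
    logger.error("Completion request cancelled.")     # red "ERROR: ..." line

The handler is created once at import time and shared by every logger that
init_logger() hands out, so all modules emit through the same stdout stream
with consistent coloring, and propagation to the root logger is disabled to
avoid duplicate lines.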