Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Color log formatter #1247

Merged
merged 5 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from lmdeploy.messages import (EngineGenerationConfig, PytorchEngineConfig,
ResponseType)
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.utils import get_logger, get_model
from lmdeploy.utils import get_logger, get_model, logging_timer

from ..adapter.adapter import ADAPTER_MANAGER, SchedulerAdapter
from ..check_env import check_env, check_model
Expand Down Expand Up @@ -404,6 +404,7 @@ def end_session(self, session_id: int):
"""Add new session."""
return end(self.req_sender, session_id)

@logging_timer('CreateModelInputs', logger)
@torch.inference_mode()
def create_model_inputs(self, messages: SeqList, adapters: AdapterList):
"""create model inputs from messages.
Expand Down Expand Up @@ -506,6 +507,8 @@ def _stopping_criteria(self, msg: SchedulerSequence, next_token_id: int):
"""

def _check_stop_word(sampling_param, next_token_id):
if sampling_param.ignore_eos:
return False
return (sampling_param.stop_words is not None
and next_token_id in sampling_param.stop_words)

Expand All @@ -527,6 +530,7 @@ def _check_session_len(msg, max_session_len):
return True
return False

@logging_timer('SamplingLogits', logger)
async def async_sampling_logits(self, logits: torch.Tensor,
running: SeqList, inputs: ModelInputs):
"""sampling logits."""
Expand Down Expand Up @@ -565,11 +569,13 @@ def _gather_history(seqs: SeqList, device: torch.device):
with torch.inference_mode(), torch.cuda.stream(self.stream):
logits = logits_processor(input_ids, split_logits)
next_token_ids = logits_processor.sampling(logits)
self.stream.synchronize()
await asyncio.get_event_loop().run_in_executor(None,
self.stream.synchronize)
next_token_ids = next_token_ids.cpu()

return next_token_ids, split_logits

@logging_timer('UpdateRunning', logger)
def update_running(self, running: SeqList, next_token_ids: torch.Tensor,
meta: Any):
"""update scheduler."""
Expand All @@ -593,6 +599,7 @@ def _can_output_token(self, token: torch.Tensor, msg: SchedulerSequence):

return True

@logging_timer('ModelForward', logger)
async def _async_model_forward(self, inputs: ModelInputs,
swap_in_map: Dict, swap_out_map: Dict):
"""model forward."""
Expand Down Expand Up @@ -699,6 +706,7 @@ async def __long_context_forward(inputs):
else:
return await __long_context_forward(inputs)

@logging_timer('AsyncStep', logger)
async def async_step(self, is_prefill: bool, return_logits: bool = False):
"""one step inference. Used to perform streaming chat.

Expand All @@ -714,6 +722,7 @@ async def async_step(self, is_prefill: bool, return_logits: bool = False):
adapters = schedule_output.adapters
if len(running) == 0:
return dict()
logger.debug(f'<AsyncStep>: batch_size={len(running)}')

inputs = self.create_model_inputs(running, adapters)

Expand Down Expand Up @@ -952,7 +961,6 @@ def _send_resp(step_tokens):
await asyncio.sleep(0.01)
continue

logger.debug('async_loop: RequestManager Step.')
self.req_manager.step()

# forward
Expand All @@ -961,16 +969,13 @@ def _send_resp(step_tokens):
is_prefill = not prefill_counter or not has_running
if is_prefill:
prefill_counter = prefill_interval
logger.debug('async_loop: Engine Step - '
f'prefilling: {is_prefill}')
with torch.inference_mode():
step_tokens: Dict[int,
InferOutput] = await self.async_step(
is_prefill=is_prefill)
prefill_counter -= 1

# send response
logger.debug('async_loop: Response.')
_send_resp(step_tokens)


Expand Down
9 changes: 6 additions & 3 deletions lmdeploy/pytorch/engine/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@


def _raise_exception_on_finish(task: asyncio.Task) -> None:
msg = ('Engine loop failed!')
try:
task.result()
except asyncio.CancelledError:
return
except Exception as exc:
raise RuntimeError(msg) from exc
except Exception as e:
logger.exception(f'Engine loop failed with error: {e}')


def _ignore_exception_on_finish(task: asyncio.Task) -> None:
Expand Down Expand Up @@ -158,6 +157,8 @@ def _resp_get(self):
except Empty:
continue
except Exception as e:
logger.exception(
f'sender[{self.sender_id}] get response failed: {e}')
raise e

async def _async_resp_get(self):
Expand All @@ -177,6 +178,8 @@ async def __no_threadsafe_get():
except asyncio.TimeoutError:
continue
except Exception as e:
logger.exception(
f'sender[{self.sender_id}] get response failed: {e}')
raise e

if self.is_thread_safe():
Expand Down
5 changes: 3 additions & 2 deletions lmdeploy/pytorch/models/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ def _find_submodulename():
rewrite_qualname = _find_submodulename()

origin_qualname = f'{module_name}.{class_name}'
logger.debug(
f'Find rewrite of module {origin_qualname}: {rewrite_qualname}.')
if rewrite_qualname is not None:
logger.debug('Find rewrite of module\n'
f'{origin_qualname} <=> {rewrite_qualname}')
return rewrite_qualname


Expand Down
4 changes: 3 additions & 1 deletion lmdeploy/pytorch/paging/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass
from typing import Dict, List, Set, Union

from lmdeploy.utils import get_logger
from lmdeploy.utils import get_logger, logging_timer

from ..adapter.adapter import ADAPTER_MANAGER, SchedulerAdapter
from ..config import CacheConfig, SchedulerConfig
Expand Down Expand Up @@ -116,6 +116,7 @@ def add_adapter(self, adapter_path: str, adapter_name: str):
adapter) - self.block_manager.num_gpu_blocks
return adapter.build_weight_map(block_table)

@logging_timer('SchedulePrefilling', logger)
def _schedule_prefill(self):
"""Schedule for prefilling."""

Expand Down Expand Up @@ -206,6 +207,7 @@ def _deactive_adapter(adapter_name):
self.running += running
return running, swap_in_map, swap_out_map, copy_map

@logging_timer('ScheduleDecoding', logger)
def _schedule_decoding(self):
"""schedule decoding."""
assert len(self.running) != 0
Expand Down
74 changes: 73 additions & 1 deletion lmdeploy/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import functools
import logging
import sys
import time
from contextlib import contextmanager
from logging import Logger, LogRecord
from typing import List, Optional

logger_initialized = {}


class ColorFormatter(logging.Formatter):

_LEVELNAME_COLOR_MAP = dict(CRITICAL='\033[91m',
grimoire marked this conversation as resolved.
Show resolved Hide resolved
ERROR='\033[31m',
WARN='\033[33m',
WARNING='\033[33m',
INFO='\033[37m',
DEBUG='\033[32m')

_RESET_COLOR = '\033[0m'

def format(self, record: LogRecord):
"""format."""
if sys.platform == 'win32':
# windows does not support ASNI color
return super().format(record)
levelname = record.levelname
level_color = self._LEVELNAME_COLOR_MAP.get(levelname,
self._RESET_COLOR)
levelname = f'{level_color}{levelname}{self._RESET_COLOR}'
record.levelname = levelname
return super().format(record)


class FilterDuplicateWarning(logging.Filter):
"""Filter the repeated warning message.

Expand Down Expand Up @@ -85,7 +113,8 @@ def get_logger(
file_handler = logging.FileHandler(log_file, file_mode)
handlers.append(file_handler)

formatter = logging.Formatter(log_formatter)
# formatter = logging.Formatter(log_formatter)
grimoire marked this conversation as resolved.
Show resolved Hide resolved
formatter = ColorFormatter(log_formatter)
for handler in handlers:
handler.setFormatter(formatter)
handler.setLevel(log_level)
Expand Down Expand Up @@ -158,3 +187,46 @@ def get_model(pretrained_model_name_or_path: str,
model_path = snapshot_download(pretrained_model_name_or_path,
**download_kwargs)
return model_path


def logging_timer(op_name: str, logger: Logger, level: int = logging.DEBUG):
"""logging timer."""

@contextmanager
def __timer():
"""timer."""
start = time.perf_counter()
yield
end = time.perf_counter()
duration = (end - start) * 1000
logger.log(level, f'<{op_name}> take time: {duration:.2f} ms')

def __inner(func):
"""inner."""

@functools.wraps(func)
def __func_warpper(*args, **kwargs):
"""func warpper."""
if not logger.isEnabledFor(level):
return func(*args, **kwargs)
with __timer():
return func(*args, **kwargs)

@functools.wraps(func)
def __async_warpper(*args, **kwargs):
"""async warpper."""

async def __tmp():
if not logger.isEnabledFor(level):
return (await func(*args, **kwargs))
with __timer():
return (await func(*args, **kwargs))

return __tmp()

if asyncio.iscoroutinefunction(func):
return __async_warpper
else:
return __func_warpper

return __inner
Loading