From 8a7a3e4436d7284df4c0913f074d77d640a9c6c3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 18 Apr 2024 16:15:12 -0700 Subject: [PATCH] [Core] add an option to log every function call to for debugging hang/crash in distributed inference (#4079) Co-authored-by: Simon Mo --- .buildkite/test-pipeline.yaml | 2 +- .github/ISSUE_TEMPLATE/400-bug report.yml | 2 + tests/test_logger.py | 27 ++++++++++++ vllm/executor/ray_gpu_executor.py | 12 ++++-- vllm/logger.py | 52 +++++++++++++++++++++++ vllm/utils.py | 13 +++++- vllm/worker/worker_base.py | 20 +++++++-- 7 files changed, 120 insertions(+), 8 deletions(-) create mode 100644 tests/test_logger.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 2263dee20fbed..0f920c7ec1442 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -40,7 +40,7 @@ steps: - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py - label: Engine Test - command: pytest -v -s engine tokenization test_sequence.py test_config.py + command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py - label: Entrypoints Test commands: diff --git a/.github/ISSUE_TEMPLATE/400-bug report.yml b/.github/ISSUE_TEMPLATE/400-bug report.yml index f1124dfa78bbc..c87f8fddcb776 100644 --- a/.github/ISSUE_TEMPLATE/400-bug report.yml +++ b/.github/ISSUE_TEMPLATE/400-bug report.yml @@ -57,6 +57,8 @@ body: If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. + + If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs. placeholder: | A clear and concise description of what the bug is. diff --git a/tests/test_logger.py b/tests/test_logger.py new file mode 100644 index 0000000000000..601f72b50811c --- /dev/null +++ b/tests/test_logger.py @@ -0,0 +1,27 @@ +import os +import sys +import tempfile + +from vllm.logger import enable_trace_function_call + + +def f1(x): + return f2(x) + + +def f2(x): + return x + + +def test_trace_function_call(): + fd, path = tempfile.mkstemp() + cur_dir = os.path.dirname(__file__) + enable_trace_function_call(path, cur_dir) + f1(1) + with open(path, 'r') as f: + content = f.read() + + assert "f1" in content + assert "f2" in content + sys.settrace(None) + os.remove(path) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 5a43f1fc28a84..f779b0f8a5113 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -10,7 +10,7 @@ from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - make_async) + get_vllm_instance_id, make_async) if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -133,12 +133,18 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", for node_id, gpu_ids in node_gpus.items(): node_gpus[node_id] = sorted(gpu_ids) - # Set CUDA_VISIBLE_DEVICES for the driver and workers. + VLLM_INSTANCE_ID = get_vllm_instance_id() + + # Set environment variables for the driver and workers. all_args_to_update_environment_variables = [] for (node_id, _) in worker_node_and_gpu_ids: all_args_to_update_environment_variables.append([{ "CUDA_VISIBLE_DEVICES": - ",".join(map(str, node_gpus[node_id])) + ",".join(map(str, node_gpus[node_id])), + "VLLM_INSTANCE_ID": + VLLM_INSTANCE_ID, + "VLLM_TRACE_FUNCTION": + os.getenv("VLLM_TRACE_FUNCTION", "0"), }]) self._run_workers("update_environment_variables", all_args=all_args_to_update_environment_variables) diff --git a/vllm/logger.py b/vllm/logger.py index af9575085ef37..046f0e9099a4b 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -1,9 +1,11 @@ # Adapted from # https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py """Logging configuration for vLLM.""" +import datetime import logging import os import sys +from functools import partial from typing import Optional VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) @@ -65,3 +67,53 @@ def init_logger(name: str): logger.addHandler(_default_handler) logger.propagate = False return logger + + +logger = init_logger(__name__) + + +def _trace_calls(log_path, root_dir, frame, event, arg=None): + if event in ['call', 'return']: + # Extract the filename, line number, function name, and the code object + filename = frame.f_code.co_filename + lineno = frame.f_lineno + func_name = frame.f_code.co_name + if not filename.startswith(root_dir): + # only log the functions in the vllm root_dir + return + # Log every function call or return + try: + with open(log_path, 'a') as f: + if event == 'call': + f.write(f"{datetime.datetime.now()} Call to" + f" {func_name} in {filename}:{lineno}\n") + else: + f.write(f"{datetime.datetime.now()} Return from" + f" {func_name} in {filename}:{lineno}\n") + except NameError: + # modules are deleted during shutdown + pass + return partial(_trace_calls, log_path, root_dir) + + +def enable_trace_function_call(log_file_path: str, + root_dir: Optional[str] = None): + """ + Enable tracing of every function call in code under `root_dir`. + This is useful for debugging hangs or crashes. + `log_file_path` is the path to the log file. + `root_dir` is the root directory of the code to trace. If None, it is the + vllm root directory. + + Note that this call is thread-level, any threads calling this function + will have the trace enabled. Other threads will not be affected. + """ + logger.warning( + "VLLM_TRACE_FUNCTION is enabled. It will record every" + " function executed by Python. This will slow down the code. It " + "is suggested to be used for debugging hang or crashes only.") + logger.info(f"Trace frame log is saved to {log_file_path}") + if root_dir is None: + # by default, this is the vllm root directory + root_dir = os.path.dirname(os.path.dirname(__file__)) + sys.settrace(partial(_trace_calls, log_file_path, root_dir)) diff --git a/vllm/utils.py b/vllm/utils.py index 49e7033c23e6a..fbe86dacaeb99 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -163,6 +163,17 @@ def random_uuid() -> str: return str(uuid.uuid4().hex) +@lru_cache(maxsize=None) +def get_vllm_instance_id(): + """ + If the environment variable VLLM_INSTANCE_ID is set, return it. + Otherwise, return a random UUID. + Instance id represents an instance of the VLLM. All processes in the same + instance should have the same instance id. + """ + return os.environ.get("VLLM_INSTANCE_ID", f"vllm-instance-{random_uuid()}") + + @lru_cache(maxsize=None) def in_wsl() -> bool: # Reference: https://github.com/microsoft/WSL/issues/4071 @@ -274,7 +285,7 @@ def get_open_port() -> int: def update_environment_variables(envs: Dict[str, str]): for k, v in envs.items(): - if k in os.environ: + if k in os.environ and os.environ[k] != v: logger.warning(f"Overwriting environment variable {k} " f"from '{os.environ[k]}' to '{v}'") os.environ[k] = v diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 13e062fe64b29..783dff3a43404 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,12 +1,15 @@ +import datetime import importlib import os +import tempfile +import threading from abc import ABC, abstractmethod from typing import Dict, List, Set, Tuple -from vllm.logger import init_logger +from vllm.logger import enable_trace_function_call, init_logger from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.utils import update_environment_variables +from vllm.utils import get_vllm_instance_id, update_environment_variables logger = init_logger(__name__) @@ -115,9 +118,20 @@ def update_environment_variables(self, envs: Dict[str, str]) -> None: def init_worker(self, *args, **kwargs): """ - Actual initialization of the worker class. + Actual initialization of the worker class, and set up + function tracing if required. Arguments are passed to the worker class constructor. """ + if int(os.getenv("VLLM_TRACE_FUNCTION", "0")): + tmp_dir = tempfile.gettempdir() + filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" + f"_thread_{threading.get_ident()}_" + f"at_{datetime.datetime.now()}.log").replace(" ", "_") + log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(), + filename) + os.makedirs(os.path.dirname(log_path), exist_ok=True) + enable_trace_function_call(log_path) + mod = importlib.import_module(self.worker_module_name) worker_class = getattr(mod, self.worker_class_name) self.worker = worker_class(*args, **kwargs)