From ccea49a582404b3dbbb21131a21551d00ab945d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Thu, 12 Sep 2024 10:02:44 +0200 Subject: [PATCH] Add `logging_handlers` argument (#969) * Add `logging_handlers` argument * Update call to super init --- src/distilabel/pipeline/base.py | 8 +++++++- src/distilabel/pipeline/local.py | 26 +++++++++++++++++++++----- src/distilabel/pipeline/ray.py | 14 ++++++++++---- src/distilabel/utils/logging.py | 15 +++++++++++---- 4 files changed, 49 insertions(+), 14 deletions(-) diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py index ced081343b..cd6f35e59d 100644 --- a/src/distilabel/pipeline/base.py +++ b/src/distilabel/pipeline/base.py @@ -312,6 +312,7 @@ def run( storage_parameters: Optional[Dict[str, Any]] = None, use_fs_to_pass_data: bool = False, dataset: Optional["InputDataset"] = None, + logging_handlers: Optional[List[logging.Handler]] = None, ) -> "Distiset": # type: ignore """Run the pipeline. It will set the runtime parameters for the steps and validate the pipeline. @@ -338,6 +339,9 @@ def run( dataset: If given, it will be used to create a `GeneratorStep` and put it as the root step. Convenient method when you have already processed the dataset in your script and just want to pass it already processed. Defaults to `None`. + logging_handlers: A list of logging handlers that will be used to log the + output of the pipeline. This argument can be useful so the logging messages + can be extracted and used in a different context. Defaults to `None`. Returns: The `Distiset` created by the pipeline. @@ -356,7 +360,9 @@ def run( self._add_dataset_generator_step(dataset) setup_logging( - log_queue=self._log_queue, filename=str(self._cache_location["log_file"]) + log_queue=self._log_queue, + filename=str(self._cache_location["log_file"]), + logging_handlers=logging_handlers, ) # Set the name of the pipeline if it's the default one. This should be called diff --git a/src/distilabel/pipeline/local.py b/src/distilabel/pipeline/local.py index 1daa63944d..be7919d56d 100644 --- a/src/distilabel/pipeline/local.py +++ b/src/distilabel/pipeline/local.py @@ -16,7 +16,17 @@ import signal import sys from multiprocessing.pool import Pool -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Union, + cast, +) import tblib @@ -30,6 +40,7 @@ from distilabel.utils.ray import script_executed_in_ray_cluster if TYPE_CHECKING: + import logging from queue import Queue from distilabel.distiset import Distiset @@ -141,6 +152,7 @@ def run( storage_parameters: Optional[Dict[str, Any]] = None, use_fs_to_pass_data: bool = False, dataset: Optional["InputDataset"] = None, + logging_handlers: Optional[List["logging.Handler"]] = None, ) -> "Distiset": """Runs the pipeline. @@ -163,6 +175,9 @@ def run( dataset: If given, it will be used to create a `GeneratorStep` and put it as the root step. Convenient method when you have already processed the dataset in your script and just want to pass it already processed. Defaults to `None`. + logging_handlers: A list of logging handlers that will be used to log the + output of the pipeline. This argument can be useful so the logging messages + can be extracted and used in a different context. Defaults to `None`. Returns: The `Distiset` created by the pipeline. @@ -183,11 +198,12 @@ def run( self._log_queue = cast("Queue[Any]", mp.Queue()) if distiset := super().run( - parameters, - use_cache, - storage_parameters, - use_fs_to_pass_data, + parameters=parameters, + use_cache=use_cache, + storage_parameters=storage_parameters, + use_fs_to_pass_data=use_fs_to_pass_data, dataset=dataset, + logging_handlers=logging_handlers, ): return distiset diff --git a/src/distilabel/pipeline/ray.py b/src/distilabel/pipeline/ray.py index bfdff96c64..cf9a26064a 100644 --- a/src/distilabel/pipeline/ray.py +++ b/src/distilabel/pipeline/ray.py @@ -25,6 +25,7 @@ from distilabel.utils.serialization import TYPE_INFO_KEY if TYPE_CHECKING: + import logging from os import PathLike from queue import Queue @@ -82,6 +83,7 @@ def run( storage_parameters: Optional[Dict[str, Any]] = None, use_fs_to_pass_data: bool = False, dataset: Optional["InputDataset"] = None, + logging_handlers: Optional[List["logging.Handler"]] = None, ) -> "Distiset": """Runs the pipeline in the Ray cluster. @@ -104,6 +106,9 @@ def run( dataset: If given, it will be used to create a `GeneratorStep` and put it as the root step. Convenient method when you have already processed the dataset in your script and just want to pass it already processed. Defaults to `None`. + logging_handlers: A list of logging handlers that will be used to log the + output of the pipeline. This argument can be useful so the logging messages + can be extracted and used in a different context. Defaults to `None`. Returns: The `Distiset` created by the pipeline. @@ -120,11 +125,12 @@ def run( ) if distiset := super().run( - parameters, - use_cache, - storage_parameters, - use_fs_to_pass_data, + parameters=parameters, + use_cache=use_cache, + storage_parameters=storage_parameters, + use_fs_to_pass_data=use_fs_to_pass_data, dataset=dataset, + logging_handlers=logging_handlers, ): return distiset diff --git a/src/distilabel/utils/logging.py b/src/distilabel/utils/logging.py index 9939527aa4..994c81e321 100644 --- a/src/distilabel/utils/logging.py +++ b/src/distilabel/utils/logging.py @@ -18,7 +18,7 @@ from logging import FileHandler from logging.handlers import QueueHandler, QueueListener from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, List, Optional, Union from rich.logging import RichHandler @@ -47,7 +47,9 @@ def setup_logging( - log_queue: Optional["Queue[Any]"] = None, filename: Optional[str] = None + log_queue: Optional["Queue[Any]"] = None, + filename: Optional[str] = None, + logging_handlers: Optional[List[logging.Handler]] = None, ) -> None: """Sets up logging to use a queue across all processes.""" global queue_listener @@ -60,21 +62,26 @@ def setup_logging( # If the current process is the main process, set up a `QueueListener` # to handle logs from all subprocesses if mp.current_process().name == "MainProcess" and filename: + if logging_handlers is None: + logging_handlers = [] + formatter = logging.Formatter("['%(name)s'] %(message)s") handler = RichHandler(rich_tracebacks=True) handler.setFormatter(formatter) + logging_handlers.append(handler) + if not Path(filename).parent.exists(): Path(filename).parent.mkdir(parents=True, exist_ok=True) - file_handler = FileHandler(filename, delay=True, encoding="utf-8") file_formatter = logging.Formatter( "[%(asctime)s] %(levelname)-8s %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) file_handler.setFormatter(file_formatter) + logging_handlers.append(file_handler) if log_queue is not None: queue_listener = QueueListener( - log_queue, handler, file_handler, respect_handler_level=True + log_queue, *logging_handlers, respect_handler_level=True ) queue_listener.start()