Add logging_handlers argument (#969)
* Add `logging_handlers` argument

* Update call to super init
gabrielmbmb authored Sep 12, 2024
1 parent 6e2c9b1 commit ccea49a
Showing 4 changed files with 49 additions and 14 deletions.
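
In short, `run()` on the base pipeline and its `Pipeline` and `RayPipeline` subclasses now accepts a list of `logging.Handler` instances that get attached to the pipeline's log queue. A minimal usage sketch, assuming the `Pipeline` API at this commit (the handler, file name, and step definitions are illustrative, not part of the change):

import logging

from distilabel.pipeline import Pipeline

# Illustrative extra handler: mirror pipeline logs to our own file, in
# addition to distilabel's Rich console and cache-directory log file.
extra_handler = logging.FileHandler("my_pipeline_run.log")
extra_handler.setLevel(logging.INFO)

with Pipeline(name="my-pipeline") as pipeline:
    ...  # steps would be defined here

distiset = pipeline.run(
    use_cache=False,
    logging_handlers=[extra_handler],  # new argument added in #969
)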
8 changes: 7 additions & 1 deletion src/distilabel/pipeline/base.py
@@ -312,6 +312,7 @@ def run(
         storage_parameters: Optional[Dict[str, Any]] = None,
         use_fs_to_pass_data: bool = False,
         dataset: Optional["InputDataset"] = None,
+        logging_handlers: Optional[List[logging.Handler]] = None,
     ) -> "Distiset":  # type: ignore
         """Run the pipeline. It will set the runtime parameters for the steps and validate
         the pipeline.
@@ -338,6 +339,9 @@
             dataset: If given, it will be used to create a `GeneratorStep` and put it as the
                 root step. Convenient method when you have already processed the dataset in
                 your script and just want to pass it already processed. Defaults to `None`.
+            logging_handlers: A list of logging handlers that will be used to log the
+                output of the pipeline. This argument can be useful so the logging messages
+                can be extracted and used in a different context. Defaults to `None`.

         Returns:
             The `Distiset` created by the pipeline.
@@ -356,7 +360,9 @@ def run(
             self._add_dataset_generator_step(dataset)

         setup_logging(
-            log_queue=self._log_queue, filename=str(self._cache_location["log_file"])
+            log_queue=self._log_queue,
+            filename=str(self._cache_location["log_file"]),
+            logging_handlers=logging_handlers,
         )

         # Set the name of the pipeline if it's the default one. This should be called
26 changes: 21 additions & 5 deletions src/distilabel/pipeline/local.py
@@ -16,7 +16,17 @@
 import signal
 import sys
 from multiprocessing.pool import Pool
-from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Union,
+    cast,
+)

 import tblib

@@ -30,6 +40,7 @@
 from distilabel.utils.ray import script_executed_in_ray_cluster

 if TYPE_CHECKING:
+    import logging
     from queue import Queue

     from distilabel.distiset import Distiset
@@ -141,6 +152,7 @@ def run(
         storage_parameters: Optional[Dict[str, Any]] = None,
         use_fs_to_pass_data: bool = False,
         dataset: Optional["InputDataset"] = None,
+        logging_handlers: Optional[List["logging.Handler"]] = None,
     ) -> "Distiset":
         """Runs the pipeline.
@@ -163,6 +175,9 @@
             dataset: If given, it will be used to create a `GeneratorStep` and put it as the
                 root step. Convenient method when you have already processed the dataset in
                 your script and just want to pass it already processed. Defaults to `None`.
+            logging_handlers: A list of logging handlers that will be used to log the
+                output of the pipeline. This argument can be useful so the logging messages
+                can be extracted and used in a different context. Defaults to `None`.

         Returns:
             The `Distiset` created by the pipeline.
@@ -183,11 +198,12 @@ def run(
         self._log_queue = cast("Queue[Any]", mp.Queue())

         if distiset := super().run(
-            parameters,
-            use_cache,
-            storage_parameters,
-            use_fs_to_pass_data,
+            parameters=parameters,
+            use_cache=use_cache,
+            storage_parameters=storage_parameters,
+            use_fs_to_pass_data=use_fs_to_pass_data,
             dataset=dataset,
+            logging_handlers=logging_handlers,
         ):
             return distiset
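
The docstring's stated motivation, extracting the log messages for use in a different context, suggests a pattern like the following sketch: a custom handler that buffers records in memory until `run()` returns. `ListHandler` and the surrounding `pipeline` object are assumptions for illustration, not part of this diff:

import logging
from typing import List

class ListHandler(logging.Handler):
    """Buffers every record so the pipeline's log output can be
    inspected programmatically after the run finishes."""

    def __init__(self) -> None:
        super().__init__()
        self.records: List[logging.LogRecord] = []

    def emit(self, record: logging.LogRecord) -> None:
        self.records.append(record)

collector = ListHandler()
distiset = pipeline.run(logging_handlers=[collector])

# Records were emitted in worker processes but funneled through the
# multiprocessing log queue to this handler in the main process.
for record in collector.records:
    print(record.name, record.getMessage())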
14 changes: 10 additions & 4 deletions src/distilabel/pipeline/ray.py
@@ -25,6 +25,7 @@
 from distilabel.utils.serialization import TYPE_INFO_KEY

 if TYPE_CHECKING:
+    import logging
     from os import PathLike
     from queue import Queue

@@ -82,6 +83,7 @@ def run(
         storage_parameters: Optional[Dict[str, Any]] = None,
         use_fs_to_pass_data: bool = False,
         dataset: Optional["InputDataset"] = None,
+        logging_handlers: Optional[List["logging.Handler"]] = None,
     ) -> "Distiset":
         """Runs the pipeline in the Ray cluster.
@@ -104,6 +106,9 @@
             dataset: If given, it will be used to create a `GeneratorStep` and put it as the
                 root step. Convenient method when you have already processed the dataset in
                 your script and just want to pass it already processed. Defaults to `None`.
+            logging_handlers: A list of logging handlers that will be used to log the
+                output of the pipeline. This argument can be useful so the logging messages
+                can be extracted and used in a different context. Defaults to `None`.

         Returns:
             The `Distiset` created by the pipeline.
@@ -120,11 +125,12 @@
         )

         if distiset := super().run(
-            parameters,
-            use_cache,
-            storage_parameters,
-            use_fs_to_pass_data,
+            parameters=parameters,
+            use_cache=use_cache,
+            storage_parameters=storage_parameters,
+            use_fs_to_pass_data=use_fs_to_pass_data,
             dataset=dataset,
+            logging_handlers=logging_handlers,
         ):
             return distiset
15 changes: 11 additions & 4 deletions src/distilabel/utils/logging.py
@@ -18,7 +18,7 @@
 from logging import FileHandler
 from logging.handlers import QueueHandler, QueueListener
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Union

 from rich.logging import RichHandler

@@ -47,7 +47,9 @@


 def setup_logging(
-    log_queue: Optional["Queue[Any]"] = None, filename: Optional[str] = None
+    log_queue: Optional["Queue[Any]"] = None,
+    filename: Optional[str] = None,
+    logging_handlers: Optional[List[logging.Handler]] = None,
 ) -> None:
     """Sets up logging to use a queue across all processes."""
     global queue_listener
@@ -60,21 +62,26 @@ def setup_logging(
     # If the current process is the main process, set up a `QueueListener`
     # to handle logs from all subprocesses
     if mp.current_process().name == "MainProcess" and filename:
+        if logging_handlers is None:
+            logging_handlers = []
+
         formatter = logging.Formatter("['%(name)s'] %(message)s")
         handler = RichHandler(rich_tracebacks=True)
         handler.setFormatter(formatter)
+        logging_handlers.append(handler)

         if not Path(filename).parent.exists():
             Path(filename).parent.mkdir(parents=True, exist_ok=True)

         file_handler = FileHandler(filename, delay=True, encoding="utf-8")
         file_formatter = logging.Formatter(
             "[%(asctime)s] %(levelname)-8s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
         )
         file_handler.setFormatter(file_formatter)
+        logging_handlers.append(file_handler)

         if log_queue is not None:
             queue_listener = QueueListener(
-                log_queue, handler, file_handler, respect_handler_level=True
+                log_queue, *logging_handlers, respect_handler_level=True
             )
             queue_listener.start()
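
For context on why the change reduces to a single splat: `setup_logging` uses the stdlib `QueueHandler`/`QueueListener` pattern, where the listener drains the shared queue and fans each record out to every handler it was constructed with, so user handlers only need to be appended to that argument list. A self-contained sketch of the mechanism (plain stdlib, not distilabel code):

import logging
import multiprocessing as mp
from logging.handlers import QueueHandler, QueueListener

log_queue = mp.Queue()

# Worker processes install only a lightweight QueueHandler ...
worker_logger = logging.getLogger("worker")
worker_logger.addHandler(QueueHandler(log_queue))
worker_logger.setLevel(logging.INFO)

# ... while the main process owns the real handlers. Supporting extra
# user handlers is just a matter of extending this handler list.
console = logging.StreamHandler()
extra = logging.FileHandler("extra.log")
listener = QueueListener(log_queue, console, extra, respect_handler_level=True)
listener.start()

worker_logger.info("routed through the queue to every listener handler")
listener.stop()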
