diff --git a/src/itwinai/loggers.py b/src/itwinai/loggers.py
index f7a322e3..ce42e5c0 100644
--- a/src/itwinai/loggers.py
+++ b/src/itwinai/loggers.py
@@ -113,7 +113,7 @@ class Logger(LogMixin):
     """Base class for logger
 
     Args:
-        savedir (str, optional): filesystem location where logs are stored.
+        savedir (Path, optional): filesystem location where logs are stored.
             Defaults to 'mllogs'.
         log_freq (Union[int, Literal['epoch', 'batch']], optional):
            how often should the logger fulfill calls to the `log()`
@@ -136,7 +136,7 @@ class Logger(LogMixin):
     """
 
     #: Location on filesystem where to store data.
-    savedir: str = None
+    savedir: Optional[Path] = None
    #: Supported logging 'kind's.
    supported_kinds: Tuple[str]
    #: Current worker global rank
@@ -146,7 +146,7 @@ class Logger(LogMixin):
 
     def __init__(
         self,
-        savedir: str = "mllogs",
+        savedir: Path = Path("mllogs"),
         log_freq: Union[int, Literal["epoch", "batch"]] = "epoch",
         log_on_workers: Union[int, List[int]] = 0,
         experiment_id: Optional[str] = None,
@@ -240,7 +240,7 @@ def serialize(self, obj: Any, identifier: str) -> str:
         Returns:
             str: local path of the serialized object to be logged.
         """
-        itm_path = os.path.join(self.savedir, identifier)
+        itm_path = Path(self.savedir) / str(identifier)
         with open(itm_path, "wb") as itm_file:
             pickle.dump(obj, itm_file)
 
@@ -301,7 +301,7 @@ class _EmptyLogger(Logger):
 
     def __init__(
         self,
-        savedir: str = "mllogs",
+        savedir: Path = Path("mllogs"),
         log_freq: int | Literal["epoch"] | Literal["batch"] = "epoch",
         log_on_workers: int | List[int] = 0,
     ) -> None:
@@ -332,7 +332,7 @@ class ConsoleLogger(Logger):
     """Simplified logger.
 
     Args:
-        savedir (str, optional): where to store artifacts.
+        savedir (Path, optional): where to store artifacts.
             Defaults to 'mllogs'.
         log_freq (Union[int, Literal['epoch', 'batch']], optional):
             determines whether the logger should fulfill or ignore
@@ -349,11 +349,11 @@ class ConsoleLogger(Logger):
 
     def __init__(
         self,
-        savedir: str = "mllogs",
+        savedir: Path = Path("mllogs"),
         log_freq: Union[int, Literal["epoch", "batch"]] = "epoch",
         log_on_workers: Union[int, List[int]] = 0,
     ) -> None:
-        savedir = savedir = Path(savedir) / "simple-logger"
+        savedir = Path(savedir) / "simple-logger"
         super().__init__(savedir=savedir, log_freq=log_freq, log_on_workers=log_on_workers)
 
     def create_logger_context(self, rank: Optional[int] = None):
@@ -458,11 +458,11 @@ def log(
             print(f"ConsoleLogger: {identifier} = {item}")
 
 
 class MLFlowLogger(Logger):
     """Abstraction around MLFlow logger.
 
     Args:
-        savedir (str, optional): path on local filesystem where logs are
+        savedir (Path, optional): path on local filesystem where logs are
             stored. Defaults to 'mllogs'.
         experiment_name (str, optional): experiment name. Defaults to
             ``itwinai.loggers.BASE_EXP_NAME``.
@@ -501,7 +501,7 @@ class MLFlowLogger(Logger):
 
     def __init__(
         self,
-        savedir: str = "mllogs",
+        savedir: Path = Path("mllogs"),
         experiment_name: str = BASE_EXP_NAME,
         tracking_uri: Optional[str] = None,
         run_description: Optional[str] = None,
@@ -509,7 +509,7 @@ def __init__(
         log_freq: Union[int, Literal["epoch", "batch"]] = "epoch",
         log_on_workers: Union[int, List[int]] = 0,
     ):
-        savedir = os.path.join(savedir, "mlflow")
+        savedir = Path(savedir) / "mlflow"
         super().__init__(savedir=savedir, log_freq=log_freq, log_on_workers=log_on_workers)
         self.tracking_uri = tracking_uri
         self.run_description = run_description
@@ -604,7 +604,7 @@ def log(
         if not isinstance(item, str):
             # Save the object locally and then log it
             name = os.path.basename(identifier)
-            save_path = Path(self.savedir) / ".trash" / name
-            save_path.mkdir(os.path.dirname(save_path), exist_ok=True)
+            save_path = Path(self.savedir) / ".trash" / str(name)
+            save_path.parent.mkdir(parents=True, exist_ok=True)
             item = self.serialize(item, save_path)
             mlflow.log_artifact(local_path=item, artifact_path=identifier)
@@ -630,7 +630,7 @@ def log(
 
             # Save the object locally and then log it
             name = os.path.basename(identifier)
-            save_path = Path(self.savedir) / ".trash" / name
-            save_path.mkdir(os.path.dirname(save_path), exist_ok=True)
+            save_path = Path(self.savedir) / ".trash" / str(name)
+            save_path.parent.mkdir(parents=True, exist_ok=True)
             torch.save(item, save_path)
             # Log into mlflow
@@ -655,7 +655,7 @@ class WandBLogger(Logger):
     """Abstraction around WandB logger.
 
     Args:
-        savedir (str, optional): location on local filesystem where logs
+        savedir (Path, optional): location on local filesystem where logs
             are stored. Defaults to 'mllogs'.
         project_name (str, optional): experiment name. Defaults to
             ``itwinai.loggers.BASE_EXP_NAME``.
@@ -685,12 +685,12 @@ class WandBLogger(Logger):
 
     def __init__(
         self,
-        savedir: str = "mllogs",
+        savedir: Path = Path("mllogs"),
         project_name: str = BASE_EXP_NAME,
         log_freq: Union[int, Literal["epoch", "batch"]] = "epoch",
         log_on_workers: Union[int, List[int]] = 0,
     ) -> None:
-        savedir = os.path.join(savedir, "wandb")
+        savedir = Path(savedir) / "wandb"
         super().__init__(savedir=savedir, log_freq=log_freq, log_on_workers=log_on_workers)
         self.project_name = project_name
 
@@ -707,10 +707,8 @@ def create_logger_context(self, rank: Optional[int] = None) -> None:
         if not self.should_log():
             return
 
-        os.makedirs(os.path.join(self.savedir, "wandb"), exist_ok=True)
-        self.active_run = wandb.init(
-            dir=os.path.abspath(self.savedir), project=self.project_name
-        )
+        os.makedirs(self.savedir / "wandb", exist_ok=True)
+        self.active_run = wandb.init(dir=self.savedir.resolve(), project=self.project_name)
 
     def destroy_logger_context(self):
         """Destroy logger."""
@@ -767,7 +765,7 @@ class TensorBoardLogger(Logger):
     TensorFlow.
 
     Args:
-        savedir (str, optional): location on local filesystem where logs
+        savedir (Path, optional): location on local filesystem where logs
             are stored. Defaults to 'mllogs'.
         log_freq (Union[int, Literal['epoch', 'batch']], optional):
             determines whether the logger should fulfill or ignore
@@ -793,12 +791,12 @@ class TensorBoardLogger(Logger):
 
     def __init__(
         self,
-        savedir: str = "mllogs",
+        savedir: Path = Path("mllogs"),
         log_freq: Union[int, Literal["epoch", "batch"]] = "epoch",
         framework: Literal["tensorflow", "pytorch"] = "pytorch",
         log_on_workers: Union[int, List[int]] = 0,
     ) -> None:
-        savedir = os.path.join(savedir, "tensorboard")
+        savedir = Path(savedir) / "tensorboard"
         super().__init__(savedir=savedir, log_freq=log_freq, log_on_workers=log_on_workers)
         self.framework = framework
         if framework.lower() == "tensorflow":
@@ -914,7 +912,7 @@ class LoggersCollection(Logger):
     supported_kinds: Tuple[str]
 
     def __init__(self, loggers: List[Logger]) -> None:
-        super().__init__(savedir="/tmp/mllogs_LoggersCollection", log_freq=1)
+        super().__init__(savedir=Path("/tmp/mllogs_LoggersCollection"), log_freq=1)
         self.loggers = loggers
 
     def should_log(self, batch_idx: int = None) -> bool:
@@ -998,7 +996,7 @@ class Prov4MLLogger(Logger):
             files will be uploaded. Defaults to "www.example.org".
         experiment_name (str, optional): experiment name. Defaults to
             "experiment_name".
-        provenance_save_dir (str, optional): path where to store provenance
+        provenance_save_dir (Path, optional): path where to store provenance
             files and logs. Defaults to "prov".
         save_after_n_logs (Optional[int], optional): how often to save logs
             to disk from main memory. Defaults to 100.
@@ -1031,9 +1029,9 @@
 
     def __init__(
         self,
-        prov_user_namespace="www.example.org",
-        experiment_name="experiment_name",
-        provenance_save_dir="mllogs/prov_logs",
+        prov_user_namespace: str = "www.example.org",
+        experiment_name: str = "experiment_name",
+        provenance_save_dir: Path = Path("mllogs/prov_logs"),
         save_after_n_logs: Optional[int] = 100,
         create_graph: Optional[bool] = True,
         create_svg: Optional[bool] = True,
diff --git a/src/itwinai/parser.py b/src/itwinai/parser.py
index 8d0a9efc..de702e6b 100644
--- a/src/itwinai/parser.py
+++ b/src/itwinai/parser.py
@@ -27,6 +27,14 @@
 from .utils import load_yaml
 
 
+class _ArgumentParser(JAPArgumentParser):
+    def error(self, message: str, ex: Optional[Exception] = None) -> None:
+        """Override the error method to re-raise the exception instead of exiting execution."""
+        if ex is not None:
+            raise ex
+        raise ValueError(message)
+
+
 def add_replace_field(config: Dict, key_chain: str, value: Any) -> None:
     """Replace or add (if not present) a field in a dictionary, following
     a path of dot-separated keys. Adding is not supported for list items.
@@ -61,9 +69,16 @@ def add_replace_field(config: Dict, key_chain: str, value: Any) -> None:
     sub_config[k] = value
 
 
+def get_root_cause(exception: Exception) -> Exception:
+    """Return the innermost exception by walking the explicit ``__cause__`` chain."""
+    root = exception
+    while root.__cause__ is not None:  # Traverse the exception chain
+        root = root.__cause__
+    return root
+
+
 class ConfigParser:
-    """
-    Parses a pipeline from a configuration file.
+    """Parses a pipeline from a configuration file.
 
     It also provides functionalities for dynamic override of fields
     by means of nested key notation.
@@ -150,7 +165,7 @@ def parse_pipeline(
         Returns:
             Pipeline: instantiated pipeline.
""" - pipe_parser = JAPArgumentParser() + pipe_parser = _ArgumentParser() pipe_parser.add_subclass_arguments(Pipeline, "pipeline") pipe_dict = self.config @@ -163,9 +176,13 @@ def parse_pipeline( print("Assembled pipeline:") print(json.dumps(pipe_dict, indent=4)) - # Parse pipeline dict once merged with steps - conf = pipe_parser.parse_object(pipe_dict) - pipe = pipe_parser.instantiate_classes(conf) + try: + # Parse pipeline dict once merged with steps + conf = pipe_parser.parse_object(pipe_dict) + pipe = pipe_parser.instantiate_classes(conf) + except Exception as exc: + exc = get_root_cause(exc) + raise exc self.pipeline = pipe["pipeline"] return self.pipeline @@ -187,10 +204,15 @@ def parse_step( # Wrap config under "step" field and parse it step_dict_config = {"step": step_dict_config} - step_parser = JAPArgumentParser() + step_parser = _ArgumentParser() step_parser.add_subclass_arguments(BaseComponent, "step") - parsed_namespace = step_parser.parse_object(step_dict_config) - return step_parser.instantiate_classes(parsed_namespace)["step"] + try: + parsed_namespace = step_parser.parse_object(step_dict_config) + step = step_parser.instantiate_classes(parsed_namespace)["step"] + except Exception as exc: + exc = get_root_cause(exc) + raise exc + return step class ArgumentParser(JAPArgumentParser):