@app.command()
def generate_communication_plot(
    log_dir: str = "profiling_logs",
    pattern: str = r"profile_(\w+)_(\d+)_(\d+)\.csv$",
    output_file: str = "plots/comm_plot.png",
) -> None:
    """Generate a stacked plot of computation vs. communication fraction.

    Reads the profiling CSV files found in ``log_dir``, prints a summary table and
    stores the resulting plot at ``output_file``.

    Args:
        log_dir: The directory where the csv logs are stored. Defaults to
            ``profiling_logs``.
        pattern: A regex pattern to recognize the file names in the 'log_dir' folder.
            Defaults to ``profile_(\\w+)_(\\d+)_(\\d+)\\.csv$``.
        output_file: The path to where the resulting plot should be saved. Defaults to
            ``plots/comm_plot.png``.

    Raises:
        IOError: If ``log_dir`` does not exist.
    """
    # Imported lazily so that the other CLI commands do not pay for (or depend on)
    # the plotting stack.
    from itwinai.torch.profiling.communication_plot import (
        create_combined_comm_overhead_df,
        create_stacked_plot,
        get_comp_fraction_full_array,
    )

    log_dir_path = Path(log_dir)
    if not log_dir_path.exists():
        raise IOError(
            f"The directory '{log_dir_path.resolve()}' does not exist, so could not "
            "extract profiling logs. Make sure you are running this command in the "
            "same directory as the logging dir."
        )

    df = create_combined_comm_overhead_df(logs_dir=log_dir_path, pattern=pattern)
    values = get_comp_fraction_full_array(df, print_table=True)

    # Same ordering that get_comp_fraction_full_array uses for rows and columns
    strategies = sorted(df["strategy"].unique())
    gpu_numbers = sorted(df["num_gpus"].unique(), key=int)

    fig, _ = create_stacked_plot(values, strategies, gpu_numbers)

    # TODO: set these dynamically?
    fig.set_figwidth(8)
    fig.set_figheight(6)

    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Save the figure we just built explicitly instead of relying on pyplot's
    # notion of the "current" figure.
    fig.savefig(output_path)
    print(f"\nSaved computation vs. communication plot at '{output_path.resolve()}'")
- )], - plot_title: Annotated[Optional[str], typer.Option( - help=("Plot name.") - )] = None, - skip_id: Annotated[Optional[int], typer.Option( - help=("Skip epoch ID.") - )] = None, - archive: Annotated[Optional[str], typer.Option( - help=("Archive name to backup the data, without extension.") - )] = None, + pattern: Annotated[ + str, typer.Option(help="Python pattern matching names of CSVs in sub-folders.") + ], + plot_title: Annotated[Optional[str], typer.Option(help=("Plot name."))] = None, + skip_id: Annotated[Optional[int], typer.Option(help=("Skip epoch ID."))] = None, + archive: Annotated[ + Optional[str], + typer.Option(help=("Archive name to backup the data, without extension.")), + ] = None, ): """ Generate scalability report merging all CSVs containing epoch time @@ -88,7 +134,7 @@ def scalability_report( import numpy as np import pandas as pd - regex = re.compile(r'{}'.format(pattern)) + regex = re.compile(r"{}".format(pattern)) combined_df = pd.DataFrame() csv_files = [] for root, _, files in os.walk(os.getcwd()): @@ -104,14 +150,13 @@ def scalability_report( print(combined_df) avg_times = ( - combined_df - .drop(columns='epoch_id') - .groupby(['name', 'nodes']) + combined_df.drop(columns="epoch_id") + .groupby(["name", "nodes"]) .mean() .reset_index() ) print("\nAvg over name and nodes:") - print(avg_times.rename(columns=dict(time='avg(time)'))) + print(avg_times.rename(columns=dict(time="avg(time)"))) # fig, (sp_up_ax, eff_ax) = plt.subplots(1, 2, figsize=(12, 4)) fig, sp_up_ax = plt.subplots(1, 1, figsize=(6, 4)) @@ -125,7 +170,7 @@ def scalability_report( series_names = sorted(set(avg_times.name.values)) for name in series_names: - df = avg_times[avg_times.name == name].drop(columns='name') + df = avg_times[avg_times.name == name].drop(columns="name") # Debug # compute_time = [3791., 1884., 1011., 598.] 
@@ -133,32 +178,42 @@ def scalability_report( # d = {'nodes': nodes, 'time': compute_time} # df = pd.DataFrame(data=d) - df["NGPUs"] = df["nodes"]*4 + df["NGPUs"] = df["nodes"] * 4 # speedup df["Speedup - ideal"] = df["nodes"].astype(float) df["Speedup"] = df["time"].iloc[0] / df["time"] df["Nworkers"] = 1 # efficiency - df["Threadscaled Sim. Time / s"] = df["time"] * \ - df["nodes"] * df["Nworkers"] - df["Efficiency"] = df["Threadscaled Sim. Time / s"].iloc[0] / \ - df["Threadscaled Sim. Time / s"] + df["Threadscaled Sim. Time / s"] = df["time"] * df["nodes"] * df["Nworkers"] + df["Efficiency"] = ( + df["Threadscaled Sim. Time / s"].iloc[0] / df["Threadscaled Sim. Time / s"] + ) sp_up_ax.plot( - df["NGPUs"].values, df["Speedup"].values, - marker=next(markers), lw=1.0, label=name, alpha=0.7) + df["NGPUs"].values, + df["Speedup"].values, + marker=next(markers), + lw=1.0, + label=name, + alpha=0.7, + ) - sp_up_ax.plot(df["NGPUs"].values, df["Speedup - ideal"].values, - ls='dashed', lw=1.0, c='k', label="ideal") + sp_up_ax.plot( + df["NGPUs"].values, + df["Speedup - ideal"].values, + ls="dashed", + lw=1.0, + c="k", + label="ideal", + ) sp_up_ax.legend(ncol=1) sp_up_ax.set_xticks(df["NGPUs"].values) - sp_up_ax.get_xaxis().set_major_formatter( - matplotlib.ticker.ScalarFormatter()) + sp_up_ax.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter()) - sp_up_ax.set_ylabel('Speedup') - sp_up_ax.set_xlabel('NGPUs (4 per node)') + sp_up_ax.set_ylabel("Speedup") + sp_up_ax.set_xlabel("NGPUs (4 per node)") sp_up_ax.grid() # Sort legend @@ -168,42 +223,42 @@ def scalability_report( plot_png = f"scaling_plot_{plot_title}.png" plt.tight_layout() - plt.savefig(plot_png, bbox_inches='tight', format='png', dpi=300) + plt.savefig(plot_png, bbox_inches="tight", format="png", dpi=300) print("Saved scaling plot to: ", plot_png) if archive is not None: - if '/' in archive: - raise ValueError("Archive name must NOT contain a path. " - f"Received: '{archive}'") - if '.' 
in archive: - raise ValueError("Archive name must NOT contain an extension. " - f"Received: '{archive}'") + if "/" in archive: + raise ValueError( + f"Archive name must NOT contain a path. Received: '{archive}'" + ) + if "." in archive: + raise ValueError( + f"Archive name must NOT contain an extension. Received: '{archive}'" + ) if os.path.isdir(archive): - raise ValueError(f"Folder '{archive}' already exists. " - "Change archive name.") + raise ValueError(f"Folder '{archive}' already exists. Change archive name.") os.makedirs(archive) for csvfile in csv_files: - shutil.copyfile(csvfile, os.path.join(archive, - os.path.basename(csvfile))) + shutil.copyfile(csvfile, os.path.join(archive, os.path.basename(csvfile))) shutil.copyfile(plot_png, os.path.join(archive, plot_png)) avg_times.to_csv(os.path.join(archive, "avg_times.csv"), index=False) print("Archived AVG epoch times CSV") # Copy SLURM logs: *.err *.out files - if os.path.exists('logs_slurm'): + if os.path.exists("logs_slurm"): print("Archived SLURM logs") - shutil.copytree('logs_slurm', os.path.join(archive, 'logs_slurm')) + shutil.copytree("logs_slurm", os.path.join(archive, "logs_slurm")) # Copy other SLURM logs - for ext in ['*.out', '*.err']: + for ext in ["*.out", "*.err"]: for file in glob.glob(ext): shutil.copyfile(file, os.path.join(archive, file)) # Create archive archive_name = shutil.make_archive( base_name=archive, # archive file name - format='gztar', + format="gztar", # root_dir='.', - base_dir=archive # folder path inside archive + base_dir=archive, # folder path inside archive ) shutil.rmtree(archive) print("Archived logs and plot at: ", archive_name) @@ -211,24 +266,37 @@ def scalability_report( @app.command() def exec_pipeline( - config: Annotated[Path, typer.Option( - help="Path to the configuration file of the pipeline to execute." 
- )], - pipe_key: Annotated[str, typer.Option( - help=("Key in the configuration file identifying " - "the pipeline object to execute.") - )] = "pipeline", - steps: Annotated[Optional[str], typer.Option( - help=("Run only some steps of the pipeline. Accepted values are " - "indices, python slices (e.g., 0:3 or 2:10:100), and " - "string names of steps.") - )] = None, - print_config: Annotated[bool, typer.Option( - help=("Print config to be executed after overrides.") - )] = False, + config: Annotated[ + Path, + typer.Option(help="Path to the configuration file of the pipeline to execute."), + ], + pipe_key: Annotated[ + str, + typer.Option( + help=( + "Key in the configuration file identifying " + "the pipeline object to execute." + ) + ), + ] = "pipeline", + steps: Annotated[ + Optional[str], + typer.Option( + help=( + "Run only some steps of the pipeline. Accepted values are " + "indices, python slices (e.g., 0:3 or 2:10:100), and " + "string names of steps." + ) + ), + ] = None, + print_config: Annotated[ + bool, typer.Option(help=("Print config to be executed after overrides.")) + ] = False, overrides_list: Annotated[ - Optional[List[str]], typer.Option( - "--override", "-o", + Optional[List[str]], + typer.Option( + "--override", + "-o", help=( "Nested key to dynamically override elements in the " "configuration file with the " @@ -237,13 +305,11 @@ def exec_pipeline( "Example: [...] " "-o pipeline.init_args.trainer.init_args.lr=0.001 " "-o pipeline.my_list.2.batch_size=64 " - ) - ) - ] = None + ), + ), + ] = None, ): - """Execute a pipeline from configuration file. - Allows dynamic override of fields. - """ + """Execute a pipeline from configuration file. 
@app.command()
def mlflow_ui(
    path: str = typer.Option("ml-logs/", help="Path to logs storage."),
    port: int = typer.Option(5000, help="Port on which the MLFlow UI is listening."),
):
    """Visualize Mlflow logs.

    Runs ``mlflow ui`` with ``path`` as backend store URI, listening on ``port``.
    """
    import subprocess

    # Pass an explicit argument list: splitting a formatted string on whitespace
    # would break a storage path that contains spaces.
    subprocess.run(
        ["mlflow", "ui", "--backend-store-uri", str(path), "--port", str(port)]
    )


@app.command()
def mlflow_server(
    path: str = typer.Option("ml-logs/", help="Path to logs storage."),
    port: int = typer.Option(5000, help="Port on which the server is listening."),
):
    """Spawn Mlflow server.

    Runs ``mlflow server`` with ``path`` as backend store URI, listening on ``port``.
    """
    import subprocess

    # Explicit argument list, see mlflow_ui: keeps paths with whitespace intact.
    subprocess.run(
        ["mlflow", "server", "--backend-store-uri", str(path), "--port", str(port)]
    )


@app.command()
def kill_mlflow_server(
    port: int = typer.Option(5000, help="Port on which the server is listening."),
):
    """Kill Mlflow server.

    Looks up the process IDs reported by ``lsof -t -i:<port>`` and sends them
    ``kill -9``.

    Raises:
        subprocess.CalledProcessError: If ``lsof`` or ``kill`` exits with a non-zero
            status.
    """
    import subprocess

    # ``$(...)`` command substitution is a shell feature and is never expanded when
    # an argument list is executed directly, so resolve the PIDs in a first step.
    lsof = subprocess.run(
        ["lsof", "-t", f"-i:{port}"],
        check=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
        text=True,
    )
    pids = lsof.stdout.split()
    subprocess.run(["kill", "-9", *pids], check=True, stderr=subprocess.DEVNULL)
Defaults to ``self.__class__.__name__``.""" + return self._name if self._name is not None else self.__class__.__name__ @name.setter def name(self, name: str) -> None: @@ -157,7 +158,7 @@ def name(self, name: str) -> None: @abstractmethod @monitor_exec def execute(self, *args, **kwargs) -> Any: - """"Execute some operations.""" + """Execute some operations.""" # def setup_console(self): # """Setup Python logging""" @@ -186,9 +187,9 @@ def cleanup(self): @staticmethod def _printout(msg: str): msg = f"# {msg} #" - print("#"*len(msg)) + print("#" * len(msg)) print(msg) - print("#"*len(msg)) + print("#" * len(msg)) class DataGetter(BaseComponent): @@ -213,7 +214,7 @@ def execute( self, train_dataset: MLDataset, validation_dataset: MLDataset, - test_dataset: MLDataset + test_dataset: MLDataset, ) -> Tuple[MLDataset, MLDataset, MLDataset]: """Trains a machine learning model. @@ -230,6 +231,7 @@ def execute( class DataSplitter(BaseComponent): """Splits a dataset into train, validation, and test splits.""" + _train_proportion: Union[int, float] _validation_proportion: Union[int, float] _test_proportion: Union[int, float] @@ -239,7 +241,7 @@ def __init__( train_proportion: Union[int, float], validation_proportion: Union[int, float], test_proportion: Union[int, float], - name: Optional[str] = None + name: Optional[str] = None, ) -> None: super().__init__(name) self.save_parameters(**self.locals2params(locals())) @@ -291,10 +293,7 @@ def test_proportion(self, prop: Union[int, float]) -> None: @abstractmethod @monitor_exec - def execute( - self, - dataset: MLDataset - ) -> Tuple[MLDataset, MLDataset, MLDataset]: + def execute(self, dataset: MLDataset) -> Tuple[MLDataset, MLDataset, MLDataset]: """Splits a dataset into train, validation and test splits. 
Args: @@ -315,7 +314,7 @@ def execute( self, train_dataset: MLDataset, validation_dataset: MLDataset, - test_dataset: MLDataset + test_dataset: MLDataset, ) -> Tuple[MLDataset, MLDataset, MLDataset, MLModel]: """Trains a machine learning model. @@ -348,9 +347,7 @@ def __init__( @abstractmethod @monitor_exec def execute( - self, - predict_dataset: MLDataset, - model: Optional[MLModel] = None + self, predict_dataset: MLDataset, model: Optional[MLModel] = None ) -> MLDataset: """Applies a machine learning model on a dataset of samples. @@ -433,26 +430,29 @@ def execute(self, *args) -> Tuple: """ result = [] for itm in self.policy: - if isinstance(itm, str) and itm.startswith(self.INPUT_PREFIX): - arg_idx = int(itm[len(self.INPUT_PREFIX):]) - if arg_idx >= len(args): - max_idx = max(map( - lambda itm: int(itm[len(self.INPUT_PREFIX):]), + if not (isinstance(itm, str) and itm.startswith(self.INPUT_PREFIX)): + result.append(itm) + continue + + arg_idx = int(itm[len(self.INPUT_PREFIX) :]) + if arg_idx >= len(args): + max_idx = max( + map( + lambda itm: int(itm[len(self.INPUT_PREFIX) :]), filter( lambda el: ( - isinstance(el, str) - and el.startswith(self.INPUT_PREFIX) + isinstance(el, str) and el.startswith(self.INPUT_PREFIX) ), - self.policy - ))) - raise IndexError( - f"The args received as input by '{self.name}' " - "are not consistent with the given adapter policy " - "because input args are too few! " - f"Input args are {len(args)} but the policy foresees " - f"at least {max_idx+1} items." + self.policy, + ), ) - result.append(args[arg_idx]) - else: - result.append(itm) + ) + raise IndexError( + f"The args received as input by '{self.name}' " + "are not consistent with the given adapter policy " + "because input args are too few! " + f"Input args are {len(args)} but the policy foresees " + f"at least {max_idx+1} items." 
+ ) + result.append(args[arg_idx]) return tuple(result) diff --git a/src/itwinai/torch/distributed.py b/src/itwinai/torch/distributed.py index 6534e508..559ba2b0 100644 --- a/src/itwinai/torch/distributed.py +++ b/src/itwinai/torch/distributed.py @@ -2,8 +2,6 @@ import os from typing import Any, Iterable, List, Literal, Optional, Tuple, Union -import deepspeed -import horovod.torch as hvd import torch import torch.distributed as dist import torch.nn as nn @@ -565,10 +563,7 @@ class DeepSpeedStrategy(TorchDistributedStrategy): #: Torch distributed communication backend. backend: Literal['nccl', 'gloo', 'mpi'] - def __init__( - self, - backend: Literal['nccl', 'gloo', 'mpi'] - ) -> None: + def __init__(self, backend: Literal['nccl', 'gloo', 'mpi']) -> None: super().__init__() self.backend = backend @@ -581,6 +576,8 @@ def init(self) -> None: DistributedStrategyError: when trying to initialize a strategy already initialized. """ + import deepspeed + self.deepspeed = deepspeed if not distributed_resources_available(): raise RuntimeError( "Trying to run distributed on insufficient resources.") @@ -591,10 +588,11 @@ def init(self) -> None: # https://github.com/Lightning-AI/pytorch-lightning/issues/13567 ompi_lrank = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') os.environ['OMPI_COMM_WORLD_LOCAL_RANK'] = os.environ.get( - 'LOCAL_RANK', ompi_lrank) + 'LOCAL_RANK', ompi_lrank + ) # https://deepspeed.readthedocs.io/en/latest/initialize.html#training-initialization - deepspeed.init_distributed(dist_backend=self.backend) + self.deepspeed.init_distributed(dist_backend=self.backend) self.is_initialized = True self.set_device() @@ -608,9 +606,10 @@ def distributed( """Setup model, optimizer and scheduler for distributed.""" if not self.is_initialized: raise UninitializedStrategyError( - "Strategy has not been initialized. Use the init method.") + "Strategy has not been initialized. Use the init method." 
+ ) - distrib_model, optimizer, _, lr_scheduler = deepspeed.initialize( + distrib_model, optimizer, _, lr_scheduler = self.deepspeed.initialize( model=model, model_parameters=model_parameters, optimizer=optimizer, @@ -752,7 +751,11 @@ def init(self) -> None: "Trying to run distributed on insufficient resources.") if self.is_initialized: raise DistributedStrategyError("Strategy was already initialized") - hvd.init() + + import horovod.torch as hvd + self.hvd = hvd + + self.hvd.init() self.is_initialized = True self.set_device() @@ -772,16 +775,16 @@ def distributed( # Scale learning rate # https://github.com/horovod/horovod/issues/1653#issuecomment-574764452 lr_scaler = 1 - if optim_kwargs.get('op') == hvd.Adasum: - lr_scaler = hvd.local_size() - elif optim_kwargs.get('op') == hvd.Average: - lr_scaler = hvd.size() + if optim_kwargs.get('op') == self.hvd.Adasum: + lr_scaler = self.hvd.local_size() + elif optim_kwargs.get('op') == self.hvd.Average: + lr_scaler = self.hvd.size() for g in optimizer.param_groups: g['lr'] *= lr_scaler self._broadcast_params(model, optimizer) - distOptimizer = hvd.DistributedOptimizer( + distOptimizer = self.hvd.DistributedOptimizer( optimizer, named_parameters=model.named_parameters(), **optim_kwargs @@ -799,8 +802,8 @@ def _broadcast_params( optimizer (optim.Optimizer): Optimizer that is to be broadcasted across processes. """ - hvd.broadcast_parameters(model.state_dict(), root_rank=0) - hvd.broadcast_optimizer_state(optimizer, root_rank=-0) + self.hvd.broadcast_parameters(model.state_dict(), root_rank=0) + self.hvd.broadcast_optimizer_state(optimizer, root_rank=-0) def global_world_size(self) -> int: """Returns the total number of processes (global world size). @@ -811,7 +814,7 @@ def global_world_size(self) -> int: if not self.is_initialized: raise UninitializedStrategyError( "Strategy has not been initialized. 
Use the init method.") - return hvd.size() + return self.hvd.size() def local_world_size(self) -> int: """Returns the local number of workers available per node, @@ -823,7 +826,7 @@ def local_world_size(self) -> int: if not self.is_initialized: raise UninitializedStrategyError( "Strategy has not been initialized. Use the init method.") - return hvd.local_size() + return self.hvd.local_size() def global_rank(self) -> int: """Returns the global rank of the current process, where @@ -835,7 +838,7 @@ def global_rank(self) -> int: if not self.is_initialized: raise UninitializedStrategyError( "Strategy has not been initialized. Use the init method.") - return hvd.rank() + return self.hvd.rank() def local_rank(self) -> int: """Returns the local rank of the current process. @@ -846,14 +849,14 @@ def local_rank(self) -> int: if not self.is_initialized: raise UninitializedStrategyError( "Strategy has not been initialized. Use the init method.") - return hvd.local_rank() + return self.hvd.local_rank() def clean_up(self) -> None: """Shuts Horovod down.""" if not self.is_initialized: raise UninitializedStrategyError( "Strategy has not been initialized. Use the init method.") - hvd.shutdown() + self.hvd.shutdown() def allgather_obj(self, obj: Any) -> list[Any]: """All-gathers scalar objects across all workers to a @@ -869,7 +872,7 @@ def allgather_obj(self, obj: Any) -> list[Any]: raise UninitializedStrategyError( "Strategy has not been initialized. Use the init method." 
def calculate_comp_and_comm_time(df: pd.DataFrame) -> Tuple[float, float]:
    """Calculates the time spent computing and time spent communicating and returns a
    tuple of these numbers in seconds. Assumes that you are running with an NCCL
    backend.

    Rows whose ``name`` contains ``aten::`` count as computation; rows whose ``name``
    matches an NCCL kernel or a ``cudaStream`` wait/synchronize call count as
    communication. Rows with a missing ``name`` are counted as neither.

    Args:
        df: Profiling data with at least the columns ``name`` and
            ``self_cuda_time_total``.

    Returns:
        Tuple ``(comp_time, comm_time)``: the summed ``self_cuda_time_total`` of the
        matching rows, scaled by 1e-6.

    Raises:
        ValueError: If not all expected columns ('name', 'self_cuda_time_total') are
            found in the given DataFrame.
    """
    expected_columns = {"name", "self_cuda_time_total"}
    missing_columns = expected_columns - set(df.columns)
    if missing_columns:
        raise ValueError(
            "Invalid data format! DataFrame does not contain the necessary columns."
            f"\nMissing columns: {missing_columns}"
        )

    # Non-capturing groups on purpose: pandas warns when a pattern passed to
    # ``str.contains`` has capturing groups.
    nccl_comm_pattern = (
        r"ncclKernel_(?:AllReduce|Broadcast|Reduce|AllGather|ReduceScatter|SendRecv)"
    )
    cuda_stream_pattern = r"cudaStream(?:WaitEvent|Synchronize)"

    # Any operation that is a part of PyTorch's ATen library is considered a computation
    aten_comp_pattern = r"aten::"

    # na=False: a missing name must yield False, otherwise the mask contains NaN
    # and boolean indexing raises.
    names = df["name"].str
    comm_mask = names.contains(nccl_comm_pattern, na=False) | names.contains(
        cuda_stream_pattern, na=False
    )
    comp_mask = names.contains(aten_comp_pattern, na=False)

    comp_time = df.loc[comp_mask, "self_cuda_time_total"].sum()
    comm_time = df.loc[comm_mask, "self_cuda_time_total"].sum()

    # Converting from microseconds to seconds
    return float(comp_time * 1e-6), float(comm_time * 1e-6)


def create_stacked_plot(
    values: np.ndarray, strategy_labels: List, gpu_numbers: List
) -> Tuple[Any, Any]:
    """Creates a stacked plot showing values from 0 to 1, where the given value
    will be placed on the bottom and the complement will be placed on top for
    each value in 'values'. Returns the figure and the axis so that the caller can
    do what they want with it, e.g. save to file, change it or just show it.

    Args:
        values: 2D array-like of fractions, one row per strategy and one column per
            GPU count. NaN entries get no strategy label under the bar.
        strategy_labels: Names of the strategies (sorted alphabetically here).
        gpu_numbers: GPU counts (sorted numerically here).

    Returns:
        Tuple ``(fig, ax)`` of the created matplotlib figure and axis.

    Notes:
        - Assumes that the rows of 'values' correspond to the labels in
          'strategy_labels' sorted alphabetically and that the columns correspond to
          the GPU numbers in 'gpu_numbers' sorted numerically in ascending order.
    """
    sns.set_theme()

    # Accept nested lists as well as arrays
    values = np.asarray(values, dtype=float)
    strategy_labels = sorted(strategy_labels)
    gpu_numbers = sorted(gpu_numbers, key=int)

    # One slot per strategy plus one empty slot separating the GPU groups
    width = 1 / (len(strategy_labels) + 1)
    comp_color = "lightblue"
    comm_color = "lightgreen"
    complements = 1 - values

    x = np.arange(len(gpu_numbers))
    fig, ax = plt.subplots()

    # Creating an offset to "center" around zero
    static_offset = len(strategy_labels) / 2 - 0.5
    for strategy_idx, strategy_label in enumerate(strategy_labels):
        bar_positions = x + (strategy_idx - static_offset) * width

        ax.bar(
            x=bar_positions,
            height=values[strategy_idx],
            width=width,
            color=comp_color,
        )
        ax.bar(
            x=bar_positions,
            height=complements[strategy_idx],
            width=width,
            bottom=values[strategy_idx],
            color=comm_color,
        )

        # Positioning the labels under the stacks, skipping missing combinations
        for bar_position, fraction in zip(bar_positions, values[strategy_idx]):
            if np.isnan(fraction):
                continue
            ax.text(
                x=bar_position,
                y=-0.1,
                s=strategy_label,
                ha="center",
                va="top",
                fontsize=10,
                rotation=60,
            )

    ax.set_ylabel("Computation fraction")
    ax.set_title("Computation vs Communication Time by Method")
    ax.set_xticks(x)
    ax.set_xticklabels(gpu_numbers)
    ax.set_ylim(0, 1)

    # Setting the appropriate colors since the legend is manual
    legend_elements = [
        Patch(facecolor=comm_color, label="Communication"),
        Patch(facecolor=comp_color, label="Computation"),
    ]

    # Positioning the legend outside of the plot to not obstruct it
    ax.legend(
        handles=legend_elements,
        loc="upper left",
        bbox_to_anchor=(0.80, 1.22),
        borderaxespad=0.0,
    )
    fig.subplots_adjust(bottom=0.25)
    fig.subplots_adjust(top=0.85)
    return fig, ax


def create_combined_comm_overhead_df(logs_dir: Path, pattern: str) -> pd.DataFrame:
    """Reads and combines all files in a folder that matches the given regex pattern
    into a single DataFrame. The files must be formatted as csv files.

    Args:
        logs_dir: Directory whose direct entries are scanned.
        pattern: Regex searched in the string form of each entry's path.

    Returns:
        The concatenation of all matching CSV files, read in sorted path order.

    Raises:
        ValueError: If not all expected columns are found in the stored DataFrame.
        ValueError: If no matching files are found in the given logging directory.
    """
    re_pattern: Pattern = compile(pattern)
    expected_columns = {
        "strategy",
        "num_gpus",
        "global_rank",
        "name",
        "self_cuda_time_total",
    }

    dataframes = []
    # Sorted so that the row order of the result does not depend on the file system
    for entry in sorted(logs_dir.iterdir()):
        match = re_pattern.search(str(entry))
        if not match:
            continue

        df = pd.read_csv(entry)
        missing_columns = expected_columns - set(df.columns)
        if missing_columns:
            raise ValueError(
                f"Invalid data format! File at '{match.string}' doesn't contain all"
                f" necessary columns. \nMissing columns: {missing_columns}"
            )
        dataframes.append(df)

    if not dataframes:
        raise ValueError(
            f"No matching files found in '{logs_dir.resolve()}' for pattern '{pattern}'"
        )
    return pd.concat(dataframes)


def get_comp_fraction_full_array(
    df: pd.DataFrame, print_table: bool = False
) -> np.ndarray:
    """Creates a MxN NumPy array where M is the number of strategies
    and N is the number of GPU configurations. The strategies are sorted
    alphabetically and the GPU configurations are sorted in ascending number
    of GPUs.

    Args:
        df: Profiling data containing the columns ``strategy`` and ``num_gpus`` plus
            the columns required by ``calculate_comp_and_comm_time``.
        print_table: If True, prints one table row per (strategy, GPU count) pair
            with the computation and communication times.

    Returns:
        Array of computation fractions ``comp / (comp + comm)``; combinations
        without any data are NaN.
    """
    unique_num_gpus = sorted(df["num_gpus"].unique(), key=int)
    unique_strategies = sorted(df["strategy"].unique())

    values = []
    table_rows = []
    for strategy in unique_strategies:
        strategy_values = []
        for num_gpus in unique_num_gpus:
            filtered_df = df[
                (df["strategy"] == strategy) & (df["num_gpus"] == num_gpus)
            ]

            row_string = f"{strategy:>12} | {num_gpus:>10}"

            # Allows asymmetric testing, i.e. not testing all num gpus and all
            # strategies together
            if filtered_df.empty:
                # np.nan, not np.NaN: the capitalised alias was removed in NumPy 2.0
                strategy_values.append(np.nan)
                row_string += f" | {'(NO DATA)':>15}"
            else:
                comp_time, comm_time = calculate_comp_and_comm_time(df=filtered_df)
                # Avoid division-by-zero errors (1e-10)
                comp_fraction = comp_time / (comp_time + comm_time + 1e-10)
                strategy_values.append(comp_fraction)

                row_string += f" | {comp_time:>8.2f}s"
                row_string += f" | {comm_time:>8.2f}s"

            table_rows.append(row_string)
        values.append(strategy_values)

    if print_table:
        rule = "-" * 50
        print(rule)
        print(f"{'Strategy':>12} | {'Num. GPUs':>10} | {'Comp.':>9} | {'Comm.':>8}")
        print(rule)
        # Every row is newline-terminated, so print() leaves one blank line
        # before the closing rule.
        print("".join(f"{row}\n" for row in table_rows))
        print(rule)

    return np.array(values)
def profile_torch_trainer(method: Callable) -> Callable:
    """Decorate a trainer method so that it runs under the PyTorch profiler.

    The wrapper starts a ``torch.profiler.profile`` session, exposes it on the
    trainer as ``self.profiler`` while the wrapped method runs, and afterwards
    writes the aggregated profiler events to
    ``profiling_logs/profile_<strategy>_<num_gpus>_<global_rank>.csv`` (relative to
    the current working directory). Finally it calls ``strategy.clean_up()``.

    Args:
        method: The method to profile. The wrapper reads ``self.strategy`` and sets
            ``self.profiler``, so ``self`` is expected to be a ``TorchTrainer``.

    Returns:
        The wrapped method, which returns whatever ``method`` returns.
    """

    def gather_profiling_data(key_averages: Iterable) -> pd.DataFrame:
        """Turn the profiler's aggregated events into one DataFrame row per event."""
        return pd.DataFrame(
            [
                {
                    "name": event.key,
                    "node_id": event.node_id,
                    "self_cpu_time_total": event.self_cpu_time_total,
                    "cpu_time_total": event.cpu_time_total,
                    "cpu_time_total_str": event.cpu_time_total_str,
                    "self_cuda_time_total": event.self_cuda_time_total,
                    "cuda_time_total": event.cuda_time_total,
                    "cuda_time_total_str": event.cuda_time_total_str,
                    "calls": event.count,
                }
                for event in key_averages
            ]
        )

    def get_strategy_name(strategy: Any) -> str:
        """Map a strategy instance to the short name used in file names and CSVs."""
        # Checked in order; the first matching type wins.
        # NOTE(review): "non-dist" contains a hyphen, which ``\w+`` does not match,
        # so the default file pattern of the ``generate_communication_plot`` CLI
        # command (``profile_(\w+)_(\d+)_(\d+)\.csv$``) appears to skip these files.
        # Confirm whether the name or the pattern should change.
        known_strategies = (
            (NonDistributedStrategy, "non-dist"),
            (TorchDDPStrategy, "ddp"),
            (DeepSpeedStrategy, "deepspeed"),
            (HorovodStrategy, "horovod"),
        )
        for strategy_type, strategy_name in known_strategies:
            if isinstance(strategy, strategy_type):
                return strategy_name
        return "unk"

    @functools.wraps(method)
    def profiled_method(self: TorchTrainer, *args, **kwargs) -> Any:
        profiler = profile(
            activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU],
            with_modules=True,
            schedule=schedule(wait=1, warmup=2, active=100),
        )
        profiler.start()

        # Exposed on the trainer so that it can advance the profiler schedule
        # (via ``self.profiler.step()``) while the wrapped method runs.
        self.profiler = profiler
        try:
            result = method(self, *args, **kwargs)
        finally:
            profiler.stop()
            # Do not leave a stopped profiler on the trainer: a later ``step()``
            # call would otherwise act on a finished profiling session.
            self.profiler = None

        strategy = self.strategy
        strategy_str = get_strategy_name(strategy)
        global_rank = strategy.global_rank()
        num_gpus_global = strategy.global_world_size()

        # Extracting and storing the profiling data
        profiling_dataframe = gather_profiling_data(
            key_averages=profiler.key_averages()
        )
        profiling_dataframe["strategy"] = strategy_str
        profiling_dataframe["num_gpus"] = num_gpus_global
        profiling_dataframe["global_rank"] = global_rank

        profiling_log_dir = Path("profiling_logs")
        profiling_log_dir.mkdir(parents=True, exist_ok=True)

        filename = f"profile_{strategy_str}_{num_gpus_global}_{global_rank}.csv"
        output_path = profiling_log_dir / filename

        print(f"Writing profiling dataframe to {output_path}")
        profiling_dataframe.to_csv(output_path)

        # Clean-up happens here, after the strategy was queried for rank/world size
        strategy.clean_up()

        return result

    return profiled_method
""" + if self.profiler is not None: + self.profiler.step() self._set_epoch_dataloaders(epoch) def log( @@ -520,10 +525,8 @@ def train(self): val_loss = self.validation_epoch(epoch) # Checkpointing current best model - worker_val_losses = self.strategy.gather( - val_loss, dst_rank=0 - ) - if self.strategy.global_rank() == 0: + worker_val_losses = self.strategy.gather(val_loss, dst_rank=0) + if self.strategy.is_main_worker: avg_loss = torch.mean( torch.stack(worker_val_losses) ).detach().cpu() @@ -628,7 +631,7 @@ def train_step( ) return loss, metrics - def validation_epoch(self, epoch: int) -> torch.Tensor: + def validation_epoch(self, epoch: int) -> Optional[torch.Tensor]: """Perform a complete sweep over the validation dataset, completing an epoch of validation. @@ -636,43 +639,45 @@ def validation_epoch(self, epoch: int) -> torch.Tensor: epoch (int): current epoch number, from 0 to ``self.epochs - 1``. Returns: - Loss: average validation loss for the current epoch. + Optional[Loss]: average validation loss for the current epoch if + self.validation_dataloader is not None """ - if self.validation_dataloader is not None: - self.model.eval() - validation_losses = [] - validation_metrics = [] - for batch_idx, val_batch \ - in enumerate(self.validation_dataloader): - loss, metrics = self.validation_step( - batch=val_batch, - batch_idx=batch_idx - ) - validation_losses.append(loss) - validation_metrics.append(metrics) + if self.validation_dataloader is None: + return + + self.model.eval() + validation_losses = [] + validation_metrics = [] + for batch_idx, val_batch in enumerate(self.validation_dataloader): + loss, metrics = self.validation_step( + batch=val_batch, + batch_idx=batch_idx + ) + validation_losses.append(loss) + validation_metrics.append(metrics) - # Important: update counter - self.validation_glob_step += 1 + # Important: update counter + self.validation_glob_step += 1 - # Aggregate and log losses - avg_loss = torch.mean(torch.stack(validation_losses)) + # 
Aggregate and log losses + avg_loss = torch.mean(torch.stack(validation_losses)) + self.log( + item=avg_loss.item(), + identifier='validation_loss_epoch', + kind='metric', + step=self.validation_glob_step, + ) + # Aggregate and log metrics + avg_metrics = pd.DataFrame(validation_metrics).mean().to_dict() + for m_name, m_val in avg_metrics.items(): self.log( - item=avg_loss.item(), - identifier='validation_loss_epoch', + item=m_val, + identifier='validation_' + m_name + '_epoch', kind='metric', step=self.validation_glob_step, ) - # Aggregate and log metrics - avg_metrics = pd.DataFrame(validation_metrics).mean().to_dict() - for m_name, m_val in avg_metrics.items(): - self.log( - item=m_val, - identifier='validation_' + m_name + '_epoch', - kind='metric', - step=self.validation_glob_step, - ) - return avg_loss + return avg_loss def validation_step( self, diff --git a/use-cases/eurac/config.yaml b/use-cases/eurac/config.yaml index 64ee45f7..8912e898 100644 --- a/use-cases/eurac/config.yaml +++ b/use-cases/eurac/config.yaml @@ -6,7 +6,7 @@ tmp_stats: /p/scratch/intertwin/datasets/eurac/stats experiment: "drought use case lstm" run_name: "alps_test" -epochs: 2 +epochs: 5 random_seed: 1010 lr: 0.001 batch_size: 256 diff --git a/use-cases/eurac/plots/comm_plot.png b/use-cases/eurac/plots/comm_plot.png new file mode 100644 index 00000000..b406b1a0 Binary files /dev/null and b/use-cases/eurac/plots/comm_plot.png differ diff --git a/use-cases/eurac/runall.sh b/use-cases/eurac/runall.sh index 6169366e..03a4798d 100755 --- a/use-cases/eurac/runall.sh +++ b/use-cases/eurac/runall.sh @@ -12,7 +12,7 @@ if [ -z "$NUM_GPUS" ]; then NUM_GPUS=4 fi if [ -z "$TIME" ]; then - TIME=0:20:00 + TIME=0:40:00 fi if [ -z "$DEBUG" ]; then DEBUG=false diff --git a/use-cases/eurac/slurm.sh b/use-cases/eurac/slurm.sh index e1ec58b1..e907e54c 100644 --- a/use-cases/eurac/slurm.sh +++ b/use-cases/eurac/slurm.sh @@ -100,7 +100,7 @@ if [ "$DIST_MODE" == "horovod" ] ; then srun --cpu-bind=none \ 
--ntasks-per-node=$SLURM_GPUS_PER_NODE \ --cpus-per-task=$SLURM_CPUS_PER_GPU \ - --ntasks=$SLURM_GPUS_PER_NODE \ + --ntasks=$(($SLURM_GPUS_PER_NODE * $SLURM_NNODES)) \ $TRAINING_CMD else # E.g. for 'deepspeed' or 'ddp' srun --cpu-bind=none --ntasks-per-node=1 \ diff --git a/use-cases/eurac/trainer.py b/use-cases/eurac/trainer.py index 53c50202..88ac42f5 100644 --- a/use-cases/eurac/trainer.py +++ b/use-cases/eurac/trainer.py @@ -1,7 +1,7 @@ import os from pathlib import Path from timeit import default_timer as timer -from typing import Dict, Literal, Optional, Union +from typing import Dict, Literal, Optional, Union, Any, Tuple import pandas as pd import torch @@ -13,8 +13,10 @@ from hython.trainer import ConvTrainer, RNNTrainer, RNNTrainParams from ray import train from torch.optim.lr_scheduler import ReduceLROnPlateau +from torch.utils.data import Dataset from tqdm.auto import tqdm +from itwinai.distributed import suppress_workers_print from itwinai.loggers import EpochTimeTracker, Logger from itwinai.torch.config import TrainingConfiguration from itwinai.torch.distributed import ( @@ -25,6 +27,7 @@ ) from itwinai.torch.trainer import TorchTrainer from itwinai.torch.type import Metric +from itwinai.torch.profiling.profiler import profile_torch_trainer class RNNDistributedTrainer(TorchTrainer): @@ -88,6 +91,16 @@ def __init__( ) self.save_parameters(**self.locals2params(locals())) + @suppress_workers_print + @profile_torch_trainer + def execute( + self, + train_dataset: Dataset, + validation_dataset: Optional[Dataset] = None, + test_dataset: Optional[Dataset] = None + ) -> Tuple[Dataset, Dataset, Dataset, Any]: + return super().execute(train_dataset, validation_dataset, test_dataset) + def create_model_loss_optimizer(self) -> None: self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.lr) self.lr_scheduler = ReduceLROnPlateau( @@ -125,6 +138,14 @@ def create_model_loss_optimizer(self) -> None: **distribute_kwargs, ) + def set_epoch(self, epoch: int): 
+ if self.profiler is not None: + self.profiler.step() + + if self.strategy.is_distributed: + self.train_loader.sampler.set_epoch(epoch) + self.val_loader.sampler.set_epoch(epoch) + def train(self): """Override version of hython to support distributed strategy.""" # Tracking epoch times for scaling test @@ -158,11 +179,7 @@ def train(self): best_loss = float("inf") for epoch in tqdm(range(self.epochs)): epoch_start_time = timer() - if self.strategy.is_distributed: - # *Added for distributed* - self.train_loader.sampler.set_epoch(epoch) - self.val_loader.sampler.set_epoch(epoch) - + self.set_epoch(epoch) self.model.train() # set time indices for training @@ -368,7 +385,6 @@ def create_model_loss_optimizer(self) -> None: patience=self.config.lr_reduction_patience ) - target_weights = { t: 1 / len(self.config.target_names) for t in self.config.target_names } @@ -489,7 +505,7 @@ def create_dataloaders(self, train_dataset, validation_dataset, test_dataset): processing=( "multi-gpu" if self.config.distributed else "single-gpu" ), - ) + ) val_sampler_builder = SamplerBuilder( validation_dataset, diff --git a/use-cases/virgo/trainer.py b/use-cases/virgo/trainer.py index 8226e070..5fd2a3f9 100644 --- a/use-cases/virgo/trainer.py +++ b/use-cases/virgo/trainer.py @@ -43,7 +43,7 @@ def __init__( ) -> None: super().__init__( epochs=num_epochs, - config={}, + config=config, strategy=strategy, logger=logger, random_seed=random_seed,