diff --git a/.gitignore b/.gitignore index 258769c78..57308e72e 100644 --- a/.gitignore +++ b/.gitignore @@ -202,4 +202,7 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ \ No newline at end of file +#.idea/ + +# MacOS +.DS_Store diff --git a/pyproject.toml b/pyproject.toml index d448046c2..b5a8af05f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ # "prov4ml@git+https://github.com/HPCI-Lab/ProvML@main", # Prov4ML # "prov4ml@git+https://github.com/matbun/ProvML@main", "pandas", + "seaborn" ] # dynamic = ["version", "description"] diff --git a/src/itwinai/cli.py b/src/itwinai/cli.py index 9bb6b27ef..922ffb686 100644 --- a/src/itwinai/cli.py +++ b/src/itwinai/cli.py @@ -1,14 +1,22 @@ -""" -Command line interface for out Python application. -You can call commands from the command line. -Example - ->>> $ itwinai --help - -""" - -# NOTE: import libs in the command"s function, not here. -# Otherwise this will slow the whole CLI. +# -------------------------------------------------------------------------------------- +# Part of the interTwin Project: https://www.intertwin.eu/ +# +# Created by: Matteo Bunino +# +# Credit: +# - Matteo Bunino - CERN +# - Jarl Sondre Sæther - CERN +# +# -------------------------------------------------------------------------------------- +# Command-line interface for the itwinai Python library. +# Example: +# +# >>> itwinai --help +# +# -------------------------------------------------------------------------------------- +# +# NOTE: import libraries in the command's function, not here, as having them here will +# slow down the CLI commands significantly. from pathlib import Path from typing import List, Optional @@ -20,13 +28,18 @@ @app.command() -def generate_gpu_energy_plot( - log_dir: str = "scalability_metrics/gpu_energy_data", - pattern: str = r"gpu_energy_data.*\.csv$", - output_file: str = "plots/gpu_energy_plot.png", +def generate_gpu_data_plots( + log_dir: str = "scalability-metrics/gpu-energy-data", + pattern: str = r".*\.csv$", + plot_dir: str = "plots/", + do_backup: bool = False, + backup_dir: str = "backup-scalability-metrics/", + experiment_name: Optional[str] = None, + run_name: Optional[str] = None, ) -> None: - """Generate a GPU energy plot showing the expenditure for each combination of - strategy and number of GPUs in Watt hours. + """Generate GPU energy and utilization plots showing the expenditure for each + combination of strategy and number of GPUs in Watt hours and total computing + percentage. Backs up the data used to create the plot if ``backup_dir`` is not None Args: log_dir: The directory where the csv logs are stored. Defaults to @@ -34,13 +47,27 @@ def generate_gpu_energy_plot( pattern: A regex pattern to recognize the file names in the 'log_dir' folder. Defaults to ``dataframe_(?:\\w+)_(?:\\d+)\\.csv$``. Set it to 'None' to make it None. In this case, it will match all files in the given folder. - output_file: The path to where the resulting plot should be saved. Defaults to - ``plots/gpu_energy_plot.png``. + plot_dir: The directory where the resulting plots should be saved. Defaults to + ``plots/``. + do_backup: Whether to backup the data used for making the plot or not. + backup_dir: The path to where the data used to produce the plot should be + saved. 
+ experiment_name: The name of the experiment to be used when creating a backup + of the data used for the plot. + run_name: The name of the run to be used when creating a backup of the data + used for the plot. """ - import matplotlib.pyplot as plt - from itwinai.torch.monitoring.plotting import gpu_energy_plot, read_energy_df + from itwinai.scalability import ( + backup_scalability_metrics, + convert_matching_files_to_dataframe, + ) + from itwinai.torch.monitoring.plotting import ( + calculate_average_gpu_utilization, + calculate_total_energy_expenditure, + gpu_bar_plot, + ) log_dir_path = Path(log_dir) if not log_dir_path.exists(): @@ -48,24 +75,60 @@ def generate_gpu_energy_plot( f"The provided log_dir, '{log_dir_path.resolve()}', does not exist." ) + plot_dir_path = Path(plot_dir) if pattern.lower() == "none": pattern = None - gpu_utilization_df = read_energy_df(pattern=pattern, log_dir=log_dir_path) - gpu_energy_plot(gpu_utilization_df=gpu_utilization_df) + gpu_data_df = convert_matching_files_to_dataframe( + pattern=pattern, log_dir=log_dir_path + ) - output_path = Path(output_file) - output_path.parent.mkdir(parents=True, exist_ok=True) + energy_df = calculate_total_energy_expenditure(gpu_data_df=gpu_data_df) + utilization_df = calculate_average_gpu_utilization(gpu_data_df=gpu_data_df) + + plot_dir_path.mkdir(parents=True, exist_ok=True) + energy_plot_path = plot_dir_path / "gpu_energy_plot.png" + utilization_plot_path = plot_dir_path / "utilization_plot.png" + + energy_fig, _ = gpu_bar_plot( + data_df=energy_df, + plot_title="Energy Consumption by Strategy and Number of GPUs", + y_label="Energy Consumption (Wh)", + main_column="total_energy_wh", + ) + utilization_fig, _ = gpu_bar_plot( + data_df=utilization_df, + plot_title="GPU Utilization by Strategy and Number of GPUs", + y_label="GPU Utilization (%)", + main_column="utilization", + ) + + energy_fig.savefig(energy_plot_path) + utilization_fig.savefig(utilization_plot_path) + print(f"Saved GPU energy plot at '{energy_plot_path.resolve()}'.") + print(f"Saved utilization plot at '{utilization_plot_path.resolve()}'.") - plt.savefig(output_path) - print(f"\nSaved GPU energy plot at '{output_path.resolve()}'.") + if not do_backup: + return + + backup_scalability_metrics( + experiment_name=experiment_name, + run_name=run_name, + backup_dir=backup_dir, + metric_df=gpu_data_df, + filename="gpu_data.csv", + ) @app.command() def generate_communication_plot( - log_dir: str = "scalability_metrics/communication_data", + log_dir: str = "scalability-metrics/communication-data", pattern: str = r"(.+)_(\d+)_(\d+)\.csv$", output_file: str = "plots/communication_plot.png", + do_backup: bool = False, + backup_dir: str = "backup-scalability-metrics/", + experiment_name: Optional[str] = None, + run_name: Optional[str] = None, ) -> None: """Generate stacked plot showing computation vs. communication fraction. Stores it to output_file. @@ -78,45 +141,124 @@ def generate_communication_plot( make it None. In this case, it will match all files in the given folder. output_file: The path to where the resulting plot should be saved. Defaults to ``plots/comm_plot.png``. + do_backup: Whether to backup the data used for making the plot or not. + backup_dir: The path to where the data used to produce the plot should be + saved. + experiment_name: The name of the experiment to be used when creating a backup + of the data used for the plot. + run_name: The name of the run to be used when creating a backup of the data + used for the plot. 
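        Example (editor's illustrative invocation; the kebab-case command and option
        names follow typer's defaults and are an assumption, not part of this patch):

        >>> itwinai generate-communication-plot --log-dir scalability-metrics/communication-data \
        >>>     --output-file plots/communication_plot.png --do-backup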
""" - import matplotlib.pyplot as plt + from itwinai.scalability import ( + backup_scalability_metrics, + convert_matching_files_to_dataframe, + ) from itwinai.torch.profiling.communication_plot import ( - create_combined_comm_overhead_df, - create_stacked_plot, + communication_overhead_stacked_bar_plot, get_comp_fraction_full_array, ) log_dir_path = Path(log_dir) if not log_dir_path.exists(): raise ValueError( - f"The directory '{log_dir_path.resolve()}' does not exist, so could not" - f"extract profiling logs. Make sure you are running this command in the " - f"same directory as the logging dir or are passing a sufficient relative" - f"path." + f"The provided directory, '{log_dir_path.resolve()}', does not exist." ) if pattern.lower() == "none": pattern = None - df = create_combined_comm_overhead_df(log_dir=log_dir_path, pattern=pattern) - values = get_comp_fraction_full_array(df, print_table=True) - - strategies = sorted(df["strategy"].unique()) - gpu_numbers = sorted(df["num_gpus"].unique(), key=lambda x: int(x)) + expected_columns = { + "strategy", + "num_gpus", + "global_rank", + "name", + "self_cuda_time_total", + } + communication_df = convert_matching_files_to_dataframe( + log_dir=log_dir_path, pattern=pattern, expected_columns=expected_columns + ) + values = get_comp_fraction_full_array(communication_df, print_table=True) - fig, _ = create_stacked_plot(values, strategies, gpu_numbers) + strategies = sorted(communication_df["strategy"].unique()) + gpu_numbers = sorted(communication_df["num_gpus"].unique(), key=lambda x: int(x)) - # TODO: set these dynamically? - fig.set_figwidth(8) - fig.set_figheight(6) + fig, _ = communication_overhead_stacked_bar_plot(values, strategies, gpu_numbers) output_path = Path(output_file) output_path.parent.mkdir(parents=True, exist_ok=True) - plt.savefig(output_path) + fig.savefig(output_path) print(f"\nSaved computation vs. communication plot at '{output_path.resolve()}'.") + if not do_backup: + return + + backup_scalability_metrics( + experiment_name=experiment_name, + run_name=run_name, + backup_dir=backup_dir, + metric_df=communication_df, + filename="communication_data.csv", + ) + + +@app.command() +def generate_scalability_plot( + pattern: str = "None", + log_dir: str = "scalability-metrics/epoch-time", + do_backup: bool = False, + backup_dir: str = "backup-scalability-metrics/", + experiment_name: Optional[str] = None, + run_name: Optional[str] = None, +) -> None: + """Creates two scalability plots from measured wall-clock times of an experiment + run and saves them to file. Uses pattern to filter out files if given, otherwise + it will try to use all files it finds in the given log directory. Will store all + the data that was used as a backup file if do_backup is provided. 
+ """ + + from itwinai.scalability import ( # archive_data, + backup_scalability_metrics, + convert_matching_files_to_dataframe, + create_absolute_plot, + create_relative_plot, + ) + + log_dir_path = Path(log_dir) + if pattern.lower() == "none": + pattern = None + + expected_columns = {"name", "nodes", "epoch_id", "time"} + combined_df = convert_matching_files_to_dataframe( + log_dir=log_dir_path, pattern=pattern, expected_columns=expected_columns + ) + print("Merged CSV:") + print(combined_df) + + avg_time_df = ( + combined_df.drop(columns="epoch_id") + .groupby(["name", "nodes"]) + .mean() + .reset_index() + ) + print("\nAvg over name and nodes:") + print(avg_time_df.rename(columns=dict(time="avg(time)"))) + + create_absolute_plot(avg_time_df) + create_relative_plot(avg_time_df) + + if not do_backup: + return + + backup_scalability_metrics( + experiment_name=experiment_name, + run_name=run_name, + backup_dir=backup_dir, + metric_df=combined_df, + filename="epoch_time.csv", + ) + @app.command() def sanity_check( @@ -149,169 +291,6 @@ def sanity_check( sanity_check_slim() -@app.command() -def scalability_report( - pattern: Annotated[ - str, typer.Option(help="Python pattern matching names of CSVs in sub-folders.") - ], - plot_title: Annotated[Optional[str], typer.Option(help=("Plot name."))] = None, - skip_id: Annotated[Optional[int], typer.Option(help=("Skip epoch ID."))] = None, - archive: Annotated[ - Optional[str], - typer.Option(help=("Archive name to backup the data, without extension.")), - ] = None, -): - """ - Generate scalability report merging all CSVs containing epoch time - records in sub-folders. - - Example: - - >>> itwinai scalability-report --pattern="^epoch.+\\.csv$" --skip-id 0 \\ - >>> --plot-title "Some title" --archive archive_name - - """ - # TODO: add max depth and path different from CWD - import glob - import os - import re - import shutil - - import matplotlib - import matplotlib.pyplot as plt - import numpy as np - import pandas as pd - - regex = re.compile(r"{}".format(pattern)) - combined_df = pd.DataFrame() - csv_files = [] - for root, _, files in os.walk(os.getcwd()): - for file in files: - if regex.match(file): - fpath = os.path.join(root, file) - csv_files.append(fpath) - df = pd.read_csv(fpath) - if skip_id is not None: - df = df.drop(df[df.epoch_id == skip_id].index) - combined_df = pd.concat([combined_df, df]) - print("Merged CSV:") - print(combined_df) - - avg_times = ( - combined_df.drop(columns="epoch_id") - .groupby(["name", "nodes"]) - .mean() - .reset_index() - ) - print("\nAvg over name and nodes:") - print(avg_times.rename(columns=dict(time="avg(time)"))) - - # fig, (sp_up_ax, eff_ax) = plt.subplots(1, 2, figsize=(12, 4)) - fig, sp_up_ax = plt.subplots(1, 1, figsize=(6, 4)) - if plot_title is not None: - fig.suptitle(plot_title) - - sp_up_ax.set_yscale("log") - sp_up_ax.set_xscale("log") - - markers = iter("ov^s*dXpD.+12348") - - series_names = sorted(set(avg_times.name.values)) - for name in series_names: - df = avg_times[avg_times.name == name].drop(columns="name") - - # Debug - # compute_time = [3791., 1884., 1011., 598.] - # nodes = [1, 2, 4, 8] - # d = {'nodes': nodes, 'time': compute_time} - # df = pd.DataFrame(data=d) - - df["NGPUs"] = df["nodes"] * 4 - # speedup - df["Speedup - ideal"] = df["nodes"].astype(float) - df["Speedup"] = df["time"].iloc[0] / df["time"] - df["Nworkers"] = 1 - - # efficiency - df["Threadscaled Sim. Time / s"] = df["time"] * df["nodes"] * df["Nworkers"] - df["Efficiency"] = ( - df["Threadscaled Sim. 
Time / s"].iloc[0] / df["Threadscaled Sim. Time / s"] - ) - - sp_up_ax.plot( - df["NGPUs"].values, - df["Speedup"].values, - marker=next(markers), - lw=1.0, - label=name, - alpha=0.7, - ) - - sp_up_ax.plot( - df["NGPUs"].values, - df["Speedup - ideal"].values, - ls="dashed", - lw=1.0, - c="k", - label="ideal", - ) - sp_up_ax.legend(ncol=1) - - sp_up_ax.set_xticks(df["NGPUs"].values) - sp_up_ax.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter()) - - sp_up_ax.set_ylabel("Speedup") - sp_up_ax.set_xlabel("NGPUs (4 per node)") - sp_up_ax.grid() - - # Sort legend - handles, labels = sp_up_ax.get_legend_handles_labels() - order = np.argsort(labels) - plt.legend([handles[idx] for idx in order], [labels[idx] for idx in order]) - - plot_png = f"scaling_plot_{plot_title}.png" - plt.tight_layout() - plt.savefig(plot_png, bbox_inches="tight", format="png", dpi=300) - print("Saved scaling plot to: ", plot_png) - - if archive is not None: - if "/" in archive: - raise ValueError( - f"Archive name must NOT contain a path. Received: '{archive}'" - ) - if "." in archive: - raise ValueError( - f"Archive name must NOT contain an extension. Received: '{archive}'" - ) - if os.path.isdir(archive): - raise ValueError(f"Folder '{archive}' already exists. Change archive name.") - os.makedirs(archive) - for csvfile in csv_files: - shutil.copyfile(csvfile, os.path.join(archive, os.path.basename(csvfile))) - shutil.copyfile(plot_png, os.path.join(archive, plot_png)) - avg_times.to_csv(os.path.join(archive, "avg_times.csv"), index=False) - print("Archived AVG epoch times CSV") - - # Copy SLURM logs: *.err *.out files - if os.path.exists("logs_slurm"): - print("Archived SLURM logs") - shutil.copytree("logs_slurm", os.path.join(archive, "logs_slurm")) - # Copy other SLURM logs - for ext in ["*.out", "*.err"]: - for file in glob.glob(ext): - shutil.copyfile(file, os.path.join(archive, file)) - - # Create archive - archive_name = shutil.make_archive( - base_name=archive, # archive file name - format="gztar", - # root_dir='.', - base_dir=archive, # folder path inside archive - ) - shutil.rmtree(archive) - print("Archived logs and plot at: ", archive_name) - - @app.command() def exec_pipeline( config: Annotated[ diff --git a/src/itwinai/loggers.py b/src/itwinai/loggers.py index e827d0053..5012f2383 100644 --- a/src/itwinai/loggers.py +++ b/src/itwinai/loggers.py @@ -55,20 +55,21 @@ https://docs.wandb.ai/ref/python/watch """ -import csv import os import pathlib import pickle from abc import ABC, abstractmethod from contextlib import contextmanager +from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Tuple, Union import mlflow +import pandas as pd import prov4ml import wandb from typing_extensions import override -BASE_EXP_NAME: str = 'default_experiment' +BASE_EXP_NAME: str = "default_experiment" class LogMixin(ABC): @@ -77,10 +78,10 @@ def log( self, item: Union[Any, List[Any]], identifier: Union[str, List[str]], - kind: str = 'metric', + kind: str = "metric", step: Optional[int] = None, batch_idx: Optional[int] = None, - **kwargs + **kwargs, ) -> None: """Log ``item`` with ``identifier`` name of ``kind`` type at ``step`` time step. @@ -122,6 +123,7 @@ class Logger(LogMixin): if List[int], log on workers which rank is in the list. Defaults to 0 (the global rank of the main worker). """ + #: Location on filesystem where to store data. savedir: str = None #: Supported logging 'kind's. 
@@ -129,28 +131,28 @@ class Logger(LogMixin): #: Current worker global rank worker_rank: int - _log_freq: Union[int, Literal['epoch', 'batch']] + _log_freq: Union[int, Literal["epoch", "batch"]] def __init__( self, - savedir: str = 'mllogs', - log_freq: Union[int, Literal['epoch', 'batch']] = 'epoch', - log_on_workers: Union[int, List[int]] = 0 + savedir: str = "mllogs", + log_freq: Union[int, Literal["epoch", "batch"]] = "epoch", + log_on_workers: Union[int, List[int]] = 0, ) -> None: self.savedir = savedir self.log_freq = log_freq self.log_on_workers = log_on_workers @property - def log_freq(self) -> Union[int, Literal['epoch', 'batch']]: + def log_freq(self) -> Union[int, Literal["epoch", "batch"]]: """Get ``log_feq``, namely how often should the logger fulfill or ignore calls to the `log()` method.""" return self._log_freq @log_freq.setter - def log_freq(self, val: Union[int, Literal['epoch', 'batch']]): + def log_freq(self, val: Union[int, Literal["epoch", "batch"]]): """Sanitize log_freq value.""" - if val in ['epoch', 'batch'] or (isinstance(val, int) and val > 0): + if val in ["epoch", "batch"] or (isinstance(val, int) and val > 0): self._log_freq = val else: raise ValueError( @@ -214,13 +216,10 @@ def serialize(self, obj: Any, identifier: str) -> str: str: local path of the serialized object to be logged. """ itm_path = os.path.join(self.savedir, identifier) - with open(itm_path, 'wb') as itm_file: + with open(itm_path, "wb") as itm_file: pickle.dump(obj, itm_file) - def should_log( - self, - batch_idx: Optional[int] = None - ) -> bool: + def should_log(self, batch_idx: Optional[int] = None) -> bool: """Determines whether the logger should fulfill or ignore calls to the `log()` method, depending on the ``log_freq`` property: @@ -245,15 +244,17 @@ def should_log( """ # Check worker's global rank worker_ok = ( - self.worker_rank is None or - (isinstance(self.log_on_workers, int) and ( - self.log_on_workers == -1 or - self.log_on_workers == self.worker_rank + self.worker_rank is None + or ( + isinstance(self.log_on_workers, int) + and ( + self.log_on_workers == -1 or self.log_on_workers == self.worker_rank + ) ) + or ( + isinstance(self.log_on_workers, list) + and self.worker_rank in self.log_on_workers ) - or - (isinstance(self.log_on_workers, list) - and self.worker_rank in self.log_on_workers) ) if not worker_ok: return False @@ -264,7 +265,7 @@ def should_log( if batch_idx % self.log_freq == 0: return True return False - if self.log_freq == 'batch': + if self.log_freq == "batch": return True return False return True @@ -277,9 +278,9 @@ class _EmptyLogger(Logger): def __init__( self, - savedir: str = 'mllogs', - log_freq: int | Literal['epoch'] | Literal['batch'] = 'epoch', - log_on_workers: int | List[int] = 0 + savedir: str = "mllogs", + log_freq: int | Literal["epoch"] | Literal["batch"] = "epoch", + log_on_workers: int | List[int] = 0, ) -> None: super().__init__(savedir, log_freq, log_on_workers) @@ -296,10 +297,10 @@ def log( self, item: Union[Any, List[Any]], identifier: Union[str, List[str]], - kind: str = 'metric', + kind: str = "metric", step: Optional[int] = None, batch_idx: Optional[int] = None, - **kwargs + **kwargs, ) -> None: pass @@ -321,19 +322,17 @@ class ConsoleLogger(Logger): """ #: Supported kinds in the ``log`` method - supported_kinds: Tuple[str] = ('torch', 'artifact', 'metric') + supported_kinds: Tuple[str] = ("torch", "artifact", "metric") def __init__( self, - savedir: str = 'mllogs', - log_freq: Union[int, Literal['epoch', 'batch']] = 'epoch', - 
log_on_workers: Union[int, List[int]] = 0 + savedir: str = "mllogs", + log_freq: Union[int, Literal["epoch", "batch"]] = "epoch", + log_on_workers: Union[int, List[int]] = 0, ) -> None: - savedir = os.path.join(savedir, 'simple-logger') + savedir = os.path.join(savedir, "simple-logger") super().__init__( - savedir=savedir, - log_freq=log_freq, - log_on_workers=log_on_workers + savedir=savedir, log_freq=log_freq, log_on_workers=log_on_workers ) def create_logger_context(self, rank: Optional[int] = None): @@ -376,10 +375,10 @@ def log( self, item: Union[Any, List[Any]], identifier: Union[str, List[str]], - kind: str = 'metric', + kind: str = "metric", step: Optional[int] = None, batch_idx: Optional[int] = None, - **kwargs + **kwargs, ) -> None: """Print metrics to stdout and save artifacts to the filesystem. @@ -398,28 +397,24 @@ def log( if not self.should_log(batch_idx=batch_idx): return - if kind == 'artifact': + if kind == "artifact": if isinstance(item, str) and os.path.isfile(item): import shutil - identifier = os.path.join( - self.run_path, - identifier - ) + + identifier = os.path.join(self.run_path, identifier) if len(os.path.dirname(identifier)) > 0: os.makedirs(os.path.dirname(identifier), exist_ok=True) print(f"ConsoleLogger: Serializing to {identifier}...") shutil.copyfile(item, identifier) else: - identifier = os.path.join( - os.path.basename(self.run_path), - identifier - ) + identifier = os.path.join(os.path.basename(self.run_path), identifier) print(f"ConsoleLogger: Serializing to {identifier}...") self.serialize(item, identifier) - elif kind == 'torch': + elif kind == "torch": identifier = os.path.join(self.run_path, identifier) print(f"ConsoleLogger: Saving to {identifier}...") import torch + torch.save(item, identifier) else: print(f"ConsoleLogger: {identifier} = {item}") @@ -451,8 +446,16 @@ class MLFlowLogger(Logger): #: Supported kinds in the ``log`` method supported_kinds: Tuple[str] = ( - 'metric', 'figure', 'image', 'artifact', 'torch', 'dict', 'param', - 'text', 'model', 'dataset' + "metric", + "figure", + "image", + "artifact", + "torch", + "dict", + "param", + "text", + "model", + "dataset", ) #: Current MLFLow experiment's run. @@ -460,19 +463,17 @@ class MLFlowLogger(Logger): def __init__( self, - savedir: str = 'mllogs', + savedir: str = "mllogs", experiment_name: str = BASE_EXP_NAME, tracking_uri: Optional[str] = None, run_description: Optional[str] = None, run_name: Optional[str] = None, - log_freq: Union[int, Literal['epoch', 'batch']] = 'epoch', - log_on_workers: Union[int, List[int]] = 0 + log_freq: Union[int, Literal["epoch", "batch"]] = "epoch", + log_on_workers: Union[int, List[int]] = 0, ): - savedir = os.path.join(savedir, 'mlflow') + savedir = os.path.join(savedir, "mlflow") super().__init__( - savedir=savedir, - log_freq=log_freq, - log_on_workers=log_on_workers + savedir=savedir, log_freq=log_freq, log_on_workers=log_on_workers ) self.experiment_name = experiment_name self.tracking_uri = tracking_uri @@ -487,10 +488,7 @@ def __init__( # TODO: for pytorch lightning: # mlflow.pytorch.autolog() - def create_logger_context( - self, - rank: Optional[int] = None - ) -> mlflow.ActiveRun: + def create_logger_context(self, rank: Optional[int] = None) -> mlflow.ActiveRun: """ Initializes the logger context. Start MLFLow run. 
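# Editor's note: a minimal, illustrative sketch of how the reformatted logger API is
# typically driven (not part of the patch; the parameter values are invented, the
# method names mirror the signatures shown in this diff):
#
#     logger = MLFlowLogger(experiment_name="my-exp", log_freq="batch")
#     logger.create_logger_context(rank=0)
#     logger.log(item=0.42, identifier="loss", kind="metric", step=1, batch_idx=0)
#     logger.destroy_logger_context()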
@@ -514,8 +512,7 @@ def create_logger_context( mlflow.set_tracking_uri(self.tracking_uri) mlflow.set_experiment(experiment_name=self.experiment_name) self.active_run: mlflow.ActiveRun = mlflow.start_run( - description=self.run_description, - run_name=self.run_name + description=self.run_description, run_name=self.run_name ) return self.active_run @@ -536,16 +533,16 @@ def save_hyperparameters(self, params: Dict[str, Any]) -> None: return for param_name, val in params.items(): - self.log(item=val, identifier=param_name, step=0, kind='param') + self.log(item=val, identifier=param_name, step=0, kind="param") def log( self, item: Union[Any, List[Any]], identifier: Union[str, List[str]], - kind: str = 'metric', + kind: str = "metric", step: Optional[int] = None, batch_idx: Optional[int] = None, - **kwargs + **kwargs, ) -> None: """Log with MLFlow. @@ -564,31 +561,25 @@ def log( if not self.should_log(batch_idx=batch_idx): return - if kind == 'metric': + if kind == "metric": # if isinstance(item, list) and isinstance(identifier, list): - mlflow.log_metric( - key=identifier, - value=item, - step=step - ) - if kind == 'artifact': + mlflow.log_metric(key=identifier, value=item, step=step) + if kind == "artifact": if not isinstance(item, str): # Save the object locally and then log it name = os.path.basename(identifier) - save_path = os.path.join(self.savedir, '.trash', name) + save_path = os.path.join(self.savedir, ".trash", name) os.makedirs(os.path.dirname(save_path), exist_ok=True) item = self.serialize(item, save_path) - mlflow.log_artifact( - local_path=item, - artifact_path=identifier - ) - if kind == 'model': + mlflow.log_artifact(local_path=item, artifact_path=identifier) + if kind == "model": import torch + if isinstance(item, torch.nn.Module): mlflow.pytorch.log_model(item, identifier) else: print("WARNING: unrecognized model type") - if kind == 'dataset': + if kind == "dataset": # Log mlflow dataset # https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.log_input # It may be needed to convert item into a mlflow dataset, e.g.: @@ -597,47 +588,33 @@ def log( if isinstance(item, mlflow.data.Dataset): mlflow.log_input(item) else: - print("WARNING: unrecognized dataset type. " - "Must be an MLFlow dataset") - if kind == 'torch': + print( + "WARNING: unrecognized dataset type. 
" "Must be an MLFlow dataset" + ) + if kind == "torch": import torch # Save the object locally and then log it name = os.path.basename(identifier) - save_path = os.path.join(self.savedir, '.trash', name) + save_path = os.path.join(self.savedir, ".trash", name) os.makedirs(os.path.dirname(save_path), exist_ok=True) torch.save(item, save_path) # Log into mlflow - mlflow.log_artifact( - local_path=save_path, - artifact_path=identifier - ) - if kind == 'dict': - mlflow.log_dict( - dictionary=item, - artifact_file=identifier - ) - if kind == 'figure': + mlflow.log_artifact(local_path=save_path, artifact_path=identifier) + if kind == "dict": + mlflow.log_dict(dictionary=item, artifact_file=identifier) + if kind == "figure": mlflow.log_figure( artifact_file=identifier, figure=item, - save_kwargs=kwargs.get('save_kwargs') - ) - if kind == 'image': - mlflow.log_image( - artifact_file=identifier, - image=item - ) - if kind == 'param': - mlflow.log_param( - key=identifier, - value=item - ) - if kind == 'text': - mlflow.log_text( - artifact_file=identifier, - text=item + save_kwargs=kwargs.get("save_kwargs"), ) + if kind == "image": + mlflow.log_image(artifact_file=identifier, image=item) + if kind == "param": + mlflow.log_param(key=identifier, value=item) + if kind == "text": + mlflow.log_text(artifact_file=identifier, text=item) class WandBLogger(Logger): @@ -662,21 +639,26 @@ class WandBLogger(Logger): #: Supported kinds in the ``log`` method supported_kinds: Tuple[str] = ( - 'watch', 'metric', 'figure', 'image', 'torch', 'dict', - 'param', 'text') + "watch", + "metric", + "figure", + "image", + "torch", + "dict", + "param", + "text", + ) def __init__( self, - savedir: str = 'mllogs', + savedir: str = "mllogs", project_name: str = BASE_EXP_NAME, - log_freq: Union[int, Literal['epoch', 'batch']] = 'epoch', - log_on_workers: Union[int, List[int]] = 0 + log_freq: Union[int, Literal["epoch", "batch"]] = "epoch", + log_on_workers: Union[int, List[int]] = 0, ) -> None: - savedir = os.path.join(savedir, 'wandb') + savedir = os.path.join(savedir, "wandb") super().__init__( - savedir=savedir, - log_freq=log_freq, - log_on_workers=log_on_workers + savedir=savedir, log_freq=log_freq, log_on_workers=log_on_workers ) self.project_name = project_name @@ -693,10 +675,9 @@ def create_logger_context(self, rank: Optional[int] = None) -> None: if not self.should_log(): return - os.makedirs(os.path.join(self.savedir, 'wandb'), exist_ok=True) + os.makedirs(os.path.join(self.savedir, "wandb"), exist_ok=True) self.active_run = wandb.init( - dir=os.path.abspath(self.savedir), - project=self.project_name + dir=os.path.abspath(self.savedir), project=self.project_name ) def destroy_logger_context(self): @@ -719,10 +700,10 @@ def log( self, item: Union[Any, List[Any]], identifier: Union[str, List[str]], - kind: str = 'metric', + kind: str = "metric", step: Optional[int] = None, batch_idx: Optional[int] = None, - **kwargs + **kwargs, ) -> None: """Log with WandB. Wrapper of https://docs.wandb.ai/ref/python/log @@ -741,7 +722,7 @@ def log( if not self.should_log(batch_idx=batch_idx): return - if kind == 'watch': + if kind == "watch": wandb.watch(item) elif kind in self.supported_kinds: # wandb.log({identifier: item}, step=step, commit=True) @@ -776,33 +757,31 @@ class TensorBoardLogger(Logger): # and add the missing logging types supported by each. 
#: Supported kinds in the ``log`` method - supported_kinds: Tuple[str] = ( - 'metric', 'image', 'text', 'figure', 'torch') + supported_kinds: Tuple[str] = ("metric", "image", "text", "figure", "torch") def __init__( self, - savedir: str = 'mllogs', - log_freq: Union[int, Literal['epoch', 'batch']] = 'epoch', - framework: Literal['tensorflow', 'pytorch'] = 'pytorch', - log_on_workers: Union[int, List[int]] = 0 + savedir: str = "mllogs", + log_freq: Union[int, Literal["epoch", "batch"]] = "epoch", + framework: Literal["tensorflow", "pytorch"] = "pytorch", + log_on_workers: Union[int, List[int]] = 0, ) -> None: - savedir = os.path.join(savedir, 'tensorboard') + savedir = os.path.join(savedir, "tensorboard") super().__init__( - savedir=savedir, - log_freq=log_freq, - log_on_workers=log_on_workers + savedir=savedir, log_freq=log_freq, log_on_workers=log_on_workers ) self.framework = framework - if framework.lower() == 'tensorflow': + if framework.lower() == "tensorflow": import tensorflow as tf + self.tf = tf self.writer = tf.summary.create_file_writer(savedir) - elif framework.lower() == 'pytorch': + elif framework.lower() == "pytorch": from torch.utils.tensorboard import SummaryWriter + self.writer = SummaryWriter(savedir) else: - raise ValueError( - "Framework must be either 'tensorflow' or 'pytorch'") + raise ValueError("Framework must be either 'tensorflow' or 'pytorch'") def create_logger_context(self, rank: Optional[int] = None) -> None: """ @@ -817,7 +796,7 @@ def create_logger_context(self, rank: Optional[int] = None) -> None: if not self.should_log(): return - if self.framework == 'tensorflow': + if self.framework == "tensorflow": self.writer.set_as_default() def destroy_logger_context(self): @@ -836,22 +815,23 @@ def save_hyperparameters(self, params: Dict[str, Any]) -> None: if not self.should_log(): return - if self.framework == 'tensorflow': + if self.framework == "tensorflow": from tensorboard.plugins.hparams import api as hp + hparams = {hp.HParam(k): v for k, v in params.items()} with self.writer.as_default(): hp.hparams(hparams) - elif self.framework == 'pytorch': + elif self.framework == "pytorch": self.writer.add_hparams(params, {}) def log( self, item: Union[Any, List[Any]], identifier: Union[str, List[str]], - kind: str = 'metric', + kind: str = "metric", step: Optional[int] = None, batch_idx: Optional[int] = None, - **kwargs + **kwargs, ) -> None: """Log with Tensorboard. 
@@ -870,26 +850,26 @@ def log( if not self.should_log(batch_idx=batch_idx): return - if self.framework == 'tensorflow': + if self.framework == "tensorflow": with self.writer.as_default(): - if kind == 'metric': + if kind == "metric": self.tf.summary.scalar(identifier, item, step=step) - elif kind == 'image': + elif kind == "image": self.tf.summary.image(identifier, item, step=step) - elif kind == 'text': + elif kind == "text": self.tf.summary.text(identifier, item, step=step) - elif kind == 'figure': + elif kind == "figure": self.tf.summary.figure(identifier, item, step=step) - elif self.framework == 'pytorch': - if kind == 'metric': + elif self.framework == "pytorch": + if kind == "metric": self.writer.add_scalar(identifier, item, global_step=step) - elif kind == 'image': + elif kind == "image": self.writer.add_image(identifier, item, global_step=step) - elif kind == 'text': + elif kind == "text": self.writer.add_text(identifier, item, global_step=step) - elif kind == 'figure': + elif kind == "figure": self.writer.add_figure(identifier, item, global_step=step) - elif kind == 'torch': + elif kind == "torch": self.writer.add_graph(item) @@ -903,11 +883,8 @@ class LoggersCollection(Logger): #: Supported kinds are delegated to the loggers in the collection. supported_kinds: Tuple[str] - def __init__( - self, - loggers: List[Logger] - ) -> None: - super().__init__(savedir='/tmp/mllogs_LoggersCollection', log_freq=1) + def __init__(self, loggers: List[Logger]) -> None: + super().__init__(savedir="/tmp/mllogs_LoggersCollection", log_freq=1) self.loggers = loggers def should_log(self, batch_idx: int = None) -> bool: @@ -927,10 +904,10 @@ def log( self, item: Union[Any, List[Any]], identifier: Union[str, List[str]], - kind: str = 'metric', + kind: str = "metric", step: Optional[int] = None, batch_idx: Optional[int] = None, - **kwargs + **kwargs, ) -> None: """Log on all loggers. 
@@ -953,7 +930,7 @@ def log( kind=kind, step=step, batch_idx=batch_idx, - **kwargs + **kwargs, ) def create_logger_context(self, rank: Optional[int] = None) -> Any: @@ -1011,9 +988,16 @@ class Prov4MLLogger(Logger): #: Supported kinds in the ``log`` method supported_kinds: Tuple[str] = ( - 'metric', 'flops_pb', 'flops_pe', 'system', 'carbon', - 'execution_time', 'model', 'best_model', - 'torch') + "metric", + "flops_pb", + "flops_pe", + "system", + "carbon", + "execution_time", + "model", + "best_model", + "torch", + ) def __init__( self, @@ -1023,13 +1007,13 @@ def __init__( save_after_n_logs: Optional[int] = 100, create_graph: Optional[bool] = True, create_svg: Optional[bool] = True, - log_freq: Union[int, Literal['epoch', 'batch']] = 'epoch', - log_on_workers: Union[int, List[int]] = 0 + log_freq: Union[int, Literal["epoch", "batch"]] = "epoch", + log_on_workers: Union[int, List[int]] = 0, ) -> None: super().__init__( savedir=provenance_save_dir, log_freq=log_freq, - log_on_workers=log_on_workers + log_on_workers=log_on_workers, ) self.name = experiment_name self.version = None @@ -1060,7 +1044,7 @@ def create_logger_context(self, rank: Optional[int] = None): save_after_n_logs=self.save_after_n_logs, # This class will control which workers can log collect_all_processes=True, - rank=rank + rank=rank, ) @override @@ -1071,9 +1055,7 @@ def destroy_logger_context(self): if not self.should_log(): return - prov4ml.end_run( - create_graph=self.create_graph, - create_svg=self.create_svg) + prov4ml.end_run(create_graph=self.create_graph, create_svg=self.create_svg) @override def save_hyperparameters(self, params: Dict[str, Any]) -> None: @@ -1089,11 +1071,11 @@ def log( self, item: Union[Any, List[Any]], identifier: Union[str, List[str]], - kind: str = 'metric', + kind: str = "metric", step: Optional[int] = None, batch_idx: Optional[int] = None, - context: Optional[str] = 'training', - **kwargs + context: Optional[str] = "training", + **kwargs, ) -> None: """Logs with Prov4ML. 
@@ -1114,33 +1096,39 @@ def log( return if kind == "metric": - prov4ml.log_metric(key=identifier, value=item, - context=context, step=step) + prov4ml.log_metric(key=identifier, value=item, context=context, step=step) elif kind == "flops_pb": model, batch = item prov4ml.log_flops_per_batch( - identifier, model=model, - batch=batch, context=context, step=step) + identifier, model=model, batch=batch, context=context, step=step + ) elif kind == "flops_pe": model, dataset = item prov4ml.log_flops_per_epoch( - identifier, model=model, - dataset=dataset, context=context, step=step) + identifier, model=model, dataset=dataset, context=context, step=step + ) elif kind == "system": prov4ml.log_system_metrics(context=context, step=step) elif kind == "carbon": prov4ml.log_carbon_metrics(context=context, step=step) elif kind == "execution_time": prov4ml.log_current_execution_time( - label=identifier, context=context, step=step) - elif kind == 'model': + label=identifier, context=context, step=step + ) + elif kind == "model": prov4ml.save_model_version( - model=item, model_name=identifier, context=context, step=step) - elif kind == 'best_model': - prov4ml.log_model(model=item, model_name=identifier, - log_model_info=True, log_as_artifact=True) - elif kind == 'torch': + model=item, model_name=identifier, context=context, step=step + ) + elif kind == "best_model": + prov4ml.log_model( + model=item, + model_name=identifier, + log_model_info=True, + log_as_artifact=True, + ) + elif kind == "torch": from torch.utils.data import DataLoader + if isinstance(item, DataLoader): prov4ml.log_dataset(dataset=item, label=identifier) else: @@ -1148,49 +1136,30 @@ def log( class EpochTimeTracker: - """Profiler for epoch execution time used to support scaling tests. - It uses CSV files to store, for each epoch, the ``name`` of the - experiment, the number of compute ``nodes`` used, the ``epoch_id``, - and the execution ``time`` in seconds. + """Tracker for epoch execution time during training.""" - Args: - series_name (str): name of the experiment/job. - csv_file (str): path to CSV file to store experiments times. - """ + def __init__( + self, strategy_name: str, save_path: Union[Path, str], num_nodes: int + ) -> None: + if isinstance(save_path, str): + save_path = Path(save_path) - def __init__(self, series_name: str, csv_file: str) -> None: - self.series_name = series_name - self._data = [] - self.csv_file = csv_file - with open(csv_file, 'w') as csvfile: - csvwriter = csv.writer(csvfile) - csvwriter.writerow(['name', 'nodes', 'epoch_id', 'time']) + self.save_path: Path = save_path + self.strategy_name = strategy_name + self.num_nodes = num_nodes + self.data = {"epoch_id": [], "time": []} def add_epoch_time(self, epoch_idx: int, time: float) -> None: - """Add row to the current experiment's CSV file in append mode. - - Args: - epoch_idx (int): epoch order idx. - time (float): epoch execution time (seconds). - """ - n_nodes = os.environ.get('SLURM_NNODES', -1) - fields = (self.series_name, n_nodes, epoch_idx, time) - self._data.append(fields) - with open(self.csv_file, 'a') as csvfile: - csvwriter = csv.writer(csvfile) - csvwriter.writerow(fields) - - def save(self, csv_file: Optional[str] = None) -> None: - """Save data to a new CSV file. - - Args: - csv_file (Optional[str], optional): path to the CSV file. - If not given, uses the one given in the constructor. - Defaults to None. 
- """ - if not csv_file: - csv_file = self.csv_file - with open(csv_file, 'w') as csvfile: - csvwriter = csv.writer(csvfile) - csvwriter.writerow(['name', 'nodes', 'epoch_id', 'time']) - csvwriter.writerows(self._data) + """Add epoch time to data.""" + self.data["epoch_id"].append(epoch_idx) + self.data["time"].append(time) + + def save(self) -> None: + """Save data to a new CSV file.""" + df = pd.DataFrame(self.data) + df["name"] = self.strategy_name + df["nodes"] = self.num_nodes + + self.save_path.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(self.save_path, index=False) + print(f"Saving EpochTimeTracking data to '{self.save_path.resolve()}'.") diff --git a/src/itwinai/scalability.py b/src/itwinai/scalability.py new file mode 100644 index 000000000..d194092e6 --- /dev/null +++ b/src/itwinai/scalability.py @@ -0,0 +1,195 @@ +# -------------------------------------------------------------------------------------- +# Part of the interTwin Project: https://www.intertwin.eu/ +# +# Created by: Jarl Sondre Sæther +# +# Credit: +# - Jarl Sondre Sæther - CERN +# - Matteo Bunino - CERN +# -------------------------------------------------------------------------------------- + +import uuid +from itertools import cycle +from pathlib import Path +from re import Match, Pattern, compile +from typing import Optional, Union + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns + + +def convert_matching_files_to_dataframe( + log_dir: Path, pattern: Optional[str], expected_columns: Optional[set] = None +) -> pd.DataFrame: + """Reads and combines all files in a folder that matches the given regex pattern + into a single DataFrame. The files must be formatted as csv files. If pattern is + None, we assume a match on all files. + + Raises: + ValueError: If not all expected columns are found in the stored DataFrame. + ValueError: If no matching files are found in the given logging directory. + """ + re_pattern: Optional[Pattern] = None + if pattern is not None: + re_pattern = compile(pattern) + + if expected_columns is None: + expected_columns = set() + + dataframes = [] + for entry in log_dir.iterdir(): + match: Union[bool, Match] = True + if re_pattern is not None: + match = re_pattern.search(str(entry)) + + if not match: + continue + + df = pd.read_csv(entry) + if not expected_columns.issubset(df.columns): + missing_columns = expected_columns - set(df.columns) + raise ValueError( + f"Invalid data format! File at '{str(entry)}' doesn't contain all" + f" necessary columns. \nMissing columns: {missing_columns}" + ) + + dataframes.append(df) + + if len(dataframes) == 0: + if pattern is None: + error_message = f"Unable to find any files in {log_dir.resolve()}!" + else: + error_message = ( + f"No files matched pattern, '{pattern}', in log_dir, " + f"{log_dir.resolve()}!" + ) + raise ValueError(error_message) + + return pd.concat(dataframes) + + +def create_absolute_plot(avg_epoch_time_df: pd.DataFrame) -> None: + """Creates a plot showing the absolute training times for the different + distributed strategies and different number of GPUs. 
+ """ + sns.set_theme() + fig, ax = plt.subplots() + + marker_cycle = cycle("ov^s*dXpD.+12348") + + unique_nodes = list(avg_epoch_time_df["nodes"].unique()) + unique_names = avg_epoch_time_df["name"].unique() + for name in unique_names: + # color, marker = next(color_marker_combinations) + marker = next(marker_cycle) + data = avg_epoch_time_df[avg_epoch_time_df["name"] == name] + + ax.plot( + data["nodes"], + data["time"], + marker=marker, + label=name, + linestyle="-", + markersize=6, + ) + + ax.set_yscale("log") + ax.set_xscale("log") + + ax.set_xticks(unique_nodes) + + ax.set_xlabel("Number of Nodes") + ax.set_ylabel("Average Time (s)") + ax.set_title("Average Time vs Number of Nodes") + + ax.legend(title="Method") + ax.grid(True) + + output_path = Path("plots/absolute_scalability_plot.png") + plt.savefig(output_path) + print(f"Saving absolute plot to '{output_path.resolve()}'.") + sns.reset_orig() + + +def create_relative_plot(avg_epoch_time_df: pd.DataFrame, gpus_per_node: int = 4): + """Creates a plot showing the relative training times for the different + distributed strategies and different number of GPUs. In particular, it shows the + speedup when adding more GPUs, compared to the baseline of using a single node. + """ + sns.set_theme() + + fig, ax = plt.subplots(figsize=(6, 4)) + # fig.suptitle(plot_title) + + ax.set_yscale("log") + ax.set_xscale("log") + + marker_cycle = cycle("ov^s*dXpD.+12348") + avg_epoch_time_df["num_gpus"] = avg_epoch_time_df["nodes"] * gpus_per_node + avg_epoch_time_df["linear_speedup"] = avg_epoch_time_df["nodes"].astype(float) + + # Plotting the speedup for each strategy + strategy_names = sorted(avg_epoch_time_df["name"].unique()) + for strategy in strategy_names: + strategy_data = avg_epoch_time_df[avg_epoch_time_df.name == strategy] + + base_time = strategy_data["time"].iloc[0] + speedup = base_time / strategy_data["time"] + num_gpus = strategy_data["num_gpus"] + + marker = next(marker_cycle) + ax.plot(num_gpus, speedup, marker=marker, lw=1.0, label=strategy, alpha=0.7) + + # Plotting the linear line + num_gpus = np.array(avg_epoch_time_df["num_gpus"].unique()) + linear_speedup = np.array(avg_epoch_time_df["linear_speedup"].unique()) + ax.plot( + num_gpus, linear_speedup, ls="dashed", lw=1.0, c="k", label="linear speedup" + ) + + ax.legend(ncol=1) + ax.set_xticks(num_gpus) + ax.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter()) + ax.set_ylabel("Speedup") + ax.set_xlabel("Number of GPUs (4 per node)") + ax.grid(True) + + # Sorted legend + handles, labels = ax.get_legend_handles_labels() + sorted_handles_labels = sorted(zip(handles, labels), key=lambda x: x[1]) + sorted_handles, sorted_labels = zip(*sorted_handles_labels) + plt.legend(sorted_handles, sorted_labels) + + plot_path = Path("plots/relative_scalability_plot.png") + plt.tight_layout() + plt.savefig(plot_path, bbox_inches="tight", format="png", dpi=300) + print(f"Saving relative plot to '{plot_path.resolve()}'.") + + sns.reset_orig() + + +def backup_scalability_metrics( + metric_df: pd.DataFrame, + experiment_name: Optional[str], + run_name: Optional[str], + backup_dir: str, + filename: str, +) -> None: + """Stores the data in the given dataframe as a .csv file in its own folder for the + experiment name and its own subfolder for the run_name. If these are not provided, + then they will be generated randomly using uuid4. 
+ """ + if experiment_name is None: + random_id = str(uuid.uuid4()) + experiment_name = "exp_" + random_id[:6] + if run_name is None: + random_id = str(uuid.uuid4()) + run_name = "run_" + random_id[:6] + + backup_path = Path(backup_dir) / experiment_name / run_name / filename + backup_path.parent.mkdir(parents=True, exist_ok=True) + metric_df.to_csv(backup_path, index=False) + print(f"Storing backup file at '{backup_path.resolve()}'.") diff --git a/src/itwinai/torch/monitoring/monitoring.py b/src/itwinai/torch/monitoring/monitoring.py index 9857ccb78..602195158 100644 --- a/src/itwinai/torch/monitoring/monitoring.py +++ b/src/itwinai/torch/monitoring/monitoring.py @@ -1,3 +1,13 @@ +# -------------------------------------------------------------------------------------- +# Part of the interTwin Project: https://www.intertwin.eu/ +# +# Created by: Jarl Sondre Sæther +# +# Credit: +# - Jarl Sondre Sæther - CERN +# - Matteo Bunino - CERN +# -------------------------------------------------------------------------------------- + import functools import time from multiprocessing import Manager, Process @@ -83,7 +93,6 @@ def probe_gpu_utilization_loop( log_dict["probing_interval"].append(probing_interval) sample_idx += 1 - time.sleep(probing_interval) @@ -99,7 +108,7 @@ def write_logs_to_file(utilization_logs: List[Dict], output_path: Path) -> None: log_df = pd.concat(dataframes) log_df.to_csv(output_path, index=False) - print(f"Writing GPU energy dataframe to '{output_path}'.") + print(f"Writing GPU energy dataframe to '{output_path.resolve()}'.") @functools.wraps(method) def measured_method(self: TorchTrainer, *args, **kwargs) -> Any: @@ -115,11 +124,6 @@ def measured_method(self: TorchTrainer, *args, **kwargs) -> Any: num_local_gpus = strategy.local_world_size() node_idx = global_rank // num_local_gpus - output_path = Path( - f"scalability_metrics/gpu_energy_data_{strategy_name}_{num_global_gpus}.csv" - ) - output_path.parent.mkdir(exist_ok=True, parents=True) - gpu_monitor_process = None manager = None stop_flag = None @@ -165,7 +169,7 @@ def measured_method(self: TorchTrainer, *args, **kwargs) -> Any: global_utilization_log = strategy.gather_obj(local_utilization_log, dst_rank=0) if strategy.is_main_worker: - output_dir = Path("scalability_metrics/gpu_energy_data") + output_dir = Path("scalability-metrics/gpu-energy-data") output_dir.mkdir(exist_ok=True, parents=True) output_path = output_dir / f"{strategy_name}_{num_global_gpus}.csv" diff --git a/src/itwinai/torch/monitoring/plotting.py b/src/itwinai/torch/monitoring/plotting.py index 554f5dc42..4ae764d99 100644 --- a/src/itwinai/torch/monitoring/plotting.py +++ b/src/itwinai/torch/monitoring/plotting.py @@ -1,6 +1,14 @@ -from pathlib import Path -from re import Match, Pattern, compile -from typing import Optional, Tuple, Union +# -------------------------------------------------------------------------------------- +# Part of the interTwin Project: https://www.intertwin.eu/ +# +# Created by: Jarl Sondre Sæther +# +# Credit: +# - Jarl Sondre Sæther - CERN +# - Matteo Bunino - CERN +# -------------------------------------------------------------------------------------- + +from typing import Tuple import matplotlib import matplotlib.pyplot as plt @@ -14,53 +22,36 @@ matplotlib.use("Agg") -def read_energy_df(pattern: Optional[str], log_dir: Path) -> pd.DataFrame: - """Read files matching the given regex pattern from directory and converting them - into a Pandas DataFrame. If pattern is None, we assume a match on all files. 
- Expects that the existence of ``log_dir`` is handled before calling this function. +def calculate_average_gpu_utilization(gpu_data_df: pd.DataFrame) -> pd.DataFrame: + """Calculates the average GPU utilization for each strategy and + number of GPUs. - Args: - pattern: The regex string used to match files. - log_dir: The directory to search for files in. - - Raises: - ValueError: If no matching files are found in the given logging directory. + Returns: + pd.DataFrame: A DataFrame containing the average gpu utilization for + each strategy and number of GPUs, with the columns ``strategy``, + ``num_global_gpus`` and ``utilization``. """ + required_columns = {"strategy", "utilization", "num_global_gpus"} + if not required_columns.issubset(set(gpu_data_df.columns)): + missing_columns = set(required_columns) - set(gpu_data_df.columns) + raise ValueError( + f"DataFrame is missing the following columns: {missing_columns}" + ) - pattern_re: Optional[Pattern] = None - if pattern is not None: - pattern_re = compile(pattern) - - # Load and concatenate dataframes - dataframes = [] - for entry in log_dir.iterdir(): - match: Union[bool, Match] = True - if pattern_re is not None: - match = pattern_re.search(str(entry)) - - if not match: - continue - - print(f"Loading data from file: '{entry}' when creating energy DataFrame") - df = pd.read_csv(entry) - dataframes.append(df) - - if len(dataframes) == 0: - if pattern is None: - error_message = f"Unable to find any files in {log_dir.resolve()}!" - else: - error_message = ( - f"No files matched pattern, '{pattern}', in log_dir, " - f"{log_dir.resolve()}!" - ) - raise ValueError(error_message) - - return pd.concat(dataframes) + utilization_data = [] + grouped_df = gpu_data_df.groupby(["strategy", "num_global_gpus"]) + for (strategy, num_gpus), group in grouped_df: + utilization_data.append( + { + "strategy": strategy, + "num_global_gpus": num_gpus, + "utilization": group["utilization"].mean(), + } + ) + return pd.DataFrame(utilization_data) -def calculate_aggregated_energy_expenditure( - gpu_utilization_df: pd.DataFrame, -) -> pd.DataFrame: +def calculate_total_energy_expenditure(gpu_data_df: pd.DataFrame) -> pd.DataFrame: """Calculates the total energy expenditure in Watt hours for each strategy and number of GPUs. Expects that the existence of the appropriate DataFrame columns is handled before calling this function. @@ -70,9 +61,14 @@ def calculate_aggregated_energy_expenditure( each strategy and number of GPUs, with the columns ``strategy``, ``num_global_gpus`` and ``total_energy_wh``. """ + required_columns = {"strategy", "power", "num_global_gpus", "probing_interval"} + if not required_columns.issubset(set(gpu_data_df.columns)): + missing_columns = set(required_columns) - set(gpu_data_df.columns) + raise ValueError( + f"DataFrame is missing the following columns: {missing_columns}" + ) energy_data = [] - - grouped_df = gpu_utilization_df.groupby(["strategy", "num_global_gpus"]) + grouped_df = gpu_data_df.groupby(["strategy", "num_global_gpus"]) for (strategy, num_gpus), group in grouped_df: if len(group["probing_interval"].unique()) != 1: @@ -94,22 +90,27 @@ def calculate_aggregated_energy_expenditure( return pd.DataFrame(energy_data) -def gpu_energy_plot(gpu_utilization_df: pd.DataFrame) -> Tuple[Figure, Axes]: - """Makes an energy bar plot of the GPU utilization dataframe, showing the total - energy expenditure for each strategy and number of GPUs in Watt hours. 
+def gpu_bar_plot( + data_df: pd.DataFrame, plot_title: str, y_label: str, main_column: str +) -> Tuple[Figure, Axes]: + """Makes a bar plot of the given data for each strategy and number of GPUs. + + Args: + data_df: The dataframe to extract the data from. + plot_title: The title to give the plot. + y_label: The label for the y-axis. + main_column: The column to use for the height of the bar plot. """ - required_columns = {"strategy", "power", "num_global_gpus", "probing_interval"} - if not required_columns.issubset(set(gpu_utilization_df.columns)): - missing_columns = set(required_columns) - set(set(gpu_utilization_df.columns)) + required_columns = {"strategy", "num_global_gpus", main_column} + if not required_columns.issubset(set(data_df.columns)): + missing_columns = set(required_columns) - set(data_df.columns) raise ValueError( f"DataFrame is missing the following columns: {missing_columns}" ) sns.set_theme() - energy_df = calculate_aggregated_energy_expenditure(gpu_utilization_df) - - strategies = energy_df["strategy"].unique() - unique_gpu_counts = np.array(energy_df["num_global_gpus"].unique()) + strategies = data_df["strategy"].unique() + unique_gpu_counts = np.array(data_df["num_global_gpus"].unique()) fig, ax = plt.subplots() x = np.arange(len(unique_gpu_counts)) @@ -118,23 +119,29 @@ def gpu_energy_plot(gpu_utilization_df: pd.DataFrame) -> Tuple[Figure, Axes]: static_offset = (len(strategies) - 1) / 2 for strategy_idx, strategy in enumerate(strategies): dynamic_bar_offset = strategy_idx - static_offset - strategy_data = energy_df[energy_df["strategy"] == strategy] + strategy_data = data_df[data_df["strategy"] == strategy] # Ensuring the correct spacing of the bars strategy_num_gpus = len(strategy_data["num_global_gpus"]) ax.bar( x=x[:strategy_num_gpus] + dynamic_bar_offset * bar_width, - height=strategy_data["total_energy_wh"], + height=strategy_data[main_column], width=bar_width, label=strategy, ) ax.set_xlabel("Num GPUs") - ax.set_ylabel("Energy Consumption (Wh)") - ax.set_title("Energy Consumption by Strategy and Number of GPUs") + ax.set_ylabel(y_label) + ax.set_title(plot_title) ax.set_xticks(x) ax.set_xticklabels(unique_gpu_counts) ax.legend(title="Strategy") + figure_width = int(1.5 * len(unique_gpu_counts)) + fig.set_figheight(6) + fig.set_figwidth(figure_width) + + sns.reset_orig() + return fig, ax diff --git a/src/itwinai/torch/profiling/__init__.py b/src/itwinai/torch/profiling/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/itwinai/torch/profiling/communication_plot.py b/src/itwinai/torch/profiling/communication_plot.py index 285a62dac..fc119826e 100644 --- a/src/itwinai/torch/profiling/communication_plot.py +++ b/src/itwinai/torch/profiling/communication_plot.py @@ -1,6 +1,13 @@ -from pathlib import Path -from re import Match, Pattern, compile -from typing import Any, List, Optional, Tuple, Union +# -------------------------------------------------------------------------------------- +# Part of the interTwin Project: https://www.intertwin.eu/ +# +# Created by: Jarl Sondre Sæther +# +# Credit: +# - Jarl Sondre Sæther - CERN +# -------------------------------------------------------------------------------------- + +from typing import Any, List, Tuple import matplotlib import matplotlib.pyplot as plt @@ -9,13 +16,12 @@ import seaborn as sns from matplotlib.patches import Patch +# from itwinai.scalability import convert_matching_files_to_dataframe + # Doing this because otherwise I get an error about X11 Forwarding which I believe # 
is due to the server trying to pass the image to the client computer matplotlib.use("Agg") -# import logging -# from logging import Logger as PythonLogger - def calculate_comp_and_comm_time(df: pd.DataFrame) -> Tuple[float, float]: """Calculates the time spent computing and time spent communicating and returns a @@ -58,7 +64,7 @@ def calculate_comp_and_comm_time(df: pd.DataFrame) -> Tuple[float, float]: return comp_time, comm_time -def create_stacked_plot( +def communication_overhead_stacked_bar_plot( values: np.ndarray, strategy_labels: List, gpu_numbers: List ) -> Tuple[Any, Any]: """Creates a stacked plot showing values from 0 to 1, where the given value @@ -72,13 +78,13 @@ def create_stacked_plot( the GPU numbers in 'gpu_numbers' sorted numerically in ascending order. """ sns.set_theme() + color_map = plt.get_cmap("tab10") + hatch_patterns = ["//", r"\\"] strategy_labels = sorted(strategy_labels) gpu_numbers = sorted(gpu_numbers, key=lambda x: int(x)) width = 1 / (len(strategy_labels) + 1) - comp_color = "lightblue" - comm_color = "lightgreen" complements = 1 - values x = np.arange(len(gpu_numbers)) @@ -89,120 +95,59 @@ def create_stacked_plot( for strategy_idx in range(len(strategy_labels)): dynamic_bar_offset = strategy_idx - static_offset + color = color_map(strategy_idx % 10) + hatch = hatch_patterns[strategy_idx % 2] + ax.bar( x=x + dynamic_bar_offset * width, height=values[strategy_idx], width=width, - color=comp_color, + color=color, + label=strategy_labels[strategy_idx], + edgecolor="gray", + linewidth=0.6, ) ax.bar( x=x + dynamic_bar_offset * width, height=complements[strategy_idx], width=width, bottom=values[strategy_idx], - color=comm_color, + facecolor="none", + edgecolor="gray", + alpha=0.8, + linewidth=0.6, + hatch=hatch, ) - # Positioning the labels under the stacks - for gpu_idx in range(len(gpu_numbers)): - if np.isnan(values[strategy_idx, gpu_idx]): - continue - dynamic_label_offset = strategy_idx - static_offset - ax.text( - x=x[gpu_idx] + dynamic_label_offset * width, - y=-0.1, - s=strategy_labels[strategy_idx], - ha="center", - va="top", - fontsize=10, - rotation=60, - ) - ax.set_ylabel("Computation fraction") + ax.set_xlabel("Number of GPUs") ax.set_title("Computation vs Communication Time by Method") ax.set_xticks(x) ax.set_xticklabels(gpu_numbers) - ax.set_ylim(0, 1) + ax.set_ylim(0, 1.1) - # Setting the appropriate colors since the legend is manual - legend_elements = [ - Patch(facecolor=comm_color, label="Communication"), - Patch(facecolor=comp_color, label="Computation"), - ] - - # Positioning the legend outside of the plot to not obstruct it - ax.legend( - handles=legend_elements, - loc="upper left", - bbox_to_anchor=(0.80, 1.22), - borderaxespad=0.0, + # Adding communication time to the legend + hatch_patch = Patch( + facecolor="none", edgecolor="gray", hatch="//", label="Communication" ) - fig.subplots_adjust(bottom=0.25) - fig.subplots_adjust(top=0.85) - return fig, ax + ax.legend(handles=ax.get_legend_handles_labels()[0] + [hatch_patch]) + # Dynamically adjusting the width of the figure + figure_width = int(1.5 * len(gpu_numbers)) + fig.set_figheight(5) + fig.set_figwidth(figure_width) -def create_combined_comm_overhead_df( - log_dir: Path, pattern: Optional[str] -) -> pd.DataFrame: - """Reads and combines all files in a folder that matches the given regex pattern - into a single DataFrame. The files must be formatted as csv files. If pattern is - None, we assume a match on all files. 
+ sns.reset_orig() - Raises: - ValueError: If not all expected columns are found in the stored DataFrame. - ValueError: If no matching files are found in the given logging directory. - """ - re_pattern: Optional[Pattern] = None - if pattern is not None: - re_pattern = compile(pattern) - - dataframes = [] - expected_columns = { - "strategy", - "num_gpus", - "global_rank", - "name", - "self_cuda_time_total", - } - for entry in log_dir.iterdir(): - match: Union[bool, Match] = True - if re_pattern is not None: - match = re_pattern.search(str(entry)) - - if not match: - continue - - df = pd.read_csv(entry) - if not expected_columns.issubset(df.columns): - missing_columns = expected_columns - set(df.columns) - raise ValueError( - f"Invalid data format! File at '{str(entry)}' doesn't contain all" - f" necessary columns. \nMissing columns: {missing_columns}" - ) - - dataframes.append(df) - - if len(dataframes) == 0: - if pattern is None: - error_message = f"Unable to find any files in {log_dir.resolve()}!" - else: - error_message = ( - f"No files matched pattern, '{pattern}', in log_dir, " - f"{log_dir.resolve()}!" - ) - raise ValueError(error_message) - - return pd.concat(dataframes) + return fig, ax def get_comp_fraction_full_array( df: pd.DataFrame, print_table: bool = False ) -> np.ndarray: - """Creates a MxN NumPy array where M is the number of strategies - and N is the number of GPU configurations. The strategies are sorted - alphabetically and the GPU configurations are sorted in ascending number - of GPUs. + """Creates a MxN NumPy array where M is the number of strategies and N is the + number of GPU configurations. The strategies are sorted alphabetically and the GPU + configurations are sorted in ascending number of GPUs. """ unique_num_gpus = sorted(df["num_gpus"].unique(), key=lambda x: int(x)) unique_strategies = sorted(df["strategy"].unique()) @@ -219,8 +164,7 @@ def get_comp_fraction_full_array( row_string = f"{strategy:>12} | {num_gpus:>10}" - # Allows asymmetric testing, i.e. 
not testing all num gpus and all - # strategies together + # Allows some strategies or num GPUs to not be included if len(filtered_df) == 0: comp_time, comm_time = np.NaN, np.NaN strategy_values.append(np.NaN) diff --git a/src/itwinai/torch/profiling/profiler.py b/src/itwinai/torch/profiling/profiler.py index 7ff43665b..4d35d9b5f 100644 --- a/src/itwinai/torch/profiling/profiler.py +++ b/src/itwinai/torch/profiling/profiler.py @@ -1,4 +1,12 @@ -from __future__ import annotations +# -------------------------------------------------------------------------------------- +# Part of the interTwin Project: https://www.intertwin.eu/ +# +# Created by: Jarl Sondre Sæther +# +# Credit: +# - Jarl Sondre Sæther - CERN +# - Matteo Bunino - CERN +# -------------------------------------------------------------------------------------- import functools from pathlib import Path @@ -27,9 +35,9 @@ def gather_profiling_data(key_averages: Iterable) -> pd.DataFrame: { "name": event.key, "node_id": event.node_id, - "self_cpu_time_total": event.self_cpu_time_total, - "cpu_time_total": event.cpu_time_total, - "cpu_time_total_str": event.cpu_time_total_str, + # "self_cpu_time_total": event.self_cpu_time_total, + # "cpu_time_total": event.cpu_time_total, + # "cpu_time_total_str": event.cpu_time_total_str, "self_cuda_time_total": event.self_cuda_time_total, "cuda_time_total": event.cuda_time_total, "cuda_time_total_str": event.cuda_time_total_str, @@ -85,8 +93,8 @@ def profiled_method(self: TorchTrainer, *args, **kwargs) -> Any: ) profiler = profile( - activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], - with_modules=True, + activities=[ProfilerActivity.CUDA], # , ProfilerActivity.CPU], + # with_modules=True, schedule=schedule( wait=wait_epochs, warmup=warmup_epochs, @@ -114,7 +122,7 @@ def profiled_method(self: TorchTrainer, *args, **kwargs) -> Any: profiling_dataframe["num_gpus"] = num_gpus_global profiling_dataframe["global_rank"] = global_rank - profiling_log_dir = Path("scalability_metrics/communication_data") + profiling_log_dir = Path("scalability-metrics/communication-data") profiling_log_dir.mkdir(parents=True, exist_ok=True) filename = f"{strategy_name}_{num_gpus_global}_{global_rank}.csv" diff --git a/tutorials/distributed-ml/torch-scaling-test/ddp_trainer.py b/tutorials/distributed-ml/torch-scaling-test/ddp_trainer.py index 0a25ae5ba..46389a9b1 100755 --- a/tutorials/distributed-ml/torch-scaling-test/ddp_trainer.py +++ b/tutorials/distributed-ml/torch-scaling-test/ddp_trainer.py @@ -1,28 +1,25 @@ """ Scaling test of torch Distributed Data Parallel on Imagenet using Resnet. 
""" -from typing import Optional import argparse -import sys import os -from timeit import default_timer as timer +import sys import time +from timeit import default_timer as timer +from typing import Optional import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F +import torchvision from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -import torchvision +from utils import imagenet_dataset -from itwinai.parser import ArgumentParser as ItAIArgumentParser from itwinai.loggers import EpochTimeTracker -from itwinai.torch.reproducibility import ( - seed_worker, set_seed -) - -from utils import imagenet_dataset +from itwinai.parser import ArgumentParser as ItAIArgumentParser +from itwinai.torch.reproducibility import seed_worker, set_seed def parse_params(): @@ -208,8 +205,9 @@ def main(): print('--------------------------------------------------------') nnod = os.environ.get('SLURM_NNODES', 'unk') epoch_time_tracker = EpochTimeTracker( - series_name="ddp-bl", - csv_file=f"epochtime_ddp-bl_{nnod}N.csv" + strategy_name="ddp-bl", + save_path=f"epochtime_ddp-bl_{nnod}N.csv", + num_nodes=int(nnod) ) et = timer() diff --git a/tutorials/distributed-ml/torch-scaling-test/deepspeed_trainer.py b/tutorials/distributed-ml/torch-scaling-test/deepspeed_trainer.py index e60220216..1cc5f8489 100644 --- a/tutorials/distributed-ml/torch-scaling-test/deepspeed_trainer.py +++ b/tutorials/distributed-ml/torch-scaling-test/deepspeed_trainer.py @@ -1,28 +1,25 @@ """ Scaling test of Microsoft Deepspeed on Imagenet using Resnet. """ -from typing import Optional import argparse -import sys import os -from timeit import default_timer as timer +import sys import time -import deepspeed +from timeit import default_timer as timer +from typing import Optional +import deepspeed import torch import torch.distributed as dist import torch.nn.functional as F +import torchvision from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -import torchvision +from utils import imagenet_dataset -from itwinai.parser import ArgumentParser as ItAIArgumentParser from itwinai.loggers import EpochTimeTracker -from itwinai.torch.reproducibility import ( - seed_worker, set_seed -) - -from utils import imagenet_dataset +from itwinai.parser import ArgumentParser as ItAIArgumentParser +from itwinai.torch.reproducibility import seed_worker, set_seed def parse_params(): @@ -222,8 +219,9 @@ def main(): print('--------------------------------------------------------') nnod = os.environ.get('SLURM_NNODES', 'unk') epoch_time_tracker = EpochTimeTracker( - series_name="deepspeed-bl", - csv_file=f"epochtime_deepspeed-bl_{nnod}N.csv" + strategy_name="deepspeed-bl", + save_path=f"epochtime_deepspeed-bl_{nnod}N.csv", + num_nodes=int(nnod) ) et = timer() diff --git a/tutorials/distributed-ml/torch-scaling-test/horovod_trainer.py b/tutorials/distributed-ml/torch-scaling-test/horovod_trainer.py index a4c3eaa47..6e27d9216 100755 --- a/tutorials/distributed-ml/torch-scaling-test/horovod_trainer.py +++ b/tutorials/distributed-ml/torch-scaling-test/horovod_trainer.py @@ -1,29 +1,27 @@ """ Scaling test of Horovod on Imagenet using Resnet. 
""" -from typing import Optional import argparse import os import sys -from timeit import default_timer as timer import time +from timeit import default_timer as timer +from typing import Optional +import horovod.torch as hvd import torch + # import torch.multiprocessing as mp import torch.nn.functional as F import torch.optim as optim +import torchvision from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -import horovod.torch as hvd -import torchvision +from utils import imagenet_dataset -from itwinai.parser import ArgumentParser as ItAIArgumentParser from itwinai.loggers import EpochTimeTracker -from itwinai.torch.reproducibility import ( - seed_worker, set_seed -) - -from utils import imagenet_dataset +from itwinai.parser import ArgumentParser as ItAIArgumentParser +from itwinai.torch.reproducibility import seed_worker, set_seed def parse_params(): @@ -264,8 +262,8 @@ def main(): print('--------------------------------------------------------') nnod = os.environ.get('SLURM_NNODES', 'unk') epoch_time_tracker = EpochTimeTracker( - series_name="horovod-bl", - csv_file=f"epochtime_horovod-bl_{nnod}N.csv" + strategy_name="horovod-bl", + save_path=f"epochtime_horovod-bl_{nnod}N.csv" ) et = timer() diff --git a/tutorials/distributed-ml/torch-scaling-test/itwinai_trainer.py b/tutorials/distributed-ml/torch-scaling-test/itwinai_trainer.py index b6bb99a03..f98c553b5 100644 --- a/tutorials/distributed-ml/torch-scaling-test/itwinai_trainer.py +++ b/tutorials/distributed-ml/torch-scaling-test/itwinai_trainer.py @@ -269,8 +269,8 @@ def main(): nnod = os.environ.get('SLURM_NNODES', 'unk') s_name = f"{args.strategy}-it" epoch_time_tracker = EpochTimeTracker( - series_name=s_name, - csv_file=f"epochtime_{s_name}_{nnod}N.csv" + strategy_name=s_name, + save_path=f"epochtime_{s_name}_{nnod}N.csv" ) et = timer() diff --git a/use-cases/eurac/plots/absolute_scalability_plot.png b/use-cases/eurac/plots/absolute_scalability_plot.png new file mode 100644 index 000000000..4ea81658e Binary files /dev/null and b/use-cases/eurac/plots/absolute_scalability_plot.png differ diff --git a/use-cases/eurac/plots/communication_plot.png b/use-cases/eurac/plots/communication_plot.png index c40459d02..dec75113c 100644 Binary files a/use-cases/eurac/plots/communication_plot.png and b/use-cases/eurac/plots/communication_plot.png differ diff --git a/use-cases/eurac/plots/gpu_energy_plot.png b/use-cases/eurac/plots/gpu_energy_plot.png index 618ec8c83..d951859e4 100644 Binary files a/use-cases/eurac/plots/gpu_energy_plot.png and b/use-cases/eurac/plots/gpu_energy_plot.png differ diff --git a/use-cases/eurac/plots/relative_scalability_plot.png b/use-cases/eurac/plots/relative_scalability_plot.png new file mode 100644 index 000000000..71d7b06de Binary files /dev/null and b/use-cases/eurac/plots/relative_scalability_plot.png differ diff --git a/use-cases/eurac/plots/utilization_plot.png b/use-cases/eurac/plots/utilization_plot.png new file mode 100644 index 000000000..2878c70e2 Binary files /dev/null and b/use-cases/eurac/plots/utilization_plot.png differ diff --git a/use-cases/eurac/trainer.py b/use-cases/eurac/trainer.py index 770c82591..5d0fbfa7c 100644 --- a/use-cases/eurac/trainer.py +++ b/use-cases/eurac/trainer.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from timeit import default_timer as timer +from timeit import default_timer from typing import Dict, Literal, Optional, Union, Any, Tuple import pandas as pd @@ -154,13 +154,17 @@ def train(self): """Override version 
of hython to support distributed strategy.""" # Tracking epoch times for scaling test if self.strategy.is_main_worker: - num_nodes = os.environ.get("SLURM_NNODES", "unk") - series_name = os.environ.get("DIST_MODE", "unk") + "-torch" - file_name = f"epochtime_{series_name}_{num_nodes}N.csv" - file_path = Path("logs_epoch") / file_name + num_nodes = int(os.environ.get("SLURM_NNODES", "unk")) + epoch_time_output_dir = Path("scalability-metrics/epoch-time") + epoch_time_file_name = f"epochtime_{self.strategy.name}_{num_nodes}N.csv" + epoch_time_output_path = epoch_time_output_dir / epoch_time_file_name + epoch_time_tracker = EpochTimeTracker( - series_name=series_name, csv_file=file_path + strategy_name=self.strategy.name, + save_path=epoch_time_output_path, + num_nodes=num_nodes ) + trainer = RNNTrainer( RNNTrainParams( experiment=self.config.experiment, @@ -182,7 +186,7 @@ def train(self): best_loss = float("inf") for epoch in tqdm(range(self.epochs)): - epoch_start_time = timer() + epoch_start_time = default_timer() self.set_epoch(epoch) self.model.train() @@ -249,12 +253,11 @@ def train(self): best_loss = avg_val_loss best_model = self.model.state_dict() - epoch_end_time = timer() - epoch_time_tracker.add_epoch_time( - epoch - 1, epoch_end_time - epoch_start_time - ) + epoch_time = default_timer() - epoch_start_time + epoch_time_tracker.add_epoch_time(epoch + 1, epoch_time) if self.strategy.is_main_worker: + epoch_time_tracker.save() self.model.load_state_dict(best_model) self.log( item=self.model, diff --git a/use-cases/virgo/runall.sh b/use-cases/virgo/runall.sh old mode 100644 new mode 100755 diff --git a/use-cases/virgo/trainer.py b/use-cases/virgo/trainer.py index 1813a2a4a..3cba3474b 100644 --- a/use-cases/virgo/trainer.py +++ b/use-cases/virgo/trainer.py @@ -16,7 +16,8 @@ from itwinai.torch.config import TrainingConfiguration from itwinai.torch.distributed import DeepSpeedStrategy, RayDeepSpeedStrategy, RayDDPStrategy from itwinai.torch.trainer import TorchTrainer, RayTorchTrainer -from deepspeed.accelerator import get_accelerator +from itwinai.torch.monitoring.monitoring import measure_gpu_utilization +from itwinai.torch.profiling.profiler import profile_torch_trainer class VirgoTrainingConfiguration(TrainingConfiguration): @@ -169,6 +170,8 @@ def custom_collate(self, batch): return torch.cat(batch) + @profile_torch_trainer + @measure_gpu_utilization def train(self): # Start the timer for profiling st = timer() @@ -185,8 +188,8 @@ def train(self): nnod = os.environ.get('SLURM_NNODES', 'unk') s_name = f"{os.environ.get('DIST_MODE', 'unk')}-torch" epoch_time_tracker = EpochTimeTracker( - series_name=s_name, - csv_file=f"epochtime_{s_name}_{nnod}N.csv" + strategy_name=s_name, + save_path=f"epochtime_{s_name}_{nnod}N.csv" ) loss_plot = [] val_loss_plot = []
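The decorator pattern used in the Virgo trainer above generalizes to other trainers; a minimal sketch, assuming a hypothetical TorchTrainer subclass, with the decorators imported exactly as in the Virgo diff:

    from itwinai.torch.monitoring.monitoring import measure_gpu_utilization
    from itwinai.torch.profiling.profiler import profile_torch_trainer
    from itwinai.torch.trainer import TorchTrainer

    class MyTrainer(TorchTrainer):  # hypothetical subclass, for illustration only
        @profile_torch_trainer      # writes per-rank CSVs under scalability-metrics/communication-data/
        @measure_gpu_utilization    # expected to record GPU power/utilization data for the energy and utilization plots
        def train(self):
            ...  # regular training loop, unchanged

Stacking the decorators keeps profiling and GPU monitoring out of the training loop itself; the CSVs they produce are what the plotting helpers earlier in this diff consume.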