Skip to content

Commit

Permalink
Scalability report update - Communication Plot [cleaned up] (#231)
Browse files: browse the repository at this point in the history
* make comm profiler into decorator

* add tryout profiler script with mnist

* add code for testing out pytorch profiler

* add dummy script for monitoring and profiling mnist

* add functionality for multi node GPU utilization

* split code into files and create analysis

* update profiler to handle multi-gpu

* add comm vs comp analysis

* remove adjustable output path

* create script for comm plot

* add slurm script for comm calculation

* update jupyter notebook

* add docstrings, error handling and more comm patterns

* Do data analysis

* add docstrings etc.

* update slurm script

* make comm profiler into decorator

* accommodate asymmetric runs and make table prettier

* make comm profiler into decorator

* add scheduler to profiler

* remove regex dependency from file names in comm plot

* add dynamic specification of directories for comm plot generator

* small bugfix and black formatter

* format code

* fix linting errors

* remove unused files and create new directory for gpu-monitoring

* update docstrings

* move imports into function in cli

* move profiler to own file

* move communication plot to torch folder

* add deepspeed import in ds strategy

* fix linting errors

* remove gpu-monitoring files for this branch

* add another communication entry

* remove plots

* fix small docstring typo

* remove plots and small cleanup

* move profiling files into new profiling module

* move horovod imports and create new profiling module

* fix diffs

---------

Co-authored-by: Jarl Saether <[email protected]>
  • Loading branch information
matbun and jarlsondre authored Oct 17, 2024
1 parent 474034f commit 30765e0
Show file tree
Hide file tree
Showing 13 changed files with 629 additions and 214 deletions.
267 changes: 162 additions & 105 deletions src/itwinai/cli.py

Large diffs are not rendered by default.

76 changes: 38 additions & 38 deletions src/itwinai/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,14 @@
>>> python my_train.py --config training_pipe.yaml --lr 0.002
"""


from __future__ import annotations

import functools
import time
from abc import ABC, abstractmethod

# import logging
# from logging import Logger as PythonLogger
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from .serialization import ModelLoader, Serializable
Expand All @@ -113,6 +115,7 @@ def wrapper(self: BaseComponent, *args, **kwargs) -> Any:
msg = f"'{self.name}' executed in {self.exec_t:.3f}s"
self._printout(msg)
return result

return wrapper


Expand All @@ -127,7 +130,8 @@ class BaseComponent(ABC, Serializable):
Args:
name (Optional[str], optional): unique identifier for a step.
Defaults to None.
"""
"""

_name: str = None
#: Dictionary storing constructor arguments. Needed to serialize the
#: class to dictionary. Set by ``self.save_parameters()`` method.
Expand All @@ -144,11 +148,8 @@ def __init__(

@property
def name(self) -> str:
"""Name of current component. Defaults to ``self.__class__.__name__``.
"""
return (
self._name if self._name is not None else self.__class__.__name__
)
"""Name of current component. Defaults to ``self.__class__.__name__``."""
return self._name if self._name is not None else self.__class__.__name__

@name.setter
def name(self, name: str) -> None:
Expand All @@ -157,7 +158,7 @@ def name(self, name: str) -> None:
@abstractmethod
@monitor_exec
def execute(self, *args, **kwargs) -> Any:
""""Execute some operations."""
"""Execute some operations."""

# def setup_console(self):
# """Setup Python logging"""
Expand Down Expand Up @@ -186,9 +187,9 @@ def cleanup(self):
@staticmethod
def _printout(msg: str):
msg = f"# {msg} #"
print("#"*len(msg))
print("#" * len(msg))
print(msg)
print("#"*len(msg))
print("#" * len(msg))


class DataGetter(BaseComponent):
Expand All @@ -213,7 +214,7 @@ def execute(
self,
train_dataset: MLDataset,
validation_dataset: MLDataset,
test_dataset: MLDataset
test_dataset: MLDataset,
) -> Tuple[MLDataset, MLDataset, MLDataset]:
"""Trains a machine learning model.
Expand All @@ -230,6 +231,7 @@ def execute(

class DataSplitter(BaseComponent):
"""Splits a dataset into train, validation, and test splits."""

_train_proportion: Union[int, float]
_validation_proportion: Union[int, float]
_test_proportion: Union[int, float]
Expand All @@ -239,7 +241,7 @@ def __init__(
train_proportion: Union[int, float],
validation_proportion: Union[int, float],
test_proportion: Union[int, float],
name: Optional[str] = None
name: Optional[str] = None,
) -> None:
super().__init__(name)
self.save_parameters(**self.locals2params(locals()))
Expand Down Expand Up @@ -291,10 +293,7 @@ def test_proportion(self, prop: Union[int, float]) -> None:

@abstractmethod
@monitor_exec
def execute(
self,
dataset: MLDataset
) -> Tuple[MLDataset, MLDataset, MLDataset]:
def execute(self, dataset: MLDataset) -> Tuple[MLDataset, MLDataset, MLDataset]:
"""Splits a dataset into train, validation and test splits.
Args:
Expand All @@ -315,7 +314,7 @@ def execute(
self,
train_dataset: MLDataset,
validation_dataset: MLDataset,
test_dataset: MLDataset
test_dataset: MLDataset,
) -> Tuple[MLDataset, MLDataset, MLDataset, MLModel]:
"""Trains a machine learning model.
Expand Down Expand Up @@ -348,9 +347,7 @@ def __init__(
@abstractmethod
@monitor_exec
def execute(
self,
predict_dataset: MLDataset,
model: Optional[MLModel] = None
self, predict_dataset: MLDataset, model: Optional[MLModel] = None
) -> MLDataset:
"""Applies a machine learning model on a dataset of samples.
Expand Down Expand Up @@ -433,26 +430,29 @@ def execute(self, *args) -> Tuple:
"""
result = []
for itm in self.policy:
if isinstance(itm, str) and itm.startswith(self.INPUT_PREFIX):
arg_idx = int(itm[len(self.INPUT_PREFIX):])
if arg_idx >= len(args):
max_idx = max(map(
lambda itm: int(itm[len(self.INPUT_PREFIX):]),
if not (isinstance(itm, str) and itm.startswith(self.INPUT_PREFIX)):
result.append(itm)
continue

arg_idx = int(itm[len(self.INPUT_PREFIX) :])
if arg_idx >= len(args):
max_idx = max(
map(
lambda itm: int(itm[len(self.INPUT_PREFIX) :]),
filter(
lambda el: (
isinstance(el, str)
and el.startswith(self.INPUT_PREFIX)
isinstance(el, str) and el.startswith(self.INPUT_PREFIX)
),
self.policy
)))
raise IndexError(
f"The args received as input by '{self.name}' "
"are not consistent with the given adapter policy "
"because input args are too few! "
f"Input args are {len(args)} but the policy foresees "
f"at least {max_idx+1} items."
self.policy,
),
)
result.append(args[arg_idx])
else:
result.append(itm)
)
raise IndexError(
f"The args received as input by '{self.name}' "
"are not consistent with the given adapter policy "
"because input args are too few! "
f"Input args are {len(args)} but the policy foresees "
f"at least {max_idx+1} items."
)
result.append(args[arg_idx])
return tuple(result)
51 changes: 27 additions & 24 deletions src/itwinai/torch/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import os
from typing import Any, Iterable, List, Literal, Optional, Tuple, Union

import deepspeed
import horovod.torch as hvd
import torch
import torch.distributed as dist
import torch.nn as nn
Expand Down Expand Up @@ -565,10 +563,7 @@ class DeepSpeedStrategy(TorchDistributedStrategy):
#: Torch distributed communication backend.
backend: Literal['nccl', 'gloo', 'mpi']

def __init__(
self,
backend: Literal['nccl', 'gloo', 'mpi']
) -> None:
def __init__(self, backend: Literal['nccl', 'gloo', 'mpi']) -> None:
super().__init__()
self.backend = backend

Expand All @@ -581,6 +576,8 @@ def init(self) -> None:
DistributedStrategyError: when trying to initialize a strategy
already initialized.
"""
import deepspeed
self.deepspeed = deepspeed
if not distributed_resources_available():
raise RuntimeError(
"Trying to run distributed on insufficient resources.")
Expand All @@ -591,10 +588,11 @@ def init(self) -> None:
# https://github.com/Lightning-AI/pytorch-lightning/issues/13567
ompi_lrank = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK')
os.environ['OMPI_COMM_WORLD_LOCAL_RANK'] = os.environ.get(
'LOCAL_RANK', ompi_lrank)
'LOCAL_RANK', ompi_lrank
)

# https://deepspeed.readthedocs.io/en/latest/initialize.html#training-initialization
deepspeed.init_distributed(dist_backend=self.backend)
self.deepspeed.init_distributed(dist_backend=self.backend)
self.is_initialized = True

self.set_device()
Expand All @@ -608,9 +606,10 @@ def distributed(
"""Setup model, optimizer and scheduler for distributed."""
if not self.is_initialized:
raise UninitializedStrategyError(
"Strategy has not been initialized. Use the init method.")
"Strategy has not been initialized. Use the init method."
)

distrib_model, optimizer, _, lr_scheduler = deepspeed.initialize(
distrib_model, optimizer, _, lr_scheduler = self.deepspeed.initialize(
model=model,
model_parameters=model_parameters,
optimizer=optimizer,
Expand Down Expand Up @@ -752,7 +751,11 @@ def init(self) -> None:
"Trying to run distributed on insufficient resources.")
if self.is_initialized:
raise DistributedStrategyError("Strategy was already initialized")
hvd.init()

import horovod.torch as hvd
self.hvd = hvd

self.hvd.init()
self.is_initialized = True

self.set_device()
Expand All @@ -772,16 +775,16 @@ def distributed(
# Scale learning rate
# https://github.com/horovod/horovod/issues/1653#issuecomment-574764452
lr_scaler = 1
if optim_kwargs.get('op') == hvd.Adasum:
lr_scaler = hvd.local_size()
elif optim_kwargs.get('op') == hvd.Average:
lr_scaler = hvd.size()
if optim_kwargs.get('op') == self.hvd.Adasum:
lr_scaler = self.hvd.local_size()
elif optim_kwargs.get('op') == self.hvd.Average:
lr_scaler = self.hvd.size()
for g in optimizer.param_groups:
g['lr'] *= lr_scaler

self._broadcast_params(model, optimizer)

distOptimizer = hvd.DistributedOptimizer(
distOptimizer = self.hvd.DistributedOptimizer(
optimizer,
named_parameters=model.named_parameters(),
**optim_kwargs
Expand All @@ -799,8 +802,8 @@ def _broadcast_params(
optimizer (optim.Optimizer): Optimizer that is to be broadcasted
across processes.
"""
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=-0)
self.hvd.broadcast_parameters(model.state_dict(), root_rank=0)
self.hvd.broadcast_optimizer_state(optimizer, root_rank=-0)

def global_world_size(self) -> int:
"""Returns the total number of processes (global world size).
Expand All @@ -811,7 +814,7 @@ def global_world_size(self) -> int:
if not self.is_initialized:
raise UninitializedStrategyError(
"Strategy has not been initialized. Use the init method.")
return hvd.size()
return self.hvd.size()

def local_world_size(self) -> int:
"""Returns the local number of workers available per node,
Expand All @@ -823,7 +826,7 @@ def local_world_size(self) -> int:
if not self.is_initialized:
raise UninitializedStrategyError(
"Strategy has not been initialized. Use the init method.")
return hvd.local_size()
return self.hvd.local_size()

def global_rank(self) -> int:
"""Returns the global rank of the current process, where
Expand All @@ -835,7 +838,7 @@ def global_rank(self) -> int:
if not self.is_initialized:
raise UninitializedStrategyError(
"Strategy has not been initialized. Use the init method.")
return hvd.rank()
return self.hvd.rank()

def local_rank(self) -> int:
"""Returns the local rank of the current process.
Expand All @@ -846,14 +849,14 @@ def local_rank(self) -> int:
if not self.is_initialized:
raise UninitializedStrategyError(
"Strategy has not been initialized. Use the init method.")
return hvd.local_rank()
return self.hvd.local_rank()

def clean_up(self) -> None:
"""Shuts Horovod down."""
if not self.is_initialized:
raise UninitializedStrategyError(
"Strategy has not been initialized. Use the init method.")
hvd.shutdown()
self.hvd.shutdown()

def allgather_obj(self, obj: Any) -> list[Any]:
"""All-gathers scalar objects across all workers to a
Expand All @@ -869,7 +872,7 @@ def allgather_obj(self, obj: Any) -> list[Any]:
raise UninitializedStrategyError(
"Strategy has not been initialized. Use the init method."
)
return hvd.allgather_object(obj)
return self.hvd.allgather_object(obj)

def gather_obj(self, obj: Any, dst_rank: int = 0) -> list[Any]:
"""The same as ``allgather_obj``, as gather is not supported
Expand Down
Empty file.
Loading

0 comments on commit 30765e0

Please sign in to comment.