
Commit

Incorporate PR comments (most importantly, change inheritance for ray strategies)
annaelisalappe committed Nov 26, 2024
1 parent 038c1ab commit 78d57f2
Showing 3 changed files with 39 additions and 92 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -37,7 +37,8 @@ dependencies = [
# "prov4ml@git+https://github.com/HPCI-Lab/ProvML@main", # Prov4ML
# "prov4ml@git+https://github.com/matbun/ProvML@main",
"pandas",
"seaborn"
"seaborn",
"ray[default, train, tune]"
]

# dynamic = ["version", "description"]
122 changes: 34 additions & 88 deletions src/itwinai/torch/distributed.py
@@ -9,7 +9,6 @@
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.modules import Module
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset, DistributedSampler, Sampler
@@ -48,6 +47,28 @@ def wrapper(self: "TorchDistributedStrategy", *args, **kwargs):
return wrapper


def _initialize_ray() -> None:
"""This method is used by the RayDDPStrategy and RayDeepSpeedStrategy to initialize
the Ray backend if it is not already initialized.
Raises:
EnvironmentError: If required environment variables `IP_HEAD` or `HEAD_NODE_IP`
are not set. These should be set from the slurm script where the ray cluster
is launched.
"""
if not ray.is_initialized():
IP_HEAD = os.environ.get("IP_HEAD")
HEAD_NODE_IP = os.environ.get("HEAD_NODE_IP")

if not IP_HEAD or not HEAD_NODE_IP:
raise EnvironmentError(
"Ray initialization requires env variables 'IP_HEAD' and \
'HEAD_NODE_IP' to be set."
)

ray.init(address="auto")


class TorchDistributedStrategy(DistributedStrategy):
"""Abstract class to define the distributed backend methods for
PyTorch models.
@@ -1009,27 +1030,11 @@ def gather(self, tensor: torch.Tensor, dst_rank: int = 0):
return [tensor]


class RayDistributedStrategy(TorchDistributedStrategy):
def _initialize_ray(self) -> None:
if not ray.is_initialized():
try:
ip_head = os.environ.get("ip_head")
head_node_ip = os.environ.get("head_node_ip")

if not ip_head or not head_node_ip:
raise EnvironmentError(
"Ray initialization requires 'ip_head' and 'head_node_ip' to be set."
)

except Exception as e:
raise RuntimeError(f"Error initializing Ray: {str(e)}")

ray.init(address="auto")
class RayDDPStrategy(TorchDistributedStrategy):
"""A distributed data-parallel (DDP) strategy using Ray Train for PyTorch training."""


class RayDDPStrategy(RayDistributedStrategy):
def __init__(self) -> None:
super()._initialize_ray()
_initialize_ray()

def init(self) -> None:
self.is_initialized = True
@@ -1056,7 +1061,7 @@ def distributed(
model: nn.Module,
optimizer: Optimizer,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, Optimizer, LRScheduler | None]:
) -> Tuple[nn.Module, Optimizer, LRScheduler | None]:
model = ray.train.torch.prepare_model(model)

return model, optimizer, lr_scheduler
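
For context, a rough sketch of how this strategy is typically driven from a Ray Train worker function; the TorchTrainer wiring and the toy model are assumptions for illustration, not part of this commit:

import torch.nn as nn
import torch.optim as optim
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_func(config: dict):
    strategy = RayDDPStrategy()        # __init__ calls _initialize_ray()
    strategy.init()

    model = nn.Linear(16, 1)
    optimizer = optim.SGD(model.parameters(), lr=1e-3)
    # distributed() calls ray.train.torch.prepare_model(), wrapping the model for DDP on this worker.
    model, optimizer, _ = strategy.distributed(model, optimizer)
    # ... training loop ...

trainer = TorchTrainer(train_func, scaling_config=ScalingConfig(num_workers=2))
# trainer.fit()  # would launch train_func on the Ray cluster's workers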
@@ -1115,72 +1120,13 @@ def gather(self, tensor: torch.Tensor, dst_rank: int = 0) -> List | None:
pass


class RayDeepSpeedStrategy(RayDistributedStrategy):
def __init__(self) -> None:
import deepspeed

self.deepspeed = deepspeed
super()._initialize_ray()

def init(self) -> None:
# This block of code should be removed as some point
if os.environ.get("LOCAL_RANK"):
os.environ["OMPI_COMM_WORLD_LOCAL_RANK"] = os.environ.get("LOCAL_RANK")

self.deepspeed.init_distributed()

print("Deepspeed initialized")
self.is_initialized = True

self.set_device()

@check_initialized
def global_world_size(self) -> int:
return dist.get_world_size()

@check_initialized
def local_world_size(self) -> int:
return torch.cuda.device_count()

@check_initialized
def global_rank(self) -> int:
return dist.get_rank()

@check_initialized
def local_rank(self) -> int:
return dist.get_rank() % torch.cuda.device_count()

@check_initialized
def distributed(
self,
model: Module,
optimizer: Optimizer,
lr_scheduler: Optional[LRScheduler] = None,
model_parameters: Optional[Any] = None,
**init_kwargs,
) -> Tuple[Module | Optimizer | LRScheduler | None]:
master_port = os.environ.get("MASTER_PORT")

distrib_model, optimizer, _, lr_scheduler = self.deepspeed.initialize(
model=model,
model_parameters=model_parameters,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
distributed_port=master_port,
dist_init_required=True,
**init_kwargs,
)
return distrib_model, optimizer, lr_scheduler

def clean_up(self) -> None:
pass

def allgather_obj(self, obj: Any) -> List[Any]:
pass

def gather_obj(self, obj: Any, dst_rank: int = 0) -> List[Any]:
pass

def gather(self, tensor: torch.Tensor, dst_rank: int = 0) -> List | None:
pass


class RayDeepSpeedStrategy(DeepSpeedStrategy):
"""A distributed strategy using Ray and DeepSpeed for PyTorch training.

Args:
backend (Literal["nccl", "gloo", "mpi"]): The backend for distributed communication.
"""

def __init__(self, backend: Literal["nccl", "gloo", "mpi"]) -> None:
_initialize_ray()
super().__init__(backend=backend)
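
A short usage sketch for the reworked class, assuming (as the new inheritance implies) that DeepSpeedStrategy handles the actual DeepSpeed setup once a backend is selected:

# Ray is initialized first; everything DeepSpeed-specific is delegated to the parent class.
strategy = RayDeepSpeedStrategy(backend="nccl")   # "gloo" and "mpi" are the other accepted backends
strategy.init()                                   # inherited from DeepSpeedStrategy (assumption)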
6 changes: 3 additions & 3 deletions use-cases/virgo/slurm_ray.sh
@@ -50,10 +50,10 @@ port=7639 # This port will be used by Ray to communicate with worker nodes

# This is so that the ray.init() command called from the hpo.py script knows
# which ports to connect to
export ip_head="$head_node"i:"$port"
export head_node_ip="$head_node"i
export IP_HEAD="$head_node"i:"$port"
export HEAD_NODE_IP="$head_node"i

export MASTER_ADDR=$head_node_ip
export MASTER_ADDR=$HEAD_NODE_IP
export MASTER_PORT=$port

echo "Starting HEAD at $head_node"
