Merge branch 'ray-torchtrainer-integration' into hpo-tutorial
annaelisalappe committed Nov 28, 2024
2 parents 0162603 + 2b47188 commit 98ee685
Showing 8 changed files with 73 additions and 153 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -37,7 +37,7 @@ dependencies = [
"prov4ml@git+https://github.com/matbun/ProvML@new-main",
"pandas",
"seaborn",
"ray[default, train, tune]",
"ray[default,train,tune]",
]

# dynamic = ["version", "description"]
31 changes: 10 additions & 21 deletions src/itwinai/torch/distributed.py
@@ -7,6 +7,7 @@
# - Matteo Bunino <[email protected]> - CERN
# - Jarl Sondre Sæther <[email protected]> - CERN
# - Henry Mutegeki <[email protected]> - CERN
+ # - Anna Lappe <[email protected]> - CERN
# --------------------------------------------------------------------------------------

import abc
@@ -63,18 +64,18 @@ def _initialize_ray() -> None:
the Ray backend if it is not already initialized.
Raises:
-         EnvironmentError: If required environment variables `IP_HEAD` or `HEAD_NODE_IP`
-             are not set. These should be set from the slurm script where the ray cluster
-             is launched.
+         EnvironmentError: If required environment variables `HEAD_NODE_PORT` or
+             `HEAD_NODE_IP` are not set.
+             These should be set from the slurm script where the ray cluster is launched.
"""
if not ray.is_initialized():
-         IP_HEAD = os.environ.get("IP_HEAD")
+         HEAD_NODE_PORT = os.environ.get("HEAD_NODE_PORT")
HEAD_NODE_IP = os.environ.get("HEAD_NODE_IP")

-         if not IP_HEAD or not HEAD_NODE_IP:
+         if not HEAD_NODE_PORT or not HEAD_NODE_IP:
raise EnvironmentError(
"Ray initialization requires env variables 'IP_HEAD' and \
'HEAD_NODE_IP' to be set."
"Ray initialization requires env variables 'HEAD_NODE_PORT' and "
"'HEAD_NODE_IP' to be set."
)

ray.init(address="auto")
@@ -1041,7 +1042,7 @@ def gather(self, tensor: torch.Tensor, dst_rank: int = 0):
return [tensor]


- class RayDDPStrategy(TorchDistributedStrategy):
+ class RayDDPStrategy(TorchDDPStrategy):
"""A distributed data-parallel (DDP) strategy using Ray Train for PyTorch training."""

def __init__(self) -> None:
@@ -1098,7 +1099,7 @@ def create_dataloader(
persistent_workers: bool = False,
pin_memory_device: str = "",
):
-         dataloader = super().create_dataloader(
+         dataloader = DataLoader(
dataset=dataset,
batch_size=batch_size,
sampler=sampler,
@@ -1118,18 +1119,6 @@ def create_dataloader(

return ray.train.torch.prepare_data_loader(dataloader)

-     def clean_up(self) -> None:
-         pass
-
-     def allgather_obj(self, obj: Any) -> List[Any]:
-         pass
-
-     def gather_obj(self, obj: Any, dst_rank: int = 0) -> List[Any]:
-         pass
-
-     def gather(self, tensor: torch.Tensor, dst_rank: int = 0) -> List | None:
-         pass


class RayDeepSpeedStrategy(DeepSpeedStrategy):
"""A distributed strategy using Ray and DeepSpeed for PyTorch training.
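For orientation, the `create_dataloader` change above builds a plain `DataLoader` and hands it to `ray.train.torch.prepare_data_loader`. A minimal, hypothetical sketch of how such a loader is consumed inside a Ray Train worker function; `train_func` and its data are illustrative, not from this repository:

```python
# Minimal sketch (not from this repo) of consuming a DataLoader that was
# passed through ray.train.torch.prepare_data_loader. Intended to run as a
# Ray Train worker function, e.g. via ray.train.torch.TorchTrainer.
import torch
from torch.utils.data import DataLoader, TensorDataset

import ray.train.torch


def train_func(config: dict):
    dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
    loader = DataLoader(dataset, batch_size=16)
    # prepare_data_loader attaches a DistributedSampler when running with
    # multiple workers and moves each batch to the worker's device.
    loader = ray.train.torch.prepare_data_loader(loader)
    for features, target in loader:
        pass  # batches already live on the correct device
```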
30 changes: 19 additions & 11 deletions src/itwinai/torch/trainer.py
@@ -39,7 +39,7 @@
from torch.utils.data import DataLoader, Dataset, Sampler
from torch.utils.data.distributed import DistributedSampler

- from itwinai.torch.raytune import get_raytune_schedule, get_raytune_search_alg
+ from itwinai.torch.tuning import get_raytune_schedule, get_raytune_search_alg

# Imports from this repository
from ..components import Trainer, monitor_exec
@@ -172,7 +172,7 @@ def strategy(self) -> TorchDistributedStrategy:
return self._strategy

@strategy.setter
-     def strategy(self, strategy: Union[str, TorchDistributedStrategy]) -> None:
+     def strategy(self, strategy: str | TorchDistributedStrategy) -> None:
if isinstance(strategy, TorchDistributedStrategy):
self._strategy = strategy
else:
@@ -1233,7 +1233,7 @@ def dist_train(
return dist_train


- DEFAULT_CONFIG = {
+ DEFAULT_RAY_CONFIG = {
"scaling_config": {
"num_workers": 4, # Default to 4 workers
"use_gpu": True,
@@ -1253,15 +1253,10 @@ def dist_train(
"learning_rate": 1e-3,
"batch_size": 32,
"epochs": 10,
"shuffle_train": False,
"shuffle_validation": False,
"shuffle_test": False,
"pin_gpu_memory": False,
"optimizer": "adam",
"loss": "cross_entropy",
"optim_momentum": 0.9,
"optim_weight_decay": 0,
"num_workers_dataloader": 4,
"random_seed": 21,
},
}
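`_set_configs` below merges the user config over these defaults with `deep_update`. The helper is imported from elsewhere in the repository and not shown in this diff; a plausible sketch of the recursive-merge semantics it is assumed to provide:

```python
# Plausible sketch of deep_update's assumed recursive-merge semantics;
# the real helper lives elsewhere in this repository.
from typing import Dict


def deep_update(base: Dict, update: Dict) -> Dict:
    """Merge `update` into `base`, recursing into nested dicts."""
    for key, value in update.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            base[key] = deep_update(base[key], value)
        else:
            base[key] = value
    return base


# Overriding one nested default while keeping its siblings intact:
merged = deep_update(
    {"scaling_config": {"num_workers": 4, "use_gpu": True}},
    {"scaling_config": {"num_workers": 2}},
)
assert merged == {"scaling_config": {"num_workers": 2, "use_gpu": True}}
```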
@@ -1290,6 +1285,7 @@ def __init__(
self.logger = logger
self._set_strategy_and_init_ray(strategy)
self._set_configs(config=config)
+         self.torch_rng = set_seed(self.train_loop_config["random_seed"])

def _set_strategy_and_init_ray(self, strategy: str):
"""Set the distributed training strategy. This will initialize the ray backend.
@@ -1309,7 +1305,8 @@ def _set_strategy_and_init_ray(self, strategy: str):
raise ValueError(f"Unsupported strategy: {strategy}")

def _set_configs(self, config: Dict):
-         self.config = deep_update(DEFAULT_CONFIG, config)
+         # TODO: Think about how to implement the config more nicely
+         self.config = deep_update(DEFAULT_RAY_CONFIG, config)
self._set_scaling_config()
self._set_tune_config()
self._set_run_config()
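`set_seed`, called in `__init__` above with `train_loop_config["random_seed"]`, is imported from elsewhere in itwinai and its body is not part of this diff. Since its result is later passed as the dataloaders' `generator`, a plausible sketch is a helper that seeds the global RNGs and returns a seeded `torch.Generator`:

```python
# Plausible sketch of a set_seed helper; the actual itwinai implementation
# is not shown in this diff and may differ.
import random

import numpy as np
import torch


def set_seed(seed: int) -> torch.Generator:
    """Seed Python, NumPy, and torch RNGs; return a seeded generator."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator
```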
@@ -1329,6 +1326,8 @@ def create_dataloaders(
validation_dataset: Dataset | None = None,
test_dataset: Dataset | None = None,
batch_size: int = 1,
+         num_workers_dataloader: int = 4,
+         pin_memory: bool = False,
shuffle_train: bool | None = False,
shuffle_test: bool | None = False,
shuffle_validation: bool | None = False,
@@ -1356,6 +1355,9 @@
self.train_dataloader = self.strategy.create_dataloader(
dataset=train_dataset,
batch_size=batch_size,
+             num_workers=num_workers_dataloader,
+             pin_memory=pin_memory,
+             generator=self.torch_rng,
shuffle=shuffle_train,
sampler=sampler,
collate_fn=collate_fn,
@@ -1364,6 +1366,9 @@
self.validation_dataloader = self.strategy.create_dataloader(
dataset=validation_dataset,
batch_size=batch_size,
+             num_workers=num_workers_dataloader,
+             pin_memory=pin_memory,
+             generator=self.torch_rng,
shuffle=shuffle_validation,
sampler=sampler,
collate_fn=collate_fn,
@@ -1374,6 +1379,9 @@
self.test_dataloader = self.strategy.create_dataloader(
dataset=test_dataset,
batch_size=batch_size,
+             num_workers=num_workers_dataloader,
+             pin_memory=pin_memory,
+             generator=self.torch_rng,
shuffle=shuffle_test,
sampler=sampler,
collate_fn=collate_fn,
@@ -1477,8 +1485,8 @@ def _set_train_loop_config(self):
self.train_loop_config = self._set_searchspace(train_loop_config)
else:
print(
"INFO: No training_loop_config detected. \
No parameters are being tuned or passed to the training function."
"INFO: No training_loop_config detected. "
"No parameters are being tuned or passed to the training function."
)
self.train_loop_config = {}

39 changes: 37 additions & 2 deletions src/itwinai/torch/raytune.py → src/itwinai/torch/tuning.py
@@ -10,7 +10,24 @@
from ray.tune.search.hyperopt import HyperOptSearch


- def get_raytune_search_alg(tune_config, seeds=False):
+ def get_raytune_search_alg(
+     tune_config, seeds=False
+ ) -> TuneBOHB | BayesOptSearch | HyperOptSearch | None:
+     """Get the appropriate Ray Tune search algorithm based on the provided configuration.
+
+     Args:
+         tune_config (Dict): Configuration dictionary specifying the search algorithm,
+             metric, mode, and, depending on the search algorithm, other parameters.
+         seeds (bool, optional): Whether to use a fixed seed for reproducibility for some
+             search algorithms that take a seed. Defaults to False.
+
+     Returns:
+         An instance of the chosen Ray Tune search algorithm, or None if no search
+         algorithm is used or if it does not match any of the supported options.
+
+     Notes:
+         - `TuneBOHB` is automatically chosen for BOHB scheduling.
+     """
if "scheduler" in tune_config:
scheduler = tune_config["scheduler"]["name"]
else:
@@ -67,7 +84,25 @@ def get_raytune_search_alg(tune_config, seeds=False):
return None


- def get_raytune_schedule(tune_config):
+ def get_raytune_schedule(
+     tune_config,
+ ) -> (
+     AsyncHyperBandScheduler
+     | HyperBandScheduler
+     | HyperBandForBOHB
+     | PopulationBasedTraining
+     | PB2
+     | None
+ ):
+     """Get the appropriate Ray Tune scheduler based on the provided configuration.
+
+     Args:
+         tune_config (Dict): Configuration dictionary specifying the scheduler type,
+             metric, mode, and, depending on the scheduler, other parameters.
+
+     Returns:
+         An instance of the chosen Ray Tune scheduler or None if no scheduler is used
+         or if the scheduler does not match any of the supported options.
+     """
scheduler = tune_config["scheduler"]["name"]
metric = tune_config["metric"]
mode = tune_config["mode"]
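The full `tune_config` schema is mostly elided in this diff. Judging from the visible accesses (`tune_config["scheduler"]["name"]`, `tune_config["metric"]`, `tune_config["mode"]`), a config exercising both helpers might look roughly as follows; the `search_alg` block and the literal names are assumptions, not confirmed by the diff:

```python
# Hypothetical tune_config; keys inferred from the accesses visible above.
# The "search_alg" block and the name strings are assumptions.
from itwinai.torch.tuning import get_raytune_schedule, get_raytune_search_alg

tune_config = {
    "scheduler": {"name": "asha"},       # assumed scheduler identifier
    "search_alg": {"name": "hyperopt"},  # assumed search-algorithm block
    "metric": "loss",
    "mode": "min",
}

scheduler = get_raytune_schedule(tune_config)
search_alg = get_raytune_search_alg(tune_config, seeds=True)
```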
2 changes: 1 addition & 1 deletion use-cases/virgo/config.yaml
@@ -119,7 +119,7 @@ ray_training_pipeline:
init_args:
config:
scaling_config:
-       num_workers: 4
+       num_workers: 2
use_gpu: true
resources_per_worker:
CPU: 5
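The `scaling_config` block above mirrors Ray Train's `ScalingConfig`. Presumably the trainer forwards it roughly as sketched below; how `_set_scaling_config` actually builds the object is not shown in this diff:

```python
# Plausible mapping of the YAML block above onto Ray Train's ScalingConfig;
# the actual construction in _set_scaling_config is not part of this diff.
from ray.train import ScalingConfig

scaling_config = ScalingConfig(
    num_workers=2,
    use_gpu=True,
    resources_per_worker={"CPU": 5},
)
```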
67 changes: 0 additions & 67 deletions use-cases/virgo/pipeline_runner_for_testing.py

This file was deleted.

2 changes: 1 addition & 1 deletion use-cases/virgo/slurm_ray.sh
@@ -50,7 +50,7 @@ port=7639 # This port will be used by Ray to communicate with worker nodes

# This is so that the ray.init() command called from the hpo.py script knows
# which ports to connect to
- export IP_HEAD="$head_node"i:"$port"
+ export HEAD_NODE_PORT="$head_node"i:"$port"
export HEAD_NODE_IP="$head_node"i

export MASTER_ADDR=$HEAD_NODE_IP
53 changes: 4 additions & 49 deletions use-cases/virgo/trainer.py
@@ -441,54 +441,6 @@ def create_model_loss_optimizer(self) -> None:
self.model, self.optimizer, **distribute_kwargs
)

-     def create_dataloaders(
-         self,
-         train_dataset: Dataset,
-         validation_dataset: Dataset | None = None,
-         test_dataset: Dataset | None = None,
-     ) -> None:
-         """Override the create_dataloaders function to use the custom_collate function."""
-         # This is the case if a small dataset is used in-memory
-         # - we can use the default collate_fn function
-         if isinstance(train_dataset, TensorDataset):
-             return super().create_dataloaders(
-                 train_dataset=train_dataset,
-                 validation_dataset=validation_dataset,
-                 test_dataset=test_dataset,
-             )
-
-         # If we are using a custom dataset for the large dataset,
-         # we need to overwrite the collate_fn function
-         self.train_dataloader = self.strategy.create_dataloader(
-             dataset=train_dataset,
-             batch_size=self.training_config["batch_size"],
-             num_workers=self.training_config["num_workers_dataloader"],
-             pin_memory=self.training_config["pin_gpu_memory"],
-             # generator=self.torch_rng,
-             shuffle=self.training_config["shuffle_train"],
-             collate_fn=self.custom_collate,
-         )
-         if validation_dataset is not None:
-             self.validation_dataloader = self.strategy.create_dataloader(
-                 dataset=validation_dataset,
-                 batch_size=self.training_config["batch_size"],
-                 num_workers=self.training_config["num_workers_dataloader"],
-                 pin_memory=self.training_config["pin_gpu_memory"],
-                 # generator=self.torch_rng,
-                 shuffle=self.training_config["shuffle_validation"],
-                 collate_fn=self.custom_collate,
-             )
-         if test_dataset is not None:
-             self.test_dataloader = self.strategy.create_dataloader(
-                 dataset=test_dataset,
-                 batch_size=self.training_config["batch_size"],
-                 num_workers=self.training_config["num_workers_dataloader"],
-                 pin_memory=self.training_config["pin_gpu_memory"],
-                 # generator=self.torch_rng,
-                 shuffle=self.training_config["shuffle_test"],
-                 collate_fn=self.custom_collate,
-             )

def custom_collate(self, batch):
"""Custom collate function to concatenate input tensors along their first dimension."""
# Some batches contain None values,
@@ -511,7 +463,10 @@ def train(self, config, data):
self.create_model_loss_optimizer()

self.create_dataloaders(
-             train_dataset=data[0], validation_dataset=data[1], test_dataset=data[2]
+             train_dataset=data[0],
+             validation_dataset=data[1],
+             test_dataset=data[2],
+             collate_fn=self.custom_collate,
)

self.initialize_logger(hyperparams=config, rank=self.strategy.global_rank())
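The body of `custom_collate` is elided above; its docstring and the surviving comment describe dropping `None` entries and concatenating the remaining input tensors along their first dimension. A hypothetical reconstruction consistent with that description:

```python
# Hypothetical reconstruction of custom_collate from its docstring; the
# actual body is elided in this diff and may differ.
import torch


def custom_collate(batch):
    """Drop None entries, then concatenate tensors along dim 0."""
    batch = [item for item in batch if item is not None]
    if not batch:
        return None
    return torch.cat(batch, dim=0)
```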
