-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'ray-torchtrainer-integration' into hpo-tutorial
- Loading branch information
Showing
8 changed files
with
73 additions
and
153 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
# - Matteo Bunino <[email protected]> - CERN | ||
# - Jarl Sondre Sæther <[email protected]> - CERN | ||
# - Henry Mutegeki <[email protected]> - CERN | ||
# - Anna Lappe <[email protected]> - CERN | ||
# -------------------------------------------------------------------------------------- | ||
|
||
import abc | ||
|
@@ -63,18 +64,18 @@ def _initialize_ray() -> None: | |
the Ray backend if it is not already initialized. | ||
Raises: | ||
EnvironmentError: If required environment variables `IP_HEAD` or `HEAD_NODE_IP` | ||
are not set. These should be set from the slurm script where the ray cluster | ||
is launched. | ||
EnvironmentError: If required environment variables `HEAD_NODE_PORT` or | ||
`HEAD_NODE_IP` are not set. | ||
These should be set from the slurm script where the ray cluster is launched. | ||
""" | ||
if not ray.is_initialized(): | ||
IP_HEAD = os.environ.get("IP_HEAD") | ||
HEAD_NODE_PORT = os.environ.get("HEAD_NODE_PORT") | ||
HEAD_NODE_IP = os.environ.get("HEAD_NODE_IP") | ||
|
||
if not IP_HEAD or not HEAD_NODE_IP: | ||
if not HEAD_NODE_PORT or not HEAD_NODE_IP: | ||
raise EnvironmentError( | ||
"Ray initialization requires env variables 'IP_HEAD' and \ | ||
'HEAD_NODE_IP' to be set." | ||
"Ray initialization requires env variables 'HEAD_NODE_PORT' and " | ||
"'HEAD_NODE_IP' to be set." | ||
) | ||
|
||
ray.init(address="auto") | ||
|
@@ -1041,7 +1042,7 @@ def gather(self, tensor: torch.Tensor, dst_rank: int = 0): | |
return [tensor] | ||
|
||
|
||
class RayDDPStrategy(TorchDistributedStrategy): | ||
class RayDDPStrategy(TorchDDPStrategy): | ||
"""A distributed data-parallel (DDP) strategy using Ray Train for PyTorch training.""" | ||
|
||
def __init__(self) -> None: | ||
|
@@ -1098,7 +1099,7 @@ def create_dataloader( | |
persistent_workers: bool = False, | ||
pin_memory_device: str = "", | ||
): | ||
dataloader = super().create_dataloader( | ||
dataloader = DataLoader( | ||
dataset=dataset, | ||
batch_size=batch_size, | ||
sampler=sampler, | ||
|
@@ -1118,18 +1119,6 @@ def create_dataloader( | |
|
||
return ray.train.torch.prepare_data_loader(dataloader) | ||
|
||
def clean_up(self) -> None: | ||
pass | ||
|
||
def allgather_obj(self, obj: Any) -> List[Any]: | ||
pass | ||
|
||
def gather_obj(self, obj: Any, dst_rank: int = 0) -> List[Any]: | ||
pass | ||
|
||
def gather(self, tensor: torch.Tensor, dst_rank: int = 0) -> List | None: | ||
pass | ||
|
||
|
||
class RayDeepSpeedStrategy(DeepSpeedStrategy): | ||
"""A distributed strategy using Ray and DeepSpeed for PyTorch training. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters