Revert "Use SMP rank and size when applicable (#411)" (#424)
This reverts commit 07a3fd9.

Co-authored-by: Nihal Harish <[email protected]>
ndodda-amazon and NihalHarish authored Jan 16, 2021
1 parent c6554a7 commit e431609
Showing 2 changed files with 38 additions and 89 deletions.
103 changes: 36 additions & 67 deletions smdebug/core/utils.py
@@ -23,54 +23,6 @@
 from smdebug.exceptions import IndexReaderException

 _is_invoked_via_smddp = None
-
-try:
-    import smdistributed.modelparallel.tensorflow as smp
-
-    _smp_imported = smp
-except (ImportError, ModuleNotFoundError):
-    try:
-        import smdistributed.modelparallel.torch as smp
-
-        _smp_imported = smp
-    except (ImportError, ModuleNotFoundError):
-        _smp_imported = None
-
-
-try:
-    import torch.distributed as dist
-
-    _torch_dist_imported = dist
-except (ImportError, ModuleNotFoundError):
-    _torch_dist_imported = None
-
-
-try:
-    import horovod.torch as hvd
-
-    _hvd_imported = hvd
-except (ModuleNotFoundError, ImportError):
-    try:
-        import horovod.tensorflow as hvd
-
-        _hvd_imported = hvd
-    except (ModuleNotFoundError, ImportError):
-        _hvd_imported = None
-
-
-try:
-    import smdistributed.dataparallel.torch.distributed as smdataparallel
-
-    _smdataparallel_imported = smdataparallel
-except (ModuleNotFoundError, ImportError):
-    try:
-        import smdistributed.dataparallel.tensorflow as smdataparallel
-
-        _smdataparallel_imported = smdataparallel
-    except (ModuleNotFoundError, ImportError):
-        _smdataparallel_imported = None
-
-
 logger = get_logger()


@@ -365,34 +317,51 @@ def get_tb_worker()


 def get_distributed_worker():
-    """
-    Get the rank for horovod or torch distributed. If none of them are being used,
+    """Get the rank for horovod or torch distributed. If none of them are being used,
     return None"""
-    rank = None
-    if (
-        _torch_dist_imported
-        and hasattr(_torch_dist_imported, "is_initialized")
-        and _torch_dist_imported.is_initialized()
-    ):
-        rank = _torch_dist_imported.get_rank()
-    elif _smp_imported and smp.core.initialized:
-        rank = smp.rank()
-    elif check_smdataparallel_env():
-        # smdistributed.dataparallel should be invoked via `mpirun`.
-        # It supports EC2 machines with 8 GPUs per machine.
-        assert smdataparallel is not None
-        try:
-            if smdataparallel.get_world_size():
-                return smdataparallel.get_rank()
-        except ValueError:
-            pass
-    elif _hvd_imported:
-        try:
-            if hvd.size():
-                rank = hvd.rank()
-        except ValueError:
-            pass
+    try:
+        import torch.distributed as dist
+    except (ImportError, ModuleNotFoundError):
+        dist = None
+    rank = None
+    if dist and hasattr(dist, "is_initialized") and dist.is_initialized():
+        rank = dist.get_rank()
+    else:
+        try:
+            import horovod.torch as hvd
+
+            if hvd.size():
+                rank = hvd.rank()
+        except (ModuleNotFoundError, ValueError, ImportError):
+            pass
+
+        try:
+            import horovod.tensorflow as hvd
+
+            if hvd.size():
+                rank = hvd.rank()
+        except (ModuleNotFoundError, ValueError, ImportError):
+            pass
+
+    # smdistributed.dataparallel should be invoked via `mpirun`.
+    # It supports EC2 machines with 8 GPUs per machine.
+    if check_smdataparallel_env():
+        try:
+            import smdistributed.dataparallel.torch.distributed as smdataparallel
+
+            if smdataparallel.get_world_size():
+                return smdataparallel.get_rank()
+        except (ModuleNotFoundError, ValueError, ImportError):
+            pass
+
+        try:
+            import smdistributed.dataparallel.tensorflow as smdataparallel
+
+            if smdataparallel.size():
+                return smdataparallel.rank()
+        except (ModuleNotFoundError, ValueError, ImportError):
+            pass
     return rank


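For readers tracing what this revert restores in smdebug/core/utils.py: the rank lookup goes back to importing each distributed backend lazily inside the function and falling through whenever a backend is missing or uninitialized. The snippet below is a minimal standalone sketch of that pattern, not smdebug's actual code; the helper name resolve_rank is invented for illustration, and only the torch.distributed and Horovod branches are shown (the real function is smdebug.core.utils.get_distributed_worker()).

def resolve_rank():
    """Return this process's rank, or None if no distributed backend is active."""
    # torch.distributed is checked first, mirroring the restored ordering.
    try:
        import torch.distributed as dist

        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
    except (ImportError, ModuleNotFoundError):
        pass  # torch not installed; fall through to the next backend

    # Horovod raises ValueError from hvd.size() if hvd.init() was never called.
    try:
        import horovod.torch as hvd

        if hvd.size():
            return hvd.rank()
    except (ImportError, ModuleNotFoundError, ValueError):
        pass  # horovod not installed or not initialized

    return None  # plain single-process run


print(resolve_rank())  # prints None when run outside any distributed launcher
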
24 changes: 2 additions & 22 deletions smdebug/tensorflow/base_hook.py
@@ -33,11 +33,9 @@
 )

 try:
-    import smdistributed.modelparallel.tensorflow as smp  # noqa isort:skip
-
-    _smp_importable = True
+    pass
 except ImportError:
-    _smp_importable = False
+    pass


 DEFAULT_INCLUDE_COLLECTIONS = [
@@ -185,15 +183,6 @@ def _get_worker_name(self) -> str:
         """
         self._assert_distribution_strategy()
         if self.distribution_strategy == TFDistributionStrategy.HOROVOD:
-            if _smp_importable:
-                # when model parallel is being used, there will be multiple processes
-                # with same hvd rank, hence use smp.rank
-                import smdistributed.modelparallel.tensorflow as smp
-
-                if smp.core.initialized:
-                    # if smp is in use
-                    return f"worker_{smp.rank()}"
-
             import horovod.tensorflow as hvd

             return f"worker_{hvd.rank()}"
@@ -271,15 +260,6 @@ def _get_custom_and_default_collections(self) -> Tuple[Set["Collection"], Set["C
     def _get_num_workers(self):
         self._assert_distribution_strategy()
         if self.distribution_strategy == TFDistributionStrategy.HOROVOD:
-            if _smp_importable:
-                # when model parallel is being used, there will be multiple hvd process groups,
-                # hence use smp.size
-                import smdistributed.modelparallel.tensorflow as smp
-
-                if smp.core.initialized:
-                    # if smp is in use
-                    return smp.size()
-
             import horovod.tensorflow as hvd

             return hvd.size()
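In smdebug/tensorflow/base_hook.py the revert drops the smp.rank()/smp.size() special case, so under the Horovod distribution strategy the worker name and worker count come straight from Horovod again. Below is a minimal sketch of that fallback, assuming Horovod is installed and the script is launched with horovodrun or mpirun; the free functions stand in for the hook's _get_worker_name and _get_num_workers methods and are not part of smdebug.

import horovod.tensorflow as hvd

hvd.init()


def worker_name() -> str:
    # Mirrors the restored Horovod branch: one worker name per Horovod rank,
    # with no smp.rank() override.
    return f"worker_{hvd.rank()}"


def num_workers() -> int:
    # Total number of Horovod processes in the job.
    return hvd.size()


print(f"{worker_name()} of {num_workers()}")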
