Skip to content

Commit

Permalink
pytorch#2213 refactor signature checker / update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fco-dv committed Nov 23, 2021
1 parent c402cd7 commit b5a9ac6
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 51 deletions.
24 changes: 9 additions & 15 deletions ignite/distributed/comp_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,21 +277,15 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
pass

def _check_barrier_fn_kwargs(self, barrier_fn: Callable, kwargs_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Remove from ``kwargs_dict`` the keys that ``barrier_fn`` cannot receive by keyword.

    Args:
        barrier_fn: callable whose signature is inspected.
        kwargs_dict: candidate keyword arguments. Unsupported keys are deleted
            from this very dictionary (it is modified in place).

    Returns:
        The same ``kwargs_dict`` object, restricted to the supported keys.

    A ``UserWarning`` listing the removed keys is emitted when at least one key
    is dropped.
    """
    parameters = signature(barrier_fn).parameters.values()

    # A callable declaring ``**kwargs`` accepts any keyword: nothing is extra.
    if any(param.kind == param.VAR_KEYWORD for param in parameters):
        return kwargs_dict

    # Keyword-only parameters are just as valid as positional-or-keyword ones
    # when the call is made with keywords.
    accepted = {
        param.name for param in parameters if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
    }
    extra_keys = kwargs_dict.keys() - accepted
    if extra_keys:
        warnings.warn(f"Extra keys : {extra_keys} will not be used by {self._backend}.")
        for key in extra_keys:
            del kwargs_dict[key]
    return kwargs_dict
def _check_signature(self, fn: Callable, **kwargs: Any) -> Dict[str, Any]:
    """Keep only the keyword arguments that ``fn`` can receive by keyword.

    Args:
        fn: callable whose signature is inspected.
        kwargs: candidate keyword arguments for ``fn``.

    Returns:
        ``kwargs`` itself when every key is supported, otherwise a new
        dictionary without the unsupported keys.

    A ``UserWarning`` listing the dropped names is emitted only when at least
    one key is actually dropped. Errors raised by ``inspect.signature`` (for
    objects it cannot introspect) propagate to the caller.
    """
    parameters = signature(fn).parameters.values()

    # A callable declaring ``**kwargs`` accepts any keyword: nothing is extra.
    if any(param.kind == param.VAR_KEYWORD for param in parameters):
        return kwargs

    # Compare names directly instead of relying on ``bind`` raising TypeError:
    # ``bind`` also fails for a missing required argument, which is not an
    # "extra params" situation and must not trigger the warning.
    accepted = {
        param.name for param in parameters if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
    }
    extra_params = kwargs.keys() - accepted
    if extra_params:
        warnings.warn(f"Extra params : {extra_params} will not be used by {self._backend}.")
        kwargs = {key: value for key, value in kwargs.items() if key not in extra_params}
    return kwargs

@abstractmethod
def barrier(self, **kwargs: Any) -> None:
Expand Down
9 changes: 2 additions & 7 deletions ignite/distributed/comp_models/horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,6 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
return hvd.broadcast(tensor, root_rank=src)

def barrier(self, **kwargs: Any) -> None:
    """Synchronise processes by running ``hvd.allreduce`` on a dummy tensor.

    Args:
        kwargs: keyword arguments forwarded to ``hvd.allreduce``. ``tensor``
            and ``name`` default to a CPU scalar tensor and ``"barrier"`` when
            the caller does not provide them. Keys rejected by
            ``self._check_signature`` are not forwarded.
    """
    # Horovod has no dedicated barrier; an allreduce on a dummy tensor is used:
    # https://github.com/horovod/horovod/issues/159#issuecomment-424834603
    # Fill the defaults in *before* the signature check so that a plain
    # ``barrier()`` call still provides the tensor to reduce.
    kwargs.setdefault("tensor", torch.tensor(0, device="cpu"))
    kwargs.setdefault("name", "barrier")
    kwargs = self._check_signature(fn=hvd.allreduce, **kwargs)
    hvd.allreduce(**kwargs)
2 changes: 1 addition & 1 deletion ignite/distributed/comp_models/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
return tensor

def barrier(self, **kwargs: Any) -> None:
kwargs = self._check_barrier_fn_kwargs(barrier_fn=dist.barrier, kwargs_dict=kwargs)
kwargs = self._check_signature(fn=dist.barrier, **kwargs)
dist.barrier(**kwargs)

def _expand_hostlist(nodelist: str) -> List[str]:
Expand Down
6 changes: 2 additions & 4 deletions ignite/distributed/comp_models/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,5 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
return tensor

def barrier(self, **kwargs: Any) -> None:
    """Synchronise processes through ``xm.rendezvous``.

    Args:
        kwargs: keyword arguments forwarded to ``xm.rendezvous``. ``tag``
            defaults to ``"barrier"`` when the caller does not provide it.
            Keys rejected by ``self._check_signature`` are not forwarded.
    """
    # Supply the tag before the signature check so that a plain ``barrier()``
    # call still passes one to ``xm.rendezvous``.
    kwargs.setdefault("tag", "barrier")
    kwargs = self._check_signature(fn=xm.rendezvous, **kwargs)
    xm.rendezvous(**kwargs)
2 changes: 0 additions & 2 deletions ignite/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,10 +433,8 @@ def barrier(**kwargs: Any) -> None:
- | "horovod" : ``average`` (default, None), ``compression`` (default, Compression.none),
| ``op`` (default, None), ``prescale_factor`` (default, 1.0), ``postscale_factor`` (default, 1.0),
| ``process_set`` (default, global_process_set).
| Arguments ``tensor=torch.tensor(0, device="cpu")`` and ``name="barrier"`` are redefined.
- | "xla-tpu" : ``payload`` (default, b""), ``replicas`` (default, []).
| Argument ``tag="barrier"`` is redefined.
.. versionchanged:: 0.5.1
Method now accepts ``kwargs`` for all supported backends.
Expand Down
16 changes: 8 additions & 8 deletions tests/ignite/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,14 +259,14 @@ def distributed_context_multi_node_nccl(multi_node_conf):
_destroy_mnodes_dist_context()


def _xla_template_worker_task(index, fn, args, kwargs=None):
    """Run ``fn(index, *args, **kwargs)`` after an initial XLA rendezvous.

    Args:
        index: first positional argument handed to ``fn``.
        fn: callable to execute.
        args: positional arguments placed after ``index``.
        kwargs: optional dict of keyword arguments for ``fn``; ``None`` means
            no keyword arguments.
    """
    import torch_xla.core.xla_model as xm

    xm.rendezvous("init")
    # ``**None`` raises TypeError, so fall back to an empty mapping.
    fn(index, *args, **(kwargs or {}))


def _xla_execute(fn, args, nprocs):
def _xla_execute(fn, args, nprocs, kwargs_dict=None):

import torch_xla.distributed.xla_multiprocessing as xmp

Expand All @@ -275,7 +275,7 @@ def _xla_execute(fn, args, nprocs):
spawn_kwargs["start_method"] = "fork"

try:
xmp.spawn(_xla_template_worker_task, args=(fn, args), nprocs=nprocs, **spawn_kwargs)
xmp.spawn(_xla_template_worker_task, args=(fn, args, kwargs_dict), nprocs=nprocs, **spawn_kwargs)
except SystemExit as ex_:
assert ex_.code == 0, "Didn't successfully exit in XLA test"

Expand All @@ -294,19 +294,19 @@ def mock_gpu_is_not_available():
yield mock_cuda


def _hvd_task_with_init(func, args, kwargs=None):
    """Initialise Horovod, run ``func(*args, **kwargs)``, then shut Horovod down.

    Args:
        func: callable to execute.
        args: positional arguments for ``func``.
        kwargs: optional dict of keyword arguments for ``func``; ``None`` means
            no keyword arguments.
    """
    import horovod.torch as hvd

    hvd.init()
    lrank = hvd.local_rank()
    if torch.cuda.is_available():
        torch.cuda.set_device(lrank)

    # ``**None`` raises TypeError, so fall back to an empty mapping.
    func(*args, **(kwargs or {}))
    hvd.shutdown()


def _gloo_hvd_execute(func, args, np=1, do_init=False):
def _gloo_hvd_execute(func, args, np=1, do_init=False, kwargs_dict=None):
try:
# old API
from horovod.run.runner import run
Expand All @@ -317,7 +317,7 @@ def _gloo_hvd_execute(func, args, np=1, do_init=False):
kwargs = dict(use_gloo=True, np=np)

if do_init:
return run(_hvd_task_with_init, args=(func, args), **kwargs)
return run(_hvd_task_with_init, args=(func, args, kwargs_dict), **kwargs)

return run(func, args=args, **kwargs)

Expand Down
4 changes: 2 additions & 2 deletions tests/ignite/distributed/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,15 +215,15 @@ def _test(data_src, data_others, safe_mode):
idist.broadcast(None, src=0)


def _test_distrib_barrier(device, kwargs_dict=None):
def _test_distrib_barrier(device, **kwargs):

t = torch.tensor([idist.get_rank()], device=device, dtype=torch.float)
true_res = sum([i for i in range(idist.get_world_size())])

if idist.get_rank() == 0:
t += 10.0

idist.barrier(**kwargs_dict) if kwargs_dict else idist.barrier()
idist.barrier(**kwargs)

tt = idist.all_reduce(t)
assert tt.item() == true_res + 10.0
Expand Down
2 changes: 1 addition & 1 deletion tests/ignite/distributed/utils/test_horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def test_idist_barrier_kwargs_hvd(gloo_hvd_executor):
postscale_factor=1.0,
process_set=global_process_set,
)
gloo_hvd_executor(_test_distrib_barrier, (device, kwargs_dict,), np=np, do_init=True)
gloo_hvd_executor(_test_distrib_barrier, (device,), np=np, do_init=True, kwargs_dict=kwargs_dict)


def _test_idist_methods_overhead(ok_factor, sync_model):
Expand Down
12 changes: 6 additions & 6 deletions tests/ignite/distributed/utils/test_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,13 +281,13 @@ def test_idist_barrier_kwargs_nccl(distributed_context_single_node_nccl):
from torch.distributed import GroupMember

kwargs_dict = {"group": GroupMember.WORLD, "async_op": False, "device_ids": None}
_test_distrib_barrier(device, kwargs_dict)
_test_distrib_barrier(device, **kwargs_dict)

kwargs_dict.update({"tag": "barrier", "payload": b"", "replicas": []})
with pytest.warns(
UserWarning, match=r"Extra keys : \{((, )?('payload'|'replicas'|'tag')(, )?)+\} will not be used by nccl."
UserWarning, match=r"Extra params : \{((, )?('payload'|'replicas'|'tag')(, )?)+\} will not be used by nccl."
):
_test_distrib_barrier(device, kwargs_dict)
_test_distrib_barrier(device, **kwargs_dict)


@pytest.mark.distributed
Expand All @@ -298,13 +298,13 @@ def test_idist_barrier_kwargs_gloo(distributed_context_single_node_gloo):
from torch.distributed import GroupMember

kwargs_dict = {"group": GroupMember.WORLD, "async_op": False, "device_ids": None}
_test_distrib_barrier(device, kwargs_dict)
_test_distrib_barrier(device, **kwargs_dict)

kwargs_dict.update({"tag": "barrier", "payload": b"", "replicas": []})
with pytest.warns(
UserWarning, match=r"Extra keys : \{((, )?('payload'|'replicas'|'tag')(, )?)+\} will not be used by gloo."
UserWarning, match=r"Extra params : \{((, )?('payload'|'replicas'|'tag')(, )?)+\} will not be used by gloo."
):
_test_distrib_barrier(device, kwargs_dict)
_test_distrib_barrier(device, **kwargs_dict)


def _test_idist_methods_overhead(ok_factor):
Expand Down
10 changes: 5 additions & 5 deletions tests/ignite/distributed/utils/test_xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def test_idist_barrier_xla():

def _test_idist_barrier_xla_in_child_proc(index, kwargs_dict=None, **kwargs):
    """Run the shared barrier check on ``idist.device()``.

    Args:
        index: first positional argument supplied by the caller; unused here.
        kwargs_dict: optional dict of barrier options (kept for callers that
            still pass the options as a single dict).
        kwargs: barrier options passed directly as keyword arguments; they take
            precedence over ``kwargs_dict`` on duplicate keys.
    """
    device = idist.device()
    # Accept both calling styles and never unpack ``None``.
    options = {**(kwargs_dict or {}), **kwargs}
    _test_distrib_barrier(device, **options)


@pytest.mark.tpu
Expand All @@ -196,16 +196,16 @@ def test_idist_barrier_kwargs_xla():

device = idist.device()
kwargs_dict = {"tag": "barrier", "payload": b"", "replicas": []}
_test_distrib_barrier(device, kwargs_dict)
_test_distrib_barrier(device, **kwargs_dict)

from torch.distributed import GroupMember

kwargs_dict.update({"group": GroupMember.WORLD, "async_op": False, "device_ids": None})
with pytest.warns(
UserWarning,
match=r"Extra keys : \{((, )?('async_op'|'group'|'device_ids')(, )?)+\} will not be used by xla-tpu.",
match=r"Extra params : \{((, )?('async_op'|'group'|'device_ids')(, )?)+\} will not be used by xla-tpu.",
):
_test_distrib_barrier(device, kwargs_dict)
_test_distrib_barrier(device, **kwargs_dict)


@pytest.mark.tpu
Expand All @@ -222,7 +222,7 @@ def test_idist_barrier_xla_in_child_proc(xmp_executor):
def test_idist_barrier_kwargs_xla_in_child_proc(xmp_executor):
n = int(os.environ["NUM_TPU_WORKERS"])
kwargs_dict = {"tag": "barrier", "payload": b"", "replicas": []}
xmp_executor(_test_idist_barrier_xla_in_child_proc, args=(kwargs_dict,), nprocs=n)
xmp_executor(_test_idist_barrier_xla_in_child_proc, args=(), nprocs=n, kwargs_dict=kwargs_dict)


@pytest.mark.tpu
Expand Down

0 comments on commit b5a9ac6

Please sign in to comment.