diff --git a/ignite/distributed/comp_models/base.py b/ignite/distributed/comp_models/base.py
index aef14246e89..532881c9de6 100644
--- a/ignite/distributed/comp_models/base.py
+++ b/ignite/distributed/comp_models/base.py
@@ -277,21 +277,17 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
     def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         pass
 
-    def _check_barrier_fn_kwargs(self, barrier_fn: Callable, kwargs_dict: Dict[str, Any]) -> Dict[str, Any]:
-        fn_params_name = set(
-            map(
-                lambda param: param.name,
-                filter(
-                    lambda param: param.kind == param.POSITIONAL_OR_KEYWORD, signature(barrier_fn).parameters.values()
-                ),
-            )
-        )
-        extra_keys = kwargs_dict.keys() - fn_params_name
-        if extra_keys:
-            warnings.warn(f"Extra keys : {extra_keys} will not be used by {self._backend}.")
-            for k in extra_keys:
-                del kwargs_dict[k]
-        return kwargs_dict
+    def _check_signature(self, fn: Callable, **kwargs: Any) -> Dict[str, Any]:
+        # resolve signature() outside the try-block so it is always bound in the except-clause
+        fn_signature = signature(fn)
+        try:
+            fn_signature.bind(**kwargs)
+        except TypeError:
+            extra_params = kwargs.keys() - set(fn_signature.parameters)
+            if extra_params:
+                warnings.warn(f"Extra params : {extra_params} will not be used by {self._backend}.")
+                kwargs = {key: value for key, value in kwargs.items() if key not in extra_params}
+        return kwargs
 
     @abstractmethod
     def barrier(self, **kwargs: Any) -> None:
diff --git a/ignite/distributed/comp_models/horovod.py b/ignite/distributed/comp_models/horovod.py
index a143146f097..20c409b701f 100644
--- a/ignite/distributed/comp_models/horovod.py
+++ b/ignite/distributed/comp_models/horovod.py
@@ -195,11 +195,8 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         return hvd.broadcast(tensor, root_rank=src)
 
     def barrier(self, **kwargs: Any) -> None:
-        kwargs = self._check_barrier_fn_kwargs(barrier_fn=hvd.allreduce, kwargs_dict=kwargs)
-        if "tensor" in kwargs:
-            del kwargs["tensor"]
-        if "name" in kwargs:
-            del kwargs["name"]
+        kwargs.setdefault("tensor", torch.tensor(0, device="cpu"))
+        kwargs.setdefault("name", "barrier")
+        kwargs = self._check_signature(fn=hvd.allreduce, **kwargs)
         # https://github.com/horovod/horovod/issues/159#issuecomment-424834603
-        # hvd.allreduce(torch.tensor(0, device=self.device()), name="barrier")
-        hvd.allreduce(tensor=torch.tensor(0, device="cpu"), name="barrier", **kwargs)
+        hvd.allreduce(**kwargs)
diff --git a/ignite/distributed/comp_models/native.py b/ignite/distributed/comp_models/native.py
index af152f1aec2..5bf39592c62 100644
--- a/ignite/distributed/comp_models/native.py
+++ b/ignite/distributed/comp_models/native.py
@@ -433,7 +433,7 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         return tensor
 
     def barrier(self, **kwargs: Any) -> None:
-        kwargs = self._check_barrier_fn_kwargs(barrier_fn=dist.barrier, kwargs_dict=kwargs)
+        kwargs = self._check_signature(fn=dist.barrier, **kwargs)
         dist.barrier(**kwargs)
 
 def _expand_hostlist(nodelist: str) -> List[str]:
diff --git a/ignite/distributed/comp_models/xla.py b/ignite/distributed/comp_models/xla.py
index eca77197bc2..2c437baee62 100644
--- a/ignite/distributed/comp_models/xla.py
+++ b/ignite/distributed/comp_models/xla.py
@@ -161,7 +161,6 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
         return tensor
 
     def barrier(self, **kwargs: Any) -> None:
-        kwargs = self._check_barrier_fn_kwargs(barrier_fn=xm.rendezvous, kwargs_dict=kwargs)
-        if "tag" in kwargs:
-            del kwargs["tag"]
-        xm.rendezvous(tag="barrier", **kwargs)
+        kwargs.setdefault("tag", "barrier")
+        kwargs = self._check_signature(fn=xm.rendezvous, **kwargs)
+        xm.rendezvous(**kwargs)
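For context, `_check_signature` leans on `inspect.Signature.bind`, which raises `TypeError` whenever keyword arguments cannot be matched against the wrapped backend primitive. A minimal standalone sketch of the same filtering logic, with a hypothetical `fake_barrier` standing in for the backend call and a plain `backend` string in place of `self._backend`:

```python
import warnings
from inspect import signature
from typing import Any, Callable, Dict


def filter_kwargs(fn: Callable, backend: str, **kwargs: Any) -> Dict[str, Any]:
    fn_signature = signature(fn)
    try:
        fn_signature.bind(**kwargs)  # raises TypeError on unexpected keywords
    except TypeError:
        extra_params = kwargs.keys() - set(fn_signature.parameters)
        if extra_params:
            warnings.warn(f"Extra params : {extra_params} will not be used by {backend}.")
            kwargs = {key: value for key, value in kwargs.items() if key not in extra_params}
    return kwargs


def fake_barrier(tag: str = "barrier") -> None:
    pass


# 'replicas' is not a parameter of fake_barrier, so it is dropped with a UserWarning
assert filter_kwargs(fake_barrier, "xla-tpu", tag="sync", replicas=[]) == {"tag": "sync"}
```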
diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py
index ca6c7db4f01..9667ede28fb 100644
--- a/ignite/distributed/utils.py
+++ b/ignite/distributed/utils.py
@@ -433,10 +433,10 @@ def barrier(**kwargs: Any) -> None:
         - | "horovod" : ``average`` (default, None), ``compression`` (default, Compression.none),
           | ``op`` (default, None), ``prescale_factor`` (default, 1.0), ``postscale_factor`` (default, 1.0),
           | ``process_set`` (default, global_process_set).
-          | Arguments ``tensor=torch.tensor(0, device="cpu")`` and ``name="barrier"`` are redefined.
+          | Arguments ``tensor`` and ``name`` default to ``torch.tensor(0, device="cpu")`` and ``"barrier"``.
 
         - | "xla-tpu" : ``payload`` (default, b""), ``replicas`` (default, []).
-          | Argument ``tag="barrier"`` is redefined.
+          | Argument ``tag`` defaults to ``"barrier"``.
 
     .. versionchanged:: 0.5.1
         Method now accepts ``kwargs`` for all supported backends.
diff --git a/tests/ignite/conftest.py b/tests/ignite/conftest.py
index ce52495c32c..f58044770c4 100644
--- a/tests/ignite/conftest.py
+++ b/tests/ignite/conftest.py
@@ -259,14 +259,14 @@ def distributed_context_multi_node_nccl(multi_node_conf):
     _destroy_mnodes_dist_context()
 
 
-def _xla_template_worker_task(index, fn, args):
+def _xla_template_worker_task(index, fn, args, kwargs=None):
     import torch_xla.core.xla_model as xm
 
     xm.rendezvous("init")
-    fn(index, *args)
+    fn(index, *args, **(kwargs or {}))
 
 
-def _xla_execute(fn, args, nprocs):
+def _xla_execute(fn, args, nprocs, kwargs_dict=None):
     import torch_xla.distributed.xla_multiprocessing as xmp
 
     spawn_kwargs = {}
@@ -275,7 +275,7 @@ def _xla_execute(fn, args, nprocs):
         spawn_kwargs["start_method"] = "fork"
 
     try:
-        xmp.spawn(_xla_template_worker_task, args=(fn, args), nprocs=nprocs, **spawn_kwargs)
+        xmp.spawn(_xla_template_worker_task, args=(fn, args, kwargs_dict), nprocs=nprocs, **spawn_kwargs)
     except SystemExit as ex_:
         assert ex_.code == 0, "Didn't successfully exit in XLA test"
 
@@ -294,7 +294,7 @@ def mock_gpu_is_not_available():
     yield mock_cuda
 
 
-def _hvd_task_with_init(func, args):
+def _hvd_task_with_init(func, args, kwargs):
     import horovod.torch as hvd
 
     hvd.init()
@@ -302,11 +302,11 @@ def _hvd_task_with_init(func, args):
     if torch.cuda.is_available():
         torch.cuda.set_device(lrank)
 
-    func(*args)
+    func(*args, **(kwargs or {}))
     hvd.shutdown()
 
 
-def _gloo_hvd_execute(func, args, np=1, do_init=False):
+def _gloo_hvd_execute(func, args, np=1, do_init=False, kwargs_dict=None):
     try:
         # old API
        from horovod.run.runner import run
@@ -317,7 +317,7 @@ def _gloo_hvd_execute(func, args, np=1, do_init=False):
     kwargs = dict(use_gloo=True, np=np)
 
     if do_init:
-        return run(_hvd_task_with_init, args=(func, args), **kwargs)
+        return run(_hvd_task_with_init, args=(func, args, kwargs_dict), **kwargs)
 
     return run(func, args=args, **kwargs)
diff --git a/tests/ignite/distributed/utils/__init__.py b/tests/ignite/distributed/utils/__init__.py
index 800aea18dba..3019e5a1e4d 100644
--- a/tests/ignite/distributed/utils/__init__.py
+++ b/tests/ignite/distributed/utils/__init__.py
@@ -215,7 +215,7 @@ def _test(data_src, data_others, safe_mode):
         idist.broadcast(None, src=0)
 
 
-def _test_distrib_barrier(device, kwargs_dict=None):
+def _test_distrib_barrier(device, **kwargs):
     t = torch.tensor([idist.get_rank()], device=device, dtype=torch.float)
     true_res = sum([i for i in range(idist.get_world_size())])
 
@@ -223,7 +223,7 @@ def _test_distrib_barrier(device, kwargs_dict=None):
     if idist.get_rank() == 0:
         t += 10.0
 
-    idist.barrier(**kwargs_dict) if kwargs_dict else idist.barrier()
+    idist.barrier(**kwargs)
 
     tt = idist.all_reduce(t)
     assert tt.item() == true_res + 10.0
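With the docstring above, a usage sketch of the public API (assuming an initialized distributed context; which kwargs are accepted depends on the active backend):

```python
import ignite.distributed as idist

# native torch.distributed backends ("nccl", "gloo"): kwargs are forwarded to dist.barrier
idist.barrier(async_op=False)

# "xla-tpu": kwargs are forwarded to xm.rendezvous
idist.barrier(tag="barrier", payload=b"")

# keys the active backend cannot bind are dropped with a UserWarning instead of raising
idist.barrier(async_op=False, payload=b"")
```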
diff --git a/tests/ignite/distributed/utils/test_horovod.py b/tests/ignite/distributed/utils/test_horovod.py
index 9b30889d6fc..5fc6bcb84a1 100644
--- a/tests/ignite/distributed/utils/test_horovod.py
+++ b/tests/ignite/distributed/utils/test_horovod.py
@@ -204,7 +204,7 @@ def test_idist_barrier_kwargs_hvd(gloo_hvd_executor):
         postscale_factor=1.0,
         process_set=global_process_set,
     )
-    gloo_hvd_executor(_test_distrib_barrier, (device, kwargs_dict,), np=np, do_init=True)
+    gloo_hvd_executor(_test_distrib_barrier, (device,), np=np, do_init=True, kwargs_dict=kwargs_dict)
 
 
 def _test_idist_methods_overhead(ok_factor, sync_model):
diff --git a/tests/ignite/distributed/utils/test_native.py b/tests/ignite/distributed/utils/test_native.py
index 3a49b836123..42d33441095 100644
--- a/tests/ignite/distributed/utils/test_native.py
+++ b/tests/ignite/distributed/utils/test_native.py
@@ -281,13 +281,13 @@ def test_idist_barrier_kwargs_nccl(distributed_context_single_node_nccl):
     from torch.distributed import GroupMember
 
     kwargs_dict = {"group": GroupMember.WORLD, "async_op": False, "device_ids": None}
-    _test_distrib_barrier(device, kwargs_dict)
+    _test_distrib_barrier(device, **kwargs_dict)
 
     kwargs_dict.update({"tag": "barrier", "payload": b"", "replicas": []})
     with pytest.warns(
-        UserWarning, match=r"Extra keys : \{((, )?('payload'|'replicas'|'tag')(, )?)+\} will not be used by nccl."
+        UserWarning, match=r"Extra params : \{((, )?('payload'|'replicas'|'tag')(, )?)+\} will not be used by nccl."
     ):
-        _test_distrib_barrier(device, kwargs_dict)
+        _test_distrib_barrier(device, **kwargs_dict)
 
 
 @pytest.mark.distributed
@@ -298,13 +298,13 @@ def test_idist_barrier_kwargs_gloo(distributed_context_single_node_gloo):
     from torch.distributed import GroupMember
 
     kwargs_dict = {"group": GroupMember.WORLD, "async_op": False, "device_ids": None}
-    _test_distrib_barrier(device, kwargs_dict)
+    _test_distrib_barrier(device, **kwargs_dict)
 
     kwargs_dict.update({"tag": "barrier", "payload": b"", "replicas": []})
     with pytest.warns(
-        UserWarning, match=r"Extra keys : \{((, )?('payload'|'replicas'|'tag')(, )?)+\} will not be used by gloo."
+        UserWarning, match=r"Extra params : \{((, )?('payload'|'replicas'|'tag')(, )?)+\} will not be used by gloo."
     ):
-        _test_distrib_barrier(device, kwargs_dict)
+        _test_distrib_barrier(device, **kwargs_dict)
 
 
 def _test_idist_methods_overhead(ok_factor):
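The alternation in the `match` patterns above is deliberate: the extra parameters are reported as a Python set, whose rendering order is not guaranteed, so the regex must accept every permutation. A self-contained check against hand-built messages (`pytest.warns(match=...)` applies `re.search` the same way):

```python
import re

pattern = r"Extra params : \{((, )?('payload'|'replicas'|'tag')(, )?)+\} will not be used by gloo."

# the same set may render in any order; the pattern accepts each permutation
for msg in (
    "Extra params : {'payload', 'replicas', 'tag'} will not be used by gloo.",
    "Extra params : {'tag', 'replicas', 'payload'} will not be used by gloo.",
):
    assert re.search(pattern, msg)
```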
diff --git a/tests/ignite/distributed/utils/test_xla.py b/tests/ignite/distributed/utils/test_xla.py
index b66ecc408a5..dddcf5985ae 100644
--- a/tests/ignite/distributed/utils/test_xla.py
+++ b/tests/ignite/distributed/utils/test_xla.py
@@ -186,7 +186,7 @@ def test_idist_barrier_xla():
 def _test_idist_barrier_xla_in_child_proc(index, kwargs_dict=None):
     device = idist.device()
 
-    _test_distrib_barrier(device, kwargs_dict)
+    _test_distrib_barrier(device, **(kwargs_dict or {}))
 
 
 @pytest.mark.tpu
@@ -196,16 +196,16 @@ def test_idist_barrier_kwargs_xla():
     device = idist.device()
 
     kwargs_dict = {"tag": "barrier", "payload": b"", "replicas": []}
-    _test_distrib_barrier(device, kwargs_dict)
+    _test_distrib_barrier(device, **kwargs_dict)
 
     from torch.distributed import GroupMember
 
     kwargs_dict.update({"group": GroupMember.WORLD, "async_op": False, "device_ids": None})
     with pytest.warns(
         UserWarning,
-        match=r"Extra keys : \{((, )?('async_op'|'group'|'device_ids')(, )?)+\} will not be used by xla-tpu.",
+        match=r"Extra params : \{((, )?('async_op'|'group'|'device_ids')(, )?)+\} will not be used by xla-tpu.",
     ):
-        _test_distrib_barrier(device, kwargs_dict)
+        _test_distrib_barrier(device, **kwargs_dict)
 
 
 @pytest.mark.tpu
@@ -222,7 +222,7 @@ def test_idist_barrier_xla_in_child_proc(xmp_executor):
 def test_idist_barrier_kwargs_xla_in_child_proc(xmp_executor):
     n = int(os.environ["NUM_TPU_WORKERS"])
     kwargs_dict = {"tag": "barrier", "payload": b"", "replicas": []}
-    xmp_executor(_test_idist_barrier_xla_in_child_proc, args=(kwargs_dict,), nprocs=n)
+    xmp_executor(_test_idist_barrier_xla_in_child_proc, args=(), nprocs=n, kwargs_dict=kwargs_dict)
 
 
 @pytest.mark.tpu
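Taken together, the updated harness contract is: the positional test payload keeps travelling through `args`, while barrier kwargs travel through the new `kwargs_dict` parameter and are splatted inside the worker process. A sketch of both fixture entry points as used in the tests above (only meaningful inside the pytest fixtures defined in tests/ignite/conftest.py):

```python
kwargs_dict = {"tag": "barrier", "payload": b"", "replicas": []}

# XLA: each worker runs fn(index, *args, **kwargs_dict) after an "init" rendezvous
xmp_executor(_test_idist_barrier_xla_in_child_proc, args=(), nprocs=n, kwargs_dict=kwargs_dict)

# Horovod with do_init=True: each worker runs func(*args, **kwargs_dict) after hvd.init()
gloo_hvd_executor(_test_distrib_barrier, (device,), np=np, do_init=True, kwargs_dict=kwargs_dict)
```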