From fd3992473b6c5e4d908fa7caeef84d69768aca1d Mon Sep 17 00:00:00 2001
From: Rui Silva
Date: Thu, 19 Sep 2024 23:27:03 +0000
Subject: [PATCH] Introduce multi-node SPMD initialization for Neuron

At the end of use_spmd(), when PJRT_DEVICE is 'NEURON', call
torch_neuronx.initialization.initialize() again. torch_neuronx is
treated as optional: an ImportError is ignored. PJRT_DEVICE is read
with os.environ.get() so that an unset variable skips this step
instead of raising KeyError.
---
 torch_xla/runtime.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/torch_xla/runtime.py b/torch_xla/runtime.py
index e4560df6c704..1039699c4473 100644
--- a/torch_xla/runtime.py
+++ b/torch_xla/runtime.py
@@ -253,6 +253,14 @@ def use_spmd(auto: Optional[bool] = False):
     torch_xla._XLAC._xla_set_auto_sharding()
     os.environ["XLA_AUTO_SPMD"] = "1"
 
+  if os.environ.get(xenv.PJRT_DEVICE) == 'NEURON':
+    # In case of Neuron, retrigger the PJRT initialization if possible.
+    try:
+      from torch_neuronx.initialization import initialize
+      initialize()
+    except ImportError:
+      pass  # torch_neuronx not importable; skip the re-init (best effort)
+
 
 def is_spmd():
   """Returns if SPMD is set for execution."""