diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index 87e37b5e231..a921cf88eab 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -125,18 +125,15 @@ def _aws_ec2_inf_trn_init(): try: from libneuronxla.libneuronpjrt_path import libneuronpjrt_path except ImportError: - pass - else: - # Need to set NEURON_LIBRARY_PATH here for proper Neuron Cache behavior - os.environ.setdefault('NEURON_LIBRARY_PATH', libneuronpjrt_path()) - # Enable addition features and overrides - try: - from torch_neuronx import xla - except ImportError: - pass - else: - xla.init() + # Did not find libneuronxla + return False + # Need to set NEURON_LIBRARY_PATH here for proper Neuron Cache behavior + os.environ.setdefault('NEURON_LIBRARY_PATH', libneuronpjrt_path()) + # Enable addition features and overrides + try: + from torch_neuronx import xla + except ImportError: # Basic initializations if torch-neuronx is not available from ._internal import neuron if os.path.basename(sys.argv[0]) != 'neuron_parallel_compile': @@ -144,10 +141,10 @@ def _aws_ec2_inf_trn_init(): libneuronxla.configure_environment() neuron.set_envvar_defaults() neuron.configure_pjrt_environment() - # Found libneuronxla - return True - # Did not find libneuronxla - return False + else: + xla.init() + # Found libneuronxla + return True def _setup_tpu_vm_library_path() -> bool: