diff --git a/tensorflow_federated/python/core/environments/xla_backend/runtime.py b/tensorflow_federated/python/core/environments/xla_backend/runtime.py index e3ede7b118..adcd21a880 100644 --- a/tensorflow_federated/python/core/environments/xla_backend/runtime.py +++ b/tensorflow_federated/python/core/environments/xla_backend/runtime.py @@ -14,6 +14,7 @@ """Runtime components for use by the XLA executor.""" from jax.lib import xla_client +from jax.lib import xla_extension import numpy as np from tensorflow_federated.proto.v0 import computation_pb2 as pb @@ -116,7 +117,7 @@ def __init__( 'Unsupported computation type: {}'.format(which_computation) ) xla_comp = xla_serialization.unpack_xla_computation(comp_pb.xla.hlo_module) - mhlo_module = xla_client._xla.mlir.xla_computation_to_mlir_module(xla_comp) + mhlo_module = xla_extension.mlir.xla_computation_to_mlir_module(xla_comp) compile_options = xla_client.CompileOptions() self._executable = backend.compile(mhlo_module, compile_options) self._inverted_parameter_tensor_indexes = list(