diff --git a/MaxText/maxengine_server.py b/MaxText/maxengine_server.py index 36826eba4..741af8e74 100644 --- a/MaxText/maxengine_server.py +++ b/MaxText/maxengine_server.py @@ -19,8 +19,6 @@ import sys import pyconfig -# pylint: disable-next=unused-import -import register_jax_proxy_backend import maxengine_config from jetstream.core import server_lib, config_lib diff --git a/MaxText/register_jax_proxy_backend.py b/MaxText/register_jax_proxy_backend.py deleted file mode 100644 index 6bdf2f7fc..000000000 --- a/MaxText/register_jax_proxy_backend.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Copyright 2024 Google LLC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -"""Register the IFRT Proxy as a backend for JAX.""" - -import jax - -try: - from jaxlib.xla_extension import ifrt_proxy - - jax._src.xla_bridge.register_backend_factory( # pylint: disable=protected-access - "proxy", - lambda: ifrt_proxy.get_client( - jax.config.read("jax_backend_target"), - ifrt_proxy.ClientConnectionOptions(), - ), - priority=-1, - ) -except ImportError: - pass diff --git a/MaxText/train.py b/MaxText/train.py index 5ef3161fc..cc045e3f1 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -43,8 +43,6 @@ import optimizers import profiler import pyconfig -# pylint: disable-next=unused-import -import register_jax_proxy_backend from vertex_tensorboard import VertexTensorboardManager # Placeholder: internal