Skip to content

Commit

Permalink
Introducing a connection timeout in a ifrt proxy.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696187916
  • Loading branch information
Google-ML-Automation committed Nov 13, 2024
1 parent c65b83d commit 9741339
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 0 deletions.
1 change: 1 addition & 0 deletions xla/python/ifrt_proxy/client/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,7 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/time",
"@nanobind",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:statusor",
Expand Down
11 changes: 11 additions & 0 deletions xla/python/ifrt_proxy/client/py_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.
#include "xla/python/ifrt_proxy/client/py_module.h"

#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
Expand All @@ -26,6 +27,7 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "nanobind/nanobind.h"
#include "nanobind/stl/function.h" // IWYU pragma: keep
#include "nanobind/stl/optional.h" // IWYU pragma: keep
Expand All @@ -48,6 +50,7 @@ namespace {
struct PyClientConnectionOptions {
std::optional<std::function<void(std::string)>> on_disconnect;
std::optional<std::function<void(std::string)>> on_connection_update;
std::optional<int64_t> connection_timeout_in_seconds;
};

absl::StatusOr<nb_class_ptr<PyClient>> GetClient(
Expand Down Expand Up @@ -90,6 +93,11 @@ absl::StatusOr<nb_class_ptr<PyClient>> GetClient(
};
}

if (py_options.connection_timeout_in_seconds.has_value()) {
options.connection_timeout =
absl::Seconds(*py_options.connection_timeout_in_seconds);
}

{
nb::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(client, CreateClient(proxy_server_address, options));
Expand All @@ -110,6 +118,9 @@ void BuildIfrtProxySubmodule(nb::module_& m) {
nb::arg().none())
.def_rw("on_connection_update",
&PyClientConnectionOptions::on_connection_update,
nb::arg().none())
.def_rw("connection_timeout_in_seconds",
&PyClientConnectionOptions::connection_timeout_in_seconds,
nb::arg().none());

sub_module.def("get_client", xla::ValueOrThrowWrapper(GetClient),
Expand Down
6 changes: 6 additions & 0 deletions xla/python/ifrt_proxy/jax/ifrt_proxy_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,13 @@ class ConnectionOptions:
on_connection_update: Optional, a callback that will be called with status
updates about initial connection establishment. The updates will be
provided as human-readable strings, and an end-user may find them helpful.
connection_timeout_in_seconds: Optional, the timeout for establishing a
connection to the proxy server.
"""

on_disconnect: Optional[Callable[[str], None]] = None
on_connection_update: Optional[Callable[[str], None]] = None
connection_timeout_in_seconds: Optional[int] = None


_backend_created: bool = False
Expand All @@ -52,6 +55,9 @@ def get_client(proxy_server_address: str) -> xla_client.Client:
cpp_options = py_module.ClientConnectionOptions()
cpp_options.on_disconnect = _connection_options.on_disconnect
cpp_options.on_connection_update = _connection_options.on_connection_update
cpp_options.connection_timeout_in_seconds = (
_connection_options.connection_timeout_in_seconds
)
client = py_module.get_client(proxy_server_address, cpp_options)
if client is not None:
_backend_created = True
Expand Down
2 changes: 2 additions & 0 deletions xla/python/xla_extension/ifrt_proxy.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================

import datetime
from typing import Any, Optional, Callable

from xla.python import xla_extension
Expand All @@ -24,6 +25,7 @@ Client = xla_extension.Client
class ClientConnectionOptions:
on_disconnect: Optional[Callable[[_Status], None]] = None
on_connection_update: Optional[Callable[[str], None]] = None
connection_timeout_in_seconds: Optional[int] = None


def get_client(
Expand Down

0 comments on commit 9741339

Please sign in to comment.