diff --git a/demos/run_vizier_client.py b/demos/run_vizier_client.py index 28a2a884c..747bfe3c8 100644 --- a/demos/run_vizier_client.py +++ b/demos/run_vizier_client.py @@ -14,15 +14,19 @@ """Example of a Vizier Client, which can be run on multiple machines. -This is meant to be used after the Vizier Server (see `run_vizier_server.py`) -has been launched and provided an address to connect to. Example of a launch -command: +For distributed cases, this is meant to be used after the Vizier Server (see +run_vizier_server.py`) has been launched and provided an address to connect to. +Example of a launch command: ``` python run_vizier_client.py --address="localhost:[PORT]" ``` where `address` was provided by the server. + +If not provided, the Vizier Server will be created locally, which still allows +parallelization via multithreading, but will not be able to coordinate jobs +across different machines. """ from typing import Sequence @@ -35,8 +39,8 @@ from vizier.service import pyvizier as vz flags.DEFINE_string( - 'address', '', - "Address of the Vizier Server which will be used by this demo. Should be of the form e.g. 'localhost:6006' if running on the same machine, or `[IP]:[PORT]` if running on a remote machine." + 'address', clients.NO_ENDPOINT, + "Address of the Vizier Server which will be used by this demo. Should be of the form e.g. 'localhost:6006' if running on the same machine, or `[IP]:[PORT]` if running on a remote machine. If unset, a local Vizier server will be created inside this process." ) flags.DEFINE_integer( 'max_num_iterations', 10, @@ -68,11 +72,13 @@ def main(argv: Sequence[str]) -> None: if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') - if not FLAGS.address: + if FLAGS.address == clients.NO_ENDPOINT: logging.info( - 'You did not specify the server address. Please see the documentation on the `address` FLAGS.' + 'You did not specify the server address. The Vizier Service will be created locally.' ) - clients.environment_variables.service_endpoint = FLAGS.address # Set address. + else: + # Set address. + clients.environment_variables.service_endpoint = FLAGS.address study_config = vz.StudyConfig() # Search space, metrics, and algorithm. root = study_config.search_space.root diff --git a/docs/guides/user/running_vizier.ipynb b/docs/guides/user/running_vizier.ipynb index 500181824..547f7f2f9 100644 --- a/docs/guides/user/running_vizier.ipynb +++ b/docs/guides/user/running_vizier.ipynb @@ -93,45 +93,67 @@ { "cell_type": "markdown", "metadata": { - "id": "SwaqnJECLtwa" + "id": "5e2B91UvZYIM" }, "source": [ - "## Setting up the server\n", - "As a basic example, the server will be run locally, and will process all client requests. However, it can also be run remotely, to service multiple clients in a distributed manner.\n" + "## Setting up the client\n", + "Starts a `study_client`. By default, it will implicitly create a local Vizier Server which will be shared across other clients in the same Python process." ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "id": "1IEUw_UfLvtx" + "id": "X2AR4OmXX3in" }, "outputs": [], "source": [ - "service = vizier_service.DefaultVizierService()" + "study_client = clients.Study.from_study_config(\n", + " study_config, owner='owner', study_id='example_study_id')" ] }, { "cell_type": "markdown", "metadata": { - "id": "5e2B91UvZYIM" + "id": "w3m48cPsXcxD" }, "source": [ - "## Setting up the client\n", - "Starts a `study_client`, which will connect to the server. Multiple machines can simultaneously call the code below to work on the same study, useful for parallelizing evaluation workload." + "## Distributed Setup\n", + "When using multiple Python processes (on a single machine or over multiple machines), we may explicitly create the server in a separate process to accept requests from all other client processes." ] }, { "cell_type": "code", "execution_count": null, "metadata": { - "id": "X2AR4OmXX3in" + "id": "V6ef6OfMXdpz" }, "outputs": [], "source": [ + "service = vizier_service.DefaultVizierService() # Ideally created on a separate process such as a server machine.\n", "clients.environment_variables.service_endpoint = service.endpoint # Server address.\n", - "study_client = clients.Study.from_study_config(\n", - " study_config, owner='owner', study_id='example_study_id')" + "study_client = clients.Study.from_study_config(study_config, owner='owner', study_id = 'example_study_id') # Now connects to the explicitly created service." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z0Ycmc-exzqm" + }, + "source": [ + "Regardless of whether the setup is local or distributed, we may simultaneously create multiple clients to work on the same study, useful for parallelizing evaluation workload." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VlfFb5t3yILl" + }, + "outputs": [], + "source": [ + "another_study_client = clients.Study.from_resource_name(\n", + " study_client.resource_name)" ] }, { diff --git a/vizier/__init__.py b/vizier/__init__.py index 596f83031..dc3b65a45 100644 --- a/vizier/__init__.py +++ b/vizier/__init__.py @@ -21,4 +21,4 @@ sys.path.append(PROTO_ROOT) -__version__ = "0.0.13" +__version__ = "0.0.14" diff --git a/vizier/service/clients.py b/vizier/service/clients.py index f67fb4b9d..3752cbdd6 100644 --- a/vizier/service/clients.py +++ b/vizier/service/clients.py @@ -25,11 +25,15 @@ from vizier.service import vizier_client from vizier.service import vizier_service_pb2_grpc +NO_ENDPOINT = vizier_client.NO_ENDPOINT + +# TODO: Consider if user should set a one-line flag explicitly to +# denote local NO_ENDPOINT server will be used. @attr.define class _EnviromentVariables: service_endpoint: str = attr.field( - default='UNSET', validator=attr.validators.instance_of(str)) + default=NO_ENDPOINT, validator=attr.validators.instance_of(str)) environment_variables = _EnviromentVariables() diff --git a/vizier/service/vizier_client.py b/vizier/service/vizier_client.py index 4a36fc922..6c5253172 100644 --- a/vizier/service/vizier_client.py +++ b/vizier/service/vizier_client.py @@ -18,6 +18,7 @@ """ import datetime +import functools import time from typing import Any, Dict, List, Mapping, Optional, Tuple, Union @@ -29,6 +30,7 @@ from vizier.service import resources from vizier.service import stubs_util from vizier.service import study_pb2 +from vizier.service import vizier_server from vizier.service import vizier_service_pb2 from vizier.service import vizier_service_pb2_grpc from vizier.utils import attrs_utils @@ -50,6 +52,21 @@ VizierService = Union[vizier_service_pb2_grpc.VizierServiceStub, vizier_service_pb2_grpc.VizierServiceServicer] +NO_ENDPOINT = 'NO_ENDPOINT' + + +@functools.cache +def _create_local_vizier_server( +) -> vizier_service_pb2_grpc.VizierServiceServicer: + return vizier_server.VizierService() + + +def _create_vizier_server_or_stub(endpoint: str) -> VizierService: + if endpoint == NO_ENDPOINT: + logging.info('Using cached local Vizier server.') + return _create_local_vizier_server() + return stubs_util.create_vizier_server_stub(endpoint) + @attr.frozen(init=True) class VizierClient: @@ -84,7 +101,8 @@ def from_endpoint(cls, service_endpoint: str, study_resource_name: str, Args: service_endpoint: Address of VizierService for creation of gRPC stub, e.g. - 'localhost:8998'. + 'localhost:8998'. If equal to UNSET_ENDPOINT, creates a local Vizier + server inside the client. study_resource_name: An identifier of the study. The full study name will be `owners/{owner_id}/studies/{study_id}`. client_id: An ID that identifies the worker requesting a `Trial`. Workers @@ -98,8 +116,8 @@ def from_endpoint(cls, service_endpoint: str, study_resource_name: str, Vizier client. """ return cls( - stubs_util.create_vizier_server_stub(service_endpoint), - study_resource_name, client_id) + _create_vizier_server_or_stub(service_endpoint), study_resource_name, + client_id) @property def _owner_id(self) -> str: @@ -175,7 +193,8 @@ def report_intermediate_objective_value( metric_list: List[Mapping[str, Union[int, float]]], trial_id: int, ) -> pyvizier.Trial: - """Sends intermediate objective value for the trial identified by trial_id.""" + """Sends intermediate objective value for the trial identified by trial_id. + """ new_metric_list = [] for metric in metric_list: for metric_name in metric: @@ -216,7 +235,8 @@ def complete_trial( trial_id: int, final_measurement: Optional[pyvizier.Measurement] = None, infeasibility_reason: Optional[str] = None) -> pyvizier.Trial: - """Completes the trial, which is infeasible if given a infeasibility_reason.""" + """Completes the trial, which is infeasible if given a infeasibility_reason. + """ request = vizier_service_pb2.CompleteTrialRequest( name=resources.TrialResource(self._owner_id, self._study_id, trial_id).name, @@ -358,7 +378,8 @@ def create_or_load_study( Args: service_endpoint: Address of VizierService for creation of gRPC stub, e.g. - 'localhost:8998'. + 'localhost:8998'. If equal to UNSET_ENDPOINT, creates a local Vizier + server inside the client. owner_id: An owner id. client_id: ID for the VizierClient. See class for notes. study_id: Each study is uniquely identified by the tuple (owner_id, @@ -377,7 +398,7 @@ def create_or_load_study( ValueError: Indicates that study_config is not supplied and the study with the given study_id does not exist. """ - vizier_stub = stubs_util.create_vizier_server_stub(service_endpoint) + vizier_stub = _create_vizier_server_or_stub(service_endpoint) study = study_pb2.Study( display_name=study_id, study_spec=study_config.to_proto()) request = vizier_service_pb2.CreateStudyRequest( diff --git a/vizier/service/vizier_client_test.py b/vizier/service/vizier_client_test.py index 540201ab2..f4e772d2a 100644 --- a/vizier/service/vizier_client_test.py +++ b/vizier/service/vizier_client_test.py @@ -27,6 +27,7 @@ from vizier.service import study_pb2 from vizier.service import vizier_client from vizier.service import vizier_service +from vizier.service import vizier_service_pb2_grpc from absl.testing import absltest from absl.testing import parameterized @@ -255,6 +256,37 @@ def test_update_metadata(self): on_study=on_study_metadata, on_trials={1: on_trial1_metadata}) self.client.update_metadata(metadata_delta) + def test_unset_endpoint_client(self): + study_id = 'dummy_study' + study_config = pyvizier.StudyConfig() + study_resource_name = resources.StudyResource(self.owner_id, study_id).name + + # Check if server is stored in client. + local_client1 = vizier_client.create_or_load_study( + service_endpoint=vizier_client.NO_ENDPOINT, + owner_id=self.owner_id, + client_id='local_client1', + study_id=study_id, + study_config=study_config) + self.assertIsInstance(local_client1._server, + vizier_service_pb2_grpc.VizierServiceServicer) + + # Check if the local server is shared. + local_client2 = vizier_client.VizierClient.from_endpoint( + service_endpoint=vizier_client.NO_ENDPOINT, + study_resource_name=study_resource_name, + client_id='local_client2') + self.assertEqual(local_client1._server, local_client2._server) + + # Same server still exists globally in cache after clients are deleted. + del local_client1 + del local_client2 + local_client3 = vizier_client.VizierClient.from_endpoint( + service_endpoint=vizier_client.NO_ENDPOINT, + study_resource_name=study_resource_name, + client_id='local_client3') + self.assertLen(local_client3.list_studies(), 1) + if __name__ == '__main__': absltest.main()