From 1e7f33d10c83e238aa202dd421f03a59c589489f Mon Sep 17 00:00:00 2001 From: Chris Rawles Date: Thu, 18 Jan 2024 06:40:28 -0800 Subject: [PATCH] Add a loader for connecting to an already running emulator. PiperOrigin-RevId: 599491105 --- android_env/loader.py | 64 ++++++++++++++++++++++++-------------- android_env/loader_test.py | 50 ++++++++++++++++++++++++----- 2 files changed, 83 insertions(+), 31 deletions(-) diff --git a/android_env/loader.py b/android_env/loader.py index 1814674..0feb4e7 100644 --- a/android_env/loader.py +++ b/android_env/loader.py @@ -27,13 +27,23 @@ from google.protobuf import text_format -def load(task_path: str, - avd_name: str, - android_avd_home: str = '~/.android/avd', - android_sdk_root: str = '~/Android/Sdk', - emulator_path: str = '~/Android/Sdk/emulator/emulator', - adb_path: str = '~/Android/Sdk/platform-tools/adb', - run_headless: bool = False) -> environment.AndroidEnv: +def _get_task_manager(task_path: str) -> task_manager_lib.TaskManager: + task = task_pb2.Task() + with open(task_path, 'r') as proto_file: + text_format.Parse(proto_file.read(), task) + return task_manager_lib.TaskManager(task) + + +def load( + task_path: str, + avd_name: str | None = None, + android_avd_home: str = '~/.android/avd', + android_sdk_root: str = '~/Android/Sdk', + emulator_path: str = '~/Android/Sdk/emulator/emulator', + adb_path: str = '~/Android/Sdk/platform-tools/adb', + run_headless: bool = False, + console_port: int | None = None, +) -> environment.AndroidEnv: """Loads an AndroidEnv instance. Args: @@ -44,33 +54,41 @@ def load(task_path: str, emulator_path: Path to the emulator binary. adb_path: Path to the ADB (Android Debug Bridge). run_headless: If True, the emulator display is turned off. + console_port: The console port number; for connecting to an already running + device/emulator. + Returns: env: An AndroidEnv instance. """ + connect_to_existing_device = console_port is not None + if not connect_to_existing_device and avd_name is None: + raise ValueError('An avd name must be provided if launching an emulator.') + + if connect_to_existing_device: + launcher_args = dict( + emulator_console_port=console_port, + adb_port=console_port + 1, + grpc_port=8554, + ) + else: + launcher_args = dict( + avd_name=avd_name, + android_avd_home=os.path.expanduser(android_avd_home), + android_sdk_root=os.path.expanduser(android_sdk_root), + emulator_path=os.path.expanduser(emulator_path), + run_headless=run_headless, + gpu_mode='swiftshader_indirect', + ) # Create simulator. simulator = emulator_simulator.EmulatorSimulator( - emulator_launcher_args=dict( - avd_name=avd_name, - android_avd_home=os.path.expanduser(android_avd_home), - android_sdk_root=os.path.expanduser(android_sdk_root), - emulator_path=os.path.expanduser(emulator_path), - run_headless=run_headless, - gpu_mode='swiftshader_indirect', - ), + emulator_launcher_args=launcher_args, adb_controller_config=config_classes.AdbControllerConfig( adb_path=os.path.expanduser(adb_path), adb_server_port=5037, ), ) - # Prepare task. - task = task_pb2.Task() - with open(task_path, 'r') as proto_file: - text_format.Parse(proto_file.read(), task) - - task_manager = task_manager_lib.TaskManager(task) + task_manager = _get_task_manager(task_path) coordinator = coordinator_lib.Coordinator(simulator, task_manager) - - # Load environment. return environment.AndroidEnv(coordinator=coordinator) diff --git a/android_env/loader_test.py b/android_env/loader_test.py index 0c601cc..cb9462a 100644 --- a/android_env/loader_test.py +++ b/android_env/loader_test.py @@ -35,7 +35,9 @@ class LoaderTest(absltest.TestCase): @mock.patch.object(emulator_simulator, 'EmulatorSimulator', autospec=True) @mock.patch.object(coordinator_lib, 'Coordinator', autospec=True) @mock.patch.object(builtins, 'open', autospec=True) - def test_load(self, mock_open, coordinator, simulator, task_manager): + def test_load( + self, mock_open, mock_coordinator, mock_simulator_class, mock_task_manager + ): mock_open.return_value.__enter__ = mock_open mock_open.return_value.read.return_value = '' @@ -51,7 +53,7 @@ def test_load(self, mock_open, coordinator, simulator, task_manager): ) self.assertIsInstance(env, environment.AndroidEnv) - simulator.assert_called_with( + mock_simulator_class.assert_called_with( emulator_launcher_args=dict( avd_name='my_avd', android_avd_home=os.path.expanduser('~/.android/avd'), @@ -65,18 +67,50 @@ def test_load(self, mock_open, coordinator, simulator, task_manager): adb_server_port=5037, ), ) - coordinator.assert_called_with( - simulator.return_value, - task_manager.return_value, + mock_coordinator.assert_called_with( + mock_simulator_class.return_value, + mock_task_manager.return_value, ) @mock.patch.object(task_manager_lib, 'TaskManager', autospec=True) @mock.patch.object(emulator_simulator, 'EmulatorSimulator', autospec=True) @mock.patch.object(coordinator_lib, 'Coordinator', autospec=True) @mock.patch.object(builtins, 'open', autospec=True) - def test_task(self, mock_open, coordinator, simulator, task_manager): + def test_load_existing_device( + self, mock_open, mock_coordinator, mock_simulator_class, mock_task_manager + ): + mock_open.return_value.__enter__ = mock_open + mock_open.return_value.read.return_value = '' - del coordinator, simulator + env = loader.load( + task_path='some/path/', + console_port=5554, + adb_path='~/Android/Sdk/platform-tools/adb', + ) + + self.assertIsInstance(env, environment.AndroidEnv) + mock_simulator_class.assert_called_with( + emulator_launcher_args=dict( + emulator_console_port=5554, adb_port=5555, grpc_port=8554 + ), + adb_controller_config=config_classes.AdbControllerConfig( + adb_path=os.path.expanduser('~/Android/Sdk/platform-tools/adb'), + adb_server_port=5037, + ), + ) + mock_coordinator.assert_called_with( + mock_simulator_class.return_value, + mock_task_manager.return_value, + ) + + @mock.patch.object(task_manager_lib, 'TaskManager', autospec=True) + @mock.patch.object(emulator_simulator, 'EmulatorSimulator', autospec=True) + @mock.patch.object(coordinator_lib, 'Coordinator', autospec=True) + @mock.patch.object(builtins, 'open', autospec=True) + def test_task( + self, mock_open, mock_coordinator, mock_simulator, mock_task_manager + ): + del mock_coordinator, mock_simulator mock_open.return_value.__enter__ = mock_open mock_open.return_value.read.return_value = r''' id: "fake_task" @@ -96,7 +130,7 @@ def test_task(self, mock_open, coordinator, simulator, task_manager): expected_task.description = 'Task for testing loader.' expected_task.max_episode_sec = 0 - task_manager.assert_called_with(expected_task) + mock_task_manager.assert_called_with(expected_task) assert isinstance(env, environment.AndroidEnv)