From 7482480597f782d5601006c942556ecd9c7f82f9 Mon Sep 17 00:00:00 2001 From: Daniel Toyama Date: Mon, 7 Oct 2024 08:23:24 -0700 Subject: [PATCH] Pass `BaseSimulator` and `TaskManager` directly to `AndroidEnv`. This is the first of a series of changes to keep the `AndroidEnv` and `TaskManager` classes focused on RL interactions (e.g. rewards, begin/end of episodes, resetting, etc.), while making the rest (e.g. `BaseSimulator`, `Coordinator`) more independent and easier to use in other domains such as LLMs. To keep this diff small, only `AndroidEnv.stats()` has been changed for now; more changes will follow incrementally. PiperOrigin-RevId: 683192156 --- android_env/components/coordinator.py | 5 +-- android_env/environment.py | 16 +++++-- android_env/environment_test.py | 60 ++++++++++++++++++++++----- android_env/loader.py | 4 +- 4 files changed, 67 insertions(+), 18 deletions(-) diff --git a/android_env/components/coordinator.py b/android_env/components/coordinator.py index 21af0b6..266ec7d 100644 --- a/android_env/components/coordinator.py +++ b/android_env/components/coordinator.py @@ -51,6 +51,7 @@ def __init__( Args: simulator: A BaseSimulator instance. task_manager: The TaskManager, responsible for coordinating RL tasks. + config: Settings to customize this Coordinator. 
""" self._simulator = simulator self._task_manager = task_manager @@ -453,9 +454,7 @@ def _get_time_since_last_observation(self) -> float: def stats(self) -> dict[str, Any]: """Returns various statistics.""" - output = copy.deepcopy(self._stats) - output.update(self._task_manager.stats()) - return output + return copy.deepcopy(self._stats) def load_state( self, request: state_pb2.LoadStateRequest diff --git a/android_env/environment.py b/android_env/environment.py index 8700535..42d9c51 100644 --- a/android_env/environment.py +++ b/android_env/environment.py @@ -20,9 +20,10 @@ from absl import logging from android_env import env_interface from android_env.components import coordinator as coordinator_lib +from android_env.components import task_manager as task_manager_lib +from android_env.components.simulators import base_simulator from android_env.proto import adb_pb2 from android_env.proto import state_pb2 -from android_env.proto import task_pb2 import dm_env import numpy as np @@ -30,10 +31,17 @@ class AndroidEnv(env_interface.AndroidEnvInterface): """An RL environment that interacts with Android apps.""" - def __init__(self, coordinator: coordinator_lib.Coordinator): + def __init__( + self, + simulator: base_simulator.BaseSimulator, + coordinator: coordinator_lib.Coordinator, + task_manager: task_manager_lib.TaskManager, + ): """Initializes the state of this AndroidEnv object.""" + self._simulator = simulator self._coordinator = coordinator + self._task_manager = task_manager self._latest_action = {} self._latest_observation = {} self._latest_extras = {} @@ -133,7 +141,9 @@ def raw_observation(self): return self._latest_observation.copy() def stats(self) -> dict[str, Any]: - return self._coordinator.stats() + coordinator_stats = self._coordinator.stats() + task_manager_stats = self._task_manager.stats() + return coordinator_stats | task_manager_stats def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse: return 
self._coordinator.execute_adb_call(call) diff --git a/android_env/environment_test.py b/android_env/environment_test.py index b4d3e99..93a3933 100644 --- a/android_env/environment_test.py +++ b/android_env/environment_test.py @@ -19,10 +19,12 @@ from absl.testing import absltest from android_env import environment +from android_env.components import config_classes from android_env.components import coordinator as coordinator_lib +from android_env.components import task_manager as task_manager_lib +from android_env.components.simulators.fake import fake_simulator from android_env.proto import adb_pb2 from android_env.proto import state_pb2 -from android_env.proto import task_pb2 import dm_env import numpy as np @@ -47,7 +49,14 @@ def _create_mock_coordinator() -> coordinator_lib.Coordinator: class AndroidEnvTest(absltest.TestCase): def test_specs(self): - env = environment.AndroidEnv(_create_mock_coordinator()) + simulator = fake_simulator.FakeSimulator( + config=config_classes.FakeSimulatorConfig(screen_dimensions=(123, 456)) + ) + coordinator = _create_mock_coordinator() + task_manager = mock.create_autospec(task_manager_lib.TaskManager) + env = environment.AndroidEnv( + simulator=simulator, coordinator=coordinator, task_manager=task_manager + ) # Check action spec. 
self.assertNotEmpty(env.action_spec()) @@ -77,7 +86,11 @@ def test_specs(self): self.assertEqual(env.observation_spec()['orientation'].shape, (4,)) def test_reset_and_step(self): - coordinator = mock.create_autospec(coordinator_lib.Coordinator) + simulator = fake_simulator.FakeSimulator( + config=config_classes.FakeSimulatorConfig(screen_dimensions=(123, 456)) + ) + coordinator = _create_mock_coordinator() + task_manager = mock.create_autospec(task_manager_lib.TaskManager) coordinator.action_spec.return_value = { 'action_type': dm_env.specs.DiscreteArray(num_values=3), @@ -90,7 +103,9 @@ def test_reset_and_step(self): 'timedelta': dm_env.specs.Array(shape=(), dtype=np.int64), 'orientation': dm_env.specs.Array(shape=(4,), dtype=np.uint8), } - env = environment.AndroidEnv(coordinator) + env = environment.AndroidEnv( + simulator=simulator, coordinator=coordinator, task_manager=task_manager + ) coordinator.rl_reset.return_value = dm_env.TimeStep( step_type=dm_env.StepType.FIRST, reward=0.0, @@ -125,9 +140,8 @@ def test_reset_and_step(self): self.assertIn('click', extras) self.assertEqual(extras['click'], np.array([246], dtype=np.int64)) - coordinator.stats.return_value = { - 'my_measurement': 135, - } + coordinator.stats.return_value = {'my_measurement': 135} + task_manager.stats.return_value = {'another_measurement': 79} # Step again in the environment and check expectations again. 
pixels = np.random.rand(987, 654, 3) @@ -189,8 +203,14 @@ def test_reset_and_step(self): np.testing.assert_equal(obs['orientation'], (1, 0, 0, 0)) def test_adb_call(self): + simulator = fake_simulator.FakeSimulator( + config=config_classes.FakeSimulatorConfig(screen_dimensions=(123, 456)) + ) coordinator = _create_mock_coordinator() - env = environment.AndroidEnv(coordinator) + task_manager = mock.create_autospec(task_manager_lib.TaskManager) + env = environment.AndroidEnv( + simulator=simulator, coordinator=coordinator, task_manager=task_manager + ) call = adb_pb2.AdbRequest( force_stop=adb_pb2.AdbRequest.ForceStop(package_name='blah')) expected_response = adb_pb2.AdbResponse( @@ -203,8 +223,14 @@ def test_adb_call(self): coordinator.execute_adb_call.assert_called_once_with(call) def test_load_state(self): + simulator = fake_simulator.FakeSimulator( + config=config_classes.FakeSimulatorConfig(screen_dimensions=(123, 456)) + ) coordinator = _create_mock_coordinator() - env = environment.AndroidEnv(coordinator) + task_manager = mock.create_autospec(task_manager_lib.TaskManager) + env = environment.AndroidEnv( + simulator=simulator, coordinator=coordinator, task_manager=task_manager + ) expected_response = state_pb2.LoadStateResponse( status=state_pb2.LoadStateResponse.Status.OK ) @@ -215,8 +241,14 @@ def test_load_state(self): coordinator.load_state.assert_called_once_with(request) def test_save_state(self): + simulator = fake_simulator.FakeSimulator( + config=config_classes.FakeSimulatorConfig(screen_dimensions=(123, 456)) + ) coordinator = _create_mock_coordinator() - env = environment.AndroidEnv(coordinator) + task_manager = mock.create_autospec(task_manager_lib.TaskManager) + env = environment.AndroidEnv( + simulator=simulator, coordinator=coordinator, task_manager=task_manager + ) expected_response = state_pb2.SaveStateResponse( status=state_pb2.SaveStateResponse.Status.OK ) @@ -227,8 +259,14 @@ def test_save_state(self): 
coordinator.save_state.assert_called_once_with(request) def test_double_close(self): + simulator = fake_simulator.FakeSimulator( + config=config_classes.FakeSimulatorConfig(screen_dimensions=(123, 456)) + ) coordinator = _create_mock_coordinator() - env = environment.AndroidEnv(coordinator) + task_manager = mock.create_autospec(task_manager_lib.TaskManager) + env = environment.AndroidEnv( + simulator=simulator, coordinator=coordinator, task_manager=task_manager + ) env.close() env.close() coordinator.close.assert_called_once() diff --git a/android_env/loader.py b/android_env/loader.py index 92666f6..c3c5601 100644 --- a/android_env/loader.py +++ b/android_env/loader.py @@ -59,7 +59,9 @@ def load(config: config_classes.AndroidEnvConfig) -> environment.AndroidEnv: raise ValueError('Unsupported simulator config: {config.simulator}') coordinator = coordinator_lib.Coordinator(simulator, task_manager) - return environment.AndroidEnv(coordinator=coordinator) + return environment.AndroidEnv( + simulator=simulator, coordinator=coordinator, task_manager=task_manager + ) def _process_emulator_launcher_config(