Skip to content

Commit

Permalink
Introduce TaskManagerConfig + change TaskManager.__init__().
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 597607576
  • Loading branch information
kenjitoyama authored and copybara-github committed Jan 11, 2024
1 parent ce6a3f9 commit a7dc6f6
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 24 deletions.
18 changes: 18 additions & 0 deletions android_env/components/config_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,21 @@ class FakeSimulatorConfig(SimulatorConfig):

# The dimensions in pixels of the device screen (HxW).
screen_dimensions: tuple[int, int] = (0, 0)


@dataclasses.dataclass
class TaskManagerConfig:
"""Config class for TaskManager."""

# If max_bad_states episodes finish in a bad state in a row, restart
# the simulation.
max_bad_states: int = 3
# The frequency to check for the current activity and view hierarchy.
# The unit is raw observation (i.e. each call to AndroidEnv.step()).
dumpsys_check_frequency: int = 150
# The maximum number of tries for extracting the current activity before
# forcing the episode to restart.
max_failed_current_activity: int = 10
# The maximum number of extras elements to store. If this number is exceeded,
# elements are dropped in the order they were received.
extras_max_buffer_size: int = 100
40 changes: 16 additions & 24 deletions android_env/components/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from absl import logging
from android_env.components import adb_call_parser as adb_call_parser_lib
from android_env.components import app_screen_checker
from android_env.components import config_classes
from android_env.components import dumpsys_thread
from android_env.components import log_stream as log_stream_lib
from android_env.components import logcat_thread
Expand All @@ -42,32 +43,18 @@ class TaskManager:
def __init__(
self,
task: task_pb2.Task,
max_bad_states: int = 3,
dumpsys_check_frequency: int = 150,
max_failed_current_activity: int = 10,
extras_max_buffer_size: int = 100,
config: config_classes.TaskManagerConfig = config_classes.TaskManagerConfig(),
):
"""Controls task-relevant events and information.
Args:
task: A task proto defining the RL task.
max_bad_states: How many bad states in a row are allowed before a restart
of the simulator is triggered.
dumpsys_check_frequency: Frequency, in steps, at which to check
current_activity and view hierarchy
max_failed_current_activity: The maximum number of tries for extracting
the current activity before forcing the episode to restart.
extras_max_buffer_size: The maximum number of extras elements to store. If
this number is exceeded, elements are dropped in the order they were
received.
config: Configuration for this instance.
"""
self._task = task
self._max_bad_states = max_bad_states
self._dumpsys_check_frequency = dumpsys_check_frequency
self._max_failed_current_activity = max_failed_current_activity

self._task = task
self._config = config
self._lock = threading.Lock()
self._extras_max_buffer_size = extras_max_buffer_size
self._logcat_thread = None
self._dumpsys_thread = None
self._setup_step_interpreter = None
Expand Down Expand Up @@ -246,9 +233,11 @@ def _start_dumpsys_thread(self,
self._dumpsys_thread = dumpsys_thread.DumpsysThread(
app_screen_checker=app_screen_checker.AppScreenChecker(
adb_call_parser=adb_call_parser,
expected_app_screen=self._task.expected_app_screen),
check_frequency=self._dumpsys_check_frequency,
max_failed_current_activity=self._max_failed_current_activity)
expected_app_screen=self._task.expected_app_screen,
),
check_frequency=self._config.dumpsys_check_frequency,
max_failed_current_activity=self._config.max_failed_current_activity,
)

def _stop_logcat_thread(self):
if self._logcat_thread is not None:
Expand All @@ -264,11 +253,11 @@ def _increment_bad_state(self) -> None:
to a good state.
"""
logging.warning('Bad state detected.')
if self._max_bad_states:
if self._config.max_bad_states:
self._is_bad_episode = True
self._bad_state_counter += 1
logging.warning('Bad state counter: %d.', self._bad_state_counter)
if self._bad_state_counter >= self._max_bad_states:
if self._bad_state_counter >= self._config.max_bad_states:
logging.error('Too many consecutive bad states. Restarting simulator.')
self._stats['restart_count_max_bad_states'] += 1
self._should_restart = True
Expand Down Expand Up @@ -378,7 +367,10 @@ def _process_extra(extra_name, extra):
latest_extras = self._latest_values['extra']
if extra_name in latest_extras:
# If latest extra is not flushed, append.
if len(latest_extras[extra_name]) >= self._extras_max_buffer_size:
if (
len(latest_extras[extra_name])
>= self._config.extras_max_buffer_size
):
latest_extras[extra_name].pop(0)
latest_extras[extra_name].append(extra)
else:
Expand Down

0 comments on commit a7dc6f6

Please sign in to comment.