Skip to content

Commit

Permalink
fix: incorrect soft-timeout implementation & fix hard-timeout follow-…
Browse files Browse the repository at this point in the history
…up command (#6280)
  • Loading branch information
xingyaoww authored Jan 16, 2025
1 parent da1a603 commit 0bed177
Show file tree
Hide file tree
Showing 18 changed files with 243 additions and 90 deletions.
24 changes: 12 additions & 12 deletions evaluation/benchmarks/commit0_bench/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def initialize_runtime(
action = CmdRunAction(
command=f'git clone -b commit0_combined https://github.com/{instance["repo"]}.git'
)
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -181,7 +181,7 @@ def initialize_runtime(
)

action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -191,7 +191,7 @@ def initialize_runtime(
)

action = CmdRunAction(command='git checkout -b openhands')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -201,7 +201,7 @@ def initialize_runtime(

# Install commit0
action = CmdRunAction(command='/root/.cargo/bin/uv pip install commit0')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
# logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand Down Expand Up @@ -231,7 +231,7 @@ def complete_runtime(
workspace_dir_name = _get_commit0_workspace_dir_name(instance)

action = CmdRunAction(command='git add .')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -241,7 +241,7 @@ def complete_runtime(
)

action = CmdRunAction(command='git commit -m "openhands edits"')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -258,7 +258,7 @@ def complete_runtime(
action = CmdRunAction(
command=f"git diff {instance['base_commit']} HEAD -- . ':(exclude)spec.pdf.bz2'"
)
action.timeout = 600 + 100 * n_retries
action.set_hard_timeout(600 + 100 * n_retries)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
# logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -282,7 +282,7 @@ def complete_runtime(
action = CmdRunAction(
command=f"{instance['test']['test_cmd']} --json-report --json-report-file=report.json --continue-on-collection-errors {test_dir} > test_output.txt 2>&1"
)
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -292,7 +292,7 @@ def complete_runtime(
)
# Read test output
action = CmdRunAction(command='cat test_output.txt')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
# logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -305,7 +305,7 @@ def complete_runtime(

# Save pytest exit code
action = CmdRunAction(command='echo $?')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
# logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -318,7 +318,7 @@ def complete_runtime(

# Read the test report
action = CmdRunAction(command='cat report.json')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
# logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -330,7 +330,7 @@ def complete_runtime(
repo_name = instance['repo'].split('/')[1]
repo_name = repo_name.replace('.', '-')
action = CmdRunAction(command=f'commit0 get-tests {repo_name}')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
# logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand Down
10 changes: 5 additions & 5 deletions evaluation/benchmarks/swe_bench/eval_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def process_instance(

# Set +x
action = CmdRunAction(command='chmod +x /tmp/eval.sh')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -189,7 +189,7 @@ def process_instance(
"echo 'APPLY_PATCH_FAIL')))"
)
action = CmdRunAction(command=exec_command)
action.timeout = 600
action.set_hard_timeout(600)
obs = runtime.run_action(action)
assert isinstance(obs, CmdOutputObservation)
apply_patch_output = obs.content
Expand All @@ -212,7 +212,7 @@ def process_instance(
# Run eval script in background and save output to log file
log_file = '/tmp/eval_output.log'
action = CmdRunAction(command=f'/tmp/eval.sh > {log_file} 2>&1 & echo $!')
action.timeout = 60 # Short timeout just to get the process ID
action.set_hard_timeout(60) # Short timeout just to get the process ID
obs = runtime.run_action(action)

if isinstance(obs, CmdOutputObservation) and obs.exit_code == 0:
Expand All @@ -235,7 +235,7 @@ def process_instance(
check_action = CmdRunAction(
command=f'ps -p {pid} > /dev/null; echo $?'
)
check_action.timeout = 60
check_action.set_hard_timeout(60)
check_obs = runtime.run_action(check_action)
if (
isinstance(check_obs, CmdOutputObservation)
Expand All @@ -252,7 +252,7 @@ def process_instance(

# Read the log file
cat_action = CmdRunAction(command=f'cat {log_file}')
cat_action.timeout = 300
cat_action.set_hard_timeout(300)
cat_obs = runtime.run_action(cat_action)

# Grade answer
Expand Down
30 changes: 15 additions & 15 deletions evaluation/benchmarks/swe_bench/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def initialize_runtime(
action = CmdRunAction(
command=f"""echo 'export SWE_INSTANCE_ID={instance['instance_id']}' >> ~/.bashrc && echo 'export PIP_CACHE_DIR=~/.cache/pip' >> ~/.bashrc && echo "alias git='git --no-pager'" >> ~/.bashrc"""
)
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -182,7 +182,7 @@ def initialize_runtime(
)

action = CmdRunAction(command="""export USER=$(whoami); echo USER=${USER} """)
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -194,7 +194,7 @@ def initialize_runtime(

# inject the instance info
action = CmdRunAction(command='mkdir -p /swe_util/eval_data/instances')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand Down Expand Up @@ -223,14 +223,14 @@ def initialize_runtime(
'/swe_util/',
)
action = CmdRunAction(command='cat ~/.bashrc')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(obs.exit_code == 0, f'Failed to cat ~/.bashrc: {str(obs)}')

action = CmdRunAction(command='source ~/.bashrc')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -239,7 +239,7 @@ def initialize_runtime(
assert_and_raise(obs.exit_code == 0, f'Failed to source ~/.bashrc: {str(obs)}')

action = CmdRunAction(command='source /swe_util/instance_swe_entry.sh')
action.timeout = 3600
action.set_hard_timeout(3600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -249,7 +249,7 @@ def initialize_runtime(
)
else:
action = CmdRunAction(command='source /swe_util/swe_entry.sh')
action.timeout = 1800
action.set_hard_timeout(1800)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -259,7 +259,7 @@ def initialize_runtime(
)

action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -269,7 +269,7 @@ def initialize_runtime(
)

action = CmdRunAction(command='git reset --hard')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -278,14 +278,14 @@ def initialize_runtime(
action = CmdRunAction(
command='for remote_name in $(git remote); do git remote remove "${remote_name}"; done'
)
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(obs.exit_code == 0, f'Failed to remove git remotes: {str(obs)}')

action = CmdRunAction(command='which python')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand Down Expand Up @@ -316,7 +316,7 @@ def complete_runtime(
workspace_dir_name = _get_swebench_workspace_dir_name(instance)

action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -326,7 +326,7 @@ def complete_runtime(
)

action = CmdRunAction(command='git config --global core.pager ""')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -336,7 +336,7 @@ def complete_runtime(
)

action = CmdRunAction(command='git add -A')
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand All @@ -351,7 +351,7 @@ def complete_runtime(
action = CmdRunAction(
command=f'git diff --no-color --cached {instance["base_commit"]}'
)
action.timeout = 600 + 100 * n_retries
action.set_hard_timeout(600 + 100 * n_retries)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand Down
2 changes: 1 addition & 1 deletion evaluation/benchmarks/the_agent_company/browsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def pre_login(
instruction = action.to_instruction()

browser_action = BrowseInteractiveAction(browser_actions=instruction)
browser_action.timeout = 10000
browser_action.set_hard_timeout(10000)
logger.info(browser_action, extra={'msg_type': 'ACTION'})
obs: BrowserOutputObservation = runtime.run_action(browser_action)
logger.debug(obs, extra={'msg_type': 'OBSERVATION'})
Expand Down
4 changes: 2 additions & 2 deletions evaluation/benchmarks/the_agent_company/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def init_task_env(runtime: Runtime, hostname: str, env_llm_config: LLMConfig):
'bash /utils/init.sh'
)
action = CmdRunAction(command=command)
action.timeout = 900
action.set_hard_timeout(900)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand Down Expand Up @@ -172,7 +172,7 @@ def run_evaluator(
f'python_default /utils/eval.py --trajectory_path {trajectory_path} --result_path {result_path}'
)
action = CmdRunAction(command=command)
action.timeout = 600
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand Down
10 changes: 7 additions & 3 deletions openhands/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,12 @@ def timeout(self) -> int | None:
return self._timeout # type: ignore[attr-defined]
return None

@timeout.setter
def timeout(self, value: int | None) -> None:
def set_hard_timeout(self, value: int | None, blocking: bool = True) -> None:
"""Set the timeout for the event.
NOTE, this is a hard timeout, meaning that the event will be blocked
until the timeout is reached.
"""
self._timeout = value
if value is not None and value > 600:
from openhands.core.logger import openhands_logger as logger
Expand All @@ -78,7 +82,7 @@ def timeout(self, value: int | None) -> None:
# Check if .blocking is an attribute of the event
if hasattr(self, 'blocking'):
# .blocking needs to be set to True if .timeout is set
self.blocking = True
self.blocking = blocking

# optional metadata, LLM call cost of the edit
@property
Expand Down
3 changes: 2 additions & 1 deletion openhands/events/serialization/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def action_from_dict(action: dict) -> Action:
try:
decoded_action = action_class(**args)
if 'timeout' in action:
decoded_action.timeout = action['timeout']
blocking = args.get('blocking', False)
decoded_action.set_hard_timeout(action['timeout'], blocking=blocking)

# Set timestamp if it was provided
if timestamp:
Expand Down
2 changes: 1 addition & 1 deletion openhands/resolver/resolve_issue.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def complete_runtime(
git_patch = None
while n_retries < 5:
action = CmdRunAction(command=f'git diff --no-color --cached {base_commit}')
action.timeout = 600 + 100 * n_retries
action.set_hard_timeout(600 + 100 * n_retries)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
Expand Down
5 changes: 4 additions & 1 deletion openhands/runtime/action_execution_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ async def ainit(self):
self.bash_session = BashSession(
work_dir=self._initial_cwd,
username=self.username,
no_change_timeout_seconds=int(
os.environ.get('NO_CHANGE_TIMEOUT_SECONDS', 30)
),
)
self.bash_session.initialize()
await wait_all(
Expand Down Expand Up @@ -163,7 +166,7 @@ async def _init_bash_commands(self):
logger.debug(f'Initializing by running {len(INIT_COMMANDS)} bash commands...')
for command in INIT_COMMANDS:
action = CmdRunAction(command=command)
action.timeout = 300
action.set_hard_timeout(300)
logger.debug(f'Executing init command: {command}')
obs = await self.run(action)
assert isinstance(obs, CmdOutputObservation)
Expand Down
3 changes: 2 additions & 1 deletion openhands/runtime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ def on_event(self, event: Event) -> None:

async def _handle_action(self, event: Action) -> None:
if event.timeout is None:
event.timeout = self.config.sandbox.timeout
# We don't block the command if this is a default timeout action
event.set_hard_timeout(self.config.sandbox.timeout, blocking=False)
assert event.timeout is not None
try:
observation: Observation = await call_sync_from_async(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ def send_action_for_execution(self, action: Action) -> Observation:

# set timeout to default if not set
if action.timeout is None:
action.timeout = self.config.sandbox.timeout
# We don't block the command if this is a default timeout action
action.set_hard_timeout(self.config.sandbox.timeout, blocking=False)

with self.action_semaphore:
if not action.runnable:
Expand Down
Loading

0 comments on commit 0bed177

Please sign in to comment.