diff --git a/src/inspect_ai/tool/_tools/_web_browser/_resources/playwright_crawler.py b/src/inspect_ai/tool/_tools/_web_browser/_resources/playwright_crawler.py index 977168cc4..fb2bd742c 100644 --- a/src/inspect_ai/tool/_tools/_web_browser/_resources/playwright_crawler.py +++ b/src/inspect_ai/tool/_tools/_web_browser/_resources/playwright_crawler.py @@ -285,10 +285,12 @@ def render(self, output_format: CrawlerOutputFormat) -> Any: the currently active webpage rendered using given format. """ match output_format: + case CrawlerOutputFormat.HTML: + return self._page.content() case CrawlerOutputFormat.AT: return self._render_at() case _: - # TODO: Implement DOM, HTML, PIXELS formats + # TODO: Implement DOM, PIXELS formats raise NotImplementedError( "Playwright crawler does not currently support" f" {output_format} output." diff --git a/src/inspect_ai/tool/_tools/_web_browser/_resources/test_playwright_crawler.py b/src/inspect_ai/tool/_tools/_web_browser/_resources/test_playwright_crawler.py index f09210e36..4432f89ed 100644 --- a/src/inspect_ai/tool/_tools/_web_browser/_resources/test_playwright_crawler.py +++ b/src/inspect_ai/tool/_tools/_web_browser/_resources/test_playwright_crawler.py @@ -24,6 +24,13 @@ def test_nodes_change_on_update(self): self._crawler.update() self.assertTrue(self._crawler._nodes) + def test_render_html(self): + self._crawler.go_to_page("https://www.example.com") + html = self._crawler.render(playwright_crawler.CrawlerOutputFormat.HTML) + self.assertIn(" dict[str, specs.Array]: obs_shapes = { "web_url": specs.Array(shape=(), dtype=str, name="web_url"), "web_at": specs.Array(shape=(), dtype=str, name="web_at"), + "web_html": specs.Array(shape=(), dtype=str, name="web_html"), "error": specs.Array(shape=(), dtype=str, name="error"), "info": specs.Array(shape=(), dtype=str, name="info"), } @@ -158,6 +159,7 @@ def render(mode): obs_map = { "web_url": lambda: self._web.url.split("?")[0], "web_at": render(playwright_crawler.CrawlerOutputFormat.AT), + "web_html": render(playwright_crawler.CrawlerOutputFormat.HTML), "error": lambda: self._last_error, "info": lambda: json.dumps(self.info), } diff --git a/src/inspect_ai/tool/_tools/_web_browser/_web_browser.py b/src/inspect_ai/tool/_tools/_web_browser/_web_browser.py index 677baa272..f5f0f7bca 100644 --- a/src/inspect_ai/tool/_tools/_web_browser/_web_browser.py +++ b/src/inspect_ai/tool/_tools/_web_browser/_web_browser.py @@ -1,5 +1,6 @@ import re from textwrap import dedent +from typing import Literal from inspect_ai._util.error import PrerequisiteError from inspect_ai.tool._tool import Tool, ToolError, tool @@ -11,20 +12,25 @@ from inspect_ai.util._store import store -def web_browser(interactive: bool = True) -> list[Tool]: +def web_browser( + interactive: bool = True, + output_format: Literal["at", "html"] = "at", +) -> list[Tool]: """Tools used for web browser navigation. Args: interactive (bool): Provide interactive tools (enable clicking, typing, and submitting forms). Defaults to True. + output_format (Literal["at", "html"]): Output format for + web browser tools. Defaults to "at" (accessibility tree). Returns: List of tools used for web browser navigation. """ # start with go tool (excluding interactive docs if necessary) - go = web_browser_go() + go = web_browser_go(output_format) if not interactive: go = go_without_interactive_docs(go) tools = [go] @@ -32,22 +38,22 @@ def web_browser(interactive: bool = True) -> list[Tool]: # add interactive tools if requested if interactive: tools = tools + [ - web_browser_click(), - web_browser_type_submit(), - web_browser_type(), + web_browser_click(output_format), + web_browser_type_submit(output_format), + web_browser_type(output_format), ] # add navigational tools return tools + [ - web_browser_scroll(), - web_browser_back(), - web_browser_forward(), - web_browser_refresh(), + web_browser_scroll(output_format), + web_browser_back(output_format), + web_browser_forward(output_format), + web_browser_refresh(output_format), ] @tool(parallel=False) -def web_browser_go() -> Tool: +def web_browser_go(output_format: Literal["at", "html"] = "at") -> Tool: """Web Browser tool for navigation to a URL. Returns: @@ -57,7 +63,9 @@ def web_browser_go() -> Tool: async def execute(url: str) -> str: """Navigate the web browser to a URL. - Once you have navigated to a page, you will be presented with a web accessibilty tree of the elements on the page. Each element has an ID, which is displayed in brackets at the beginning of its line. For example: + Once you have navigated to a page, you will be presented with HTML or a web accessibility tree of the elements on the page. + + In web accessibility trees each element has an ID, which is displayed in brackets at the beginning of its line. For example: ``` [1] RootWebArea "Google" [focused: True, url: https://www.google.com/] @@ -78,9 +86,9 @@ async def execute(url: str) -> str: url (str): URL to navigate to. Returns: - Web accessibility tree of the visible elements of the web page. The element_id of each element is displayed in brackets at the beginning of the line. + HTML or web accessibility tree of the visible elements of the web page. If the latter, the element_id of each element is displayed in brackets at the beginning of the line. """ - return await web_browser_cmd("web_go", url) + return await web_browser_cmd("web_go", output_format, url) return execute @@ -97,11 +105,15 @@ def go_without_interactive_docs(tool: Tool) -> Tool: # custom viewer for interactive tool calls that shows a truncated # version of current the web accessiblity tree if available +WEB_BROWSER_HTML = "web_browser:html" WEB_BROWSER_AT = "web_browser:at" -def web_at_viewer(call: ToolCall) -> ToolCallView: - # get the web accessiblity tree, if we have it create a view from it +def web_viewer(call: ToolCall) -> ToolCallView: + web_html = store().get(WEB_BROWSER_HTML, "") + if web_html: + return ToolCallView(context=ToolCallContent(format="text", content=web_html)) + web_at = store().get(WEB_BROWSER_AT, "") element_id = call.arguments.get("element_id", 0) if web_at and element_id: @@ -126,8 +138,10 @@ def web_at_viewer(call: ToolCall) -> ToolCallView: return ToolCallView() -@tool(viewer=web_at_viewer, parallel=False) -def web_browser_click() -> Tool: +@tool(viewer=web_viewer, parallel=False) +def web_browser_click( + output_format: Literal["at", "html"] = "at", +) -> Tool: """Web Browser tool for clicking an element on a web page. Returns: @@ -137,7 +151,7 @@ def web_browser_click() -> Tool: async def execute(element_id: int) -> str: """Click an element on the page currently displayed by the web browser. - For example, with the following web accessibilty tree: + For example, with the following web accessibility tree: ``` [304] RootWebArea "Poetry Foundation" [focused: True, url: https://www.poetryfoundation.org/] @@ -154,15 +168,17 @@ async def execute(element_id: int) -> str: element_id (int): ID of the element to click. Returns: - Web accessibility tree of the visible elements of the web page. The element_id of each element is displayed in brackets at the beginning of the line. + HTML or web accessibility tree of the visible elements of the web page. If the latter, the element_id of each element is displayed in brackets at the beginning of the line. """ - return await web_browser_cmd("web_click", str(element_id)) + return await web_browser_cmd("web_click", output_format, str(element_id)) return execute -@tool(viewer=web_at_viewer, parallel=False) -def web_browser_type_submit() -> Tool: +@tool(viewer=web_viewer, parallel=False) +def web_browser_type_submit( + output_format: Literal["at", "html"] = "at", +) -> Tool: """Web Browser tool for typing and submitting input. Returns: @@ -192,15 +208,19 @@ async def execute(element_id: int, text: str) -> str: text (str): Text to type. Returns: - Web accessibility tree of the visible elements of the web page. The element_id of each element is displayed in brackets at the beginning of the line. + HTML or web accessibility tree of the visible elements of the web page. If the latter, the element_id of each element is displayed in brackets at the beginning of the line. """ - return await web_browser_cmd("web_type_submit", str(element_id), text) + return await web_browser_cmd( + "web_type_submit", output_format, str(element_id), text + ) return execute -@tool(viewer=web_at_viewer, parallel=False) -def web_browser_type() -> Tool: +@tool(viewer=web_viewer, parallel=False) +def web_browser_type( + output_format: Literal["at", "html"] = "at", +) -> Tool: """Web Browser tool for typing into inputs. Returns: @@ -230,15 +250,17 @@ async def execute(element_id: int, text: str) -> str: text (str): Text to type. Returns: - Web accessibility tree of the visible elements of the web page. The element_id of each element is displayed in brackets at the beginning of the line. + HTML or web accessibility tree of the visible elements of the web page. If the latter, the element_id of each element is displayed in brackets at the beginning of the line. """ - return await web_browser_cmd("web_type", str(element_id), text) + return await web_browser_cmd("web_type", output_format, str(element_id), text) return execute @tool(parallel=False) -def web_browser_scroll() -> Tool: +def web_browser_scroll( + output_format: Literal["at", "html"] = "at", +) -> Tool: """Web Browser tool for scrolling up or down one page. Returns: @@ -260,15 +282,17 @@ async def execute(direction: str) -> str: direction (str): "up" or "down" Returns: - Web accessibility tree of the visible elements of the web page. The element_id of each element is displayed in brackets at the beginning of the line. + HTML or web accessibility tree of the visible elements of the web page. If the latter, the element_id of each element is displayed in brackets at the beginning of the line. """ - return await web_browser_cmd("web_scroll", direction) + return await web_browser_cmd("web_scroll", output_format, direction) return execute @tool(parallel=False) -def web_browser_back() -> Tool: +def web_browser_back( + output_format: Literal["at", "html"] = "at", +) -> Tool: """Web Browser tool for navigating back in the browser history. Returns: @@ -281,15 +305,17 @@ async def execute() -> str: If you want to view a page that you have previously browsed (or perhaps just didn't find what you were looking for on a page and want to backtrack) use the web_browser_back tool. Returns: - Web accessibility tree of the visible elements of the web page. The element_id of each element is displayed in brackets at the beginning of the line. + HTML or web accessibility tree of the visible elements of the web page. If the latter, the element_id of each element is displayed in brackets at the beginning of the line. """ - return await web_browser_cmd("web_back") + return await web_browser_cmd("web_back", output_format) return execute @tool(parallel=False) -def web_browser_forward() -> Tool: +def web_browser_forward( + output_format: Literal["at", "html"] = "at", +) -> Tool: """Web Browser tool for navigating forward in the browser history. Returns: @@ -302,15 +328,17 @@ async def execute() -> str: If you have navigated back in the browser history and then want to navigate forward use the web_browser_forward tool. Returns: - Web accessibility tree of the visible elements of the web page. The element_id of each element is displayed in brackets at the beginning of the line. + HTML or web accessibility tree of the visible elements of the web page. If the latter, the element_id of each element is displayed in brackets at the beginning of the line. """ - return await web_browser_cmd("web_forward") + return await web_browser_cmd("web_forward", output_format) return execute @tool(parallel=False) -def web_browser_refresh() -> Tool: +def web_browser_refresh( + output_format: Literal["at", "html"] = "at", +) -> Tool: """Web Browser tool for refreshing the current page. Returns: @@ -323,9 +351,9 @@ async def execute() -> str: If you have interacted with a page by clicking buttons and want to reset it to its original state, use the web_browser_refresh tool. Returns: - Web accessibility tree of the visible elements of the web page. The element_id of each element is displayed in brackets at the beginning of the line. + HTML or web accessibility tree of the visible elements of the web page. If the latter, the element_id of each element is displayed in brackets at the beginning of the line. """ - return await web_browser_cmd("web_refresh") + return await web_browser_cmd("web_refresh", output_format) return execute @@ -335,7 +363,9 @@ async def execute() -> str: BROWSER_SESSION_ID = "BROWSER_SESSION_ID" -async def web_browser_cmd(cmd: str, *args: str) -> str: +async def web_browser_cmd( + cmd: str, output_format: Literal["at", "html"] = "at", *args: str +) -> str: sandbox_env = await sandbox_with(WEB_CLIENT_NEW_SESSION) session_flag = "" if sandbox_env: @@ -369,18 +399,28 @@ async def web_browser_cmd(cmd: str, *args: str) -> str: ) else: response = parse_web_browser_output(result.stdout) - if "web_at" in response: - web_at = ( - str(response.get("web_at")) or "(no web accessiblity tree available)" - ) - # Remove base64 data from images. - web_at_lines = web_at.split("\n") - web_at_lines = [ - line.partition("data:image/png;base64")[0] for line in web_at_lines - ] - web_at = "\n".join(web_at_lines) - store().set(WEB_BROWSER_AT, web_at) - return web_at + if "web_html" or "web_at" in response: + if output_format == "html": + html_content = ( + str(response.get("web_html")) or "(no HTML content available)" + ) + store().set(WEB_BROWSER_HTML, html_content) + return html_content + elif output_format == "at": + web_at = ( + str(response.get("web_at")) + or "(no web accessiblity tree available)" + ) + # Remove base64 data from images. + web_at_lines = web_at.split("\n") + web_at_lines = [ + line.partition("data:image/png;base64")[0] for line in web_at_lines + ] + web_at = "\n".join(web_at_lines) + store().set(WEB_BROWSER_AT, web_at) + return web_at + else: + raise ValueError(f"Unknown output format: {output_format}") elif "error" in response: raise ToolError(str(response.get("error")) or "(unknown error)") else: @@ -418,7 +458,9 @@ async def web_browser_sandbox() -> SandboxEnvironment: def parse_web_browser_output(output: str) -> dict[str, str]: - response: dict[str, str] = dict(web_url="", web_at="", info="", error="") + response: dict[str, str] = dict( + web_url="", web_at="", web_html="", info="", error="" + ) active_field: str | None = None active_field_lines: list[str] = [] @@ -428,7 +470,9 @@ def collect_active_field() -> None: active_field_lines.clear() for line in output.splitlines(): - field_match = re.match(r"^(error|web_at|web_url|info)\s*:\s*(.+)$", line) + field_match = re.match( + r"^(error|web_at|web_html|web_url|info)\s*:\s*(.+)$", line + ) if field_match: collect_active_field() active_field = field_match.group(1)