Skip to content

Commit

Permalink
Add HTML output support to web browser tool
Browse files Browse the repository at this point in the history
  • Loading branch information
farrelmahaztra committed Dec 7, 2024
1 parent ac1c8ba commit cd99abb
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 60 deletions.
3 changes: 2 additions & 1 deletion src/inspect_ai/tool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
from ._tool_params import ToolParam, ToolParams
from ._tool_with import tool_with
from ._tools._execute import bash, python
from ._tools._web_browser import web_browser
from ._tools._web_browser import CrawlerOutputFormat, web_browser
from ._tools._web_search import web_search

__all__ = [
"bash",
"python",
"web_browser",
"CrawlerOutputFormat",
"web_search",
"tool",
"tool_with",
Expand Down
4 changes: 2 additions & 2 deletions src/inspect_ai/tool/_tools/_web_browser/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from ._web_browser import web_browser
from ._web_browser import CrawlerOutputFormat, web_browser

__all__ = ["web_browser"]
__all__ = ["web_browser", "CrawlerOutputFormat"]
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,12 @@ def render(self, output_format: CrawlerOutputFormat) -> Any:
the currently active webpage rendered using given format.
"""
match output_format:
case CrawlerOutputFormat.HTML:
return self._page.content()
case CrawlerOutputFormat.AT:
return self._render_at()
case _:
# TODO: Implement DOM, HTML, PIXELS formats
# TODO: Implement DOM, PIXELS formats
raise NotImplementedError(
"Playwright crawler does not currently support"
f" {output_format} output."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ def test_nodes_change_on_update(self):
self._crawler.update()
self.assertTrue(self._crawler._nodes)

def test_render_html(self):
    """Check that rendering with the HTML output format returns page markup.

    Loads https://www.example.com in the crawler, renders it with
    CrawlerOutputFormat.HTML, and asserts that the result contains an
    opening html tag plus two text snippets expected on that page.
    """
    self._crawler.go_to_page("https://www.example.com")
    page_markup = self._crawler.render(
        playwright_crawler.CrawlerOutputFormat.HTML
    )
    # Checked in the same order as before; the first missing snippet fails
    # the test.
    for expected_snippet in ("<html", "Example Domain", "More information..."):
        self.assertIn(expected_snippet, page_markup)

def test_render_accessibility_tree(self):
self._crawler.go_to_page("https://www.example.com")
at_no_update = self._crawler.render(playwright_crawler.CrawlerOutputFormat.AT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
class WebEnvironment(dm_env.Environment):
"""A DM environment where an agent controls a web browser."""

DEFAULT_OBSERVATIONS = ["web_url", "web_at", "error", "info"]
DEFAULT_OBSERVATIONS = ["web_url", "web_at", "web_html", "error", "info"]

def __init__(self, browser_context):
"""Initializes the environment."""
Expand Down Expand Up @@ -115,6 +115,7 @@ def observation_spec(self) -> dict[str, specs.Array]:
obs_shapes = {
"web_url": specs.Array(shape=(), dtype=str, name="web_url"),
"web_at": specs.Array(shape=(), dtype=str, name="web_at"),
"web_html": specs.Array(shape=(), dtype=str, name="web_html"),
"error": specs.Array(shape=(), dtype=str, name="error"),
"info": specs.Array(shape=(), dtype=str, name="info"),
}
Expand Down Expand Up @@ -158,6 +159,7 @@ def render(mode):
obs_map = {
"web_url": lambda: self._web.url.split("?")[0],
"web_at": render(playwright_crawler.CrawlerOutputFormat.AT),
"web_html": render(playwright_crawler.CrawlerOutputFormat.HTML),
"error": lambda: self._last_error,
"info": lambda: json.dumps(self.info),
}
Expand Down
Loading

0 comments on commit cd99abb

Please sign in to comment.