-
Notifications
You must be signed in to change notification settings - Fork 144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Carcass for browser tool: client, server, servicer and empty web env #466
Merged
jjallaire
merged 3 commits into
UKGovernmentBEIS:main
from
MariaIzobava:feature/headless-browser-tool
Sep 26, 2024
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Base docker build file. | ||
|
||
FROM python:3.12-bookworm | ||
|
||
WORKDIR /app | ||
|
||
RUN apt-get update | ||
|
||
RUN pip install --upgrade pip | ||
|
||
# Install playwright | ||
RUN pip install playwright | ||
RUN playwright install | ||
RUN playwright install-deps | ||
|
||
# Install other dependancies | ||
RUN pip install dm-env-rpc pillow bs4 lxml | ||
|
||
COPY *.py ./ | ||
|
||
CMD ["python3", "web_server.py"] |
248 changes: 248 additions & 0 deletions
248
evals/gdm_capabilities/util/web_browser_tool/dm_env_servicer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,248 @@ | ||
"""Environment service that allows clients to run shell commands in steps.""" | ||
|
||
import threading | ||
from typing import Any, Iterable, Type | ||
|
||
import dm_env | ||
import grpc | ||
import immutabledict | ||
from dm_env import specs | ||
from dm_env_rpc.v1 import ( | ||
dm_env_rpc_pb2, | ||
dm_env_rpc_pb2_grpc, | ||
dm_env_utils, | ||
spec_manager, | ||
) | ||
from google.rpc import code_pb2, status_pb2 | ||
|
||
_WORLD_NAME = "WebBrowser" | ||
|
||
|
||
class EnvironmentSpec: | ||
"""Specifications for a dm_environment. | ||
|
||
This class holds action and observation specs, as well as the required | ||
managers to pack actions and observations. | ||
""" | ||
|
||
def __init__(self, env: dm_env.Environment): | ||
convert = dm_env_utils.dm_env_spec_to_tensor_spec | ||
|
||
# We support either a single spec, of flat dictionary of specs. | ||
# In the dictionary case we need to map names to unique IDs. | ||
env_obs_spec: dict[str, Any] = env.observation_spec() | ||
if isinstance(env_obs_spec, specs.Array): | ||
self.observation_spec = {1: convert(env_obs_spec)} | ||
else: | ||
self.observation_spec = {} | ||
for i, obs_spec in enumerate(env_obs_spec.values()): | ||
self.observation_spec[i + 1] = convert(obs_spec) | ||
|
||
assert isinstance( | ||
env.action_spec(), specs.Array | ||
), "Only a single action type is supported." | ||
self.action_spec = {1: convert(env.action_spec())} | ||
|
||
self.observation_manager = spec_manager.SpecManager(self.observation_spec) | ||
self.action_manager = spec_manager.SpecManager(self.action_spec) | ||
|
||
|
||
class EnvironmentService(dm_env_rpc_pb2_grpc.EnvironmentServicer): | ||
"""Runs the environment as a gRPC EnvironmentServicer.""" | ||
|
||
def __init__(self, env_type: Type[dm_env.Environment]) -> None: | ||
"""Initializes the environment. | ||
|
||
Args: | ||
env_type: A dm_env class to serve. | ||
""" | ||
self._env_type = env_type | ||
self._env: dm_env.Environment = None | ||
self._spec: EnvironmentSpec = None | ||
self._lock = threading.Lock() | ||
# A server can only have one client connected at a time for now. | ||
self._has_joined_client = False | ||
|
||
self._handlers = immutabledict.immutabledict( | ||
{ | ||
dm_env_rpc_pb2.CreateWorldRequest: self._handle_create_world_request, | ||
dm_env_rpc_pb2.JoinWorldRequest: self._handle_join_world_request, | ||
dm_env_rpc_pb2.LeaveWorldRequest: self._handle_leave_world_request, | ||
dm_env_rpc_pb2.DestroyWorldRequest: self._handle_destroy_world_request, | ||
dm_env_rpc_pb2.ResetRequest: self._handle_reset_request, | ||
dm_env_rpc_pb2.StepRequest: self._handle_step_request, | ||
} | ||
) | ||
|
||
def Process( | ||
self, | ||
request_iterator: Iterable[dm_env_rpc_pb2.EnvironmentRequest], | ||
context: grpc.ServicerContext, | ||
): | ||
"""Processes incoming EnvironmentRequests. | ||
|
||
For each EnvironmentRequest the internal message is extracted and handled. | ||
The response for that message is then placed in a EnvironmentResponse which | ||
is returned to the client. | ||
|
||
An error status will be returned if an unknown message type is received or | ||
if the message is invalid for the current world state. | ||
|
||
|
||
Args: | ||
request_iterator: Message iterator provided by gRPC. | ||
context: Context provided by gRPC. | ||
|
||
Yields: | ||
EnvironmentResponse: Response for each incoming EnvironmentRequest. | ||
""" | ||
for request in request_iterator: | ||
environment_response = dm_env_rpc_pb2.EnvironmentResponse() | ||
try: | ||
message_type = request.WhichOneof("payload") | ||
internal_request = getattr(request, message_type) | ||
response = self._handlers[type(internal_request)](internal_request) | ||
getattr(environment_response, message_type).CopyFrom(response) | ||
except Exception as e: # pylint: disable=broad-except | ||
environment_response.error.CopyFrom( | ||
status_pb2.Status(code=code_pb2.INTERNAL, message=str(e)) | ||
) | ||
yield environment_response | ||
|
||
def _validate_settings(self, settings, valid_settings): | ||
""" "Validate the provided settings with list of valid setting keys.""" | ||
unrecognized_settings = [ | ||
setting for setting in settings if setting not in valid_settings | ||
] | ||
|
||
if unrecognized_settings: | ||
raise ValueError( | ||
"Unrecognized settings provided! Invalid settings:" | ||
f" {unrecognized_settings}" | ||
) | ||
|
||
def _add_spec_to_response(self, response: dm_env_rpc_pb2.EnvironmentResponse): | ||
"""Modifies given respose to include action/observation specifications.""" | ||
for uid, action in self._spec.action_spec.items(): | ||
response.specs.actions[uid].CopyFrom(action) | ||
for uid, observation in self._spec.observation_spec.items(): | ||
response.specs.observations[uid].CopyFrom(observation) | ||
|
||
def _handle_create_world_request( | ||
self, request: dm_env_rpc_pb2.CreateWorldRequest | ||
) -> dm_env_rpc_pb2.CreateWorldResponse: | ||
"""Handles create_world requests.""" | ||
self._validate_settings(request.settings, []) | ||
del request | ||
with self._lock: | ||
self._env = self._env_type() | ||
self._spec = EnvironmentSpec(self._env) | ||
return dm_env_rpc_pb2.CreateWorldResponse(world_name=_WORLD_NAME) | ||
|
||
def _handle_join_world_request( | ||
self, request: dm_env_rpc_pb2.JoinWorldRequest | ||
) -> dm_env_rpc_pb2.JoinWorldResponse: | ||
"""Handles join_world requests.""" | ||
self._validate_settings(request.settings, []) | ||
response = dm_env_rpc_pb2.JoinWorldResponse() | ||
with self._lock: | ||
if request.world_name != _WORLD_NAME: | ||
raise ValueError( | ||
f"Joining with the wrong world_name {request.world_name}" | ||
) | ||
if self._has_joined_client: | ||
raise ValueError("Only one client can join the environment at a time.") | ||
self._has_joined_client = True | ||
self._add_spec_to_response(response) | ||
del request | ||
return response | ||
|
||
def _handle_leave_world_request( | ||
self, request: dm_env_rpc_pb2.LeaveWorldRequest | ||
) -> dm_env_rpc_pb2.LeaveWorldResponse: | ||
"""Handles leave_world requests.""" | ||
del request | ||
with self._lock: | ||
self._has_joined_client = False | ||
|
||
response = dm_env_rpc_pb2.LeaveWorldResponse() | ||
return response | ||
|
||
def _handle_destroy_world_request( | ||
self, request: dm_env_rpc_pb2.DestroyWorldRequest | ||
) -> dm_env_rpc_pb2.DestroyWorldResponse: | ||
"""Handles destroy_world requests.""" | ||
del request | ||
with self._lock: | ||
if self._has_joined_client: | ||
raise ValueError("Destroying environment which has joined client.") | ||
if self._env is None: | ||
raise ValueError("Can not destroy uncreated environment.") | ||
self._env.close() | ||
self._env = None | ||
response = dm_env_rpc_pb2.DestroyWorldResponse() | ||
return response | ||
|
||
def _handle_reset_request( | ||
self, request: dm_env_rpc_pb2.ResetRequest | ||
) -> dm_env_rpc_pb2.ResetResponse: | ||
"""Handles reset requests.""" | ||
response = dm_env_rpc_pb2.ResetResponse() | ||
with self._lock: | ||
assert self._env, "Please create world before calling reset." | ||
self._env.reset() | ||
self._add_spec_to_response(response) | ||
return response | ||
|
||
def _handle_step_request( | ||
self, request: dm_env_rpc_pb2.StepRequest | ||
) -> dm_env_rpc_pb2.StepResponse: | ||
"""Handles step requests. | ||
|
||
Args: | ||
request: The request, which should contain a 'command' entry. | ||
|
||
Returns: | ||
Response including requested observations. | ||
|
||
Raises: | ||
KeyError: If the requested observation is not in the list of available | ||
observations. | ||
""" | ||
with self._lock: | ||
assert self._has_joined_client, "Please join world before calling step." | ||
|
||
action = self._spec.action_manager.unpack(request.actions) | ||
|
||
if "command" in action: | ||
command = action["command"] | ||
else: | ||
# For some reason dm_env calls step without actions after a reset. | ||
command = "" | ||
|
||
timestep: dm_env.TimeStep = self._env.step(command) | ||
|
||
packed_observations = self._spec.observation_manager.pack( | ||
timestep.observation | ||
) | ||
|
||
match timestep.step_type: | ||
case dm_env.StepType.MID: | ||
step_state = dm_env_rpc_pb2.RUNNING | ||
case dm_env.StepType.LAST: | ||
step_state = dm_env_rpc_pb2.TERMINATED | ||
case _: | ||
raise ValueError(f"Unsupported step type {timestep.step_type}.") | ||
|
||
response = dm_env_rpc_pb2.StepResponse(state=step_state) | ||
for requested_observation in request.requested_observations: | ||
if requested_observation not in packed_observations: | ||
name = self._spec.observation_manager.uid_to_name( | ||
requested_observation | ||
) | ||
raise KeyError(f"Requested observation not found: {name}") | ||
response.observations[requested_observation].CopyFrom( | ||
packed_observations[requested_observation] | ||
) | ||
|
||
return response |
45 changes: 45 additions & 0 deletions
45
evals/gdm_capabilities/util/web_browser_tool/mock_environment.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
"""A mock dm_env for unit testing.""" | ||
|
||
from typing import Any | ||
|
||
import dm_env | ||
from dm_env import specs | ||
|
||
|
||
class MockEnvironment(dm_env.Environment): | ||
"""A Mock DM environment.""" | ||
|
||
def __init__(self): | ||
"""Initializes the environment.""" | ||
super().__init__() | ||
self._last_command = "" | ||
|
||
def reset(self) -> dm_env.TimeStep: | ||
"""Starts a new sequence and returns the first `TimeStep` of this sequence.""" | ||
self._last_command = "" | ||
return dm_env.restart(observation=self.get_observations()) | ||
|
||
def step(self, action: list[str]) -> dm_env.TimeStep: | ||
"""Updates the environment according to the action and returns a `TimeStep`.""" | ||
self._last_command = " ".join(action) | ||
return dm_env.transition( | ||
reward=0.0, | ||
observation=self.get_observations(), | ||
) | ||
|
||
def observation_spec(self) -> dict[str, specs.Array]: | ||
"""Defines the observations provided by the environment.""" | ||
obs_shapes = { | ||
"last_command": specs.Array(shape=(), dtype=str, name="last_command"), | ||
} | ||
return obs_shapes | ||
|
||
def action_spec(self) -> specs.Array: | ||
"""Defines the actions that should be provided to `step`.""" | ||
return specs.Array(shape=(), dtype=str, name="command") | ||
|
||
def get_observations(self) -> dict[str, Any]: | ||
"""Returns dictionary containing observations.""" | ||
return { | ||
"last_command": self._last_command, | ||
} |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we thinking that the container that hosts the web browser will be dedicated to just the web browser? If so, that would result in an extra container per-sample compared to integrating it. That's is fine by me if you think that's the best approach, just wanted to clarify my understanding and flag the additional resource usage.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As I've thought about this further I very much like the idea of the web_browser having its own container. This will make "importing" it into challenges very easy (just add a service to compose.yaml). I also think that compared to the overhead of Chromium the overhead of a container will be pretty low. So no reservations at all about this approach!