Skip to content

Commit

Permalink
Made it easier to instantiate environments (#74)
Browse files Browse the repository at this point in the history
Added a new way to create environments from an explicit task. Also made it possible to list the available environments (although the list is approximate).
  • Loading branch information
whitead authored Oct 18, 2024
1 parent 95dfe61 commit a20118e
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 7 deletions.
4 changes: 4 additions & 0 deletions packages/gsm8k/src/aviary/gsm8k/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def __init__(
self.check_tool = Tool.from_function(self.check_answer)
self.tools = [self.calc_tool, self.check_tool]

@classmethod
def from_task(cls, task: str) -> "CalculatorEnv":
    """Build a calculator environment whose problem text is the given task.

    Args:
        task: Free-form problem statement, used as the environment's problem.

    Returns:
        A new environment with ``problem_id`` set to ``"task"``.
    """
    # The answer is fixed to 0.0 here; presumably a placeholder, since a
    # free-form task carries no reference answer -- confirm with callers.
    placeholder_answer = 0.0
    return cls(problem_id="task", problem=task, answer=placeholder_answer)

async def reset(self) -> tuple[list[Message], list[Tool]]:
self.state = None # this environment is effectively stateless
return [Message(content=self.problem)], self.tools
Expand Down
4 changes: 4 additions & 0 deletions packages/hotpotqa/src/aviary/hotpotqa/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ def __init__(
create_tool(self.finish, "Finish"),
]

@classmethod
def from_task(cls, task: str) -> "HotPotQAEnv":
    """Build a HotPotQA environment that uses the given task as its question.

    Args:
        task: Free-form question to pose to the agent.

    Returns:
        A new environment constructed with ``question=task``.
    """
    # NOTE(review): correct_answer is given the float 0.0 as a stand-in;
    # confirm this matches the type the constructor and reward logic expect.
    stand_in_answer = 0.0
    return cls(question=task, correct_answer=stand_in_answer)

def calculate_reward(self, answer: str | None) -> float:
"""Calculate the reward based on the agent's answer.
Expand Down
51 changes: 44 additions & 7 deletions src/aviary/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,35 @@ async def close(self) -> None:
"""

@classmethod
def from_name(cls, name: str, **env_kwargs) -> Self:
return _construct_obj_from_name(ENV_REGISTRY, name, **env_kwargs)
def from_task(cls, task: str) -> Self:
    """Create an environment from a task description.

    A task is meant to be close to a user prompt, i.e. the kind of text you
    would pass when calling an LLM. This is the intended way to use an
    environment after training and in deployment.

    No config is accepted here: the default environment config should be
    general enough for arbitrary tasks, or else the config should be coupled
    to the agent training (future TODO).

    Args:
        task: The task description to build the environment around.

    Raises:
        NotImplementedError: Always, in this base implementation; subclasses
            that support tasks override this method.
    """
    message = f"{cls.__name__} does not implement from_task"
    raise NotImplementedError(message)

@classmethod
def from_name(cls, name: str, task: str | None = None, **env_kwargs) -> Self:
    """Create an environment from the name of its class.

    Call `Environment.available()` to see the list of names.

    Args:
        name: Registry name of the environment class to instantiate.
        task: Optional task description. When given, the environment is built
            via the class's `from_task` and no other kwargs may be passed.
        **env_kwargs: Constructor arguments, used only when `task` is None.

    Raises:
        ValueError: If both `task` and `env_kwargs` are supplied.
    """
    env_cls = _get_cls_from_name(ENV_REGISTRY, name)
    # Without a task, fall back to plain construction from kwargs.
    if task is None:
        return env_cls(**env_kwargs)
    # A task and explicit kwargs are mutually exclusive.
    if env_kwargs:
        raise ValueError("Cannot pass both a task and environment kwargs.")
    return env_cls.from_task(task)

@classmethod
def available(cls) -> set[str]:
    """Return the names of environment classes known to `from_name`.

    The result is the set of keys currently in ``ENV_REGISTRY``. Treat it as
    approximate: it is meant for logging/debugging, and the definitive check
    for whether an environment can be created is to call `from_name`.
    """
    # Iterating the registry yields its keys, so this is a snapshot copy.
    return set(ENV_REGISTRY)


# Maps baseline environment names to their module and class names
Expand All @@ -272,7 +299,7 @@ class TaskDataset(ABC, Generic[TEnvironment]):

@classmethod
def from_name(cls, name: str, **env_kwargs) -> TaskDataset:
return _construct_obj_from_name(TASK_DATASET_REGISTRY, name, **env_kwargs)
return _get_cls_from_name(TASK_DATASET_REGISTRY, name)(**env_kwargs)

def __len__(self) -> int:
raise TypeError(f'"Object of type {self.__class__.__name__}" has no len()')
Expand Down Expand Up @@ -372,8 +399,13 @@ class DummyEnv(Environment[DummyEnvState]):

State = DummyEnvState

def __init__(self, end_immediately: bool = True):
def __init__(self, task: str | None = None, end_immediately: bool = True):
self.end_immediately = end_immediately
self.task = task

@classmethod
def from_task(cls, task: str) -> DummyEnv:
    """Build a dummy environment carrying the given task.

    Args:
        task: Task text, passed to the constructor as ``task``.

    Returns:
        A new instance created with all other constructor defaults.
    """
    dummy_env = cls(task=task)
    return dummy_env

async def step(
self, action: ToolRequestMessage
Expand Down Expand Up @@ -407,7 +439,12 @@ def cast_int(x: float) -> int:
Tool.from_function(cast_int, allow_empty_param_descriptions=True),
]
self.state = type(self).State(
messages=[Message(content="Write a 5 word story via print_story")]
messages=[
Message(
content="Write a 5 word story via print_story"
+ (f" about {self.task}" if self.task else "")
)
],
)
return self.state.messages, self.tools

Expand All @@ -432,7 +469,7 @@ def __bool__(self) -> bool:
return True


def _construct_obj_from_name(registry: dict[str, tuple[str, str]], name: str, **kwargs):
def _get_cls_from_name(registry: dict[str, tuple[str, str]], name: str):
try:
module_name, cls_name = registry[name]
except KeyError:
Expand All @@ -446,4 +483,4 @@ def _construct_obj_from_name(registry: dict[str, tuple[str, str]], name: str, **
f"Could not import env from {module_name}; you need to install it."
) from None

return getattr(module, cls_name)(**kwargs)
return getattr(module, cls_name)

0 comments on commit a20118e

Please sign in to comment.