* add inspect_evals package
* mypy

Co-authored-by: aisi-inspect <[email protected]>
1 parent
26acdb7
commit 12952e6
Showing 9 changed files with 249 additions and 11 deletions.
Empty file.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# ruff: noqa: F401 | ||
|
||
from .gaia import gaia |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# GAIA | ||
This is an Inspect AI implementation of [the GAIA (General AI Assistants)](https://arxiv.org/abs/2311.12983) benchmark, consisting of 450 questions testing tool use on realistic assistant tasks (mostly web browsing). | ||
|
||
## Prerequisites | ||
|
||
1) **Installation** Install the `inspect_evals` Python package with: | ||
|
||
```bash | ||
pip install git+https://github.com/UKGovernmentBEIS/inspect_ai[inspect_evals] | ||
``` | ||
|
||
2) **Docker Engine** The GAIA task uses [tool calling](https://inspect.ai-safety-institute.org.uk/tools.html) to enable the model to execute bash commands. Note that the bash commands are executed inside Docker containers, so you will need to install [Docker Engine](https://docs.docker.com/engine/install/) in order to run the evaluation. | ||
|
||
3) **Hugging Face Dataset** When the GAIA task runs, it attempts to download the dataset from the Hugging Face Hub. For this to work, you need to request access to the dataset (by filling out a form on [the GAIA Hugging Face repository](https://huggingface.co/datasets/gaia-benchmark/GAIA)) and [create and set an access token](https://huggingface.co/docs/hub/en/security-tokens). Define the `HF_TOKEN` environment variable so that the download can authenticate: | ||
|
||
``` | ||
HF_TOKEN=<hf-token> | ||
``` | ||
|
||
4) **Search Provider** The default solver for the GAIA task uses the Inspect [web_search()](https://inspect.ai-safety-institute.org.uk/tools.html#sec-web-search) tool, which is backed by [Google Programmable Search Engine](https://programmablesearchengine.google.com/about/). To use it, you need to set up your own Google Programmable Search Engine and enable the [Programmable Search Element Paid API](https://developers.google.com/custom-search/docs/paid_element). Then ensure that the following environment variables are defined: | ||
|
||
``` | ||
GOOGLE_CSE_ID=<google-search-engine-id> | ||
GOOGLE_CSE_API_KEY=<google-api-key> | ||
``` | ||
|
||
|
||
## Usage | ||
|
||
After fulfilling the prerequisites, run the GAIA task against various models with: | ||
|
||
```bash | ||
$ inspect eval inspect_evals/gaia --model openai/gpt-4o | ||
$ inspect eval inspect_evals/gaia --model google/gemini-1.5-pro | ||
``` | ||
|
||
You might want to limit the number of samples evaluated for initial experimentation: | ||
|
||
```bash | ||
$ inspect eval inspect_evals/gaia --model openai/gpt-4o --limit 5 | ||
``` | ||
|
||
## Development | ||
|
||
If you want to develop a solver for the GAIA task, import the `gaia` task from the `gaia` module and pass your own solver to it. For example: | ||
|
||
```python | ||
from inspect_ai import eval | ||
from inspect_ai.solver import use_tools, generate, system_message, basic_agent | ||
from inspect_ai.tool import bash, web_search | ||
|
||
from inspect_evals.gaia import gaia | ||
|
||
agent = basic_agent(tools=[bash()]) | ||
task = gaia(solver=agent) | ||
|
||
eval(task, model="openai/gpt-4o") | ||
``` | ||
|
||
See the documentation on the Inspect [Agents API](https://inspect.ai-safety-institute.org.uk/agents-api.html) for additional information on developing agents. | ||
|
||
### Test Split | ||
|
||
You can evaluate against the "test" split with: | ||
|
||
```python | ||
agent = basic_agent(tools=[bash()]) | ||
task = gaia(solver=agent, split="test") | ||
|
||
eval(task, model="openai/gpt-4o") | ||
``` | ||
|
||
Note that the GAIA "test" split does not come with any solutions, so it is not scored (rather, solutions are uploaded to the online leaderboard). |
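The prerequisites in the README above depend on three environment variables (`HF_TOKEN`, `GOOGLE_CSE_ID`, `GOOGLE_CSE_API_KEY`). A minimal sketch of a preflight check follows, assuming you want to fail early rather than partway through an evaluation. `missing_env_vars` is a hypothetical helper introduced here for illustration; it is not part of `inspect_ai` or `inspect_evals`.

```python
import os
from typing import Mapping, Optional

# Variable names as listed in the README prerequisites.
REQUIRED_ENV_VARS = ("HF_TOKEN", "GOOGLE_CSE_ID", "GOOGLE_CSE_API_KEY")


def missing_env_vars(
    names: tuple[str, ...] = REQUIRED_ENV_VARS,
    env: Optional[Mapping[str, str]] = None,
) -> list[str]:
    """Return the names that are unset or empty.

    Hypothetical helper for illustration; not part of inspect_ai or inspect_evals.
    """
    env = os.environ if env is None else env
    return [name for name in names if not env.get(name)]


# Example with an explicit mapping instead of the real environment.
example_env = {"HF_TOKEN": "hf_example", "GOOGLE_CSE_ID": ""}
print(missing_env_vars(env=example_env))
```

Calling `missing_env_vars()` with no arguments checks the real process environment; an empty result means all three variables are defined.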
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .gaia import gaia, gaia_dataset | ||
|
||
__all__ = ["gaia", "gaia_dataset"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
services: | ||
  default: | ||
    image: "python:3.12-bookworm" | ||
    init: true | ||
    command: "tail -f /dev/null" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import os | ||
import shutil | ||
from pathlib import Path | ||
from typing import Any, Callable, Literal | ||
|
||
from huggingface_hub import snapshot_download # type: ignore | ||
from platformdirs import user_cache_dir | ||
|
||
from inspect_ai import Task, task | ||
from inspect_ai.dataset import Dataset, Sample, hf_dataset | ||
from inspect_ai.scorer import match | ||
from inspect_ai.solver import Solver, basic_agent, system_message | ||
from inspect_ai.tool import bash, web_search | ||
|
||
|
||
@task | ||
def gaia( | ||
    solver: Solver | None = None, | ||
    input_prompt: str | None = None, | ||
    max_messages: int = 30, | ||
    subset: Literal[ | ||
        "2023_all", "2023_level1", "2023_level2", "2023_level3" | ||
    ] = "2023_all", | ||
    split: Literal["test", "validation"] = "validation", | ||
    filter: Callable[[Sample], bool] | None = None, | ||
) -> Task: | ||
    # read dataset | ||
    dataset = gaia_dataset( | ||
        input_prompt=input_prompt or DEFAULT_INPUT_PROMPT, | ||
        subset=subset, | ||
        split=split, | ||
        filter=filter if filter else lambda x: True, | ||
    ) | ||
|
||
    # provide default plan if required | ||
    if solver is None: | ||
        solver = basic_agent( | ||
            init=system_message( | ||
                "Please solve the coding task below. Once you are done, " | ||
                + "use the submit() tool to provide your answer." | ||
            ), | ||
            tools=[bash(), web_search()], | ||
        ) | ||
|
||
    # resolve scorer (test split has no answers) | ||
    scorer = match() if split == "validation" else None | ||
|
||
    # docker compose file is alongside | ||
    COMPOSE_FILE = Path(__file__).parent / "compose.yaml" | ||
|
||
    # return task | ||
    return Task( | ||
        dataset=dataset, | ||
        plan=solver, | ||
        scorer=scorer, | ||
        sandbox=("docker", COMPOSE_FILE.as_posix()), | ||
        max_messages=max_messages, | ||
    ) | ||
|
||
|
||
DEFAULT_INPUT_PROMPT = """Please answer the question below. You should: | ||
- Return only your answer, which should be a number, or a short phrase with as | ||
few words as possible, or a comma separated list of numbers and/or strings. | ||
- If the answer is a number, return only the number without any units unless | ||
specified otherwise. | ||
- If the answer is a string, don't include articles, and don't use | ||
abbreviations (e.g. for states). | ||
- If the answer is a comma separated list, apply the above rules to each | ||
element in the list. | ||
Any files or attachments mentioned in the question can be found in the | ||
/shared_files/ directory (some questions do not have associated files). Here | ||
is the question: | ||
{question}""" | ||
|
||
|
||
def gaia_dataset( | ||
    input_prompt: str, | ||
    subset: str, | ||
    split: str, | ||
    filter: Callable[[Sample], bool] = lambda x: True, | ||
) -> Dataset: | ||
    # use user cache dir for dataset | ||
    GAIA_DATASET_LOCATION = Path(user_cache_dir("gaia_eval")) / "GAIA" | ||
|
||
    # download dataset if required | ||
    if not os.path.exists(GAIA_DATASET_LOCATION): | ||
        GAIA_DATASET_LOCATION.mkdir(parents=True, exist_ok=True) | ||
        try: | ||
            snapshot_download( | ||
                repo_id="gaia-benchmark/GAIA", | ||
                repo_type="dataset", | ||
                local_dir=GAIA_DATASET_LOCATION, | ||
            ) | ||
        except Exception as ex: | ||
            # remove the partial download so the next run retries cleanly | ||
            shutil.rmtree(GAIA_DATASET_LOCATION, True) | ||
            raise ex | ||
|
||
    # map record to sample | ||
    def record_to_sample(record: dict[str, Any]) -> Sample: | ||
        # map fields (raw question here; the prompt template is applied once below) | ||
        sample = Sample( | ||
            input=record["Question"], | ||
            target=record["Final answer"], | ||
            id=record["task_id"], | ||
            metadata={ | ||
                "level": record["Level"], | ||
                "Annotator Metadata": record["Annotator Metadata"], | ||
            }, | ||
            setup="mkdir -p /shared_files/", | ||
        ) | ||
|
||
        # apply input prompt | ||
        sample.input = input_prompt.format(question=sample.input) | ||
|
||
        # provide sample files | ||
        files_location = GAIA_DATASET_LOCATION / "2023" / split | ||
        files = [file for file in os.listdir(files_location) if str(sample.id) in file] | ||
        if len(files) > 0: | ||
            sample.files = { | ||
                "/shared_files/" + files[0]: (files_location / files[0]).as_posix() | ||
            } | ||
|
||
        return sample | ||
|
||
    # read dataset | ||
    dataset = hf_dataset( | ||
        GAIA_DATASET_LOCATION.as_posix(), | ||
        name=subset, | ||
        split=split, | ||
        sample_fields=record_to_sample, | ||
        trust=True,  # Trust GAIA's remote code execution during dataset loading | ||
    ) | ||
|
||
    # apply filter (if any) and return | ||
    return dataset.filter(filter) |
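The `filter` parameter above takes a predicate over samples and is applied after each record has been mapped, so a predicate can inspect fields such as `id` and `files`. A minimal sketch of two such predicates follows. `StandInSample`, `has_attachment` and `id_in` are hypothetical names introduced here so the example runs without Inspect installed; they are not part of this commit.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class StandInSample:
    """Minimal stand-in for inspect_ai.dataset.Sample (illustration only)."""

    id: str
    files: Optional[dict[str, str]] = None


def has_attachment(sample: Any) -> bool:
    """Keep only samples that have a file mapped into /shared_files/."""
    return bool(sample.files)


def id_in(ids: set[str]) -> Callable[[Any], bool]:
    """Build a predicate that keeps samples whose id is in `ids`."""
    return lambda sample: str(sample.id) in ids


# Hypothetical sample ids and paths, used only to exercise the predicates.
samples = [
    StandInSample("task-a", {"/shared_files/task-a.pdf": "/cache/task-a.pdf"}),
    StandInSample("task-b"),
]
with_files = [s.id for s in samples if has_attachment(s)]
selected = [s.id for s in samples if id_in({"task-b"})(s)]
```

With Inspect installed, the same predicates could presumably be passed as `gaia(filter=has_attachment)`, since the task forwards `filter` to `dataset.filter()`.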