diff --git a/pyproject.toml b/pyproject.toml
index 98ad83762..9eb93b1dc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,6 +7,12 @@
 dependencies = { file = ["requirements.txt"] }
 
 [tool.setuptools_scm]
+
+[tool.setuptools.packages.find]
+where = ["src"]
+include = ["inspect_ai*"]
+
+
 [tool.ruff]
 extend-exclude = ["docs"]
 src = ["src"]
@@ -59,7 +65,7 @@ warn_unused_configs = true
 mypy_path = "tests"
 
 [[tool.mypy.overrides]]
-module = "inspect_ai.*"
+module = ["inspect_ai.*", "inspect_evals.*"]
 warn_return_any = true
 disallow_untyped_defs = true
 disallow_any_generics = true
@@ -101,6 +107,10 @@ Documentation = "https://inspect.ai-safety-institute.org.uk/"
 inspect = "inspect_ai._cli.main:main"
 
 [project.optional-dependencies]
+inspect_evals = [
+  "datasets",
+  "pillow"
+]
 dev = [
   "anthropic",
   "aioboto3",
diff --git a/src/inspect_ai/_util/registry.py b/src/inspect_ai/_util/registry.py
index 5472d52fd..24cb332d6 100644
--- a/src/inspect_ai/_util/registry.py
+++ b/src/inspect_ai/_util/registry.py
@@ -146,10 +146,16 @@ def _lookup() -> object | None:
         return None
 
     o = _lookup()
-    # if if appears to be package scoped (/ with no .) then don't
-    # load entry points
-    if o is None and name.find("/") != -1 and name.find(".") == -1:
-        ensure_entry_points()
+
+    # try to recover
+    if o is None:
+        # load entry points as required
+        if name.find("/") != -1 and name.find(".") == -1:
+            ensure_entry_points()
+        # if this is a request for inspect_evals/, import its registry
+        if name.startswith("inspect_evals/"):
+            import inspect_evals._registry  # noqa: F401
+        return _lookup()
     else:
         return o
 
diff --git a/src/inspect_ai/solver/_critique.py b/src/inspect_ai/solver/_critique.py
index 54d99f01d..02cbb53e7 100644
--- a/src/inspect_ai/solver/_critique.py
+++ b/src/inspect_ai/solver/_critique.py
@@ -38,21 +38,21 @@ def self_critique(
       is used).
     """
     # resolve templates
-    critique_template = resource(critique_template or DEFAULT_CRITIQUE_TEMPLATE)
-    completion_template = resource(
+    critique_templ = resource(critique_template or DEFAULT_CRITIQUE_TEMPLATE)
+    completion_templ = resource(
         completion_template or DEFAULT_CRITIQUE_COMPLETION_TEMPLATE
     )
 
     # resolve model
-    model = get_model(model)
+    critique_model = get_model(model)
 
     async def solve(state: TaskState, generate: Generate) -> TaskState:
         # metadata without critique template variables
         metadata = omit(state.metadata, ["question", "completion", "critique"])
 
         # run critique
-        critique = await model.generate(
-            critique_template.format(
+        critique = await critique_model.generate(
+            critique_templ.format(
                 question=state.input_text,
                 completion=state.output.completion,
                 **metadata,
@@ -62,7 +62,7 @@ async def solve(state: TaskState, generate: Generate) -> TaskState:
         # add the critique as a user message
         state.messages.append(
             ChatMessageUser(
-                content=completion_template.format(
+                content=completion_templ.format(
                     question=state.input_text,
                     completion=state.output.completion,
                     critique=critique.completion,
diff --git a/src/inspect_evals/__init__.py b/src/inspect_evals/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/inspect_evals/_registry.py b/src/inspect_evals/_registry.py
new file mode 100644
index 000000000..7923f6317
--- /dev/null
+++ b/src/inspect_evals/_registry.py
@@ -0,0 +1,3 @@
+# ruff: noqa: F401
+
+from .gaia import gaia
diff --git a/src/inspect_evals/gaia/README.md b/src/inspect_evals/gaia/README.md
new file mode 100644
index 000000000..8018ce4df
--- /dev/null
+++ b/src/inspect_evals/gaia/README.md
@@ -0,0 +1,73 @@
+# GAIA
+This is an Inspect AI implementation of [the GAIA (General AI Assistants)](https://arxiv.org/abs/2311.12983) benchmark, consisting of 450 questions testing tool use on realistic assistant tasks (mostly web browsing).
+
+## Prerequisites
+
+1) **Installation** Install the `inspect_evals` Python package with:
+
+    ```bash
+    pip install git+https://github.com/UKGovernmentBEIS/inspect_ai[inspect_evals]
+    ```
+
+2) **Docker Engine** The GAIA task uses [tool calling](https://inspect.ai-safety-institute.org.uk/tools.html) to enable the model to execute bash commands. Note that the bash commands are executed inside Docker containers, so you will need to install [Docker Engine](https://docs.docker.com/engine/install/) in order to run the evaluation.
+
+3) **Hugging Face Dataset** When you run the GAIA task, it will attempt to download the dataset from the Hugging Face Hub. For this to work, you will need to gain access to the dataset (by filling out a form on [the GAIA Hugging Face repository](https://huggingface.co/datasets/gaia-benchmark/GAIA)), and also [create and set an access token](https://huggingface.co/docs/hub/en/security-tokens). You will need to define the `HF_TOKEN` environment variable to access the dataset:
+
+    ```
+    HF_TOKEN=
+    ```
+
+4) **Search Provider** The default solver for the GAIA task uses the Inspect [web_search()](https://inspect.ai-safety-institute.org.uk/tools.html#sec-web-search) tool. The `web_search()` tool uses [Google Programmable Search Engine](https://programmablesearchengine.google.com/about/). To use it, you will therefore need to set up your own Google Programmable Search Engine and also enable the [Programmable Search Element Paid API](https://developers.google.com/custom-search/docs/paid_element). Then, ensure that the following environment variables are defined:
+
+    ```
+    GOOGLE_CSE_ID=
+    GOOGLE_CSE_API_KEY=
+    ```
+
+
+## Usage
+
+After fulfilling the prerequisites, run the GAIA task against various models with:
+
+```bash
+$ inspect eval inspect_evals/gaia --model openai/gpt-4o
+$ inspect eval inspect_evals/gaia --model google/gemini-1.5-pro
+```
+
+You might want to limit the number of samples evaluated for initial experimentation:
+
+```bash
+$ inspect eval inspect_evals/gaia --model openai/gpt-4o --limit 5
+```
+
+## Development
+
+If you want to develop a solver for the GAIA task, import the `@task` from the `gaia` module and pass your own solver to it. For example:
+
+```python
+from inspect_ai import eval
+from inspect_ai.solver import use_tools, generate, system_message, basic_agent
+from inspect_ai.tool import bash, web_search
+
+from inspect_evals.gaia import gaia
+
+agent = basic_agent(tools=[bash()])
+task = gaia(solver=agent)
+
+eval(task, model="openai/gpt-4o")
+```
+
+See the documentation on the Inspect [Agents API](https://inspect.ai-safety-institute.org.uk/agents-api.html) for additional information on developing agents.
+
+### Test Split
+
+You can evaluate against the "test" split with:
+
+```python
+agent = basic_agent(tools=[bash()])
+task = gaia(solver=agent, split="test")
+
+eval(task, model="openai/gpt-4o")
+```
+
+Note that the GAIA "test" split does not come with any solutions, so it is not scored (rather, solutions are uploaded to the online leaderboard).
\ No newline at end of file
diff --git a/src/inspect_evals/gaia/__init__.py b/src/inspect_evals/gaia/__init__.py
new file mode 100644
index 000000000..140b85f41
--- /dev/null
+++ b/src/inspect_evals/gaia/__init__.py
@@ -0,0 +1,3 @@
+from .gaia import gaia, gaia_dataset
+
+__all__ = ["gaia", "gaia_dataset"]
diff --git a/src/inspect_evals/gaia/compose.yaml b/src/inspect_evals/gaia/compose.yaml
new file mode 100644
index 000000000..001fb6ccc
--- /dev/null
+++ b/src/inspect_evals/gaia/compose.yaml
@@ -0,0 +1,5 @@
+services:
+  default:
+    image: "python:3.12-bookworm"
+    init: true
+    command: "tail -f /dev/null"
diff --git a/src/inspect_evals/gaia/gaia.py b/src/inspect_evals/gaia/gaia.py
new file mode 100644
index 000000000..abc37d179
--- /dev/null
+++ b/src/inspect_evals/gaia/gaia.py
@@ -0,0 +1,138 @@
+import os
+import shutil
+from pathlib import Path
+from typing import Any, Callable, Literal
+
+from huggingface_hub import snapshot_download  # type: ignore
+from platformdirs import user_cache_dir
+
+from inspect_ai import Task, task
+from inspect_ai.dataset import Dataset, Sample, hf_dataset
+from inspect_ai.scorer import match
+from inspect_ai.solver import Solver, basic_agent, system_message
+from inspect_ai.tool import bash, web_search
+
+
+@task
+def gaia(
+    solver: Solver | None = None,
+    input_prompt: str | None = None,
+    max_messages: int = 30,
+    subset: Literal[
+        "2023_all", "2023_level1", "2023_level2", "2023_level3"
+    ] = "2023_all",
+    split: Literal["test", "validation"] = "validation",
+    filter: Callable[[Sample], bool] | None = None,
+) -> Task:
+    # read dataset
+    dataset = gaia_dataset(
+        input_prompt=input_prompt or DEFAULT_INPUT_PROMPT,
+        subset=subset,
+        split=split,
+        filter=filter if filter else lambda x: True,
+    )
+
+    # provide default solver if required
+    if solver is None:
+        solver = basic_agent(
+            init=system_message(
+                "Please solve the task below. Once you are done, "
+                + "use the submit() tool to provide your answer."
+            ),
+            tools=[bash(), web_search()],
+        )
+
+    # resolve scorer (test split has no answers)
+    scorer = match() if split == "validation" else None
+
+    # docker compose file is alongside this source file
+    COMPOSE_FILE = Path(__file__).parent / "compose.yaml"
+
+    # return task
+    return Task(
+        dataset=dataset,
+        plan=solver,
+        scorer=scorer,
+        sandbox=("docker", COMPOSE_FILE.as_posix()),
+        max_messages=max_messages,
+    )
+
+
+DEFAULT_INPUT_PROMPT = """Please answer the question below. You should:
+
+- Return only your answer, which should be a number, or a short phrase with as
+  few words as possible, or a comma separated list of numbers and/or strings.
+- If the answer is a number, return only the number without any units unless
+  specified otherwise.
+- If the answer is a string, don't include articles, and don't use
+  abbreviations (e.g. for states).
+- If the answer is a comma separated list, apply the above rules to each
+  element in the list.
+
+Any files or attachments mentioned in the question can be found in the
+/shared_files/ directory (some questions do not have associated files). Here
+is the question:
+
+{question}"""
+
+
+def gaia_dataset(
+    input_prompt: str,
+    subset: str,
+    split: str,
+    filter: Callable[[Sample], bool] = lambda x: True,
+) -> Dataset:
+    # use user cache dir for dataset
+    GAIA_DATASET_LOCATION = Path(user_cache_dir("gaia_eval")) / "GAIA"
+
+    # download dataset if required
+    if not os.path.exists(GAIA_DATASET_LOCATION):
+        GAIA_DATASET_LOCATION.mkdir(parents=True, exist_ok=True)
+        try:
+            snapshot_download(
+                repo_id="gaia-benchmark/GAIA",
+                repo_type="dataset",
+                local_dir=GAIA_DATASET_LOCATION,
+            )
+        except Exception as ex:
+            shutil.rmtree(GAIA_DATASET_LOCATION, ignore_errors=True)
+            raise ex
+
+    # map record to sample
+    def record_to_sample(record: dict[str, Any]) -> Sample:
+        # map fields (the input prompt template is applied below)
+        sample = Sample(
+            input=record["Question"],
+            target=record["Final answer"],
+            id=record["task_id"],
+            metadata={
+                "level": record["Level"],
+                "Annotator Metadata": record["Annotator Metadata"],
+            },
+            setup="mkdir -p /shared_files/",
+        )
+
+        # apply input prompt
+        sample.input = input_prompt.format(question=sample.input)
+
+        # provide sample files
+        files_location = GAIA_DATASET_LOCATION / "2023" / split
+        files = [file for file in os.listdir(files_location) if str(sample.id) in file]
+        if len(files) > 0:
+            sample.files = {
+                "/shared_files/" + files[0]: (files_location / files[0]).as_posix()
+            }
+
+        return sample
+
+    # read dataset
+    dataset = hf_dataset(
+        GAIA_DATASET_LOCATION.as_posix(),
+        name=subset,
+        split=split,
+        sample_fields=record_to_sample,
+        trust=True,  # Trust GAIA's remote code execution during dataset loading
+    )
+
+    # apply filter (if any) and return
+    return dataset.filter(filter)
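
For reference, here is a minimal sketch (not part of the diff itself) of how the registry fallback added in `src/inspect_ai/_util/registry.py` is expected to resolve the new task by name, assuming the package was installed with the `inspect_evals` extra. The `eval()` call mirrors what `inspect eval inspect_evals/gaia` does from the CLI:

```python
from inspect_ai import eval

# "inspect_evals/gaia" contains a "/" and no ".", so when the initial registry
# lookup misses, entry points are loaded and inspect_evals._registry is
# imported (registering the @task-decorated gaia), then the lookup is retried
eval("inspect_evals/gaia", model="openai/gpt-4o", limit=1)
```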