
Commit

Merge branch 'main' of github.com:UKGovernmentBEIS/inspect_ai
jjallaire committed Aug 27, 2024
2 parents 163f75c + 51df36b commit 2236c7c
Showing 22 changed files with 606 additions and 72 deletions.
9 changes: 6 additions & 3 deletions CHANGELOG.md
@@ -2,11 +2,14 @@

## Unreleased

- Add optional `user` parameter to `SandboxEnvironment.exec()` for specifying the user. Currently only `DockerSandboxEnvironment` is supported.
- Sandbox environments can now be specified [per-sample](https://inspect.ai-safety-institute.org.uk/agents.html#sec-per-sample-sandbox) (e.g. allowing for a distinct Dockerfile or Docker compose file for each sample).
- [input_screen()](https://inspect.ai-safety-institute.org.uk/interactivity.html) context manager to temporarily clear task display for user input.
- Fix issue with resolving Docker configuration files when not running from the task directory.
- Treat `cwd` that are relative paths as relative to the sample working directory.
- Raise an error when a Solver does not return a `TaskState`.
- Only run tests that use model APIs when the `--runapi` flag is passed to `pytest` (prevents unintended token usage).
- Added [CommonsenseQA](https://github.com/UKGovernmentBEIS/inspect_ai/tree/main/benchmarks/commonsense_qa) benchmark.

## v0.3.25 (25 August 2024)

3 changes: 2 additions & 1 deletion benchmarks/README.md
@@ -16,4 +16,5 @@ This directory contains evals for several benchmarks. Datasets for evals are not
| HumanEval: Evaluating Large Language Models Trained on Code | <https://arxiv.org/pdf/2107.03374> | [humaneval.py](humaneval/humaneval.py) | Hugging Face |
| DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs | <https://arxiv.org/pdf/1903.00161> | [drop.py](drop/drop.py) | Hugging Face |
| WINOGRANDE: An Adversarial Winograd Schema Challenge at Scale | <https://arxiv.org/pdf/1907.10641> | [winogrande.py](winogrande/winogrande.py) | Hugging Face |
| RACE-H: A benchmark for testing reading comprehension and reasoning abilities of neural models. | <https://arxiv.org/abs/1704.04683> | [race-h.py](race-h/race-h.py) | Hugging Face |
| CommonsenseQA: A Question Answering Challenge Targeting Commonsense Knowledge | <https://arxiv.org/pdf/1811.00937v2> | [commonsense_qa.py](commonsense_qa/commonsense_qa.py) | Hugging Face |
19 changes: 19 additions & 0 deletions benchmarks/commonsense_qa/README.md
@@ -0,0 +1,19 @@
# CommonsenseQA: A Question Answering Challenge Targeting Commonsense Knowledge

[CommonsenseQA](https://arxiv.org/pdf/1811.00937) is a dataset designed to evaluate commonsense reasoning capabilities in natural language processing models. It consists of 12,247 multiple-choice questions that require background knowledge and commonsense to answer correctly. The dataset was constructed using CONCEPTNET, a graph-based knowledge base, where crowd-workers authored questions with complex semantics to challenge existing AI models.

## Execution
Here is an example from the dataset:
```
Question: Where can I stand on a river to see water falling without getting wet?
Options:
A) Waterfall
B) Bridge
C) Valley
D) Stream
E) Bottom
```
The model is required to choose the correct answer from the given options. In this case, the correct answer is B) Bridge.

## Evaluation
The model is prompted with the question and the five options as input, and must answer by generating the corresponding answer choice A, B, C, D, or E. The prompt template is based on the multiple choice template in OpenAI's [simple evals](https://github.com/openai/simple-evals/blob/main/mmlu_eval.py).
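
For illustration, the rendered prompt looks roughly like this (a sketch in the style of the simple evals template, not the verbatim text produced by `multiple_choice()`):
```
Answer the following multiple choice question. The last line of your response
should be of the following format: 'ANSWER: $LETTER' (without quotes) where
LETTER is one of A,B,C,D,E.

Where can I stand on a river to see water falling without getting wet?

A) Waterfall
B) Bridge
C) Valley
D) Stream
E) Bottom
```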
49 changes: 49 additions & 0 deletions benchmarks/commonsense_qa/commonsense_qa.py
@@ -0,0 +1,49 @@
"""
CommonsenseQA: A Question Answering Challenge Targeting Commonsense Knowledge
Alon Talmor, Jonathan Herzig, Nicholas Lourie, Jonathan Berant
https://arxiv.org/pdf/1811.00937v2
# eval w/ 500 randomly selected samples
inspect eval commonsense_qa.py --limit 500
"""

from inspect_ai import Task, task
from inspect_ai.dataset import Sample
from inspect_ai.dataset._sources.hf import hf_dataset
from inspect_ai.model._generate_config import GenerateConfig
from inspect_ai.scorer import choice
from inspect_ai.solver import multiple_choice


@task
def commonsense_qa():
    dataset = hf_dataset(
        path="tau/commonsense_qa",
        split="validation",
        sample_fields=record_to_sample,
        trust=True,
        shuffle=True,
    )

    return Task(
        dataset=dataset,
        plan=multiple_choice(),
        scorer=choice(),
        config=GenerateConfig(temperature=0),
    )


def record_to_sample(record):
    return Sample(
        input=record["question"],
        choices=[
            str(record["choices"]["text"][0]),
            str(record["choices"]["text"][1]),
            str(record["choices"]["text"][2]),
            str(record["choices"]["text"][3]),
            str(record["choices"]["text"][4]),
        ],
        target=record["answerKey"],
        metadata={"question_concept": record["question_concept"]},
    )
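
As with the other benchmarks in this directory, the eval can be pointed at a specific model from the command line (the model name below is just an example; any configured provider works):
```
inspect eval commonsense_qa.py --model openai/gpt-4 --limit 500
```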
5 changes: 5 additions & 0 deletions docs/extensions.qmd
@@ -118,6 +118,10 @@ The static class methods control the lifecycle of containers and other computing
@sandboxenv(name="podman")
class PodmanSandboxEnvironment(SandboxEnvironment):

    @classmethod
    def config_files(cls) -> list[str]:
        ...

    @classmethod
    async def task_init(
        cls, task_name: str, config: str | None
@@ -160,6 +164,7 @@ The class methods take care of various stages of initialisation, setup, and teardown

| Method | Lifecycle | Purpose |
|-------------------|-------------------|----------------------------------|
| `config_files()` | Called once. | Determine the names of 'default' config files for this provider (e.g. `compose.yaml`). |
| `task_init()` | Called once for each unique sandbox environment config before executing the tasks in an `eval()` run. | Expensive initialisation operations (e.g. pulling or building images) |
| `sample_init()` | Called at the beginning of each `Sample`. | Create `SandboxEnvironment` instances for the sample. |
| `sample_cleanup()` | Called at the end of each `Sample` | Cleanup `SandboxEnvironment` instances for the sample. |
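
For example, a provider's `config_files()` might simply return the compose file names it treats as defaults. A minimal sketch for the hypothetical Podman provider above (the actual file names are up to the provider):

```python
@classmethod
def config_files(cls) -> list[str]:
    # 'default' config files discovered automatically when they
    # sit alongside the task source file
    return ["compose.yaml", "compose.yml"]
```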
1 change: 1 addition & 0 deletions pyproject.toml
@@ -38,6 +38,7 @@ norecursedirs = [
"tests/test_package",
"tests/test_task_list",
]
asyncio_mode = "auto"
log_level = "warning"

[tool.mypy]
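
For context, `asyncio_mode = "auto"` tells pytest-asyncio to collect `async def` tests without a per-test marker. A minimal sketch (assuming pytest-asyncio is installed):

```python
# test_example.py: collected and run as an asyncio test without an
# explicit @pytest.mark.asyncio marker under asyncio_mode = "auto"
import asyncio


async def test_sleep_completes() -> None:
    await asyncio.sleep(0)
```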
10 changes: 7 additions & 3 deletions src/inspect_ai/_display/rich.py
@@ -475,9 +475,13 @@ def task_targets(profile: TaskProfile) -> str:
def task_config(profile: TaskProfile, generate_config: bool = True) -> str:
    # merge config
    theme = rich_theme()
    # wind params back for display
    task_args = dict(profile.task_args)
    for key in task_args.keys():
        value = task_args[key]
        if isinstance(value, dict) and "plan" in value and "params" in value:
            task_args[key] = value["plan"]
    config = task_args | dict(profile.eval_config.model_dump(exclude_none=True))
    if generate_config:
        config = config | dict(profile.generate_config.model_dump(exclude_none=True))
    config_print: list[str] = []
60 changes: 50 additions & 10 deletions src/inspect_ai/_eval/loader.py
@@ -1,13 +1,14 @@
import ast
import inspect
import os
from dataclasses import dataclass, field
from importlib.machinery import SourceFileLoader
from importlib.util import module_from_spec, spec_from_loader
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, cast

from inspect_ai._eval.task.util import task_file, task_src_dir
from inspect_ai._util.dotenv import dotenv_environ
from inspect_ai._util.path import chdir_python
from inspect_ai._util.registry import (
@@ -19,11 +20,12 @@
)
from inspect_ai.model import Model, ModelName
from inspect_ai.util import SandboxEnvironmentSpec
from inspect_ai.util._sandbox.registry import registry_find_sandboxenv

from .list import task_files
from .registry import task_create
from .task import PreviousTask, Task, TaskInfo, Tasks
from .task.constants import TASK_FILE_ATTR, TASK_RUN_DIR_ATTR, TASK_SRC_DIR_ATTR
from .task.run import EvalSampleSource, eval_log_sample_source


@@ -61,13 +63,7 @@ def as_resolved_tasks(tasks: list[Task]) -> list[ResolvedTask]:
                task_args=resolve_task_args(task),
                task_file=task_file(task, relative=True),
                model=model,
                sandbox=resolve_task_sandbox(task, sandbox),
                sequence=sequence,
            )
            for sequence, task in enumerate(tasks)
@@ -144,6 +140,49 @@ def resolve_task_args(task: Task) -> dict[str, Any]:
    return {}


def resolve_task_sandbox(
    task: Task, sandbox: SandboxEnvironmentSpec | None
) -> tuple[str, str | None] | None:
    # do the resolution
    resolved_sandbox = (
        (sandbox, None)
        if isinstance(sandbox, str)
        else sandbox
        if sandbox is not None
        else task.sandbox
    )

    # if we have a sandbox with no config, see if there are implicit
    # config files available for the provider
    if resolved_sandbox is not None:
        # look for default
        if resolved_sandbox[1] is None:
            # get config files for this type
            sandboxenv_type = registry_find_sandboxenv(resolved_sandbox[0])
            config_files_fn = cast(
                Callable[..., list[str]], getattr(sandboxenv_type, "config_files")
            )
            config_files = config_files_fn()

            # probe for them in task src dir
            src_dir = task_src_dir(task)
            for config_file in config_files:
                config_file_path = os.path.join(src_dir, config_file)
                if os.path.isfile(config_file_path):
                    resolved_sandbox = (resolved_sandbox[0], config_file)
                    break

        # resolve relative paths
        if resolved_sandbox[1] is not None:
            file_path = Path(resolved_sandbox[1])
            if not file_path.is_absolute():
                file_path = Path(task_src_dir(task)) / file_path
                resolved_sandbox = (resolved_sandbox[0], file_path.as_posix())

    # return resolved sandbox
    return resolved_sandbox


def load_tasks(
task_specs: list[str] | None, model: Model, task_args: dict[str, Any] = {}
) -> list[Task]:
@@ -231,6 +270,7 @@ def create_file_tasks(
        # (will be used later to ensure it runs in the directory)
        task = task_create(task_spec, model, **task_args)
        setattr(task, TASK_FILE_ATTR, file.as_posix())
        setattr(task, TASK_SRC_DIR_ATTR, file.parent.as_posix())
        if task.attribs.get("chdir", True):
            setattr(task, TASK_RUN_DIR_ATTR, file.parent.as_posix())
        tasks.append(task)
1 change: 1 addition & 0 deletions src/inspect_ai/_eval/task/constants.py
@@ -1,2 +1,3 @@
TASK_FILE_ATTR = "__task_file__"
TASK_RUN_DIR_ATTR = "__task_run_dir__"
TASK_SRC_DIR_ATTR = "__task_src_dir__"
6 changes: 5 additions & 1 deletion src/inspect_ai/_eval/task/util.py
@@ -7,7 +7,7 @@
from inspect_ai.model import ChatMessage, ChatMessageUser

from ..task import Task
from .constants import TASK_FILE_ATTR, TASK_RUN_DIR_ATTR, TASK_SRC_DIR_ATTR


def sample_messages(sample: Sample) -> list[ChatMessage]:
@@ -24,6 +24,10 @@ def task_run_dir(task: Task) -> str:
    return getattr(task, TASK_RUN_DIR_ATTR, os.getcwd())


def task_src_dir(task: Task) -> str:
    return getattr(task, TASK_SRC_DIR_ATTR, os.getcwd())


def task_file(task: Task, relative: bool = False) -> str | None:
    file = cast(str | None, getattr(task, TASK_FILE_ATTR, None))
    if file:
26 changes: 24 additions & 2 deletions src/inspect_ai/_util/registry.py
@@ -4,6 +4,7 @@
from typing import Any, Callable, Literal, cast

from pydantic import BaseModel, Field
from pydantic_core import to_jsonable_python

from .constants import PKG_NAME

@@ -74,14 +75,26 @@ def registry_tag(
            named_params[params[i]] = arg
    named_params |= kwargs

    # plan objects are serialised with name and params
    for param in named_params.keys():
        value = named_params[param]
        if is_registry_object(value) and registry_info(value).type == "plan":
            named_params[param] = dict(
                plan=registry_log_name(value), params=registry_params(value)
            )

    # callables are not serializable so use their names
    for param in named_params.keys():
        if is_registry_object(named_params[param]):
            named_params[param] = registry_info(named_params[param]).name
        elif callable(named_params[param]):
            named_params[param] = getattr(named_params[param], "__name__")
        elif isinstance(named_params[param], dict | list):
            named_params[param] = to_jsonable_python(
                named_params[param], fallback=lambda x: getattr(x, "__name__", None)
            )
        else:
            named_params[param] = named_params[param]

    # set attribute
    setattr(o, REGISTRY_INFO, info)
@@ -157,6 +170,15 @@ def registry_create(type: RegistryType, name: str, **kwargs: Any) -> object:
    def with_registry_info(o: object) -> object:
        return set_registry_info(o, registry_info(obj))

    # instantiate plan objects for tasks
    if type == "task":
        for param in kwargs.keys():
            value = kwargs[param]
            if isinstance(value, dict) and "plan" in value and "params" in value:
                kwargs[param] = registry_create(
                    "plan", value["plan"], **value["params"]
                )

    if isclass(obj):
        return with_registry_info(obj(**kwargs))
    elif callable(obj):
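
Taken together, these two changes round-trip plans through the eval log: `registry_tag()` serialises a plan-valued task argument to a plain dict, and `registry_create()` detects that shape and re-instantiates the plan when the task is recreated. Schematically (hypothetical plan name and params):

```python
# what registry_tag() records for a plan-valued task argument
serialised = {"plan": "my_custom_plan", "params": {"max_messages": 10}}

# what registry_create() does when it sees that shape in a task's kwargs
if isinstance(serialised, dict) and "plan" in serialised and "params" in serialised:
    plan = registry_create("plan", serialised["plan"], **serialised["params"])
```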
33 changes: 18 additions & 15 deletions src/inspect_ai/util/_sandbox/docker/config.py
@@ -7,10 +7,19 @@
logger = getLogger(__name__)


CONFIG_FILES = [
    "compose.yaml",
    "compose.yml",
    "docker-compose.yaml",
    "docker-compose.yml",
]


async def resolve_compose_file(parent: str = "") -> str | None:
    # existing compose file provides all the config we need
    compose = find_compose_file(parent)
    if compose is not None:
        return Path(os.path.join(parent, compose)).resolve().as_posix()

    # temporary auto-compose
    if has_auto_compose_file(parent):
@@ -25,17 +34,11 @@
    return await auto_compose_file(COMPOSE_GENERIC_YAML, parent)


def find_compose_file(parent: str = "") -> str | None:
    for file in CONFIG_FILES:
        if os.path.isfile(os.path.join(parent, file)):
            return file
    return None


def has_dockerfile(parent: str = "") -> bool:
@@ -52,7 +55,7 @@ def is_auto_compose_file(file: str) -> bool:

async def ensure_auto_compose_file(file: str | None) -> None:
    if file is not None and is_auto_compose_file(file) and not os.path.exists(file):
        await resolve_compose_file(os.path.dirname(file))


def safe_cleanup_auto_compose(file: str | None) -> None:
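
One behavioural note on this refactor: `find_compose_file()` returns the first match in `CONFIG_FILES`, so `compose.yaml` takes precedence over `docker-compose.yml` when a directory contains both. For instance (hypothetical directory layout):

```python
# given a task directory containing both compose.yaml and docker-compose.yml,
# the first CONFIG_FILES entry that exists wins
assert find_compose_file("/my/task") == "compose.yaml"
```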