Skip to content

Commit

Permalink
sync 01-07-24
Browse files Browse the repository at this point in the history
  • Loading branch information
aisi-inspect committed Jul 2, 2024
1 parent 2d7561f commit b0c8301
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 31 deletions.
2 changes: 2 additions & 0 deletions src/inspect_ai/_eval/context.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from inspect_ai.model import Model
from inspect_ai.model._model import init_active_model, init_model_usage
from inspect_ai.util._concurrency import init_concurrency
from inspect_ai.util._logger import init_logger_records
from inspect_ai.util._subprocess import init_max_subprocesses


def init_eval_context(max_subprocesses: int | None = None) -> None:
    """Initialise the state shared by a single eval run.

    Resets the concurrency registry via ``init_concurrency`` and then hands
    ``max_subprocesses`` to ``init_max_subprocesses``.

    Args:
        max_subprocesses: Forwarded unchanged to ``init_max_subprocesses``;
            ``None`` leaves the choice of limit to that function.
    """
    init_concurrency()
    init_max_subprocesses(max_subprocesses)


Expand Down
46 changes: 30 additions & 16 deletions src/inspect_ai/scorer/_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,23 @@
from ._target import Target


def match_target(match: str, target: Target, ignore_case: bool) -> bool:
    """Report whether ``match`` is one of the values in ``target``.

    Args:
        match: Candidate text extracted from the model output.
        target: Accepted answers.
        ignore_case: When true, the candidate and every target value are
            lowercased before the membership test.

    Returns:
        ``True`` if ``match`` is contained in ``target``.
    """
    if ignore_case:
        # Rebind locally only; the caller's strings and target are untouched.
        match = match.lower()
        target = Target([t.lower() for t in target])

    return match in target


def match_first(
    matches: tuple[str | Any, ...], target: Target, ignore_case: bool
) -> str | None:
    """Return the first captured group that matches the target.

    Args:
        matches: Captured regex groups; entries that are not strings are
            skipped.
        target: Accepted answers.
        ignore_case: Forwarded to ``match_target``.

    Returns:
        The first matching group exactly as it was captured (its case is
        not altered), or ``None`` when no group matches.
    """
    for match in matches:
        if not isinstance(match, str):
            continue

        # Case handling lives in match_target, so the value returned here
        # is the original captured text rather than a lowercased copy.
        if match_target(match, target, ignore_case):
            return match

    return None

Expand All @@ -27,12 +34,11 @@ def match_all_groups(
matches: tuple[str | Any, ...], target: Target, ignore_case: bool
) -> str | None:
for match in matches:
if isinstance(match, str):
if ignore_case:
match = match.lower()
if not isinstance(match, str):
continue

if match not in target:
return None
if not match_target(match, target, ignore_case):
return None

return target.text

Expand Down Expand Up @@ -64,21 +70,29 @@ async def score(state: TaskState, target: Target) -> Score:
)

if match:
if ignore_case:
target = Target([t.lower() for t in target])

groups = match.groups()
if match_all:
found_match = match_all_groups(
matches=match.groups(), target=target, ignore_case=ignore_case
matches=groups, target=target, ignore_case=ignore_case
)
answer = found_match
else:
found_match = match_first(
matches=match.groups(), target=target, ignore_case=ignore_case
matches=groups, target=target, ignore_case=ignore_case
)

if found_match is None and len(groups) == 1:
# A common use of a pattern is to extract a single answer
# from some templated text. If we fail to match in that
# scenario, it's worth returning the failed match because
# this is useful information for the user.
answer = groups[0]
else:
answer = found_match

return Score(
value=CORRECT if found_match else INCORRECT,
answer=found_match,
answer=answer,
explanation=state.output.completion,
)
else:
Expand Down
2 changes: 2 additions & 0 deletions src/inspect_ai/solver/_tool/environment/docker/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def safe_cleanup_auto_compose(file: str | None) -> None:
image: "python:3.12-bookworm"
command: "tail -f /dev/null"
network_mode: none
stop_grace_period: 1s
"""

COMPOSE_DOCKERFILE_YAML = f"""{COMPOSE_COMMENT}
Expand All @@ -84,6 +85,7 @@ def safe_cleanup_auto_compose(file: str | None) -> None:
context: "."
command: "tail -f /dev/null"
network_mode: none
stop_grace_period: 1s
"""


Expand Down
15 changes: 11 additions & 4 deletions src/inspect_ai/util/_concurrency.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from contextvars import ContextVar
from dataclasses import dataclass


Expand Down Expand Up @@ -40,29 +41,35 @@ def concurrency(
key = key if key else name

# do we have an existing semaphore? if not create one and store it
semaphore = _concurrency_semaphores.get(key, None)
semaphore = _concurrency_semaphores.get().get(key, None)
if semaphore is None:
semaphore = ConcurencySempahore(
name, concurrency, asyncio.Semaphore(concurrency)
)
_concurrency_semaphores[key] = semaphore
_concurrency_semaphores.get()[key] = semaphore

# return the semaphore
return semaphore.semaphore


def concurrency_status() -> dict[str, tuple[int, int]]:
    """Summarise usage of every semaphore registered in the current context.

    Returns:
        A mapping from semaphore name to ``(in_use, limit)``. ``in_use`` is
        the configured limit minus the semaphore's remaining counter, which
        is read from the private ``asyncio.Semaphore._value`` attribute.
        Entries are keyed by ``name``, so two registry entries that share a
        name collapse into one (the later one wins).
    """
    return {
        entry.name: (entry.concurrency - entry.semaphore._value, entry.concurrency)
        for entry in _concurrency_semaphores.get().values()
    }


def init_concurrency() -> None:
    """Install a fresh, empty semaphore registry for the current context.

    Any semaphores previously registered in this context are no longer
    reachable through ``_concurrency_semaphores`` after this call.
    """
    _concurrency_semaphores.set({})


@dataclass
class ConcurencySempahore:
    """Registry entry pairing an asyncio semaphore with its display metadata."""

    # Name reported by concurrency_status().
    name: str
    # Limit the semaphore was created with.
    concurrency: int
    # The semaphore handed back to callers of concurrency().
    semaphore: asyncio.Semaphore


# Registry of semaphores keyed by concurrency key. It lives in a ContextVar so
# that init_concurrency() can install a fresh dict for the current context.
# NOTE: the default is a single shared dict object; until init_concurrency()
# has run in a context, lookups read and mutate that same dict.
_concurrency_semaphores: ContextVar[dict[str, ConcurencySempahore]] = ContextVar(
    "concurrency_semaphores", default={}
)
28 changes: 28 additions & 0 deletions tests/scorer/test_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,31 @@ async def test_only_returns_exact_target_matches():
result = await scorer(state, Target(["bar"]))

assert result.text == INCORRECT


@pytest.mark.asyncio
async def test_one_match_group_returns_incorrect_match():
    """A single capture group that misses the target is still reported as the answer."""
    scorer = pattern(
        "ANSWER: (A|B)",
        ignore_case=False,
        match_all=False,
    )
    state = simple_task_state(model_output="ANSWER: A")
    result = await scorer(state, Target(["B"]))

    # The extracted (wrong) answer is surfaced, and the score is INCORRECT.
    assert result.answer == "A"
    assert result.text == INCORRECT


@pytest.mark.asyncio
async def test_multiple_match_group_returns_none():
    """With several capture groups and no match, no answer is reported."""
    scorer = pattern(
        "ANSWER: (A|B) ALTERNATE_ANSWER: (A|B)",
        ignore_case=False,
        match_all=False,
    )
    state = simple_task_state(model_output="ANSWER: A ALTERNATE_ANSWER: A")
    result = await scorer(state, Target(["B"]))

    # The failed-match fallback applies only to single-group patterns.
    assert result.answer is None
    assert result.text == INCORRECT
5 changes: 5 additions & 0 deletions tools/vscode/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## 0.3.23

- Ensure the log view only opens in the correct window when debugging a task
- Changes to improve performance and usability of large log files

## 0.3.22

- Improve reliability of opening and viewing log files upon completion of evaluations
Expand Down
2 changes: 1 addition & 1 deletion tools/vscode/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"author": {
"name": "UK AI Safety Institute"
},
"version": "0.3.22",
"version": "0.3.23",
"license": "MIT",
"homepage": "https://ukgovernmentbeis.github.io/inspect_ai/",
"repository": {
Expand Down
42 changes: 32 additions & 10 deletions tools/vscode/src/providers/inspect/inspect-eval.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import { DebugConfiguration, ExtensionContext, debug, window, workspace } from "vscode";
import {
DebugConfiguration,
ExtensionContext,
debug,
window,
workspace,
} from "vscode";
import { inspectEvalCommands } from "./inspect-eval-commands";
import { Command } from "../../core/command";
import {
Expand All @@ -20,7 +26,6 @@ export async function activateEvalManager(
// Activate the manager
const inspectEvalMgr = new InspectEvalManager(stateManager);


// Set up our terminal environment
// Update the workspace id used in our terminal environments
await stateManager.initializeWorkspaceId();
Expand All @@ -32,15 +37,14 @@ export async function activateEvalManager(

log.append(`new: ${workspaceId}`);

env.delete('INSPECT_WORKSPACE_ID');
env.append('INSPECT_WORKSPACE_ID', workspaceId);
env.delete("INSPECT_WORKSPACE_ID");
env.append("INSPECT_WORKSPACE_ID", workspaceId);

return [inspectEvalCommands(inspectEvalMgr), inspectEvalMgr];
}

export class InspectEvalManager {
constructor(private readonly stateManager_: WorkspaceStateManager) {
}
constructor(private readonly stateManager_: WorkspaceStateManager) { }

public async startEval(file: AbsolutePath, task?: string, debug = false) {
// if we don't have inspect bail and let the user know
Expand Down Expand Up @@ -113,7 +117,6 @@ export class InspectEvalManager {

// If we're debugging, launch using the debugger
if (debug) {

// Handle debugging
let debugPort = 5678;
if (debug === true) {
Expand All @@ -124,7 +127,19 @@ export class InspectEvalManager {
args.push(debugPort.toString());
}

await runDebugger(inspectBinPath()?.path || "inspect", args, workspaceDir.path, debugPort);
// Pass the workspace ID to the debug environment so we'll
// properly target the workspace window when showing the logview
const env = {
INSPECT_WORKSPACE_ID: this.stateManager_.getWorkspaceInstance(),
};

await runDebugger(
inspectBinPath()?.path || "inspect",
args,
workspaceDir.path,
debugPort,
env
);
} else {
// Run the command
runEvalCmd(args, workspaceDir.path);
Expand All @@ -146,7 +161,13 @@ const runEvalCmd = (args: string[], cwd: string) => {
terminal.sendText(["inspect", ...args].join(" "));
};

const runDebugger = async (program: string, args: string[], cwd: string, port: number) => {
const runDebugger = async (
program: string,
args: string[],
cwd: string,
port: number,
env?: Record<string, string>
) => {
const name = "Inspect Eval";
const debugConfiguration: DebugConfiguration = {
name,
Expand All @@ -157,7 +178,8 @@ const runDebugger = async (program: string, args: string[], cwd: string, port: n
console: "internalConsole",
cwd,
port,
"justMyCode": false
env,
justMyCode: false,
};
await debug.startDebugging(activeWorkspaceFolder(), debugConfiguration);
};

0 comments on commit b0c8301

Please sign in to comment.