diff --git a/CHANGELOG.md b/CHANGELOG.md index e0c61aee9..3b7d0b568 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,12 @@ ## Unreleased +- Evaluate multiple models in parallel by passing a list of models to `eval()`. - Add `api_key` to `get_model()` for explicitly specifying an API key for a model. +- Improved handling of very large (> 100MB) log files in Inspect View. +- Use `network_mode: none` for disabling networking by default in Docker tool environments. +- Allow tool environment providers to specify a default `max_samples` (set to 25 for the Docker provider). +- Prevent concurrent calls to `eval_async()` (unsafe because of the need to change directories for tasks). Parallel task evaluation will instead be implemented as a top-level feature of `eval()` and `eval_async()`. ## v0.3.17 (25 June 2024) diff --git a/docs/_format/pre-render.sh b/docs/_format/pre-render.sh index 731bc4fd5..bce4debaf 100755 --- a/docs/_format/pre-render.sh +++ b/docs/_format/pre-render.sh @@ -6,4 +6,4 @@ if [ -n "${QUARTO_PROJECT_RENDER_ALL}" ]; then (echo; echo) >> ../examples.qmd for f in security_guide.qmd hellaswag.qmd theory_of_mind.qmd mathematics.qmd biology_qa.qmd arc.qmd tool_use.qmd gsm8k.qmd footer.qmd; do (cat "${f}"; echo; echo; echo) >> ../examples.qmd; done cd .. -fi +fi \ No newline at end of file diff --git a/docs/_quarto.yml b/docs/_quarto.yml index d8edec2a1..a4792aad2 100644 --- a/docs/_quarto.yml +++ b/docs/_quarto.yml @@ -68,9 +68,9 @@ book: - part: "Advanced" chapters: - caching.qmd + - parallelism.qmd - eval-logs.qmd - eval-suites.qmd - - eval-tuning.qmd - extensions.qmd toc-depth: 2 @@ -96,4 +96,4 @@ format: # date: today execute: - enabled: false + enabled: false \ No newline at end of file diff --git a/docs/_sample-preservation.md b/docs/_sample-preservation.md index 450f50335..78ef132eb 100644 --- a/docs/_sample-preservation.md +++ b/docs/_sample-preservation.md @@ -16,4 +16,6 @@ If dataset shuffling is important to your evaluation and you want to preserve sa Another consideration is `max_samples`, which is the maximum number of samples to run concurrently within a task. Larger numbers of concurrent samples will result in higher throughput, but will also result in completed samples being written less frequently to the log file, and consequently less total recovable samples in the case of an interrupted task. -By default, Inspect sets the value of `max_samples` to `max_connections + 1`, ensuring that the model API is always fully saturated (note that it would rarely make sense to set it _lower_ than `max_connections`). The default `max_connections` is 10, which will typically result in samples being written to the log frequently. On the other hand, setting a very large `max_connections` (e.g. 100 `max_connections` for a dataset with 100 samples) may result in very few recoverable samples in the case of an interruption. \ No newline at end of file +By default, Inspect sets the value of `max_samples` to `max_connections + 1`, ensuring that the model API is always fully saturated (note that it would rarely make sense to set it _lower_ than `max_connections`). The default `max_connections` is 10, which will typically result in samples being written to the log frequently. On the other hand, setting a very large `max_connections` (e.g. 100 `max_connections` for a dataset with 100 samples) may result in very few recoverable samples in the case of an interruption.
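+
+For example, here is a minimal sketch of trading throughput off against recoverability (the task file name is illustrative, and this assumes `max_connections` and `max_samples` can be passed to `eval()` just as they are via the CLI options of the same name):
+
+``` python
+from inspect_ai import eval
+
+# more concurrent samples means higher throughput, but completed
+# samples are flushed to the log less frequently
+eval("security_guide.py", model="openai/gpt-4",
+     max_connections=25, max_samples=26)
+```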
+ +Note also that when using [Tool Environments](#sec-tool-environments), the tool environment provider may place an additional cap on the default `max_samples` (for example, the Docker provider limits the default `max_samples` to no more than 25). \ No newline at end of file diff --git a/docs/agents.qmd b/docs/agents.qmd index 436093052..ed382cbc6 100644 --- a/docs/agents.qmd +++ b/docs/agents.qmd @@ -579,7 +579,7 @@ eval("ctf.py", toolenv_cleanup = False) When you do this, you'll see something like the following printed out at the end of the eval: -![](images/toolenv-no-cleanup.png){.border} +![](images/toolenv-no-cleanup.png){.border fig-alt="A printed list of yet to be cleaned up Docker tool environments (including the container id and cleanup command for each one)"} You then might use this command to get a shell inside one of the containers: diff --git a/docs/datasets.qmd b/docs/datasets.qmd index aab4198e7..ddf097d75 100644 --- a/docs/datasets.qmd +++ b/docs/datasets.qmd @@ -2,7 +2,7 @@ ## Overview -Inspect has native support for reading datasets in the CSV, JSON, and JSON Lines formats, as well as from [Hugging Face](#sec-hugging-face-datasets). In addition, the core dataset interface for the evaluation pipeline is flexible enough to accept data read from just about any source. +Inspect has native support for reading datasets in the CSV, JSON, and JSON Lines formats, as well as from [Hugging Face](#sec-hugging-face-datasets). In addition, the core dataset interface for the evaluation pipeline is flexible enough to accept data read from just about any source (see the [Custom Reader](#sec-custom-reader) section below for details). If your data is already in a format amenable for direct reading as an Inspect `Sample`, reading a dataset is as simple as this: @@ -216,22 +216,22 @@ ChatMessageUser(content = [ Note that image input is currently only supported for OpenAI vision models (e.g. [gpt-4-vision-preview](https://platform.openai.com/docs/guides/vision)), Google Gemini vision models (e.g. [gemini-pro-vision](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/gemini-pro-vision)), and Anthropic Claude 3 models. ::: -## Custom Reader +## Custom Reader {#sec-custom-reader} -You are not restricted to the built in dataset functions for reading samples. Since the `dataset` field of the `Task` class takes either a `Dataset` or a sequences of`Sample`, the following is also valid: +You are not restricted to the built-in dataset functions for reading samples. You can also construct a `MemoryDataset` and pass that to a task. For example: ``` python from inspect_ai import Task, task -from inspect_ai.dataset import Sample +from inspect_ai.dataset import MemoryDataset, Sample from inspect_ai.scorer import model_graded_fact from inspect_ai.solver import generate, system_message -dataset=[ +dataset=MemoryDataset([ Sample( input="What cookie attributes should I use for strong security?", target="secure samesite and httponly", ) -] +]) @task def security_guide(): @@ -242,4 +242,4 @@ def security_guide(): ) ``` -So if the built in dataset functions don't meet your needs, you can create a custom function that yields a list of `Sample` instances and pass those directly to your `Task`. \ No newline at end of file +So if the built-in dataset functions don't meet your needs, you can create a custom function that yields a `MemoryDataset` and pass that directly to your `Task`.
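+
+For example, here is a minimal sketch of such a function (the function name and the in-memory records are purely illustrative; substitute whatever database, API, or file format you need to read from):
+
+``` python
+from inspect_ai.dataset import MemoryDataset, Sample
+
+def read_security_dataset() -> MemoryDataset:
+    # read records from any source you like
+    records = [
+        {
+            "input": "What cookie attributes should I use for strong security?",
+            "target": "secure samesite and httponly",
+        }
+    ]
+    # convert the records to Samples and wrap them in a MemoryDataset
+    return MemoryDataset(
+        [Sample(input=r["input"], target=r["target"]) for r in records]
+    )
+```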
\ No newline at end of file diff --git a/docs/extensions.qmd b/docs/extensions.qmd index 59b0e6e37..759fb89a7 100644 --- a/docs/extensions.qmd +++ b/docs/extensions.qmd @@ -53,12 +53,14 @@ For example, if your package was named `inspect_package` and your model provider ::: {.panel-tabset group="entry-points"} ## Setuptools + ``` toml [project.entry-points.inspect_ai] inspect_package = "inspect_package.inspect_extensions" ``` ## Poetry + ``` toml [tool.poetry.plugins.inspect_ai] inspect_package = "inspect_package.inspect_extensions" @@ -170,12 +172,15 @@ class PodmanToolEnvironment(ToolEnvironment): The class methods take care of various stages of initialisation, setup, and teardown: | Method | Lifecycle | Purpose | -|------------------|------------------|------------------------------------| +|-------------------|-------------------|----------------------------------| | `task_init()` | Called at the beginning of each `Task`. | Expensive initialisation operations (e.g. pulling or building images) | | `sample_init()` | Called at the beginning of each `Sample`. | Create `ToolEnvironment` instances for the sample. | | `sample_cleanup()` | Called at the end of each `Sample` | Cleanup `ToolEnvironment` instances for the sample. | | `task_cleanup()` | Called at the end of each `Task`. | Last chance handler for any resources not yet cleaned up (see also discussion below). | | `cli_cleanup()` | Called via `inspect toolenv cleanup` | CLI invoked manual cleanup of resources created by this `ToolEnvironment`. | +| `max_samples()` | Called at startup | Provide a default `max_samples` (used to cap the default, explicit `max_samples` will override this). | + +In the case of parallel execution of a group of tasks that share a working directory and tool environment, the `task_init()` and `task_cleanup()` functions may be called once for the entire group as a performance optimisation. The `task_cleanup()` has a number of important functions: @@ -195,9 +200,9 @@ The `task_cleanup()` function will typically print out the information required The `ToolEnvironment` instance methods provide access to process execution and file input/output within the environment. A few notes on implementing these methods: -1. The `exec()` method currently only handles text output. If a call results in binary output then a `UnicodeDecodeError` will be raised. Tool environments should catch this and raise a `ToolError`. +1. The `exec()` method currently only handles text output. If a call results in binary output then a `UnicodeDecodeError` will be raised. Tool environments should catch this and raise a `ToolError`. -2. The `read_file()` method raise a `FileNotFoundError` if the specified `file` does not exist in the tool environment, as tools calling `read_file()` will often want to catch the `FileNotFoundError` and re-throw a `ToolError` (since models will frequently attempt to read files that do not exist). +2. The `read_file()` method raise a `FileNotFoundError` if the specified `file` does not exist in the tool environment, as tools calling `read_file()` will often want to catch the `FileNotFoundError` and re-throw a `ToolError` (since models will frequently attempt to read files that do not exist). 
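+
+For example, here is a minimal sketch of the pattern described in (2). The `read_config()` helper, the `tool_environment()` accessor, and the import locations shown are assumptions for illustration; adapt them to however your tool obtains its `ToolEnvironment` instance:
+
+``` python
+from inspect_ai.solver import ToolError, tool_environment
+
+async def read_config(file: str) -> str:
+    try:
+        # read_file() raises FileNotFoundError for missing files
+        return await tool_environment().read_file(file)
+    except FileNotFoundError:
+        # models frequently ask for files that don't exist, so report this
+        # back to the model as a ToolError rather than failing the sample
+        raise ToolError(f"File {file} was not found.")
+```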
The best way to learn about writing tool environments is to look at the source code for the built in environments, [LocalToolEnvironment](https://github.com/UKGovernmentBEIS/inspect_ai/blob/main/src/inspect_ai/solver/_tool/environment/local.py) and [DockerToolEnvironment](https://github.com/UKGovernmentBEIS/inspect_ai/blob/main/src/inspect_ai/solver/_tool/environment/docker/docker.py). @@ -209,12 +214,14 @@ For example, if your package was named `inspect_package` and your tool environme ::: {.panel-tabset group="entry-points"} ## Setuptools + ``` toml [project.entry-points.inspect_ai] inspect_package = "inspect_package.inspect_extensions" ``` ## Poetry + ``` toml [tool.poetry.plugins.inspect_ai] inspect_package = "inspect_package.inspect_extensions" @@ -299,12 +306,14 @@ As with Model APIs and Tool Environments, fsspec filesystems should be registere ::: {.panel-tabset group="entry-points"} ## Setuptools + ``` toml [project.entry-points."fsspec.specs"] myfs = "inspect_package:MyFs" ``` ## Poetry + ``` toml [tool.poetry.plugins."fsspec.specs"] myfs = "inspect_package:MyFs" diff --git a/docs/images/inspect-multiple-models.png b/docs/images/inspect-multiple-models.png new file mode 100644 index 000000000..a0a3bdf3f Binary files /dev/null and b/docs/images/inspect-multiple-models.png differ diff --git a/docs/index.qmd b/docs/index.qmd index 841deacfc..5b6f442a5 100644 --- a/docs/index.qmd +++ b/docs/index.qmd @@ -179,9 +179,9 @@ These sections discuss more advanced features and workflow. You don't need to re - [Caching](#sec-caching) enables you to cache model output to reduce the number of API calls made, saving both time and expense. -- [Eval Logs](#sec-eval-logs) explores how to get the most out of evaluation logs for developing, debugging, and analyzing evaluations. +- [Parallelism](#sec-parallelism) delves into how to obtain maximum performance for evaluations. Inspect uses a highly parallel async architecture---here we cover how to tune this parallelism (e.g to stay under API rate limits or to not overburden local compute) for optimal throughput. -- [Eval Tuning](#sec-eval-tuning) delves into how to obtain maximum performance for evaluations. Inspect uses a highly parallel async architecture---here we cover how to tune this parallelism (e.g to stay under API rate limits or to not overburden local compute) for optimal throughput. +- [Eval Logs](#sec-eval-logs) explores how to get the most out of evaluation logs for developing, debugging, and analyzing evaluations. - [Eval Suites](#sec-eval-suites) covers Inspect's features for describing, running, and analysing larger sets of evaluation tasks. diff --git a/docs/models.qmd b/docs/models.qmd index 5c62e53f3..481c26960 100644 --- a/docs/models.qmd +++ b/docs/models.qmd @@ -118,7 +118,7 @@ Use `inspect eval --help` to learn about all of the available generation config Inspect uses an asynchronous architecture to run task samples in parallel. If your model provider can handle 100 concurrent connections, then Inspect can utilise all of those connections to get the highest possible throughput. The limiting factor on parallelism is therefore not typically local parallelism (e.g. number of cores) but rather what the underlying rate limit is for your interface to the provider. -If you are experiencing rate-limit errors you will need to experiment with the `max_connections` option to find the optimal value that keeps you under the rate limit (the section on [Eval Tuning](eval-tuning.qmd) includes additional documentation on how to do this). 
Note that the next section describes how you can set a model-provider specific value for `max_connections` as well as other generation options. +If you are experiencing rate-limit errors, you will need to experiment with the `max_connections` option to find the optimal value that keeps you under the rate limit (the section on [Parallelism](parallelism.qmd) includes additional documentation on how to do this). Note that the next section describes how you can set a model-provider specific value for `max_connections` as well as other generation options. ### Model Specific Configuration diff --git a/docs/eval-tuning.qmd b/docs/parallelism.qmd similarity index 74% rename from docs/eval-tuning.qmd rename to docs/parallelism.qmd index ad97c28ae..5a84d59c4 100644 --- a/docs/eval-tuning.qmd +++ b/docs/parallelism.qmd @@ -1,20 +1,22 @@ -# Eval Tuning {#sec-eval-tuning} +--- +aliases: + - eval-tuning.html +--- -## Overview - -Inspect runs evaluations using a highly parallel async architecture. Rather than processing a batch at a time, many samples are processed concurrently. This is possible because evaluations generally use relatively little local compute, but rather spend most of their time waiting for model API calls and web requests to complete. Consequently, Inspect eagerly executes as much local computation as it can and at the same time ensures that model APIs are not over-saturated by enforcing a maximum number of concurrent connections. +# Parallelism {#sec-parallelism} -This section describes how to tune Inspect's concurrency, as well as how to handle situations where more local compute is required. - -## Max Samples +## Overview -The `max_samples` option determines how many samples are executed in parallel. By default, `max_samples` is set to `max_connections` so that the connection to the Model API can be fully utilised. See the section below for more details on `max_connections`. +Inspect runs evaluations using a parallel async architecture, eagerly executing many samples in parallel while at the same time ensuring that resources aren't over-saturated by enforcing various limits (e.g. maximum number of concurrent model connections, maximum number of subprocesses, etc.). -If you have additional expensive operations beyond calling models (e.g. using a [Tool Environment](#sec-tool-environments)) then you may want to increase `max_samples` to fully saturate both the Model API and container subprocesses used for tool execution. When running an evaluation you'll see an indicator of how many connections and how many subprocesses are currently active. If neither is at capacity then you will likely benefit from increasing `max_samples`. +There is a progression of concurrency concerns, and while most evaluations can rely on the Inspect default behaviour, others will benefit from more customisation. Below we'll cover the following: -Note that setting `max_samples` to an arbitrarily high number does have some disadvantages: you will consume more memory (especially if using tool environments) as well as wait longer for completed samples to be logged (so could be subject to losing more work if your eval task fails). +1. Model API connection concurrency. +2. Evaluating multiple models in parallel. +3. Tool environment concurrency. +4. Writing parallel code in custom solvers and scorers. -## Model APIs +## Model Connections ### Max Connections @@ -54,89 +56,89 @@ $ inspect eval --model openai/gpt-4 --log-level=http Note that max connections is applied per-model.
This means that if you use a grader model from a provider distinct from the one you are evaluating you will get extra concurrency (as each model will enforce its own max connections). ::: -## Other APIs +## Multiple Models {#sec-multiple-models} -It's possible that your custom solvers, tools, or scorers will call other REST APIs. Two things to keep in mind when doing this are: +::: {.callout-note appearance="simple"} +The multiple models feature described below is only available in the development version of Inspect (it is not yet published to PyPI). You can install the development version with: -1. It's critical that connections to other APIs use `async` HTTP APIs (i.e. the `httpx` model rather than the `requests` module). This is because Inspect's parallelism relies on everything being `async`, so if you make a blocking HTTP call with `requests` it will actually hold up all of the rest of the work in system! +```bash +$ pip install git+https://github.com/UKGovernmentBEIS/inspect_ai +``` +::: -2. As with model APIs, rate limits may be in play, so it's important not to over-saturate these connections. Recall that Inspect runs all samples in parallel so if you have 500 samples and don't do anything to limit concurrency, you will likely end up making hundreds of calls at a time to the API. +You can evaluate multiple models in parallel by passing a list of models to the `eval()` function. For example: -Here's some (oversimplified) example code that illustrates how to call a REST API within an Inspect component. We use the `async` interface of the `httpx` module, and we use Inspect's `concurrency()` function to limit simultaneous connections to 10: +```python +eval("mathematics.py", model=[ + "openai/gpt-4-turbo", + "anthropic/claude-3-opus-20240229", + "google/gemini-1.5-pro" +]) +``` -``` python -import httpx -from inspect_ai.util import concurrency -from inspect_ai.solver import Generate, TaskState +![](images/inspect-multiple-models.png){fig-alt="An evaluation task display showing the progress for 3 different models."} -client = httpx.AsyncClient() +Since each model provider has its own `max_connections`, they don't contend with each other for resources. If you need to evaluate multiple models, doing so concurrently is highly recommended. -async def solve(state: TaskState, generate: Generate): - ... - # wrap the call to client.get() in an async concurrency - # block to limit simultaneous connections to 10 - async with concurrency("my-rest-api", 10): - response = await client.get("https://example.com/api") +If you want to specify multiple models when using the `--model` CLI argument or `INSPECT_EVAL_MODEL` environment variable, just separate the model names with commas. For example: + +```bash +INSPECT_EVAL_MODEL=openai/gpt-4-turbo,google/gemini-1.5-pro ``` -Note that we pass a name ("my-rest-api") to the `concurrency()` function. This provides a named scope for managing concurrency for calls to that specific API/service. +## Tool Environments {#sec-parallel-tool-environments} -## Subprocesses +[Tool Environments](#sec-tool-environments) (e.g. Docker containers) often allocate resources on a per-sample basis, and also make use of the Inspect `subprocess()` function for executing commands within the environment. -It's possible that your custom solvers, tools, or scorers will need to launch child processes to perform various tasks.
Subprocesses have similar considerations as calling APIs: you want to make sure that they don't block the rest of the work in Inspect (so they should be invoked with `async`) and you also want to make sure they don't provide *too much* concurrency (i.e. you wouldn't want to launch 200 processes at once on a 4 core machine!) +### Max Samples -To assist with this, Inspect provides the `subprocess()` function. This `async` function takes a command and arguments and invokes the specified command asynchronously, collecting and returning stdout and stderr. The `subprocess()` function also automatically limits concurrent child processes to the number of CPUs on your system (`os.cpu_count()`). Here's an example from the implementation of a `list_files()` tool: +The `max_samples` option determines how many samples are executed in parallel (and in the case of Docker containers how many containers are run in parallel). By default, `max_samples` is set to `max_connections` so that the connection to the Model API can be fully utilised. -``` python -@tool(prompt=( - "If you are asked to list the files in a directory you " - + "should call the list_files function to access the listing." -)) -def list_files(): - async def execute(dir: str): - """List the files in a directory. +Since tool environments include additional expensive operations beyond calling models, you may want to increase `max_samples` to fully saturate both the Model API and container subprocesses used for tool execution. When running an evaluation you'll see an indicator of how many connections and how many subprocesses are currently active. If neither is at capacity then you will likely benefit from increasing `max_samples`. - Args: - dir (str): Directory +Note that setting `max_samples` to an arbitrarily high number does have some disadvantages: you will consume more memory (especially if using tool environments) as well as wait longer for completed samples to be logged (so could be subject to losing more work if your eval task fails). - Returns: - File listing of the directory - """ - result = await subprocess(["ls", dir]) - if result.success: - return result.stdout - else: - raise ToolError(result.stderr) +### Max Subprocesses - return execute -``` +The `max_subprocesses` option determines how many subprocess calls can run in parallel. By default, this is set to `os.cpu_count()`. Depending on the nature of execution done inside tool environments, you might benefit from increasing or decreasing `max_subprocesses`. -The maximum number of concurrent subprocesses can be modified using the `--max-subprocesses` option. For example: +## Solvers and Scorers {#sec-parallel-solvers-and-scorers} -``` bash -$ inspect eval --model openai/gpt-4 --max-subprocesses 4 -``` +### REST APIs -Note that if you need to execute computationally expensive code in an eval, you should always factor it into a call to `subprocess()` so that you get optimal concurrency and performance. +It's possible that your custom solvers, tools, or scorers will call other REST APIs. Two things to keep in mind when doing this are: -### Timeouts +1. It's critical that connections to other APIs use `async` HTTP APIs (i.e. the `httpx` module rather than the `requests` module). This is because Inspect's parallelism relies on everything being `async`, so if you make a blocking HTTP call with `requests` it will actually hold up all of the rest of the work in the system!
-If you need to ensure that your subprocess runs for no longer than a specified interval, you can use the `timeout` option. For example: +2. As with model APIs, rate limits may be in play, so it's important not to over-saturate these connections. Recall that Inspect runs all samples in parallel so if you have 500 samples and don't do anything to limit concurrency, you will likely end up making hundreds of calls at a time to the API. + +Here's some (oversimplified) example code that illustrates how to call a REST API within an Inspect component. We use the `async` interface of the `httpx` module, and we use Inspect's `concurrency()` function to limit simultaneous connections to 10: ``` python -result = await subprocess(["ls", dir], timeout = 30) +import httpx +from inspect_ai.util import concurrency +from inspect_ai.solver import Generate, TaskState + +client = httpx.AsyncClient() + +async def solve(state: TaskState, generate: Generate): + ... + # wrap the call to client.get() in an async concurrency + # block to limit simultaneous connections to 10 + async with concurrency("my-rest-api", 10): + response = await client.get("https://example.com/api") ``` -If a timeout occurs, then the `result.status` will be `False` and a timeout error message will be included in `result.stderr`. +Note that we pass a name ("my-rest-api") to the `concurrency()` function. This provides a named scope for managing concurrency for calls to that specific API/service. -## Parallel Code +### Parallel Code {#sec-parallel-code} Generally speaking, you should try to make all of the code you write within Inspect solvers, tools, and scorers as parallel as possible. The main idea is to eagerly post as much work as you can, and then allow the various concurrency gates described above to take care of not overloading remote APIs or local resources. There are two keys to writing parallel code: 1. Use `async` for all potentially expensive operations. If you are calling a remote API, use the `httpx.AsyncClient`. If you are running local code, use the `subprocess()` function described above. 2. If your `async` work can be parallelised, do it using `asyncio.gather()`. For example, if you are calling three different model APIs to score a task, you can call them all in parallel. Or if you need to retrieve 10 web pages you don't need to do it in a loop—rather, you can fetch them all at once. -### Model Requests +#### Model Requests Let's say you have a scorer that uses three different models to score based on majority vote. You could make all of the model API calls in parallel as follows: @@ -159,7 +161,7 @@ grader_outputs = await asyncio.gather(*graders) Note that we don't await the call to `model.generate()` when building our list of graders. Rather the call to `asyncio.gather()` will await each of these requests and return when they have all completed. Inspect's internal handling of `max_connections` for model APIs will apply to these requests, so you need now worry about how many you put in flight, they will be throttled as appropriate. -### Web Requests +#### Web Requests Here's an examples of using `asyncio.gather()` to parallelise web requests: @@ -192,4 +194,53 @@ async def download(page): downloads = [download(page) for page in pages] results = await asyncio.gather(*downloads) -``` \ No newline at end of file +``` + + +### Subprocesses + +It's possible that your custom solvers, tools, or scorers will need to launch child processes to perform various tasks. 
Subprocesses have similar considerations as calling APIs: you want to make sure that they don't block the rest of the work in Inspect (so they should be invoked with `async`) and you also want to make sure they don't provide *too much* concurrency (i.e. you wouldn't want to launch 200 processes at once on a 4 core machine!) + +To assist with this, Inspect provides the `subprocess()` function. This `async` function takes a command and arguments and invokes the specified command asynchronously, collecting and returning stdout and stderr. The `subprocess()` function also automatically limits concurrent child processes to the number of CPUs on your system (`os.cpu_count()`). Here's an example from the implementation of a `list_files()` tool: + +``` python +@tool(prompt=( + "If you are asked to list the files in a directory you " + + "should call the list_files function to access the listing." +)) +def list_files(): + async def execute(dir: str): + """List the files in a directory. + + Args: + dir (str): Directory + + Returns: + File listing of the directory + """ + result = await subprocess(["ls", dir]) + if result.success: + return result.stdout + else: + raise ToolError(result.stderr) + + return execute +``` + +The maximum number of concurrent subprocesses can be modified using the `--max-subprocesses` option. For example: + +``` bash +$ inspect eval --model openai/gpt-4 --max-subprocesses 4 +``` + +Note that if you need to execute computationally expensive code in an eval, you should always factor it into a call to `subprocess()` so that you get optimal concurrency and performance. + +#### Timeouts + +If you need to ensure that your subprocess runs for no longer than a specified interval, you can use the `timeout` option. For example: + +``` python +result = await subprocess(["ls", dir], timeout = 30) +``` + +If a timeout occurs, then the `result.status` will be `False` and a timeout error message will be included in `result.stderr`. diff --git a/docs/scorers.qmd b/docs/scorers.qmd index 0e041f344..14808251b 100644 --- a/docs/scorers.qmd +++ b/docs/scorers.qmd @@ -121,7 +121,7 @@ async def score(state: TaskState, target: Target): First we'll talk about the core `Score` and `Value` objects, then provide some examples of custom scorers to make things more concrete. ::: {.callout-note appearance="simple"} -Note that `score()` above is declared as an `async` function. When creating custom scorers, it's critical that you understand Inspect's concurrency model. More specifically, if your scorer is doing non-trivial work (e.g. calling REST APIs, executing external processes, etc.) please review [Eval Tuning](#sec-eval-tuning) before proceeding. +Note that `score()` above is declared as an `async` function. When creating custom scorers, it's critical that you understand Inspect's concurrency model. More specifically, if your scorer is doing non-trivial work (e.g. calling REST APIs, executing external processes, etc.) please review [Parallelism](#sec-parallel-solvers-and-scorers) before proceeding. ::: ### Score diff --git a/docs/solvers.qmd b/docs/solvers.qmd index f3ae0679d..626307b7c 100644 --- a/docs/solvers.qmd +++ b/docs/solvers.qmd @@ -334,7 +334,7 @@ Note that calls to `generate()` (for both the critique model and the model being ### Concurrency -When creating custom solvers, it's critical that you understand Inspect's concurrency model. More specifically, if your solver is doing non-trivial work (e.g. calling REST APIs, executing external processes, etc.) 
please review [Eval Tuning](#sec-eval-tuning) for a more in depth discussion. +When creating custom solvers, it's critical that you understand Inspect's concurrency model. More specifically, if your solver is doing non-trivial work (e.g. calling REST APIs, executing external processes, etc.) please review [Parallelism](#sec-parallel-solvers-and-scorers) for a more in depth discussion. ## Early Termination diff --git a/requirements.txt b/requirements.txt index dbc633dd6..604f5cff9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,4 +20,4 @@ s3fs>=2023 semver shortuuid tenacity -typing_extensions \ No newline at end of file +typing_extensions diff --git a/src/inspect_ai/_cli/eval.py b/src/inspect_ai/_cli/eval.py index f1ab42667..c6dfd50f2 100644 --- a/src/inspect_ai/_cli/eval.py +++ b/src/inspect_ai/_cli/eval.py @@ -41,7 +41,6 @@ "--model", type=str, required=True, - envvar=["INSPECT_EVAL_MODEL", "INSPECT_MODEL_NAME"], help="Model used to evaluate tasks.", ) @click.option( diff --git a/src/inspect_ai/_cli/score.py b/src/inspect_ai/_cli/score.py index a0416f5ed..5c017421a 100644 --- a/src/inspect_ai/_cli/score.py +++ b/src/inspect_ai/_cli/score.py @@ -5,7 +5,7 @@ from inspect_ai._display import display from inspect_ai._display.logger import init_logger -from inspect_ai._eval.context import init_eval_context +from inspect_ai._eval.context import init_eval_context, init_task_context from inspect_ai._eval.loader import load_tasks from inspect_ai._eval.score import task_score from inspect_ai._util.constants import SCORED_SUFFIX @@ -69,7 +69,8 @@ async def score( ) # initialize active model - init_eval_context(model) + init_eval_context() + init_task_context(model) # instantiate the task so we can get its scorer and metrics score_task = load_tasks([task], model)[0] diff --git a/src/inspect_ai/_display/_display.py b/src/inspect_ai/_display/_display.py index a0871f493..a2a0433d4 100644 --- a/src/inspect_ai/_display/_display.py +++ b/src/inspect_ai/_display/_display.py @@ -2,53 +2,66 @@ import contextlib from dataclasses import dataclass from types import TracebackType -from typing import Any, Iterator, Type +from typing import Any, Iterator, Type, Union -from inspect_ai.log import EvalConfig, EvalError, EvalResults, EvalStats +from inspect_ai.log import EvalConfig, EvalResults, EvalStats from inspect_ai.model import GenerateConfig, ModelName class Progress(abc.ABC): @abc.abstractmethod - def update(self, n: float = 1) -> None: ... - - -class TaskDisplay(abc.ABC): - @abc.abstractmethod - @contextlib.contextmanager - def progress(self, total: int) -> Iterator[Progress]: ... - - @abc.abstractmethod - def cancelled(self, samples_logged: int, stats: EvalStats) -> None: ... + def update(self, n: int = 1) -> None: ... @abc.abstractmethod - def summary(self, results: EvalResults, stats: EvalStats) -> None: ... - - @abc.abstractmethod - def error( - self, - samples_logged: int, - error: EvalError, - exc_type: Type[Any], - exc_value: BaseException, - traceback: TracebackType | None, - ) -> None: ... + def complete(self) -> None: ... 
@dataclass class TaskProfile: name: str - sequence: tuple[int, int] model: ModelName dataset: str scorer: str samples: int + steps: int eval_config: EvalConfig task_args: dict[str, Any] generate_config: GenerateConfig log_location: str +@dataclass +class TaskError: + samples_logged: int + exc_type: Type[Any] + exc_value: BaseException + traceback: TracebackType | None + + +@dataclass +class TaskCancelled: + samples_logged: int + stats: EvalStats + + +@dataclass +class TaskSuccess: + stats: EvalStats + results: EvalResults + + +TaskResult = Union[TaskError, TaskCancelled, TaskSuccess] + + +class TaskDisplay(abc.ABC): + @abc.abstractmethod + @contextlib.contextmanager + def progress(self) -> Iterator[Progress]: ... + + @abc.abstractmethod + def complete(self, result: TaskResult) -> None: ... + + class Display(abc.ABC): @abc.abstractmethod def print(self, message: str) -> None: ... @@ -57,6 +70,10 @@ def print(self, message: str) -> None: ... @contextlib.contextmanager def progress(self, total: int) -> Iterator[Progress]: ... + @abc.abstractmethod + @contextlib.contextmanager + def live_task_status(self) -> Iterator[None]: ... + @abc.abstractmethod @contextlib.contextmanager def task(self, profile: TaskProfile) -> Iterator[TaskDisplay]: ... diff --git a/src/inspect_ai/_display/rich.py b/src/inspect_ai/_display/rich.py index 7638ee438..62054c2d1 100644 --- a/src/inspect_ai/_display/rich.py +++ b/src/inspect_ai/_display/rich.py @@ -2,17 +2,16 @@ import contextlib import datetime from dataclasses import dataclass -from types import TracebackType -from typing import Any, Callable, Iterator, Type +from typing import Callable, Iterator -from rich.align import Align -from rich.console import Console, RenderableType +from rich.console import Console, Group, RenderableType from rich.live import Live from rich.panel import Panel from rich.progress import ( BarColumn, SpinnerColumn, TaskProgressColumn, + TextColumn, TimeElapsedColumn, ) from rich.progress import Progress as RProgress @@ -22,12 +21,21 @@ from inspect_ai._util.path import cwd_relative_path from inspect_ai._util.platform import is_running_in_jupyterlab, is_running_in_vscode -from inspect_ai.log import EvalError, EvalResults, EvalStats +from inspect_ai.log import EvalResults, EvalStats from inspect_ai.log._log import rich_traceback from inspect_ai.util._concurrency import concurrency_status from inspect_ai.util._logger import logger_http_rate_limit_count -from ._display import Display, Progress, TaskDisplay, TaskProfile +from ._display import ( + Display, + Progress, + TaskCancelled, + TaskDisplay, + TaskError, + TaskProfile, + TaskResult, + TaskSuccess, +) @dataclass @@ -39,56 +47,106 @@ class Theme: error: str = "red" +@dataclass +class TaskStatus: + profile: TaskProfile + result: TaskResult | None + progress: RProgress + + class RichDisplay(Display): def __init__(self) -> None: - self.console = rich_console() - self.theme = Theme() + self.tasks: list[TaskStatus] | None = None + self.progress_ui: RProgress | None = None @override def print(self, message: str) -> None: - self.console.print(message, markup=False, highlight=False) + rich_console().print(message, markup=False, highlight=False) @override @contextlib.contextmanager def progress(self, total: int) -> Iterator[Progress]: - with rich_progress(self.console) as progress: + with rich_progress() as progress: yield RichProgress(total, progress) + @override + @contextlib.contextmanager + def live_task_status(self) -> Iterator[None]: + if self.tasks is None: + # initialise tasks + 
self.tasks = [] + self.progress_ui = rich_progress() + + with Live(None, console=rich_console(), auto_refresh=False) as live: + # setup some timed updates + loop = asyncio.get_event_loop() + handle: asyncio.TimerHandle | None + + def update_display() -> None: + if self.tasks is not None and self.progress_ui is not None: + r = tasks_live_status(self.tasks, self.progress_ui) + live.update(r, refresh=True) + nonlocal handle + handle = loop.call_later(1, update_display) + + handle = loop.call_later(1, update_display) + + # yield + yield + + # cleanup handle if we need to + if handle: + handle.cancel() + + # render task results + live.update(tasks_results(self.tasks), refresh=True) + + # clear tasks and progress + self.tasks = None + self.progress_ui = None + + else: + yield + @override @contextlib.contextmanager def task(self, profile: TaskProfile) -> Iterator[TaskDisplay]: - with Live(None, console=self.console) as live: - # create task display - display = RichTaskDisplay( - profile, - self.console, - self.theme, - lambda r: live.update(r, refresh=True), - ) + # for typechekcer + if self.tasks is None: + self.tasks = [] + if self.progress_ui is None: + self.progress_ui = rich_progress() - # setup some timed updates (for when no progress ticks are occurring) - loop = asyncio.get_event_loop() - handle: asyncio.TimerHandle | None + status = TaskStatus(profile, None, self.progress_ui) + self.tasks.append(status) + yield RichTaskDisplay(status) - def update_display() -> None: - display.on_update() - nonlocal handle - handle = loop.call_later(5, update_display) - handle = loop.call_later(5, update_display) +class RichTaskDisplay(TaskDisplay): + def __init__(self, status: TaskStatus) -> None: + self.status = status - # yield the display - yield display + @override + @contextlib.contextmanager + def progress(self) -> Iterator[Progress]: + model = str(self.status.profile.model) + p = RichProgress( + total=self.status.profile.steps, + progress=self.status.progress, + description=model, + ) + yield p + p.complete() - # cleanup handle if we need to - if handle: - handle.cancel() + @override + def complete(self, result: TaskResult) -> None: + self.status.result = result # Note that use of rich progress seems to result in an extra -# empty cell after execution, see: -# https://github.com/Textualize/rich/issues/3211 -# https://github.com/Textualize/rich/issues/3168 +# empty cell after execution, see: https://github.com/Textualize/rich/issues/3274 + +PROGRESS_TOTAL = 102 class RichProgress(Progress): @@ -96,130 +154,110 @@ def __init__( self, total: int, progress: RProgress, - on_update: Callable[[], None] | None = None, + description: str = "", + meta: Callable[[], str] | None = None, ) -> None: self.total = total self.progress = progress - self.task_id = progress.add_task("", total=102) - self.on_update = on_update + self.meta = meta if meta else lambda: "" + self.task_id = progress.add_task( + description, total=PROGRESS_TOTAL, meta=self.meta() + ) @override - def update(self, n: float = 1) -> None: - advance = (n / self.total) * 100 - self.progress.update(task_id=self.task_id, advance=advance, refresh=True) - if self.on_update: - self.on_update() + def update(self, n: int = 1) -> None: + advance = (float(n) / float(self.total)) * 100 + self.progress.update( + task_id=self.task_id, advance=advance, refresh=True, meta=self.meta() + ) + @override + def complete(self) -> None: + self.progress.update(task_id=self.task_id, completed=PROGRESS_TOTAL) -class RichTaskDisplay(TaskDisplay): - def __init__( - self, - 
profile: TaskProfile, - console: Console, - theme: Theme, - render: Callable[[RenderableType], None], - ) -> None: - self.profile = profile - self.console = console - self.theme = theme - self.progress_ui = rich_progress(console) - self.render = render - self.on_update() - @override - @contextlib.contextmanager - def progress(self, total: int) -> Iterator[Progress]: - yield RichProgress(total, self.progress_ui, self.on_update) +def tasks_live_status(tasks: list[TaskStatus], progress: RProgress) -> RenderableType: + return task_live_status(tasks, progress) - @override - def cancelled(self, samples_logged: int, stats: EvalStats) -> None: - panel = self.task_panel( - body=task_stats(self.profile, stats, self.theme), - config=None, - footer=task_interrupted( - self.profile.log_location, samples_logged, self.theme - ), - log_location=self.profile.log_location, - ) - self.render(panel) - @override - def summary(self, results: EvalResults, stats: EvalStats) -> None: - panel = self.task_panel( - body=task_stats(self.profile, stats, self.theme), - config=None, - footer=task_results(results, self.theme), - log_location=self.profile.log_location, - ) - self.render(panel) +def tasks_results(tasks: list[TaskStatus]) -> RenderableType: + def render_task(task: TaskStatus) -> RenderableType: + if isinstance(task.result, TaskCancelled): + return task_result_cancelled(task.profile, task.result) + elif isinstance(task.result, TaskError): + return task_result_error(task.profile, task.result) + elif isinstance(task.result, TaskSuccess): + return task_result_summary(task.profile, task.result) + else: + return "" - @override - def error( - self, - samples_logged: int, - error: EvalError, - exc_type: Type[Any], - exc_value: BaseException, - traceback: TracebackType | None, - ) -> None: - panel = self.task_panel( - body=rich_traceback(exc_type, exc_value, traceback), - config=None, - footer=task_interrupted( - self.profile.log_location, samples_logged, self.theme - ), - log_location=self.profile.log_location, - ) - self.render(panel) - - def on_update(self) -> None: - panel = self.task_panel( - body=Align(self.progress_ui, vertical="middle"), - config=task_config(self.profile, self.theme), - footer=live_task_footer(self.theme), - log_location=None, - ) - self.render(panel) + return Group(*[render_task(task) for task in tasks]) - def task_panel( - self, - body: RenderableType, - config: str | None, - footer: tuple[RenderableType, RenderableType] | None, - log_location: str | None, - ) -> Panel: - return task_panel( - profile=self.profile, - body=body, - config=config, - footer=footer, - log_location=log_location, - options=TaskPanelOptions( - theme=self.theme, - # rich doesn't detect vs code width properly - width=(80 if is_vscode_notebook(self.console) else None), - jupyter=self.console.is_jupyter, - ), - ) +def task_live_status(tasks: list[TaskStatus], progress: RProgress) -> RenderableType: + body: list[RenderableType] = ["", progress] + config = task_config(tasks[0].profile) + if config: + body = [config] + body + + return task_panel( + profile=tasks[0].profile, + show_model=len(tasks) == 1, + body=Group(*body), + config=None, + footer=live_task_footer(), + log_location=None, + ) -@dataclass -class TaskPanelOptions: - theme: Theme - width: int | None - jupyter: bool + +def task_result_cancelled( + profile: TaskProfile, cancelled: TaskCancelled +) -> RenderableType: + return task_panel( + profile=profile, + show_model=True, + body=task_stats(profile, cancelled.stats), + config=None, + 
footer=task_interrupted(profile.log_location, cancelled.samples_logged), + log_location=profile.log_location, + ) + + +def task_result_summary(profile: TaskProfile, success: TaskSuccess) -> RenderableType: + return task_panel( + profile=profile, + show_model=True, + body=task_stats(profile, success.stats), + config=None, + footer=task_results(success.results), + log_location=profile.log_location, + ) + + +def task_result_error(profile: TaskProfile, error: TaskError) -> RenderableType: + return task_panel( + profile=profile, + show_model=True, + body=rich_traceback(error.exc_type, error.exc_value, error.traceback), + config=None, + footer=task_interrupted(profile.log_location, error.samples_logged), + log_location=profile.log_location, + ) def task_panel( profile: TaskProfile, + show_model: bool, body: RenderableType, config: str | None, footer: tuple[RenderableType, RenderableType] | None, log_location: str | None, - options: TaskPanelOptions, ) -> Panel: - # alias theme - theme = options.theme + # rendering context + theme = rich_theme() + console = rich_console() + width = 100 if is_vscode_notebook(console) else None + jupyter = console.is_jupyter # setup table table = Table.grid(expand=True) @@ -245,7 +283,7 @@ def task_panel( root = table if log_location: # if we are in jupyter then use a real hyperlink - if options.jupyter: + if jupyter: log_location = f"[link={log_location}]{log_location}[/link]" # Print a cwd relative path @@ -266,35 +304,32 @@ def task_panel( # create panel w/ title panel = Panel( root, - title=f"[bold][{theme.meta}]{task_title(profile)}[/{theme.meta}][/bold]", + title=f"[bold][{theme.meta}]{task_title(profile, show_model)}[/{theme.meta}][/bold]", title_align="left", - width=options.width, + width=width, expand=True, ) return panel -def task_title(profile: TaskProfile) -> str: - sequence = ( - f"task {profile.sequence[0]}/{profile.sequence[1]}: " - if profile.sequence[1] > 1 - else "" - ) +def task_title(profile: TaskProfile, show_model: bool) -> str: eval_epochs = profile.eval_config.epochs or 1 epochs = f" x {profile.eval_config.epochs}" if eval_epochs > 1 else "" samples = f"{profile.samples//eval_epochs:,}{epochs} sample{'s' if profile.samples > 1 else ''}" - title = f"{sequence}{profile.name} ({samples})" + title = f"{profile.name} ({samples})" + if show_model: + title = f"{title}: {profile.model}" return title def task_targets(profile: TaskProfile) -> str: - return " " + "\n ".join( - [str(profile.model), f"dataset: {profile.dataset}", f"scorer: {profile.scorer}"] - ) + targets = [f"dataset: {profile.dataset}", f"scorer: {profile.scorer}"] + return " " + "\n ".join(targets) -def task_config(profile: TaskProfile, theme: Theme) -> str: +def task_config(profile: TaskProfile) -> str: # merge config + theme = rich_theme() config = ( dict(profile.task_args) | dict(profile.eval_config.model_dump(exclude_none=True)) @@ -314,11 +349,14 @@ def task_config(profile: TaskProfile, theme: Theme) -> str: def task_resources() -> str: resources: dict[str, str] = {} for model, resource in concurrency_status().items(): + if "/" in model: + model = model.split("/", 1)[1] resources[model] = f"{resource[0]}/{resource[1]}" return task_dict(resources) -def live_task_footer(theme: Theme) -> tuple[RenderableType, RenderableType]: +def live_task_footer() -> tuple[RenderableType, RenderableType]: + theme = rich_theme() return ( f"[{theme.light}]{task_resources()}[/{theme.light}]", Text(task_http_rate_limits(), style=theme.light), @@ -326,8 +364,9 @@ def live_task_footer(theme: Theme) -> 
tuple[RenderableType, RenderableType]: def task_interrupted( - log_location: str, samples_logged: int, theme: Theme + log_location: str, samples_logged: int ) -> tuple[RenderableType, RenderableType]: + theme = rich_theme() return ( f"[bold][{theme.error}]Task interrupted ({samples_logged} " + "completed samples logged before interruption). " @@ -337,9 +376,8 @@ def task_interrupted( ) -def task_results( - results: EvalResults, theme: Theme -) -> tuple[RenderableType, RenderableType]: +def task_results(results: EvalResults) -> tuple[RenderableType, RenderableType]: + theme = rich_theme() output: dict[str, str] = {} for name, metric in results.metrics.items(): value = ( @@ -357,10 +395,11 @@ def task_results( return (metrics, "") -def task_stats(profile: TaskProfile, stats: EvalStats, theme: Theme) -> RenderableType: +def task_stats(profile: TaskProfile, stats: EvalStats) -> RenderableType: + theme = rich_theme() panel = Table.grid(expand=True) panel.add_column() - config = task_config(profile, theme) + config = task_config(profile) if config: panel.add_row(config) panel.add_row() @@ -400,22 +439,17 @@ def task_dict(d: dict[str, str], bold_value: bool = False) -> str: ) -def rich_progress(console: Console) -> RProgress: - return RProgress( - SpinnerColumn(finished_text="✓"), - BarColumn(bar_width=40 if is_vscode_notebook(console) else None), - TaskProgressColumn(), - TimeElapsedColumn(), - transient=True, - console=console, - expand=not is_vscode_notebook(console), - ) - - def is_vscode_notebook(console: Console) -> bool: return console.is_jupyter and is_running_in_vscode() +def rich_theme() -> Theme: + global _theme + if _theme is None: + _theme = Theme() + return _theme + + def rich_console() -> Console: global _console if _console is None: @@ -433,5 +467,21 @@ def rich_display() -> RichDisplay: return _display +def rich_progress() -> RProgress: + console = rich_console() + return RProgress( + SpinnerColumn(finished_text="✓"), + TextColumn("{task.description}"), + TextColumn("{task.fields[meta]}"), + BarColumn(bar_width=40 if is_vscode_notebook(console) else None), + TaskProgressColumn(), + TimeElapsedColumn(), + transient=True, + console=console, + expand=not is_vscode_notebook(console), + ) + + +_theme: Theme | None = None _console: Console | None = None _display: RichDisplay | None = None diff --git a/src/inspect_ai/_eval/context.py b/src/inspect_ai/_eval/context.py index c86b1ff80..31e48af84 100644 --- a/src/inspect_ai/_eval/context.py +++ b/src/inspect_ai/_eval/context.py @@ -4,11 +4,11 @@ from inspect_ai.util._subprocess import init_max_subprocesses -def init_eval_context(model: Model, max_subprocesses: int | None = None) -> None: - init_active_model(model) +def init_eval_context(max_subprocesses: int | None = None) -> None: init_max_subprocesses(max_subprocesses) -def init_task_context() -> None: +def init_task_context(model: Model) -> None: + init_active_model(model) init_model_usage() init_logger_records() diff --git a/src/inspect_ai/_eval/eval.py b/src/inspect_ai/_eval/eval.py index 5ac115563..ca3d78d80 100644 --- a/src/inspect_ai/_eval/eval.py +++ b/src/inspect_ai/_eval/eval.py @@ -2,46 +2,45 @@ import logging import os from pathlib import Path -from typing import Any +from typing import Any, Awaitable, Callable from shortuuid import uuid from typing_extensions import Unpack +from inspect_ai._display import display from inspect_ai._display.logger import init_logger -from inspect_ai._util.dotenv import init_dotenv -from inspect_ai._util.path import cwd_relative_path +from 
inspect_ai._util.dotenv import dotenv_environ, init_dotenv
+from inspect_ai._util.error import exception_message
+from inspect_ai._util.path import chdir_python, cwd_relative_path
 from inspect_ai._util.platform import platform_init
 from inspect_ai._util.registry import registry_lookup
-from inspect_ai._view.view import view_notify_eval
 from inspect_ai.log import EvalConfig, EvalLog, EvalLogInfo, read_eval_log
 from inspect_ai.log._file import JSONRecorder
-from inspect_ai.model import (
-    GenerateConfig,
-    GenerateConfigArgs,
-    Model,
-    get_model,
-)
-from inspect_ai.solver import Solver, ToolEnvironmentSpec
+from inspect_ai.log._log import Recorder
+from inspect_ai.model import GenerateConfig, GenerateConfigArgs, Model
+from inspect_ai.model._model import resolve_models
+from inspect_ai.solver import Plan, Solver, ToolEnvironmentSpec
+from inspect_ai.solver._tool.environment.context import startup_tool_environments
 
 from .context import init_eval_context
-from .loader import resolve_tasks
+from .loader import ResolvedTask, resolve_tasks
 from .task import PreviousTask, Tasks
 from .task.log import TaskLogger
-from .task.run import eval_log_sample_source, task_run
-from .task.util import task_file, task_run_dir
+from .task.run import create_sample_semaphore, task_run
+from .task.util import task_run_dir
 
 log = logging.getLogger(__name__)
 
 
 def eval(
     tasks: Tasks,
-    model: str | Model | None = None,
+    model: str | Model | list[str] | list[Model] | None = None,
     model_base_url: str | None = None,
     model_args: dict[str, Any] = dict(),
     task_args: dict[str, Any] = dict(),
     toolenv: ToolEnvironmentSpec | None = None,
     toolenv_cleanup: bool | None = None,
-    plan: Solver | list[Solver] | None = None,
+    plan: Plan | Solver | list[Solver] | None = None,
     log_level: str | None = None,
     log_dir: str | None = None,
     limit: int | tuple[int, int] | None = None,
@@ -60,9 +59,9 @@ def eval(
     Args:
         tasks: (Tasks): Task(s) to evaluate. If None, attempt
            to evaluate a task in the current working directory
-        model (str | Model | None): Model for evaluation. If not
-           specified uses the current eval's model, or failing that
-           the value of the INSPECT_EVAL_MODEL environment variable.
+        model (str | Model | list[str] | list[Model] | None): Model(s) for
+           evaluation. If not specified, uses the value of the INSPECT_EVAL_MODEL
+           environment variable.
         model_base_url: (str | None): Base URL for communicating
            with the model API.
         model_args (dict[str,Any]): Model creation parameters
@@ -71,7 +70,7 @@ def eval(
            environment type (or optionally a tuple with type and config file)
         toolenv_cleanup (bool | None): Cleanup tool environments after task completes
            (defaults to True)
-        plan (Solver | list[Solver] | None): Alternative plan
+        plan (Plan | Solver | list[Solver] | None): Alternative plan
            for evaluating task(s). Optional (uses task plan by default).
         log_level (str | None): "debug", "http", "info", "warning", "error",
            or "critical" (defaults to "info")
@@ -129,13 +128,13 @@ def eval(
 
 async def eval_async(
     tasks: Tasks,
-    model: str | Model | None = None,
+    model: str | Model | list[str] | list[Model] | None = None,
     model_base_url: str | None = None,
     model_args: dict[str, Any] = dict(),
     task_args: dict[str, Any] = dict(),
     toolenv: ToolEnvironmentSpec | None = None,
     toolenv_cleanup: bool | None = None,
-    plan: Solver | list[Solver] | None = None,
+    plan: Plan | Solver | list[Solver] | None = None,
     log_level: str | None = None,
     log_dir: str | None = None,
     limit: int | tuple[int, int] | None = None,
@@ -154,9 +153,9 @@ async def eval_async(
     Args:
         tasks: (Tasks): Task(s) to evaluate. If None, attempt
            to evaluate a task in the current working directory
-        model (str | Model | None): Model for evaluation. If not
-           specified uses the current eval's model, or failing that
-           the value of the INSPECT_EVAL_MODEL environment variable.
+        model (str | Model | list[str] | list[Model] | None): Model(s) for
+           evaluation. If not specified, uses the value of the INSPECT_EVAL_MODEL
+           environment variable.
         model_base_url: (str | None): Base URL for communicating
            with the model API.
         model_args (dict[str,Any]): Model creation parameters
@@ -165,8 +164,8 @@ async def eval_async(
            environment type (or optionally a tuple with type and config file)
         toolenv_cleanup (bool | None): Cleanup tool environments after task completes
            (defaults to True)
-        plan (Solver | list[Solver] | None): Alternative plan
-           for evaluating task(s). Optional (uses task plan by default).
+        plan (Plan | Solver | list[Solver] | None): Alternative plan
+           for evaluating task(s). Optional (uses task plan by default).
         log_level (str | None): "debug", "http", "info", "warning", "error",
            or "critical" (defaults to "info")
         log_dir (str | None): Output path for logging results
@@ -192,124 +191,192 @@ async def eval_async(
     Returns:
         List of EvalLog (one for each task)
     """
-    # Provide .env and log support bootstrap for notebooks and invoking
-    # an eval as a plain Python script (as opposed to via inspect eval)
-    init_dotenv()
-    init_logger(log_level)
-
-    # resolve model
-    model = get_model(
-        model=model,
-        base_url=model_base_url,
-        config=GenerateConfig(**kwargs),
-        **model_args,
-    )
+    # only a single call to eval_async can be active at a time, this is
+    # because when running a task a chdir to the task's directory (and a
+    # similar mutation of the Python sys.path) occurs. since this is a
+    # change to global process state it cannot occur in parallel. for
+    # task parallelism, use eval_gather, which enforces the appropriate
+    # constraints on task parallelism and schedules multiple tasks for
+    # optimal concurrency
+    global _eval_async_running
+    if _eval_async_running:
+        raise RuntimeError("Multiple concurrent calls to eval_async are not allowed.")
+
+    _eval_async_running = True
+    try:
+        # Provide .env and log support bootstrap for notebooks and invoking
+        # an eval as a plain Python script (as opposed to via inspect eval)
+        init_dotenv()
+        init_logger(log_level)
+
+        # init eval context
+        init_eval_context(max_subprocesses)
+
+        # resolve models
+        models = resolve_models(
+            model, model_base_url, model_args, GenerateConfig(**kwargs)
+        )
 
-    # init eval context
-    init_eval_context(model, max_subprocesses)
-
-    # if this is a PreviousTask then we are being spotted our id and log
-    if isinstance(tasks, PreviousTask):
-        task_id = tasks.id
-        eval_log = tasks.log
-        tasks = tasks.task
-    else:
-        task_id = None
-        eval_log = None
-
-    # resolve tasks
-    eval_tasks = resolve_tasks(tasks, model, task_args)
-
-    # if we have an eval_log, see if we can re-use its logged samples
-    sample_source = eval_log_sample_source(eval_log, eval_tasks[0].dataset)
-
-    # warn and return empty string if we resolved no tasks
-    if len(eval_tasks) == 0:
-        log.warning("No inspect tasks were found at the specified paths.")
-        return []
-
-    # resolve recorder
-    log_dir = log_dir if log_dir else os.environ.get("INSPECT_LOG_DIR", "./logs")
-    log_dir = cwd_relative_path(log_dir)
-    recorder = JSONRecorder(log_dir, log_buffer=log_buffer)
-
-    # build task names and versions (include version if > 0)
-    task_names: list[str] = [task.name for task in eval_tasks]
-    task_versions: list[int] = [task.version for task in eval_tasks]
-
-    # create config
-    eval_config = EvalConfig(
-        limit=limit,
-        epochs=epochs,
-        max_messages=max_messages,
-        max_samples=max_samples,
-        max_subprocesses=max_subprocesses,
-        toolenv_cleanup=toolenv_cleanup,
-        log_samples=log_samples,
-        log_images=log_images,
-        log_buffer=log_buffer,
-    )
+        # resolve tasks
+        resolved_tasks: list[ResolvedTask] = []
+        for m in models:
+            resolved_tasks.extend(resolve_tasks(tasks, task_args, m, toolenv))
 
-    run_id = uuid()
-    loggers: list[TaskLogger] = []
-    results: list[EvalLog] = []
-    for index, name, version, task in zip(
-        range(0, len(task_names)), task_names, task_versions, eval_tasks
-    ):
-        # tasks can provide their own epochs and max_messages
-        task_eval_config = eval_config.model_copy()
-        if task.epochs is not None:
-            task_eval_config.epochs = task.epochs
-        if task.max_messages is not None:
-            task_eval_config.max_messages = task.max_messages
-
-        # create and track the logger
-        logger = TaskLogger(
-            task_name=name,
-            task_version=version,
-            task_file=task_file(task, relative=True),
-            task_run_dir=task_run_dir(task),
-            task_id=task_id if task_id else uuid(),
-            run_id=run_id,
-            model=model,
-            dataset=task.dataset,
-            tool_environment=toolenv,
-            task_attribs=task.attribs,
-            task_args=task_args,
-            model_args=model_args,
-            eval_config=task_eval_config,
-            recorder=recorder,
+        # warn and return empty list if we resolved no tasks
+        if len(resolved_tasks) == 0:
+            log.warning("No inspect tasks were found at the specified paths.")
+            return []
+
+        # resolve recorder
+        log_dir = log_dir if log_dir else os.environ.get("INSPECT_LOG_DIR", "./logs")
+        log_dir = cwd_relative_path(log_dir)
+        recorder = JSONRecorder(log_dir, log_buffer=log_buffer)
+
+        # create config
+        eval_config = EvalConfig(
+            limit=limit,
+            epochs=epochs,
+            max_messages=max_messages,
+            max_samples=max_samples,
+            max_subprocesses=max_subprocesses,
+            toolenv_cleanup=toolenv_cleanup,
+            log_samples=log_samples,
+            log_images=log_images,
+            log_buffer=log_buffer,
         )
-        loggers.append(logger)
 
-        # run the eval (create a task so it gets its own ContextVar scope)
-        result = await asyncio.create_task(
-            task_run(
-                task=task,
-                sequence=(index + 1, len(task_names)),
-                model=model,
-                logger=logger,
-                config=task_eval_config,
-                plan=plan,
-                score=score,
-                sample_source=sample_source,
-                **kwargs,
+        # run tasks (batch so that multiple models are executed in parallel)
+        run_id = uuid()
+        results: list[EvalLog] = []
+        for sequence in range(0, len(resolved_tasks) // len(models)):
+            task_batch = list(filter(lambda t: t.sequence == sequence, resolved_tasks))
+            results.extend(
+                await eval_parallel(
+                    run_id=run_id,
+                    tasks=task_batch,
+                    eval_config=eval_config,
+                    recorder=recorder,
+                    model_args=model_args,
+                    plan=plan,
+                    score=score,
+                    **kwargs,
+                )
             )
-        )
+            # exit the loop if there was a cancellation
+            if any([result.status == "cancelled" for result in results]):
+                break
+
+        # return list of eval logs
+        return EvalLogs(results)
+    finally:
+        _eval_async_running = False
+
+
+# single call to eval_async at a time
+_eval_async_running = False
 
-        # mark completed and append result
-        results.append(result)
 
-        # notify the view module that an eval just completed
-        # (in case we have a view polling for new evals)
-        view_notify_eval(logger.location)
+async def eval_parallel(
+    run_id: str,
+    tasks: list[ResolvedTask],
+    eval_config: EvalConfig,
+    recorder: Recorder,
+    model_args: dict[str, Any],
+    plan: Plan | Solver | list[Solver] | None = None,
+    score: bool = True,
+    **kwargs: Unpack[GenerateConfigArgs],
+) -> list[EvalLog]:
+    # we rely on the run_dir and toolenv being the same across all tasks
+    # alias these and then confirm that the rest of the tasks conform
+    run_dir = task_run_dir(tasks[0].task)
+    if any([task_run_dir(task.task) != run_dir for task in tasks]):
+        raise RuntimeError(
+            "Tasks passed to eval_parallel must have the same working directory."
+        )
+    toolenv = next((task.toolenv for task in tasks if task.toolenv is not None), None)
+    if any([task.toolenv is not None and task.toolenv != toolenv for task in tasks]):
+        raise RuntimeError(
+            "Tasks passed to eval_parallel must have the same tool environment."
+        )
+
+    # if we have a toolenv then we need to enforce sample concurrency at
+    # this level of the eval (so we don't explode the # of toolenvs)
+    sample_semaphore: asyncio.Semaphore | None = (
+        create_sample_semaphore(eval_config, GenerateConfig(**kwargs), toolenv)
+        if toolenv
+        else None
+    )
 
-        # exit the loop if this was a cancellation
-        if result.status == "cancelled":
-            break
+    # switch to task directory context
+    with chdir_python(run_dir), dotenv_environ():
+        # run startup pass for the tool_environment
+        shutdown_tool_environments: Callable[[], Awaitable[None]] | None = None
+        if toolenv:
+            cleanup = eval_config.toolenv_cleanup is not False
+            shutdown_tool_environments = await startup_tool_environments(
+                "startup", toolenv, cleanup
+            )
 
-    # return list of eval logs
-    return EvalLogs(results)
+        try:
+            # create asyncio tasks
+            asyncio_tasks: list[asyncio.Task[EvalLog]] = []
+            for resolved_task in tasks:
+                # tasks can provide their own epochs and max_messages
+                task = resolved_task.task
+                task_eval_config = eval_config.model_copy()
+                if task.epochs is not None:
+                    task_eval_config.epochs = task.epochs
+                if task.max_messages is not None:
+                    task_eval_config.max_messages = task.max_messages
+
+                # create and track the logger
+                logger = TaskLogger(
+                    task_name=task.name,
+                    task_version=task.version,
+                    task_file=resolved_task.task_file,
+                    task_id=resolved_task.id if resolved_task.id else uuid(),
+                    run_id=run_id,
+                    model=resolved_task.model,
+                    dataset=task.dataset,
+                    tool_environment=resolved_task.toolenv,
+                    task_attribs=task.attribs,
+                    task_args=resolved_task.task_args,
+                    model_args=model_args,
+                    eval_config=task_eval_config,
+                    recorder=recorder,
+                )
+
+                # append task
+                asyncio_tasks.append(
+                    asyncio.create_task(
+                        task_run(
+                            task=task,
+                            model=resolved_task.model,
+                            toolenv=toolenv,
+                            logger=logger,
+                            config=task_eval_config,
+                            plan=plan,
+                            score=score,
+                            sample_source=resolved_task.sample_source,
+                            sample_semaphore=sample_semaphore,
+                            **kwargs,
+                        )
+                    )
+                )
+
+            # run all of the tasks in parallel
+            with display().live_task_status():
+                return await asyncio.gather(*asyncio_tasks)
+
+        finally:
+            # shutdown tool environments
+            if shutdown_tool_environments:
+                try:
+                    await shutdown_tool_environments()
+                except BaseException as ex:
+                    log.warning(
+                        f"Error occurred shutting down tool environments: {exception_message(ex)}"
+                    )
 
 
 def eval_retry(
@@ -457,7 +524,7 @@ async def eval_retry_async(
             task_file = eval_log.eval.task_file
             if task_file:
                 if not Path(task_file).exists():
-                    raise FileNotFoundError("Task file '{task_file}' not found")
+                    raise FileNotFoundError(f"Task file '{task_file}' not found")
                 task = f"{task_file}@{task_name}"
             else:
                 if registry_lookup("task", task_name) is None:
diff --git a/src/inspect_ai/_eval/loader.py b/src/inspect_ai/_eval/loader.py
index 2c4059d69..d76bdd30f 100644
--- a/src/inspect_ai/_eval/loader.py
+++ b/src/inspect_ai/_eval/loader.py
@@ -1,11 +1,13 @@
 import ast
 import inspect
+from dataclasses import dataclass, field
 from importlib.machinery import SourceFileLoader
 from importlib.util import module_from_spec, spec_from_loader
 from pathlib import Path
 from types import ModuleType
 from typing import Any, cast
 
+from inspect_ai._eval.task.util import task_file
 from inspect_ai._util.dotenv import dotenv_environ
 from inspect_ai._util.path import chdir_python
 from inspect_ai._util.registry import (
@@ -15,27 +17,87 @@ registry_lookup,
 )
 from inspect_ai.model import Model, ModelName
+from inspect_ai.solver._tool.environment.environment import ToolEnvironmentSpec
 
 from .list import task_files
 from .registry import task_create
-from .task import Task, TaskInfo, Tasks
+from .task import PreviousTask, Task, TaskInfo, Tasks
 from .task.constants import TASK_FILE_ATTR, TASK_RUN_DIR_ATTR
+from .task.run import EvalSampleSource, eval_log_sample_source
+
+
+@dataclass
+class ResolvedTask:
+    task: Task
+    task_args: dict[str, Any]
+    task_file: str | None
+    model: Model
+    toolenv: tuple[str, str | None] | None
+    sequence: int
+    id: str | None = field(default=None)
+    sample_source: EvalSampleSource | None = field(default=None)
 
 
 def resolve_tasks(
     tasks: Tasks,
-    model: Model,
     task_args: dict[str, Any],
-) -> list[Task]:
+    model: Model,
+    toolenv: ToolEnvironmentSpec | None,
+) -> list[ResolvedTask]:
+    def as_resolved_tasks(tasks: list[Task]) -> list[ResolvedTask]:
+        return [
+            ResolvedTask(
+                task=task,
+                task_args=task_args,
+                task_file=task_file(task, relative=True),
+                model=model,
+                toolenv=(
+                    (toolenv, None)
+                    if isinstance(toolenv, str)
+                    else toolenv
+                    if toolenv is not None
+                    else task.tool_environment
+                ),
+                sequence=sequence,
+            )
+            for sequence, task in enumerate(tasks)
+        ]
+
     # take empty lists out of play
     if isinstance(tasks, list) and len(tasks) == 0:
-        return load_tasks(None, model, task_args)
+        return as_resolved_tasks(load_tasks(None, model, task_args))
 
     # simple cases of passing us Task objects
     if isinstance(tasks, Task):
-        return [tasks]
+        return as_resolved_tasks([tasks])
     elif isinstance(tasks, list) and isinstance(tasks[0], Task):
-        return cast(list[Task], tasks)
+        return as_resolved_tasks(cast(list[Task], tasks))
+
+    # simple case of passing us PreviousTask
+    if isinstance(tasks, PreviousTask):
+        tasks = [tasks]
+    if isinstance(tasks, list) and isinstance(tasks[0], PreviousTask):
+        previous_tasks = cast(list[PreviousTask], tasks)
+        loaded_tasks = load_tasks(
+            [task.task for task in previous_tasks], model, task_args
+        )
+        return [
+            ResolvedTask(
+                task=loaded_task,
+                task_args=task_args,
+                task_file=previous_task.log.eval.task_file,
+                model=model,
+                toolenv=previous_task.log.eval.tool_environment,
+                sequence=sequence,
+                id=previous_task.id,
+                sample_source=eval_log_sample_source(
+                    previous_task.log, loaded_task.dataset
+                ),
+            )
+            for sequence, loaded_task, previous_task in zip(
+                range(0, len(loaded_tasks)), loaded_tasks, previous_tasks
+            )
+        ]
 
     # convert TaskInfo to str
     if isinstance(tasks, TaskInfo):
@@ -54,7 +116,9 @@ def resolve_tasks(
         tasks = [tasks]
 
     # done! let's load the tasks
-    return load_tasks(cast(list[str] | None, tasks), model, task_args)
+    return as_resolved_tasks(
+        load_tasks(cast(list[str] | None, tasks), model, task_args)
+    )
 
 
 def load_tasks(
diff --git a/src/inspect_ai/_eval/task/log.py b/src/inspect_ai/_eval/task/log.py
index 573eb0742..110eb753c 100644
--- a/src/inspect_ai/_eval/task/log.py
+++ b/src/inspect_ai/_eval/task/log.py
@@ -35,7 +35,7 @@
 )
 from inspect_ai.model._model import model_usage
 from inspect_ai.scorer import Score
-from inspect_ai.solver import Plan, Solver, TaskState, ToolEnvironmentSpec
+from inspect_ai.solver import Plan, Solver, TaskState
 from inspect_ai.util._logger import logger_records
 
 
@@ -45,12 +45,11 @@ def __init__(
         task_name: str,
         task_version: int,
         task_file: str | None,
-        task_run_dir: str,
         task_id: str | None,
         run_id: str,
         model: Model,
         dataset: Dataset,
-        tool_environment: ToolEnvironmentSpec | None,
+        tool_environment: tuple[str, str | None] | None,
         task_attribs: dict[str, Any],
         task_args: dict[str, Any],
         model_args: dict[str, Any],
@@ -58,7 +57,7 @@ def __init__(
         recorder: Recorder,
     ) -> None:
         # determine versions
-        git = git_context(task_run_dir)
+        git = git_context()
         revision = (
             EvalRevision(type="git", origin=git.origin, commit=git.commit)
             if git
@@ -82,11 +81,7 @@ def __init__(
                 samples=len(dataset),
                 shuffled=dataset.shuffled,
             ),
-            tool_environment=(
-                (tool_environment, None)
-                if isinstance(tool_environment, str)
-                else tool_environment
-            ),
+            tool_environment=tool_environment,
             task_attribs=task_attribs,
             task_args=task_args,
             model_args=model_args,
diff --git a/src/inspect_ai/_eval/task/run.py b/src/inspect_ai/_eval/task/run.py
index ba27fa8fd..e927c52a8 100644
--- a/src/inspect_ai/_eval/task/run.py
+++ b/src/inspect_ai/_eval/task/run.py
@@ -4,23 +4,28 @@
 import sys
 from copy import deepcopy
 from logging import getLogger
-from typing import AsyncGenerator, Awaitable, Callable, Literal
+from typing import AsyncGenerator, Callable, Literal, cast
 
 from typing_extensions import Unpack
 
 from inspect_ai._display import display
-from inspect_ai._display._display import TaskProfile
-from inspect_ai._util.constants import DEFAULT_EPOCHS
+from inspect_ai._display._display import (
+    TaskCancelled,
+    TaskError,
+    TaskProfile,
+    TaskSuccess,
+)
+from inspect_ai._eval.task.util import sample_messages
+from inspect_ai._util.constants import DEFAULT_EPOCHS, DEFAULT_MAX_CONNECTIONS
 from inspect_ai._util.datetime import iso_now
-from inspect_ai._util.dotenv import dotenv_environ
 from inspect_ai._util.error import exception_message
 from inspect_ai._util.file import file, filesystem
-from inspect_ai._util.path import chdir_python
 from inspect_ai._util.registry import (
     is_registry_object,
     registry_log_name,
 )
 from inspect_ai._util.url import data_uri_to_base64, is_data_uri
+from inspect_ai._view.view import view_notify_eval
 from inspect_ai.dataset import Dataset, Sample
 from inspect_ai.log import (
     EvalConfig,
@@ -44,8 +49,8 @@
 from inspect_ai.solver._tool.environment.context import (
     cleanup_tool_environments_sample,
     init_tool_environments_sample,
-    startup_tool_environments,
 )
+from inspect_ai.solver._tool.environment.registry import registry_find_toolenv
 
 from ..context import init_task_context
 from ..task import Task
@@ -53,7 +58,6 @@
 from .images import samples_with_base64_images, states_with_base64_images
 from .log import TaskLogger, collect_eval_data, log_plan
 from .results import eval_results
-from .util import sample_messages, task_run_dir
 
 py_logger = getLogger(__name__)
 
@@ -63,13 +67,14 @@
 async def task_run(
     task: Task,
-    sequence: tuple[int, int],
     model: Model,
+    toolenv: tuple[str, str | None] | None,
     logger: TaskLogger,
     config: EvalConfig = EvalConfig(),
     plan: Plan | Solver | list[Solver] | None = None,
     score: bool = True,
     sample_source: EvalSampleSource | None = None,
+    sample_semaphore: asyncio.Semaphore | None = None,
     **kwargs: Unpack[GenerateConfigArgs],
 ) -> EvalLog:
     r"""Run the task.
@@ -79,8 +84,8 @@ async def task_run(
 
     Args:
         task (Task): Task to run.
-        sequence (int): Sequence of the run within a larger set of runs
         model (Model): Model used to generate output
+        toolenv (tuple[str, str | None] | None): Tool environment
         logger (TaskLogger): Logger for recording results.
         config (EvalConfig): Config (sample range/epochs, logging options)
         plan:(Plan | Solver | list[Solver] | None): Override of
@@ -90,202 +95,182 @@ async def task_run(
            has a solver and metrics defined.
         sample_source (EvalSampleSource | None): Source from which
            previously executed samples can be found/returned
+        sample_semaphore (Semaphore | None): Semaphore for limiting
+           number of concurrent samples.
         **kwargs (GenerateConfigArgs): Generation config options
 
     Returns:
       EvalLog for executed task.
     """
-    with chdir_python(task_run_dir(task)), dotenv_environ():
-        # init task context
-        init_task_context()
-
-        # track stats and error
-        stats = EvalStats(started_at=iso_now())
-        error: EvalError | None = None
-        cancelled = False
-
-        # resolve some config
-        model_name = ModelName(model)
-        epochs = config.epochs if config.epochs else DEFAULT_EPOCHS
-        toolenv_cleanup = config.toolenv_cleanup is not False
-        log_images = config.log_images is not False
-        log_samples = config.log_samples is not False
-        generate_config = task.config.merge(GenerateConfigArgs(**kwargs))
-
-        # resolve dataset
-        _, samples, states = await resolve_dataset(
-            dataset=task.dataset,
-            model_name=model_name,
-            limit=config.limit,
-            epochs=epochs,
-            log_images=log_images,
-            max_messages=config.max_messages,
-        )
+    # init task context
+    init_task_context(model)
+
+    # track stats and error
+    stats = EvalStats(started_at=iso_now())
+    error: EvalError | None = None
+    cancelled = False
+
+    # resolve some config
+    model_name = ModelName(model)
+    epochs = config.epochs if config.epochs else DEFAULT_EPOCHS
+    toolenv_cleanup = config.toolenv_cleanup is not False
+    log_images = config.log_images is not False
+    log_samples = config.log_samples is not False
+    generate_config = task.config.merge(GenerateConfigArgs(**kwargs))
+
+    # resolve dataset
+    _, samples, states = await resolve_dataset(
+        dataset=task.dataset,
+        model_name=model_name,
+        limit=config.limit,
+        epochs=epochs,
+        log_images=log_images,
+        max_messages=config.max_messages,
+    )
 
-        # resolve tool_environment
-        if task.tool_environment:
-            tool_environment = (
-                logger.eval.tool_environment
-                if logger.eval.tool_environment
-                else task.tool_environment
-            )
-        else:
-            tool_environment = None
-
-        # resolve the plan and scorer
-        plan = (
-            plan
-            if isinstance(plan, Plan)
-            else Plan(plan)
-            if plan is not None
-            else task.plan
-        )
-        score = score and task.scorer is not None
-        scorer: Scorer | None = task.scorer if (score and task.scorer) else None
-
-        # create task profile for display
-        profile = TaskProfile(
-            name=task.name,
-            sequence=sequence,
-            model=model_name,
-            dataset=task.dataset.name or "(samples)",
-            scorer=(
-                registry_log_name(scorer) if is_registry_object(scorer) else "(none)"
-            ),
-            samples=len(samples),
-            eval_config=config,
-            task_args=logger.eval.task_args,
-            generate_config=generate_config,
-            log_location=logger.location,
-        )
+    # resolve the plan and scorer
+    plan = (
+        plan
+        if isinstance(plan, Plan)
+        else Plan(plan)
+        if plan is not None
+        else task.plan
+    )
+    score = score and task.scorer is not None
+    scorer: Scorer | None = task.scorer if (score and task.scorer) else None
 
-        # run startup pass for the tool_environment
-        shutdown_tool_environments: Callable[[], Awaitable[None]] | None = None
-        if tool_environment:
-            shutdown_tool_environments = await startup_tool_environments(
-                task.name, tool_environment, toolenv_cleanup
-            )
+    # compute steps (steps = samples * steps in plan + 1 for scorer)
+    steps = len(samples) * (
+        len(plan.steps) + (1 if plan.finish else 0) + (1)  # scorer
+    )
 
-        with display().task(profile) as td:
-            try:
-                # log the plan
-                log_plan(logger, plan, generate_config)
+    # create task profile for display
+    profile = TaskProfile(
+        name=task.name,
+        model=model_name,
+        dataset=task.dataset.name or "(samples)",
+        scorer=(registry_log_name(scorer) if is_registry_object(scorer) else "(none)"),
+        samples=len(samples),
+        steps=steps,
+        eval_config=config,
+        task_args=logger.eval.task_args,
+        generate_config=generate_config,
+        log_location=logger.location,
+    )
 
-                # run w/ progress (steps = samples * steps in plan + 1 for scorer)
-                total_steps = len(samples) * (
-                    len(plan.steps) + (1 if plan.finish else 0) + (1)  # scorer
-                )
-                with td.progress(total=total_steps) as p:
-                    # forward progress
-                    def progress() -> None:
-                        p.update(1)
-
-                    # provide solvers a function that they can use to generate output
-                    async def generate(
-                        state: TaskState,
-                        tool_calls: Literal["loop", "single", "none"] = "loop",
-                        cache: bool | CachePolicy = False,
-                        **kwargs: Unpack[GenerateConfigArgs],
-                    ) -> TaskState:
-                        return await task_generate(
-                            model=model,
-                            state=state,
-                            tool_calls=tool_calls,
-                            cache=cache,
-                            config=generate_config.merge(kwargs),
-                        )
-
-                    # semaphore to limit concurrency
-                    task_semaphore = create_task_semaphone(
-                        config, generate_config, model.api
+    with display().task(profile) as td:
+        try:
+            # log the plan
+            log_plan(logger, plan, generate_config)
+
+            with td.progress() as p:
+                # forward progress
+                def progress() -> None:
+                    p.update(1)
+
+                # provide solvers a function that they can use to generate output
+                async def generate(
+                    state: TaskState,
+                    tool_calls: Literal["loop", "single", "none"] = "loop",
+                    cache: bool | CachePolicy = False,
+                    **kwargs: Unpack[GenerateConfigArgs],
+                ) -> TaskState:
+                    return await task_generate(
+                        model=model,
+                        state=state,
+                        tool_calls=tool_calls,
+                        cache=cache,
+                        config=generate_config.merge(kwargs),
                     )
 
-                    # create tasks
-                    tasks = [
-                        task_run_sample(
-                            task_name=task.name,
-                            sample=sample,
-                            state=state,
-                            tool_environment=tool_environment,
-                            toolenv_cleanup=toolenv_cleanup,
-                            plan=plan,
-                            max_messages=config.max_messages,
-                            scorer=scorer,
-                            generate=generate,
-                            progress=progress,
-                            logger=logger if log_samples else None,
-                            log_images=log_images,
-                            sample_source=sample_source,
-                            semaphore=task_semaphore,
-                        )
-                        for (sample, state) in zip(samples, states)
-                    ]
-
-                    # run them in parallel (subject to config.max_samples)
-                    scores = await asyncio.gather(*tasks)
-
-                # compute and record metrics if we have scores
-                completed_scores = [
-                    score for score in scores if isinstance(score, Score)
-                ]
-                if len(completed_scores) > 0:
-                    results = eval_results(
-                        scores=completed_scores,
+                # semaphore to limit concurrency
+                sample_semaphore = (
+                    sample_semaphore
+                    if sample_semaphore
+                    else create_sample_semaphore(
+                        config, generate_config, toolenv, model.api
+                    )
+                )
+
+                # create sample coroutines
+                sample_coroutines = [
+                    task_run_sample(
+                        task_name=task.name,
+                        sample=sample,
+                        state=state,
+                        tool_environment=toolenv,
+                        toolenv_cleanup=toolenv_cleanup,
+                        plan=plan,
+                        max_messages=config.max_messages,
                         scorer=scorer,
-                        metrics=task.metrics,
+                        generate=generate,
+                        progress=progress,
+                        logger=logger if log_samples else None,
+                        log_images=log_images,
+                        sample_source=sample_source,
+                        semaphore=sample_semaphore,
                     )
-                    logger.log_results(results)
-                else:
-                    results = EvalResults()
-
-                # collect eval data
-                collect_eval_data(stats, logger)
-
-                # display task summary
-                td.summary(results, stats)
-
-            except (asyncio.CancelledError, KeyboardInterrupt):
-                # flag as cancelled
-                cancelled = True
-
-                # collect eval data
-                collect_eval_data(stats, logger)
-
-                # display task cancelled
-                td.cancelled(logger.samples_logged, stats)
-
-            except BaseException as ex:
-                # get exception info
-                type, value, traceback = sys.exc_info()
-                type = type if type else BaseException
-                value = value if value else ex
-
-                # build eval error
-                error = eval_error(ex, type, value, traceback)
-
-                # collect eval data
-                collect_eval_data(stats, logger)
-
-                # display it
-                td.error(logger.samples_logged, error, type, value, traceback)
-
-        # log as appropriate
-        if cancelled:
-            eval_log = logger.log_cancelled(stats)
-        elif error:
-            eval_log = logger.log_failure(stats, error)
-        else:
-            eval_log = logger.log_success(stats)
-
-        # run tool environment shutdown if we have it
-        if shutdown_tool_environments:
-            try:
-                await shutdown_tool_environments()
-            except BaseException as ex:
-                py_logger.warning(
-                    f"Error occurred shutting down tool environments: {exception_message(ex)}"
+                    for (sample, state) in zip(samples, states)
+                ]
+
+                # run them in parallel (subject to config.max_samples)
+                scores = await asyncio.gather(*sample_coroutines)
+
+            # compute and record metrics if we have scores
+            completed_scores = [score for score in scores if isinstance(score, Score)]
+            if len(completed_scores) > 0:
+                results = eval_results(
+                    scores=completed_scores,
+                    scorer=scorer,
+                    metrics=task.metrics,
                 )
+                logger.log_results(results)
+            else:
+                results = EvalResults()
+
+            # collect eval data
+            collect_eval_data(stats, logger)
+
+            # display task summary
+            td.complete(TaskSuccess(stats, results))
+
+        except (asyncio.CancelledError, KeyboardInterrupt):
+            # flag as cancelled
+            cancelled = True
+
+            # collect eval data
+            collect_eval_data(stats, logger)
+
+            # display task cancelled
+            td.complete(TaskCancelled(logger.samples_logged, stats))
+
+        except BaseException as ex:
+            # get exception info
+            type, value, traceback = sys.exc_info()
+            type = type if type else BaseException
+            value = value if value else ex
+
+            # build eval error
+            error = eval_error(ex, type, value, traceback)
+
+            # collect eval data
+            collect_eval_data(stats, logger)
+
+            # display it
+            td.complete(TaskError(logger.samples_logged, type, value, traceback))
+
+    # log as appropriate
+    if cancelled:
+        eval_log = logger.log_cancelled(stats)
+    elif error:
+        eval_log = logger.log_failure(stats, error)
+    else:
+        eval_log = logger.log_success(stats)
+
+    # notify the view module that an eval just completed
+    # (in case we have a view polling for new evals)
+    view_notify_eval(logger.location)
 
     # return eval log
     return eval_log
@@ -489,19 +474,34 @@ def previous(id: int | str, epoch: int) -> EvalSample | None:
 # semaphore to limit concurrency. default max_samples to
 # max_connections + 1 if not explicitly specified (this is
 # to make sure it always saturates the connection pool)
-def create_task_semaphone(
-    config: EvalConfig, generate_config: GenerateConfig, modelapi: ModelAPI
+def create_sample_semaphore(
+    config: EvalConfig,
+    generate_config: GenerateConfig,
+    toolenv: tuple[str, str | None] | None = None,
+    modelapi: ModelAPI | None = None,
 ) -> asyncio.Semaphore:
+    # if the user set max_samples then use that
+    if config.max_samples is not None:
+        return asyncio.Semaphore(config.max_samples)
+
+    # use max_connections
     max_samples = (
-        config.max_samples
-        if config.max_samples is not None
-        else (
-            generate_config.max_connections
-            if generate_config.max_connections is not None
-            else modelapi.max_connections()
-        )
-        + 1
+        generate_config.max_connections
+        if generate_config.max_connections is not None
+        else modelapi.max_connections()
+        if modelapi
+        else DEFAULT_MAX_CONNECTIONS
     )
+
+    # if a toolenv is in play then it can cap max_samples
+    if toolenv:
+        toolenv_type = registry_find_toolenv(toolenv[0])
+        toolenv_max_samples = cast(int | None, getattr(toolenv_type, "max_samples")())
+        if toolenv_max_samples is not None:
+            if max_samples > toolenv_max_samples:
+                max_samples = toolenv_max_samples
+
+    # return the semaphore
     return asyncio.Semaphore(max_samples)
diff --git a/src/inspect_ai/_eval/task/task.py b/src/inspect_ai/_eval/task/task.py
index 21098a716..c1129946f 100644
--- a/src/inspect_ai/_eval/task/task.py
+++ b/src/inspect_ai/_eval/task/task.py
@@ -126,6 +126,7 @@ class PreviousTask:
     | Callable[..., Task]
     | type[Task]
     | list[str]
+    | list[PreviousTask]
    | list[TaskInfo]
    | list[Task]
    | list[Callable[..., Task]]
diff --git a/src/inspect_ai/_util/git.py b/src/inspect_ai/_util/git.py
index 60ab3604a..47d389c00 100644
--- a/src/inspect_ai/_util/git.py
+++ b/src/inspect_ai/_util/git.py
@@ -3,34 +3,31 @@
 
 from pydantic import BaseModel
 
-from .path import chdir
-
 
 class GitContext(BaseModel):
     origin: str
     commit: str
 
 
-def git_context(dir: str) -> GitContext | None:
-    with chdir(dir):
-        # check for git
-        git = shutil.which("git")
-        if not git:
-            return None
-
-        # check for a git revision in this directory
-        commit_result = subprocess.run(
-            [git, "rev-parse", "--short", "HEAD"], capture_output=True, text=True
-        )
-        if commit_result.returncode != 0:
-            return None
-
-        # check for git origin (if any)
-        origin = subprocess.run(
-            [git, "remote", "get-url", "origin"],
-            capture_output=True,
-            text=True,
-        ).stdout.strip()
-
-        # return context
-        return GitContext(origin=origin, commit=commit_result.stdout.strip())
+def git_context() -> GitContext | None:
+    # check for git
+    git = shutil.which("git")
+    if not git:
+        return None
+
+    # check for a git revision in this directory
+    commit_result = subprocess.run(
+        [git, "rev-parse", "--short", "HEAD"], capture_output=True, text=True
+    )
+    if commit_result.returncode != 0:
+        return None
+
+    # check for git origin (if any)
+    origin = subprocess.run(
+        [git, "remote", "get-url", "origin"],
+        capture_output=True,
+        text=True,
+    ).stdout.strip()
+
+    # return context
+    return GitContext(origin=origin, commit=commit_result.stdout.strip())
diff --git a/src/inspect_ai/_view/view.py b/src/inspect_ai/_view/view.py
index b1407cffb..25844d189 100644
--- a/src/inspect_ai/_view/view.py
+++ b/src/inspect_ai/_view/view.py
@@ -218,6 +218,10 @@ def log_dir_aliased(self) -> str:
 
 
 def view_notify_eval(location: str) -> None:
+    # do not do this when running under pytest
+    if os.environ.get("PYTEST_VERSION", None) is not None:
+        return
+
     file = view_last_eval_file()
     with open(file, "w", encoding="utf-8") as f:
         if not urlparse(location).scheme:
diff --git a/src/inspect_ai/dataset/_dataset.py b/src/inspect_ai/dataset/_dataset.py
index 8a3817948..7426651ba 100644
--- a/src/inspect_ai/dataset/_dataset.py
+++ b/src/inspect_ai/dataset/_dataset.py
@@ -27,7 +27,7 @@ class Sample(BaseModel):
         input (str | list[ChatMessage]): The input to be submitted to the model.
         choices (list[str] | None): Optional. List of available answer choices
            (used only for multiple-choice evals).
-        target (str | list[str] | None): Optional. Ideal target output. May be a literal value
+        target (str | list[str]): Optional. Ideal target output. May be a literal value
            or narrative text to be used by a model grader.
         id (int | str | None): Optional. Unique identifier for sample.
         metadata (dict[str,Any] | None): Optional. Arbitrary metadata associated with the sample.
diff --git a/src/inspect_ai/model/_model.py b/src/inspect_ai/model/_model.py
index 538eee058..e18cebae0 100644
--- a/src/inspect_ai/model/_model.py
+++ b/src/inspect_ai/model/_model.py
@@ -400,7 +400,7 @@ def get_model(
            If `Model` is passed it is returned unmodified,
            if `None` is passed then the model currently being
            evaluated is returned (or if there is no evaluation
-           then the model referred to by `INSPECT_MODEL_NAME`).
+           then the model referred to by `INSPECT_EVAL_MODEL`).
         config (GenerationConfig): Configuration for model.
         base_url (str | None): Optional. Alternate base URL for model.
         api_key (str | None): Optional. API key for model.
@@ -411,23 +411,25 @@ def get_model(
     Returns:
         Model instance.
     """
-    # if the model is None then use the current model from our async
-    # context, else try to use INSPECT_EVAL_MODEL (or the legacy INSPECT_MODEL_NAME)
-    model = (
-        model
-        or active_model()
-        or os.getenv("INSPECT_EVAL_MODEL", None)
-        or os.getenv("INSPECT_MODEL_NAME", None)
-    )
-    if model is None:
-        raise ValueError("No model specified (and no INSPECT_EVAL_MODEL defined)")
-
-    # reflect back model -- we take model as a convenience so that
-    # function that accept str | Model can always call get_model and
-    # have it resolve correctly (even if trivially)
+    # start with seeing if a model was passed
     if isinstance(model, Model):
         return model
 
+    # now try finding an 'ambient' model (active or env var)
+    if model is None:
+        # return active_model if there is one
+        active = active_model()
+        if active:
+            return active
+
+        # return based on env var if there is one
+        # (handle lists by taking the first model)
+        model = os.getenv("INSPECT_EVAL_MODEL", None)
+        if model is not None:
+            model = model.split(",")[0]
+        else:
+            raise ValueError("No model specified (and no INSPECT_EVAL_MODEL defined)")
+
     # ensure that inspect model provider extensions are loaded
     ensure_entry_points()
 
@@ -465,6 +467,36 @@ def match_modelapi_type(info: RegistryInfo) -> bool:
         raise ValueError(f"Model name {model}{from_api} not recognized.")
 
 
+def resolve_models(
+    model: str | Model | list[str] | list[Model] | None,
+    model_base_url: str | None = None,
+    model_args: dict[str, Any] = dict(),
+    config: GenerateConfig = GenerateConfig(),
+) -> list[Model]:
+    # reflect back a plain model
+    if isinstance(model, Model):
+        return [model]
+
+    # helper to resolve model of various types
+    def resolve_model(m: str | Model | None) -> Model:
+        return get_model(
+            model=m,
+            base_url=model_base_url,
+            config=config,
+            **model_args,
+        )
+
+    # resolve None and str to list
+    if model is None or isinstance(model, str):
+        model = model or os.getenv("INSPECT_EVAL_MODEL", None)
+        if model is None:
+            raise ValueError("No model specified (and no INSPECT_EVAL_MODEL defined)")
+        model = [m.strip() for m in model.split(",")]
+
+    # resolve models
+    return [resolve_model(m) for m in model]
+
+
 def simple_input_messages(
     input: list[ChatMessage],
     fold_system_message: Callable[[str, str], str] | None = None,
diff --git a/src/inspect_ai/solver/_tool/environment/docker/docker.py b/src/inspect_ai/solver/_tool/environment/docker/docker.py
index 0af91a81e..eac4bbf57 100644
--- a/src/inspect_ai/solver/_tool/environment/docker/docker.py
+++ b/src/inspect_ai/solver/_tool/environment/docker/docker.py
@@ -165,6 +165,10 @@ def __init__(self, service: str, project: ComposeProject) -> None:
         self._service = service
         self._project = project
 
+    @classmethod
+    def max_samples(cls) -> int | None:
+        return 25
+
     @override
     async def exec(
         self,
diff --git a/src/inspect_ai/solver/_tool/environment/environment.py b/src/inspect_ai/solver/_tool/environment/environment.py
index 890a2ddbb..a3086589c 100644
--- a/src/inspect_ai/solver/_tool/environment/environment.py
+++ b/src/inspect_ai/solver/_tool/environment/environment.py
@@ -90,6 +90,10 @@ async def cli_cleanup(cls, id: str | None) -> None:
         """
         pass
 
+    @classmethod
+    def max_samples(cls) -> int | None:
+        return None
+
     @abc.abstractmethod
     async def exec(
         self,
diff --git a/src/inspect_ai/util/_subprocess.py b/src/inspect_ai/util/_subprocess.py
index b57b7abfb..4b70cbdd2 100644
--- a/src/inspect_ai/util/_subprocess.py
+++ b/src/inspect_ai/util/_subprocess.py
@@ -65,7 +65,7 @@ async def subprocess(
     Convenience method for solvers, scorers, and tools to launch
     subprocesses. Automatically enforces a limit on concurrent
     subprocesses (defaulting to os.cpu_count() but controllable
-    via the `max_subproccesses` eval config option).
+    via the `max_subprocesses` eval config option).
 
     Args:
       args (str | list[str]): Command and arguments to execute.
diff --git a/tests/test_eval.py b/tests/test_eval.py
new file mode 100644
index 000000000..464426729
--- /dev/null
+++ b/tests/test_eval.py
@@ -0,0 +1,22 @@
+import asyncio
+
+import pytest
+
+from inspect_ai import Task, eval_async
+from inspect_ai.dataset import Sample
+from inspect_ai.scorer import match
+
+
+@pytest.mark.asyncio
+async def test_no_concurrent_eval_async():
+    tasks = [
+        Task(dataset=[Sample(input="Say Hello", target="Hello")], scorer=match())
+        for i in range(0, 2)
+    ]
+
+    results = await asyncio.gather(
+        *[eval_async(task, model="mockllm/model") for task in tasks],
+        return_exceptions=True,
+    )
+
+    assert any([isinstance(result, RuntimeError) for result in results])
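The `eval()` changes above add support for passing a list of models. A minimal usage sketch follows (not part of this patch; the task is a placeholder and `mockllm/model` is borrowed from the test above, where in practice you would list distinct provider models):

``` python
from inspect_ai import Task, eval, task
from inspect_ai.dataset import Sample
from inspect_ai.scorer import match
from inspect_ai.solver import generate

@task
def hello():
    # trivial single-sample task used only for illustration
    return Task(
        dataset=[Sample(input="Say Hello", target="Hello")],
        plan=[generate()],
        scorer=match(),
    )

# passing a list of models evaluates the task against each model in parallel;
# eval() returns one EvalLog per task/model combination
logs = eval(hello(), model=["mockllm/model", "mockllm/model"])
print(len(logs))  # 2
```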
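The new test exercises the `eval_async()` concurrency guard. Restated as a standalone script (a sketch under the same assumptions as the test, again using the mockllm model), the contract is: sequential awaits work, while concurrent calls raise `RuntimeError`:

``` python
import asyncio

from inspect_ai import Task, eval_async
from inspect_ai.dataset import Sample
from inspect_ai.scorer import match

async def main():
    task = Task(dataset=[Sample(input="Say Hello", target="Hello")], scorer=match())

    # a sequential await is fine (the guard is reset in the finally block)
    await eval_async(task, model="mockllm/model")

    # concurrent calls are rejected: at least one result is a RuntimeError
    results = await asyncio.gather(
        eval_async(task, model="mockllm/model"),
        eval_async(task, model="mockllm/model"),
        return_exceptions=True,
    )
    assert any(isinstance(r, RuntimeError) for r in results)

asyncio.run(main())
```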
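For reference, the precedence implemented by `create_sample_semaphore()` can be summarized with a small self-contained illustration. This is not the library's code; the default of 10 connections and the Docker cap of 25 are taken from the patch and docs. An explicit `max_samples` wins outright, otherwise `max_connections` applies and the tool environment provider may cap it:

``` python
DEFAULT_MAX_CONNECTIONS = 10  # library default for max_connections

def effective_max_samples(
    max_samples: int | None,
    max_connections: int | None,
    toolenv_cap: int | None,
) -> int:
    # an explicitly configured max_samples is used as-is
    if max_samples is not None:
        return max_samples
    # otherwise fall back to max_connections (or the default)
    samples = max_connections if max_connections is not None else DEFAULT_MAX_CONNECTIONS
    # a tool environment provider may cap the default (Docker caps at 25)
    if toolenv_cap is not None:
        samples = min(samples, toolenv_cap)
    return samples

assert effective_max_samples(None, 100, 25) == 25  # provider cap applies
assert effective_max_samples(40, 100, 25) == 40    # explicit setting wins
```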