From eca65fc2398fe0ae32448fcc5926b23e0a3f8e70 Mon Sep 17 00:00:00 2001 From: aisi-inspect <166920645+aisi-inspect@users.noreply.github.com> Date: Sat, 4 May 2024 21:02:45 +0000 Subject: [PATCH] release v0.3.5 --- CHANGELOG.md | 8 ++ docs/_quarto.yml | 2 +- docs/eval-logs.qmd | 4 +- docs/eval-tuning.qmd | 2 +- docs/index.qmd | 6 +- docs/log-viewer.qmd | 22 ++--- docs/models.qmd | 4 +- docs/theme.scss | 3 + docs/tools.qmd | 2 +- docs/workflow.qmd | 6 +- pyproject.toml | 1 + src/inspect_ai/_view/view.py | 17 +++- src/inspect_ai/dataset/_dataset.py | 94 +++++++++++++++++----- src/inspect_ai/model/_model.py | 11 +-- src/inspect_ai/model/_providers/hf.py | 70 +++++++++++----- src/inspect_ai/model/_providers/mistral.py | 2 +- src/inspect_ai/model/_providers/openai.py | 2 +- tests/test_hf.py | 45 +++++++++++ 18 files changed, 227 insertions(+), 74 deletions(-) create mode 100644 tests/test_hf.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f757e858..a13d59007 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## v0.3.5 (04 May 2024) + +- Fix issue with logs from S3 buckets in inspect view. +- Add `sort()` method to `Dataset` (defaults to sorting by sample input length). +- Improve tokenization for HF provider (left padding, attention mask, and allow for custom chat template) +- Improve batching for HF provider (generate as soon as queue fills, thread safety for future.set_result). +- Various improvements to documentation. + ## v0.3.4 (01 May 2024) - `write_eval_log()` now ignores unserializable objects in metadata fields. diff --git a/docs/_quarto.yml b/docs/_quarto.yml index 981f22fa7..1dabd1e37 100644 --- a/docs/_quarto.yml +++ b/docs/_quarto.yml @@ -17,7 +17,7 @@ book: description: "Open-source framework for large language model evaluations" sidebar: header: > - [![](images/aisi-logo.png)](https://www.gov.uk/government/organisations/ai-safety-institute) + [![](images/aisi-logo.png){fig-alt="UK AI Safety Institute Website"}](https://www.gov.uk/government/organisations/ai-safety-institute) page-footer: left: diff --git a/docs/eval-logs.qmd b/docs/eval-logs.qmd index a0d4968f4..904b92a9c 100644 --- a/docs/eval-logs.qmd +++ b/docs/eval-logs.qmd @@ -8,7 +8,7 @@ Every time you use `inspect eval` or call the `eval()` function, an evaluation l $ inspect eval security_guide.py --model openai/gpt-4 ``` -![](images/eval-log.png) +![](images/eval-log.png){fig-alt="The Inspect task results displayed in the terminal. A link to the evaluation log is at the bottom of the results display."} You can also use the Inspect log viewer for interactive exploration of logs. Run this command once at the beginning of a working session (the view will update automatically when new evaluations are run): @@ -16,7 +16,7 @@ You can also use the Inspect log viewer for interactive exploration of logs. Run $ inspect view ``` -![](images/inspect-view-main.png){.border .lightbox} +![](images/inspect-view-main.png){.border .lightbox fig-alt="The Inspect log viewer, displahing a summary of results for the task as well as 8 individual samples."} This section won't cover using `inspect view` though. Rather, it will cover the details of managing log usage from the CLI as well as the Python API for reading logs. See the [Log Viewer](#sec-log-viewer) section for details on interactively exploring logs. diff --git a/docs/eval-tuning.qmd b/docs/eval-tuning.qmd index 1b3f0e83d..d8b8bbacf 100644 --- a/docs/eval-tuning.qmd +++ b/docs/eval-tuning.qmd @@ -23,7 +23,7 @@ The default value for max connections is 10. By increasing it we might get bette When you run an eval you'll see information reported on the current active connection usage as well as the number of HTTP rate limit errors that have been encountered (note that Inspect will automatically retry on rate limits and other errors likely to be transient): -![](images/rate-limit.png) +![](images/rate-limit.png){fig-alt="The Inspect task results displayed in the terminal. The number of HTTP rate limit errors that have occurred (25) is printed in the bottom right of the task results."} Here we've set a higher max connections than the default (30). While you might be tempted to set this very high to see how much concurrent traffic you can sustain, more often than not setting too high a max connections will result in slower evaluations, because retries are done using [exponential backoff](https://en.wikipedia.org/wiki/Exponential_backoff), and bouncing off of rate limits too frequently will have you waiting minutes for retries to fire. diff --git a/docs/index.qmd b/docs/index.qmd index 4cd2da9f8..fba12ab94 100644 --- a/docs/index.qmd +++ b/docs/index.qmd @@ -12,7 +12,7 @@ toc: false - Adapt and extend the framework with custom Python components. -![](images/inspect-view-splash.png){.lightbox .border} +![](images/inspect-view-splash.png){.lightbox .border fig-alt="The Inspect log viewer, displahing a summary of results for the task as well as 5 individual samples."} ::: @@ -142,7 +142,7 @@ The `@task` decorator applied to the `theory_of_mind()` function is what enables $ inspect eval theory_of_mind.py --model openai/gpt-4 ``` -![](images/running-theory.png) +![](images/running-theory.png){fig-alt="The Inspect task results displayed in the terminal. A progress bar indicates that the evaluation is about 60% complete."} By default, eval logs are written to the `./logs` sub-directory of the current working directory. When the eval is complete you will find a link to the log at the bottom of the task results summary. @@ -152,7 +152,7 @@ You can also explore eval results using the Inspect log viewer. Run `inspect vie $ inspect view ``` -![](images/inspect-view-home.png){.border .lightbox} +![](images/inspect-view-home.png){.border .lightbox fig-alt="The Inspect log viewer, displahing a summary of results for the task as well as 7 individual samples."} See the [Log Viewer](#sec-log-viewer) section for additional details on using Inspect View. diff --git a/docs/log-viewer.qmd b/docs/log-viewer.qmd index 3b7a6a3a5..e30bace8b 100644 --- a/docs/log-viewer.qmd +++ b/docs/log-viewer.qmd @@ -4,7 +4,7 @@ Inspect View provides a convenient way to visualise evaluation logs, including drilling into message histories, scoring decisions, and additional metadata written to the log. Here's what the main view of an evaluation log looks like: -![](images/inspect-view-main.png){.border .lightbox} +![](images/inspect-view-main.png){.border .lightbox fig-alt="The Inspect log viewer, displahing a summary of results for the task as well as 8 individual samples."} Below we'll describe how to get the most out of using Inspect View. @@ -36,7 +36,7 @@ You only need to run `inspect view` once at the beginning of a session (as it wi You can view and navigate between a history of all evals in the log directory using the menu at the top right: -![](images/inspect-view-history.png){.border .lightbox} +![](images/inspect-view-history.png){.border .lightbox fig-alt="The Inspect log viewer, with the history panel displayed on the left overlaying the main interface. Several log files are displayed in the log history, each of which includes a summary of the results."} ## Sample Details @@ -46,7 +46,7 @@ Click a sample to drill into its messages, scoring, and metadata. The messages tab displays the message history. In this example we see that the model make two tool calls before answering (the final assistant message is not fully displayed for brevity): -![](images/inspect-view-messages.png){.border .lightbox} +![](images/inspect-view-messages.png){.border .lightbox fig-alt="The Inspect log viewer showing a sample expanded, with details on the user, assistant, and tool messages for the sample."} Looking carefully at the message history (especially for agents or multi-turn solvers) is critically important for understanding how well your evaluation is constructed. @@ -54,13 +54,13 @@ Looking carefully at the message history (especially for agents or multi-turn so The scoring tab shows additional details including the full input and full model explanation for answers: -![](images/inspect-view-scoring.png){.border .lightbox} +![](images/inspect-view-scoring.png){.border .lightbox fig-alt="The Inspect log viewer showing a sample expanded, with details on the scoring of the sample, including the input, target, answer, and explanation."} ### Metadata The metadata tab shows additional data made available by solvers, tools, an scorers (in this case the `web_search()` tool records which URLs it visited to retreive additional context): -![](images/inspect-view-metadata.png){.border .lightbox} +![](images/inspect-view-metadata.png){.border .lightbox fig-alt="The Inspect log viewer showing a sample expanded, with details on the metadata recorded by the web search tool during the evaluation (specifically, the URLs queried by the web search tool for the sample)."} ## Scores and Answers @@ -75,7 +75,7 @@ A scorer can fail to correctly score output at either of these steps. Failing to You can use the log viewer to catch and evaluate these sorts of issues. For example, here we can see that we were unable to extract answers for a couple of questions that were scored incorrect: -![](images/inspect-view-answers.png){.border .lightbox} +![](images/inspect-view-answers.png){.border .lightbox fig-alt="The Inspect log viewer with several 5 samples displayed, 3 of which are incorrect. The Answer column displays the answer extracted from the model output for each sample."} It's possible that these answers are legitimately incorrect. However it's also possible that the correct answer is in the model's output but just in a format we didn't quite expect. In each case you'll need to drill into the sample to investigate. @@ -97,11 +97,11 @@ Note there is also an `explanation` field: this is also important, as it allows It's often useful to filter log entries by score (for example, to investigate whether incorrect answers are due to scorer issues or are true negatives). Use the **Scores** picker to filter by specific scores: -![](images/inspect-view-filter.png){.border .lightbox} +![](images/inspect-view-filter.png){.border .lightbox fig-alt="The Inspect log view, with 4 samples displayed, each of which are marked incorrect. The Scores picker is focused, and has selected 'Incorrect', indicating that only incorrect scores should be displayed."} By default, samples are ordered (with all samples for an epoch presented in sequence). However you can also order by score, or order by samples (so you see all of the results for a given sample across all epochs presented together). Use the **Sort** picker to control this: -![](images/inspect-view-sort.png){.border .lightbox} +![](images/inspect-view-sort.png){.border .lightbox fig-alt="The Inspect log view, with the results of a single sample for each of the 4 epochs of the evaluation."} Viewing by sample can be especially valuable for diagnosing the sources of inconsistency (and determining whether they are inherent or an artifact of the evaluation methodology). Above we can see that sample 1 is incorrect in epoch 1 because of issue the model had with forming a correct function call. @@ -121,7 +121,7 @@ logger.info(f"web query: {query}") You can see all of these log entries in the **Logging** tab: -![](images/inspect-view-logging.png){.border .lightbox} +![](images/inspect-view-logging.png){.border .lightbox fig-alt="The Logging panel of the Inspect log viewer, displaying several info log messages from the web search tool indicating what queries were executed by the tool."} It is important to note that the Inspect View will show all log entries level `info` or higher. However, printing every `info` message to the console during an eval might be too distracting, so the default log level for printing is `warning`. If you change it to `info` then you'll also see these log messages in the console: @@ -129,7 +129,7 @@ It is important to note that the Inspect View will show all log entries level `i $ inspect eval biology_qa.py --log-level info ``` -![](images/inspect-view-logging-console.png){.lightbox} +![](images/inspect-view-logging-console.png){.lightbox fig-alt="This Inspect task display in the terminal, with several info log messages from the web search tool printed above the the task diplay."} A default log level of `warning` enables you to include many calls to `logger.info()` in your code without having them show by default, while also making them available in the log viewer should you need them. @@ -139,4 +139,4 @@ Note that you can also set the log level using the `INSPECT_LOG_LEVEL` environme The **Info** panel of the log viewer provides additional meta-information about evaluation tasks, including dataset, plan, and scorer details, git revision, and model token usage: -![](images/inspect-view-info.png){style=".border .lightbox"} \ No newline at end of file +![](images/inspect-view-info.png){.border .lightbox fig-alt="The Info panel of the Inspect log viewer, displaying various details about the evaluation including dataset, plan, and scorer details, git revision, and model token usage."} \ No newline at end of file diff --git a/docs/models.qmd b/docs/models.qmd index 74b011e04..b39880d14 100644 --- a/docs/models.qmd +++ b/docs/models.qmd @@ -295,7 +295,7 @@ def theory_of_mind(): ## Model Args -The section above illustrates passing model specific arguments to local models on the command line, in `eval()`, and in `get_model()`. This actually works for all model types, so if there is an additional aspect of a modal you want to tweak that isn't covered by the `GenerationConfig`, you can use this method to do it. For example, here we specify the `transport` option for a Google Gemini model: +The section above illustrates passing model specific arguments to local models on the command line, in `eval()`, and in `get_model()`. This actually works for all model types, so if there is an additional aspect of a model you want to tweak that isn't covered by the `GenerationConfig`, you can use this method to do it. For example, here we specify the `transport` option for a Google Gemini model: ``` bash inspect eval popularity --model google/gemini-1.0-pro -M transport:grpc @@ -358,4 +358,4 @@ model = get_model("custom/name-of-model") eval(math, model = "custom/name-of-model") ``` -In this example, the `model_name` argument passed to `__init__()` will be "name-of-model". \ No newline at end of file +In this example, the `model_name` argument passed to `__init__()` will be "name-of-model". diff --git a/docs/theme.scss b/docs/theme.scss index e36b65b63..27e5d9db9 100644 --- a/docs/theme.scss +++ b/docs/theme.scss @@ -46,3 +46,6 @@ } } +.blockquote { + color: #505a62; +} diff --git a/docs/tools.qmd b/docs/tools.qmd index 4d550af04..af340c856 100644 --- a/docs/tools.qmd +++ b/docs/tools.qmd @@ -183,7 +183,7 @@ Web search options include: - `model`---Model to use to determine if search results are relevant (defaults to the model currently being evaluated). -#### Google Provider +### Google Provider The `web_search()` tool uses [Google Programmable Search Engine](https://programmablesearchengine.google.com/about/). To use it you will therefore need to setup your own Google Programmable Search Engine and also enable the [Programmable Search Element Paid API](https://developers.google.com/custom-search/docs/paid_element). Then, ensure that the following environment variables are defined: diff --git a/docs/workflow.qmd b/docs/workflow.qmd index 62f5692b7..fa2ba17fe 100644 --- a/docs/workflow.qmd +++ b/docs/workflow.qmd @@ -37,7 +37,7 @@ You can run this evaluation from the shell using the `inspect eval` command. For $ inspect eval theory.py --model openai/gpt-4 ``` -![](images/running-theory.png) +![](images/running-theory.png){fig-alt="The Inspect task results displayed in the terminal. A progress bar indicates that the evaluation is about 60% complete."} Immediately after an evaluation completes, a link to the log for the evaluation is written to the terminal (if you are running in VS Code this link will open the log in an editor within the IDE). @@ -67,12 +67,12 @@ As you iterate on an evaluation, you'll typically want to dig further into messa $ inspect view ``` -![](images/inspect-view-main.png){.border .lightbox} +![](images/inspect-view-main.png){.border .lightbox fig-alt="The Inspect log viewer, displahing a summary of results for the task as well as 8 individual samples."} The log viewer will update automatically whenever a new evaluation is completed (you can also navigate back to previous evaluations). The log viewer summarises aggregate data and also provides a detailed view into each sample. For example, here we zoom in on the model's scoring explanation for a specific sample: -![](images/inspect-view-scoring.png){.border .lightbox} +![](images/inspect-view-scoring.png){.border .lightbox fig-alt="The Inspect log viewer showing a sample expanded, with details on the scoring of the sample, including the input, target, answer, and explanation."} See the [Log Viewer](#sec-log-viewer) section for additional details on using Inspect View. diff --git a/pyproject.toml b/pyproject.toml index 1eb6dd922..826bb4971 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,6 +108,7 @@ dev = [ "mistralai", "boto3", "transformers", + "accelerate", "torch", "datasets", "langchain", diff --git a/src/inspect_ai/_view/view.py b/src/inspect_ai/_view/view.py index cf16baf26..cd54eeace 100644 --- a/src/inspect_ai/_view/view.py +++ b/src/inspect_ai/_view/view.py @@ -9,7 +9,7 @@ from io import BytesIO from pathlib import Path from typing import Any -from urllib.parse import parse_qs, urlparse +from urllib.parse import parse_qs, urlparse, urlunparse import psutil @@ -128,10 +128,23 @@ def handle_log(self) -> None: # check for query params parsed = urlparse(path) - path = parsed.path + + # read query parameters from the URL query_params = parse_qs(parsed.query) header_only = query_params.get("header-only", None) is not None + # reconstruct the path + path = urlunparse( + ( + parsed.scheme, + parsed.netloc, + parsed.path, + parsed.params, + "", # Clear the query component + parsed.fragment, + ) + ) + ctype = self.guess_type(path) try: contents: bytes | None = None diff --git a/src/inspect_ai/dataset/_dataset.py b/src/inspect_ai/dataset/_dataset.py index 93c36490d..4a0c44a4c 100644 --- a/src/inspect_ai/dataset/_dataset.py +++ b/src/inspect_ai/dataset/_dataset.py @@ -1,12 +1,24 @@ import abc import random -from typing import Any, Callable, Iterator, Sequence, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterator, + Optional, + Sequence, + Union, + overload, +) from pydantic import BaseModel, Field from typing_extensions import override from inspect_ai.model import ChatMessage +if TYPE_CHECKING: + from _typeshed import SupportsRichComparison + class Sample(BaseModel): r"""Sample to be used in an evaluation task. @@ -37,6 +49,21 @@ class Sample(BaseModel): """Arbitrary metadata associated with the sample.""" +def sample_input_len(sample: Sample) -> int: + """Measures the length of a samples `input` field. + + The default length function use in `Dataset.sort()`. + + Args: + sample (Sample): A Sample to be used in an evaluation task. + """ + return ( + len(sample.input) + if isinstance(sample.input, str) + else sum(len(inp.text) for inp in sample.input) + ) + + DatasetRecord = dict[str, Any] DatasetReader = Iterator[DatasetRecord] @@ -50,28 +77,22 @@ class Dataset(Sequence[Sample], abc.ABC): """ @abc.abstractproperty - def name(self) -> str | None: - ... + def name(self) -> str | None: ... @abc.abstractproperty - def location(self) -> str | None: - ... + def location(self) -> str | None: ... @overload - def __getitem__(self, index: int) -> Sample: - ... + def __getitem__(self, index: int) -> Sample: ... @overload - def __getitem__(self, index: slice) -> "Dataset": - ... + def __getitem__(self, index: slice) -> "Dataset": ... @abc.abstractmethod - def __getitem__(self, index: Union[int, slice]) -> Union[Sample, "Dataset"]: - ... + def __getitem__(self, index: Union[int, slice]) -> Union[Sample, "Dataset"]: ... @abc.abstractmethod - def __len__(self) -> int: - ... + def __len__(self) -> int: ... @abc.abstractmethod def shuffle(self, seed: int | None = None) -> None: @@ -81,6 +102,24 @@ def shuffle(self, seed: int | None = None) -> None: seed: (int | None): Random seed for shuffling (optional). """ + @abc.abstractmethod + def sort( + self, + reverse: bool = False, + key: Optional[Callable[[Sample], "SupportsRichComparison"]] = sample_input_len, + ) -> None: + """Sort the dataset (in place) in ascending order and return None. + + If a key function is given, apply it once to each list item and sort them, ascending or descending, according to their function values. + + The key function defaults to measuring the length of the sample's input field. + + Args: + reverse (bool): if true, sort in descending order. Defaults to False. + key (Callable[[Any], Any]): a callable mapping each item to a numeric value (optional, defaults to sample_input_len). + """ + + @abc.abstractmethod def filter( self, predicate: Callable[[Sample], bool], name: str | None = None ) -> "Dataset": @@ -93,11 +132,6 @@ def filter( Returns: Filtered dataset. """ - return MemoryDataset( - name=name or self.name, - location=self.location, - samples=[sample for sample in self if predicate(sample)], - ) class FieldSpec(BaseModel): @@ -168,12 +202,10 @@ def location(self) -> str | None: return self._location @overload - def __getitem__(self, index: int) -> Sample: - ... + def __getitem__(self, index: int) -> Sample: ... @overload - def __getitem__(self, index: slice) -> Dataset: - ... + def __getitem__(self, index: slice) -> Dataset: ... @override def __getitem__(self, index: Union[int, slice]) -> Union[Sample, Dataset]: @@ -194,3 +226,21 @@ def shuffle(self, seed: int | None = None) -> None: random.Random(seed).shuffle(self.samples) else: random.shuffle(self.samples) + + @override + def sort( + self, + reverse: bool = False, + key: Optional[Callable[[Sample], "SupportsRichComparison"]] = sample_input_len, + ) -> None: + self.samples.sort(reverse=reverse, key=key) + + @override + def filter( + self, predicate: Callable[[Sample], bool], name: str | None = None + ) -> "MemoryDataset": + return MemoryDataset( + name=name or self.name, + location=self.location, + samples=[sample for sample in self if predicate(sample)], + ) diff --git a/src/inspect_ai/model/_model.py b/src/inspect_ai/model/_model.py index 55d54715c..a839bb699 100644 --- a/src/inspect_ai/model/_model.py +++ b/src/inspect_ai/model/_model.py @@ -345,10 +345,11 @@ def completion(self, completion: str) -> None: if len(self.choices) > 0: self.choices[0].message.text = completion else: - self.choices.append(ChatCompletionChoice( - message = ChatMessageAssistant(content = completion), - stop_reason="stop" - )) + self.choices.append( + ChatCompletionChoice( + message=ChatMessageAssistant(content=completion), stop_reason="stop" + ) + ) @staticmethod def from_content( @@ -677,7 +678,7 @@ def get_model( model: str | Model | None = None, config: GenerateConfig = GenerateConfig(), base_url: str | None = None, - **model_args: dict[str, Any], + **model_args: Any, ) -> Model: """Get an instance of a model. diff --git a/src/inspect_ai/model/_providers/hf.py b/src/inspect_ai/model/_providers/hf.py index b1327ae0b..3bfe4a209 100644 --- a/src/inspect_ai/model/_providers/hf.py +++ b/src/inspect_ai/model/_providers/hf.py @@ -9,7 +9,11 @@ import numpy as np import torch from torch import Tensor -from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed # type: ignore +from transformers import ( # type: ignore + AutoModelForCausalLM, + AutoTokenizer, + set_seed, +) from typing_extensions import override from inspect_ai._util.constants import DEFAULT_MAX_TOKENS @@ -55,6 +59,7 @@ def collect_model_arg(name: str) -> Any | None: model_path = collect_model_arg("model_path") tokenizer_path = collect_model_arg("tokenizer_path") self.batch_size = collect_model_arg("batch_size") + self.chat_template = collect_model_arg("chat_template") # device if device: @@ -88,6 +93,7 @@ def collect_model_arg(name: str) -> Any | None: self.tokenizer = AutoTokenizer.from_pretrained(model_name) # LLMs generally don't have a pad token and we need one for batching self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.padding_side = "left" async def generate( self, @@ -129,6 +135,7 @@ async def generate( tokenizer=tokenizer, generator=generator, decoder=decoder, + batch_size=config.max_connections or self.max_connections(), ) ) @@ -165,9 +172,11 @@ def hf_chat(self, messages: list[ChatMessage]) -> str: hf_messages = chat_api_input(messages) # apply chat template chat = self.tokenizer.apply_chat_template( - hf_messages, add_generation_prompt=True, tokenize=False + hf_messages, + add_generation_prompt=True, + tokenize=False, + chat_template=self.chat_template, ) - # return return cast(str, chat) @@ -182,18 +191,17 @@ def set_random_seeds(seed: int | None = None) -> None: class Tokenizer(Protocol): - def __call__(self, input: list[str]) -> dict[Literal["input_ids"], Tensor]: - ... + def __call__( + self, input: list[str] + ) -> dict[Literal["input_ids", "attention_mask"], Tensor]: ... class Generator(Protocol): - def __call__(self, input_ids: Tensor) -> Tensor: - ... + def __call__(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor: ... class Decoder(Protocol): - def __call__(self, sequences: Tensor) -> list[str]: - ... + def __call__(self, sequences: Tensor) -> list[str]: ... @dataclass @@ -203,6 +211,7 @@ class GenerateInput: tokenizer: Tokenizer generator: Generator decoder: Decoder + batch_size: int @dataclass @@ -213,9 +222,16 @@ class GenerateOutput: total_tokens: int +@dataclass +class _QueueItem: + input: GenerateInput + future: asyncio.Future[GenerateOutput] + loop: asyncio.AbstractEventLoop + + batch_thread: Thread | None = None -batch_queue: "Queue[tuple[GenerateInput, asyncio.Future[GenerateOutput]]]" = Queue() +batch_queue: "Queue[_QueueItem]" = Queue() async def batched_generate(input: GenerateInput) -> GenerateOutput: @@ -228,7 +244,7 @@ async def batched_generate(input: GenerateInput) -> GenerateOutput: # enque the job loop = asyncio.get_event_loop() future: asyncio.Future[GenerateOutput] = loop.create_future() - batch_queue.put((input, future)) + batch_queue.put(_QueueItem(input=input, future=future, loop=loop)) # await the job await future @@ -244,8 +260,13 @@ def process_batches() -> None: while True: try: input = batch_queue.get(timeout=2) - inputs.append(input) + loop = input.loop + inputs.append((input.input, input.future)) + if len(inputs) == input.input.batch_size: + # max batch size reached + break except Empty: + # we have exhausted the queue break # see if we have any work to do @@ -261,12 +282,17 @@ def process_batches() -> None: decoder = first_input.decoder # tokenize and move to device - input_ids = tokenizer([item[0].input for item in inputs])["input_ids"] + tokenized_inputs = tokenizer([item[0].input for item in inputs]) + input_ids = tokenized_inputs["input_ids"] + attention_mask = tokenized_inputs["attention_mask"] input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) # generate with torch.inference_mode(): - generate_ids = generator(input_ids=input_ids) + generate_ids = generator( + input_ids=input_ids, attention_mask=attention_mask + ) # decode outputs = decoder(sequences=generate_ids[:, input_ids.size(dim=1) :]) @@ -276,15 +302,21 @@ def process_batches() -> None: future = inputs[i][1] input_tokens = input_ids.size(dim=1) output_tokens = generate_ids.size(dim=1) - input_ids.size(dim=1) - future.set_result( + + # asyncio futures are not thread safe, so we need to pass the event loop + # down to this point, so we can mark the future as done in a thread safe manner. + # see: https://docs.python.org/3/library/asyncio-dev.html#concurrency-and-multithreading + loop.call_soon_threadsafe( + future.set_result, GenerateOutput( output=output, input_tokens=input_tokens, output_tokens=output_tokens, total_tokens=input_tokens + output_tokens, - ) + ), ) + except Exception as ex: - for input in inputs: - future = input[1] - future.set_exception(ex) + for inp in inputs: + future = inp[1] + loop.call_soon_threadsafe(future.set_exception, ex) diff --git a/src/inspect_ai/model/_providers/mistral.py b/src/inspect_ai/model/_providers/mistral.py index 3462fb0e8..b6fd2e204 100644 --- a/src/inspect_ai/model/_providers/mistral.py +++ b/src/inspect_ai/model/_providers/mistral.py @@ -75,7 +75,7 @@ def __init__( if not base_url: raise ValueError( "You must provide a base URL when using Mistral on Azure. Use the AZUREAI_MISTRAL_BASE_URL " - + " environment variable or the --model_base_url CLI flag to set the base URL." + + " environment variable or the --model-base-url CLI flag to set the base URL." ) model_args["endpoint"] = base_url diff --git a/src/inspect_ai/model/_providers/openai.py b/src/inspect_ai/model/_providers/openai.py index b8aa2c79e..d1f102d1d 100644 --- a/src/inspect_ai/model/_providers/openai.py +++ b/src/inspect_ai/model/_providers/openai.py @@ -90,7 +90,7 @@ def __init__( if not base_url: raise ValueError( "You must provide a base URL when using OpenAI on Azure. Use the AZUREAI_OPENAI_BASE_URL " - + " environment variable or the --model_base_url CLI flag to set the base URL." + + " environment variable or the --model-base-url CLI flag to set the base URL." ) self.client: AsyncAzureOpenAI | AsyncOpenAI = AsyncAzureOpenAI( diff --git a/tests/test_hf.py b/tests/test_hf.py new file mode 100644 index 000000000..983b23b26 --- /dev/null +++ b/tests/test_hf.py @@ -0,0 +1,45 @@ +import pytest +from transformers import PreTrainedModel # type: ignore +from utils import skip_if_github_action + +from inspect_ai.model import ( + ChatMessageUser, + GenerateConfig, + get_model, +) + + +@pytest.fixture +def model() -> PreTrainedModel: + return get_model( + "hf/EleutherAI/pythia-70m", + config=GenerateConfig( + max_tokens=1, + seed=42, + temperature=0.01, + ), + # this allows us to run base models with the chat message scaffolding: + chat_template="{% for message in messages %}{{ message.content }}{% endfor %}", + ) + + +@pytest.mark.asyncio +@skip_if_github_action +async def test_hf_api(model: PreTrainedModel) -> None: + message = ChatMessageUser(content="Lorem ipsum dolor") + response = await model.generate(input=[message]) + assert len(response.completion) >= 1 + + +@pytest.mark.asyncio +@skip_if_github_action +async def test_hf_api_fails(model: PreTrainedModel) -> None: + temp_before = model.config.temperature + try: + model.config.temperature = 0.0 + + message = ChatMessageUser(content="Lorem ipsum dolor") + with pytest.raises(Exception): + await model.generate(input=[message]) + finally: + model.config.temperature = temp_before