diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2fa5b056..e31fcefb 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,6 +3,37 @@ name: CI on: [ push, pull_request ] jobs: + test: + name: "Test: Python ${{ matrix.python }} on ${{ matrix.os }}" + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: + - ubuntu-latest + - macos-latest + - windows-latest + python: + - "3.8" + - "3.9" + - "3.10" + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + - name: Install dependencies + run: | + pip install --upgrade pip + pip install . + pip install -r tests/requirements_test.txt + - name: Install Cleanlab Studio client + run: pip install -e . + - name: Cleanlab login + run: cleanlab login --key "${{ secrets.CLEANLAB_STUDIO_CI_API_KEY }}" + - name: Run tests + run: | + pytest --verbose typecheck: name: Type check runs-on: ubuntu-latest diff --git a/cleanlab_studio/errors.py b/cleanlab_studio/errors.py index c235732a..23c21507 100644 --- a/cleanlab_studio/errors.py +++ b/cleanlab_studio/errors.py @@ -1,3 +1,6 @@ +from asyncio import Handle + + class HandledError(Exception): pass @@ -34,6 +37,10 @@ class SettingsError(HandledError): pass +class ValidationError(HandledError): + pass + + class UploadError(HandledError): pass @@ -63,16 +70,22 @@ class APITimeoutError(HandledError): pass -class RateLimitError(APIError): +class RateLimitError(HandledError): def __init__(self, message: str, retry_after: int): self.message = message self.retry_after = retry_after -class TlmBadRequest(APIError): +class TlmBadRequest(HandledError): pass +class TlmServerError(APIError): + def __init__(self, message: str, status_code: int) -> None: + self.message = message + self.status_code = status_code + + class UnsupportedVersionError(HandledError): def __init__(self) -> None: super().__init__( diff --git a/cleanlab_studio/internal/api/api.py b/cleanlab_studio/internal/api/api.py index ac42ea4a..383c038c 100644 --- a/cleanlab_studio/internal/api/api.py +++ b/cleanlab_studio/internal/api/api.py @@ -1,19 +1,21 @@ import asyncio import io -from math import e import os import time from typing import Callable, cast, List, Optional, Tuple, Dict, Union, Any + from cleanlab_studio.errors import ( APIError, IngestionError, InvalidProjectConfiguration, RateLimitError, TlmBadRequest, + TlmServerError, ) -from cleanlab_studio.internal.util import get_basic_info, obfuscate_stack_trace +from cleanlab_studio.internal.tlm.concurrency import TlmRateHandler import aiohttp +import aiohttp.client_exceptions import requests from tqdm import tqdm import pandas as pd @@ -88,18 +90,40 @@ def handle_rate_limit_error_from_resp(resp: aiohttp.ClientResponse) -> None: ) -async def handle_tlm_client_error_from_resp(resp: aiohttp.ClientResponse) -> None: +async def handle_tlm_client_error_from_resp( + resp: aiohttp.ClientResponse, batch_index: Optional[int] +) -> None: """Catches 4XX (client error) errors.""" if 400 <= resp.status < 500: try: res_json = await resp.json() error_message = res_json["error"] except Exception: - error_message = "Client error occurred." + error_message = "TLM query failed. Please try again and contact support@cleanlab.ai if the problem persists." 
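[Editor's illustration] The reworked error hierarchy in errors.py above (RateLimitError and TlmBadRequest now derive from HandledError, and the new TlmServerError carries the HTTP status code) is easiest to see from the caller's side. The sketch below is hypothetical and not part of this PR; the `tlm` argument and prompt text are placeholders for an object returned by `Studio.TLM()`.

```python
from cleanlab_studio.errors import RateLimitError, TlmBadRequest, TlmServerError


def prompt_with_error_reporting(tlm, prompt: str):
    """Call TLM.prompt() and report the new error types separately (sketch)."""
    try:
        return tlm.prompt(prompt)
    except TlmBadRequest as e:
        # 4XX path raised by handle_tlm_client_error_from_resp: the request was invalid
        print(f"Bad request: {e}")
    except TlmServerError as e:
        # 5XX path raised by the server-error handler added below in this hunk
        print(f"Server error {e.status_code}: {e.message}")
    except RateLimitError as e:
        # normally retried internally; retry_after says how long to wait
        print(f"Rate limited, retry again in {e.retry_after}s")
    return None
```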
+ + if batch_index is not None: + error_message = f"Error executing query at index {batch_index}:\n{error_message}" raise TlmBadRequest(error_message) +async def handle_tlm_api_error_from_resp( + resp: aiohttp.ClientResponse, batch_index: Optional[int] +) -> None: + """Catches 5XX (server error) errors.""" + if 500 <= resp.status < 600: + try: + res_json = await resp.json() + error_message = res_json["error"] + except Exception: + error_message = "TLM query failed. Please try again and contact support@cleanlab.ai if the problem persists." + + if batch_index is not None: + error_message = f"Error executing query at index {batch_index}:\n{error_message}" + + raise TlmServerError(error_message, resp.status) + + def validate_api_key(api_key: str) -> bool: res = requests.get( cli_base_url + "/validate", @@ -556,6 +580,9 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: await asyncio.sleep(sleep_time) try: return await func(*args, **kwargs) + except aiohttp.client_exceptions.ClientConnectorError as e: + # note: we don't increment num_try here, because we don't want connection errors to count against the total number of retries + sleep_time = 2**num_try except RateLimitError as e: # note: we don't increment num_try here, because we don't want rate limit retries to count against the total number of retries sleep_time = e.retry_after @@ -578,7 +605,9 @@ async def tlm_prompt( prompt: str, quality_preset: str, options: Optional[JSONDict], + rate_handler: TlmRateHandler, client_session: Optional[aiohttp.ClientSession] = None, + batch_index: Optional[int] = None, ) -> JSONDict: """ Prompt Trustworthy Language Model with a question, and get back its answer along with a confidence score @@ -588,7 +617,9 @@ async def tlm_prompt( prompt (str): prompt for TLM to respond to quality_preset (str): quality preset to use to generate response options (JSONDict): additional parameters for TLM + rate_handler (TlmRateHandler): concurrency handler used to manage TLM request rate client_session (aiohttp.ClientSession): client session used to issue TLM request + batch_index (Optional[int], optional): index of prompt in batch, used for error messages. Defaults to None if not in batch. Returns: JSONDict: dictionary with TLM response and confidence score @@ -599,16 +630,18 @@ async def tlm_prompt( local_scoped_client = True try: - res = await client_session.post( - f"{tlm_base_url}/prompt", - json=dict(prompt=prompt, quality=quality_preset, options=options or {}), - headers=_construct_headers(api_key), - ) - res_json = await res.json() + async with rate_handler: + res = await client_session.post( + f"{tlm_base_url}/prompt", + json=dict(prompt=prompt, quality=quality_preset, options=options or {}), + headers=_construct_headers(api_key), + ) - handle_rate_limit_error_from_resp(res) - await handle_tlm_client_error_from_resp(res) - handle_api_error_from_json(res_json) + res_json = await res.json() + + handle_rate_limit_error_from_resp(res) + await handle_tlm_client_error_from_resp(res, batch_index) + await handle_tlm_api_error_from_resp(res, batch_index) finally: if local_scoped_client: @@ -624,7 +657,9 @@ async def tlm_get_confidence_score( response: str, quality_preset: str, options: Optional[JSONDict], + rate_handler: TlmRateHandler, client_session: Optional[aiohttp.ClientSession] = None, + batch_index: Optional[int] = None, ) -> JSONDict: """ Query Trustworthy Language Model for a confidence score for the prompt-response pair. 
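[Editor's illustration] The hunk above only shows two new `except` branches of the retry wrapper, so the following is a simplified, hypothetical reconstruction of the overall retry shape, not the PR's exact implementation: connection errors and rate limits back off without consuming a retry, while other errors count against the retry budget.

```python
import asyncio
import functools
from typing import Any, Callable

import aiohttp.client_exceptions

from cleanlab_studio.errors import RateLimitError


def tlm_retry(func: Callable[..., Any]) -> Callable[..., Any]:
    """Retry a TLM coroutine with exponential backoff (simplified sketch)."""

    @functools.wraps(func)
    async def wrapper(*args: Any, retries: int = 3, **kwargs: Any) -> Any:
        num_try = 0
        sleep_time = 0.0
        while True:
            await asyncio.sleep(sleep_time)
            try:
                return await func(*args, **kwargs)
            except aiohttp.client_exceptions.ClientConnectorError:
                # transient connection failures back off but do not consume a retry
                sleep_time = 2**num_try
            except RateLimitError as e:
                # rate limits wait as instructed, also without consuming a retry
                sleep_time = e.retry_after
            except Exception:
                # all other errors count toward the retry budget
                if num_try >= retries:
                    raise
                num_try += 1
                sleep_time = 2**num_try

    return wrapper
```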
@@ -635,7 +670,9 @@ async def tlm_get_confidence_score( response (str): response for TLM to get confidence score for quality_preset (str): quality preset to use to generate confidence score options (JSONDict): additional parameters for TLM + rate_handler (TlmRateHandler): concurrency handler used to manage TLM request rate client_session (aiohttp.ClientSession): client session used to issue TLM request + batch_index (Optional[int], optional): index of prompt in batch, used for error messages. Defaults to None if not in batch. Returns: JSONDict: dictionary with TLM confidence score @@ -646,21 +683,20 @@ async def tlm_get_confidence_score( local_scoped_client = True try: - res = await client_session.post( - f"{tlm_base_url}/get_confidence_score", - json=dict( - prompt=prompt, response=response, quality=quality_preset, options=options or {} - ), - headers=_construct_headers(api_key), - ) - res_json = await res.json() + async with rate_handler: + res = await client_session.post( + f"{tlm_base_url}/get_confidence_score", + json=dict( + prompt=prompt, response=response, quality=quality_preset, options=options or {} + ), + headers=_construct_headers(api_key), + ) - if local_scoped_client: - await client_session.close() + res_json = await res.json() - handle_rate_limit_error_from_resp(res) - await handle_tlm_client_error_from_resp(res) - handle_api_error_from_json(res_json) + handle_rate_limit_error_from_resp(res) + await handle_tlm_client_error_from_resp(res, batch_index) + await handle_tlm_api_error_from_resp(res, batch_index) finally: if local_scoped_client: diff --git a/cleanlab_studio/internal/constants.py b/cleanlab_studio/internal/constants.py index 530325e1..4f7af2fc 100644 --- a/cleanlab_studio/internal/constants.py +++ b/cleanlab_studio/internal/constants.py @@ -1,6 +1,10 @@ -from typing import List +from typing import List, Tuple +# TLM constants # prepend constants with _ so that they don't show up in help.cleanlab.ai docs -_DEFAULT_MAX_CONCURRENT_TLM_REQUESTS: int = 16 -_MAX_CONCURRENT_TLM_REQUESTS_LIMIT: int = 128 _VALID_TLM_QUALITY_PRESETS: List[str] = ["best", "high", "medium", "low", "base"] +_VALID_TLM_MODELS: List[str] = ["gpt-3.5-turbo-16k", "gpt-4"] +_TLM_MAX_RETRIES: int = 3 # TODO: finalize this number +TLM_MAX_TOKEN_RANGE: Tuple[int, int] = (64, 512) # (min, max) +TLM_NUM_CANDIDATE_RESPONSES_RANGE: Tuple[int, int] = (1, 20) # (min, max) +TLM_NUM_CONSISTENCY_SAMPLES_RANGE: Tuple[int, int] = (0, 20) # (min, max) diff --git a/cleanlab_studio/internal/tlm/__init__.py b/cleanlab_studio/internal/tlm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cleanlab_studio/internal/tlm/concurrency.py b/cleanlab_studio/internal/tlm/concurrency.py new file mode 100644 index 00000000..14065a7e --- /dev/null +++ b/cleanlab_studio/internal/tlm/concurrency.py @@ -0,0 +1,113 @@ +import asyncio +from types import TracebackType +from typing import Optional, Type + +from cleanlab_studio.errors import RateLimitError, TlmServerError + + +class TlmRateHandler: + """Concurrency handler for TLM queries. + + Implements additive increase / multiplicative decrease congestion control algorithm. 
+ """ + + DEFAULT_CONGESTION_WINDOW: int = 4 + DEFAULT_SLOW_START_THRESHOLD: int = 16 + + SLOW_START_INCREASE_FACTOR: int = 2 + ADDITIVE_INCREMENT: int = 1 + MULTIPLICATIVE_DECREASE_FACTOR: int = 2 + + MAX_CONCURRENT_REQUESTS: int = 512 + + def __init__( + self, + congestion_window: int = DEFAULT_CONGESTION_WINDOW, + slow_start_threshold: int = DEFAULT_SLOW_START_THRESHOLD, + ): + """Initializes TLM rate handler.""" + self._congestion_window: int = congestion_window + self._slow_start_threshold = slow_start_threshold + + # create send semaphore and seed w/ initial congestion window + self._send_semaphore = asyncio.Semaphore(value=self._congestion_window) + + async def __aenter__(self) -> None: + """Acquires send semaphore, blocking until it can be acquired.""" + await self._send_semaphore.acquire() + return + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + traceback_type: Optional[TracebackType], + ) -> bool: + """Handles exiting from rate limit context. Never suppresses exceptions. + + If request succeeded, increase congestion window. + If request failed due to rate limit error, decrease congestion window. + If request failed due to 503, decrease congestion window. + Else if request failed for other reason, don't change congestion window, just exit. + """ + if exc_type is None: + await self._increase_congestion_window() + + elif ( + isinstance(exc, RateLimitError) + or isinstance(exc, TlmServerError) + and exc.status_code == 503 + ): + await self._decrease_congestion_window() + + # release acquired send semaphore from aenter + self._send_semaphore.release() + + return False + + async def _increase_congestion_window( + self, + slow_start_increase_factor: int = SLOW_START_INCREASE_FACTOR, + additive_increment: int = ADDITIVE_INCREMENT, + max_concurrent_requests: int = MAX_CONCURRENT_REQUESTS, + ) -> None: + """Increases TLM congestion window + + If in slow start, increase is exponential. + Otherwise, increase is linear. 
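[Editor's illustration] To make the AIMD flow concrete, here is an illustrative driver (not from this PR) that funnels concurrent requests through a `TlmRateHandler`; `fake_request` is a stand-in for a real TLM HTTP call.

```python
import asyncio
from typing import List

from cleanlab_studio.internal.tlm.concurrency import TlmRateHandler


async def fake_request(i: int) -> int:
    # stand-in for an actual TLM HTTP call
    await asyncio.sleep(0.01)
    return i


async def run_batch(n: int) -> List[int]:
    rate_handler = TlmRateHandler()

    async def send(i: int) -> int:
        # acquire a slot in the congestion window; on clean exit the window grows
        # (slow start, then additive increase), and on RateLimitError or a 503
        # TlmServerError it shrinks multiplicatively
        async with rate_handler:
            return await fake_request(i)

    return await asyncio.gather(*(send(i) for i in range(n)))


if __name__ == "__main__":
    print(asyncio.run(run_batch(32)))
```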
+ + After increasing congestion window, notify on send condition with n=increase + """ + # track previous congestion window size + prev_congestion_window = self._congestion_window + + # increase congestion window + if self._congestion_window < self._slow_start_threshold: + self._congestion_window *= slow_start_increase_factor + + else: + self._congestion_window += additive_increment + + # cap congestion window at max concurrent requests + self._congestion_window = min(self._congestion_window, max_concurrent_requests) + + # release from send semaphore + congestion_window_increase = self._congestion_window - prev_congestion_window + for _ in range(congestion_window_increase): + self._send_semaphore.release() + + async def _decrease_congestion_window( + self, + multiplicative_decrease_factor: int = MULTIPLICATIVE_DECREASE_FACTOR, + ) -> None: + """Decreases TLM congestion window, to minimum of 1.""" + if self._congestion_window <= 1: + return + + prev_congestion_window = self._congestion_window + self._congestion_window //= multiplicative_decrease_factor + + # acquire congestion window decrease from send semaphore + congestion_window_decrease = prev_congestion_window - self._congestion_window + for _ in range(congestion_window_decrease): + await self._send_semaphore.acquire() diff --git a/cleanlab_studio/internal/tlm/validation.py b/cleanlab_studio/internal/tlm/validation.py new file mode 100644 index 00000000..191e5bed --- /dev/null +++ b/cleanlab_studio/internal/tlm/validation.py @@ -0,0 +1,180 @@ +import os +from typing import Union, Sequence, Any +from cleanlab_studio.errors import ValidationError +from cleanlab_studio.internal.constants import ( + _VALID_TLM_MODELS, + TLM_MAX_TOKEN_RANGE, + TLM_NUM_CANDIDATE_RESPONSES_RANGE, + TLM_NUM_CONSISTENCY_SAMPLES_RANGE, +) + + +SKIP_VALIDATE_TLM_OPTIONS: bool = ( + os.environ.get("CLEANLAB_STUDIO_SKIP_VALIDATE_TLM_OPTIONS", "false").lower() == "true" +) + + +def validate_tlm_prompt(prompt: Union[str, Sequence[str]]) -> None: + if isinstance(prompt, str): + return + + elif isinstance(prompt, Sequence): + if any(not isinstance(p, str) for p in prompt): + raise ValidationError( + "Some items in prompt are of invalid types, all items in the prompt list must be of type str." + ) + + else: + raise ValidationError( + f"Invalid type {type(prompt)}, prompt must either be strings or list/iterable of strings." + ) + + +def validate_tlm_try_prompt(prompt: Sequence[str]) -> None: + if isinstance(prompt, str): + raise ValidationError(f"Invalid type str, prompt must be a list/iterable of strings.") + + elif isinstance(prompt, Sequence): + if any(not isinstance(p, str) for p in prompt): + raise ValidationError( + "Some items in prompt are of invalid types, all items in the prompt list must be of type str." + ) + + else: + raise ValidationError( + f"Invalid type {type(prompt)}, prompt must be a list/iterable of strings." + ) + + +def validate_tlm_prompt_response( + prompt: Union[str, Sequence[str]], response: Union[str, Sequence[str]] +) -> None: + if isinstance(prompt, str): + if not isinstance(response, str): + raise ValidationError( + "response type must match prompt type. " + f"prompt was provided as str but response is of type {type(response)}" + ) + + elif isinstance(prompt, Sequence): + if not isinstance(response, Sequence): + raise ValidationError( + "response type must match prompt type. 
" + f"prompt was provided as type {type(prompt)} but response is of type {type(response)}" + ) + + if len(prompt) != len(response): + raise ValidationError("Length of the prompt and response lists must match.") + + if any(not isinstance(p, str) for p in prompt): + raise ValidationError( + "Some items in prompt are of invalid types, all items in the prompt list must be of type str." + ) + if any(not isinstance(r, str) for r in response): + raise ValidationError( + "Some items in response are of invalid types, all items in the response list must be of type str." + ) + + else: + raise ValidationError( + f"Invalid type {type(prompt)}, prompt must either be strings or list/iterable of strings." + ) + + +def validate_try_tlm_prompt_response(prompt: Sequence[str], response: Sequence[str]) -> None: + if isinstance(prompt, str): + raise ValidationError(f"Invalid type str, prompt must be a list/iterable of strings.") + + elif isinstance(prompt, Sequence): + if not isinstance(response, Sequence): + raise ValidationError( + "response type must match prompt type. " + f"prompt was provided as type {type(prompt)} but response is of type {type(response)}" + ) + + if len(prompt) != len(response): + raise ValidationError("Length of the prompt and response lists must match.") + + if any(not isinstance(p, str) for p in prompt): + raise ValidationError( + "Some items in prompt are of invalid types, all items in the prompt list must be of type str." + ) + if any(not isinstance(r, str) for r in response): + raise ValidationError( + "Some items in response are of invalid types, all items in the response list must be of type str." + ) + + else: + raise ValidationError( + f"Invalid type {type(prompt)}, prompt must be a list/iterable of strings." + ) + + +def validate_tlm_options(options: Any) -> None: + from cleanlab_studio.studio.trustworthy_language_model import TLMOptions + + if SKIP_VALIDATE_TLM_OPTIONS: + return + + if not isinstance(options, dict): + raise ValidationError( + "options must be a TLMOptions object.\n" + "See: https://help.cleanlab.ai/reference/python/trustworthy_language_model/#class-tlmoptions" + ) + + invalid_keys = set(options.keys()) - set(TLMOptions.__annotations__.keys()) + if invalid_keys: + raise ValidationError( + f"Invalid keys in options dictionary: {invalid_keys}.\n" + "See https://help.cleanlab.ai/reference/python/trustworthy_language_model/#class-tlmoptions for valid options" + ) + + for option, val in options.items(): + if option == "max_tokens": + if not isinstance(val, int): + raise ValidationError(f"Invalid type {type(val)}, max_tokens must be an integer") + + if val < TLM_MAX_TOKEN_RANGE[0] or val > TLM_MAX_TOKEN_RANGE[1]: + raise ValidationError( + f"Invalid value {val}, max_tokens must be in the range {TLM_MAX_TOKEN_RANGE}" + ) + + elif option == "model": + if val not in _VALID_TLM_MODELS: + raise ValidationError( + f"{val} is not a supported model, valid models include: {_VALID_TLM_MODELS}" + ) + + elif option == "num_candidate_responses": + if not isinstance(val, int): + raise ValidationError( + f"Invalid type {type(val)}, num_candidate_responses must be an integer" + ) + + if ( + val < TLM_NUM_CANDIDATE_RESPONSES_RANGE[0] + or val > TLM_NUM_CANDIDATE_RESPONSES_RANGE[1] + ): + raise ValidationError( + f"Invalid value {val}, num_candidate_responses must be in the range {TLM_NUM_CANDIDATE_RESPONSES_RANGE}" + ) + + elif option == "num_consistency_samples": + if not isinstance(val, int): + raise ValidationError( + f"Invalid type {type(val)}, num_consistency_samples must be an 
integer" + ) + + if ( + val < TLM_NUM_CONSISTENCY_SAMPLES_RANGE[0] + or val > TLM_NUM_CONSISTENCY_SAMPLES_RANGE[1] + ): + raise ValidationError( + f"Invalid value {val}, num_consistency_samples must be in the range {TLM_NUM_CONSISTENCY_SAMPLES_RANGE}" + ) + + elif option == "use_self_reflection": + if not isinstance(val, bool): + raise ValidationError( + f"Invalid type {type(val)}, use_self_reflection must be a boolean" + ) diff --git a/cleanlab_studio/studio/studio.py b/cleanlab_studio/studio/studio.py index 93f0c09b..4e7122e2 100644 --- a/cleanlab_studio/studio/studio.py +++ b/cleanlab_studio/studio/studio.py @@ -385,20 +385,41 @@ def download_embeddings( def TLM( self, - *, quality_preset: TLMQualityPreset = "medium", - **kwargs: Any, + *, + options: Optional[trustworthy_language_model.TLMOptions] = None, + timeout: Optional[float] = None, + verbose: Optional[bool] = None, ) -> trustworthy_language_model.TLM: - """Gets Trustworthy Language Model (TLM) object to prompt. + """Gets a configured instance of Trustworthy Language Model (TLM). + + The returned TLM object can then be used as a drop-in replacement for an LLM, for estimating trustworthiness scores for LLM prompt/response pairs, and more. See the documentation for the [TLM](../trustworthy_language_model#class-TLM) class for more on what you can do with TLM. + + For advanced use cases, TLM supports a number of configuration options. The documentation below summarizes the options, and the [TLM tutorial](/tutorials/tlm) explains the tradeoffs in more detail. Args: - quality_preset: quality preset to use for prompts - kwargs (Any): additional kwargs to pass to TLM class + quality_preset (TLMQualityPreset): quality preset to use for TLM queries, which will determine the quality of the output responses and trustworthiness scores. + Supported presets include "best", "high", "medium", "low", "base". + The "best" and "high" presets will improve the LLM responses themselves, with "best" also returning the most reliable trustworthiness scores. + The "medium" and "low" presets will return standard LLM responses along with associated confidence scores, + with "medium" producing more reliable trustworthiness scores than low. + The "base" preset will not return any confidence score, just a standard LLM output response, this option is similar to using your favorite LLM API. + Higher presets have increased runtime and cost. + + options (TLMOptions, optional): a typed dict of advanced configuration options. + Options that can be passed in include "model", "max_tokens", "num_candidate_responses", "num_consistency_samples", "use_self_reflection". + For more details about the options, see the documentation for [TLMOptions](../trustworthy_language_model#class-tlmoptions). + + timeout (float, optional): timeout (in seconds) to apply to each method call. If a result is not produced within the timeout, a TimeoutError will be raised. Defaults to None, which does not apply a timeout. + + verbose (bool, optional): whether to run in verbose mode, i.e., whether to show a tqdm progress bar when TLM is prompted with batches of data. If None, this will be determined automatically based on whether the code is running in an interactive environment such as a notebook. 
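[Editor's illustration] The parameters documented above translate into a call like the following; the API key and option values are placeholders.

```python
from cleanlab_studio import Studio

studio = Studio("<YOUR_API_KEY>")  # placeholder API key
tlm = studio.TLM(
    quality_preset="high",
    options={"model": "gpt-4", "max_tokens": 256},
    timeout=60,    # seconds per method call; None disables the timeout
    verbose=True,  # force the tqdm progress bar for batched prompts
)
```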
Returns: TLM: the [Trustworthy Language Model](../trustworthy_language_model#class-tlm) object """ - return trustworthy_language_model.TLM(self._api_key, quality_preset, **kwargs) + return trustworthy_language_model.TLM( + self._api_key, quality_preset, options=options, timeout=timeout, verbose=verbose + ) def poll_cleanset_status(self, cleanset_id: str, timeout: Optional[int] = None) -> bool: """ diff --git a/cleanlab_studio/studio/trustworthy_language_model.py b/cleanlab_studio/studio/trustworthy_language_model.py index 0a60ff08..f442a337 100644 --- a/cleanlab_studio/studio/trustworthy_language_model.py +++ b/cleanlab_studio/studio/trustworthy_language_model.py @@ -1,401 +1,579 @@ """ -Cleanlab TLM is a Large Language Model that gives more reliable answers and quantifies its uncertainty in these answers +Cleanlab's Trustworthy Language Model (TLM) is a large language model that gives more reliable answers and quantifies its uncertainty in these answers. + +**This module is not meant to be imported and used directly.** Instead, use [`Studio.TLM()`](/reference/python/studio/#method-tlm) to instantiate a [TLM](#class-TLM) object, and then you can use the methods like [`prompt()`](#method-prompt) and [`get_trustworthiness_score()`](#method-get_trustworthiness_score) documented in this page. + +The [Trustworthy Language Model tutorial](/tutorials/tlm/) further explains TLM and its use cases. """ from __future__ import annotations import asyncio import sys -from typing import Coroutine, List, Optional, Union, cast +from typing import Coroutine, List, Optional, Union, cast, Sequence +from tqdm.asyncio import tqdm_asyncio import aiohttp from typing_extensions import NotRequired, TypedDict # for Python <3.11 with (Not)Required from cleanlab_studio.internal.api import api -from cleanlab_studio.internal.types import JSONDict, TLMQualityPreset +from cleanlab_studio.internal.tlm.concurrency import TlmRateHandler +from cleanlab_studio.internal.tlm.validation import ( + validate_tlm_prompt, + validate_tlm_try_prompt, + validate_tlm_prompt_response, + validate_try_tlm_prompt_response, + validate_tlm_options, +) +from cleanlab_studio.internal.types import TLMQualityPreset +from cleanlab_studio.errors import ValidationError from cleanlab_studio.internal.constants import ( - _DEFAULT_MAX_CONCURRENT_TLM_REQUESTS, - _MAX_CONCURRENT_TLM_REQUESTS_LIMIT, _VALID_TLM_QUALITY_PRESETS, + _TLM_MAX_RETRIES, ) -class TLMResponse(TypedDict): - """Trustworthy Language Model response. - - Attributes: - response (str): text response from language model - confidence_score (float): score corresponding to confidence that the response is correct - """ - - response: str - confidence_score: float - - -class TLMOptions(TypedDict): - """Trustworthy language model options. The TLM quality-preset determines many of these settings automatically, but - specifying other values here will over-ride the setting from the quality-preset. - - Args: - max_tokens (int, default = 512): the maximum number of tokens to generate in the TLM response. - - max_timeout (int, optional): the maximum timeout to query from TLM in seconds. If a max_timeout is not specified, then timeout is calculated based on number of tokens. - - num_candidate_responses (int, default = 1): this controls how many candidate responses are internally generated. - TLM scores the confidence of each candidate response, and then returns the most confident one. - A higher value here can produce better (more accurate) responses from the TLM, but at higher costs/runtimes. 
- - num_consistency_samples (int, default = 5): this controls how many samples are internally generated to evaluate the LLM-response-consistency. - This is a big part of the returned confidence_score, in particular for ensuring lower scores for strange input prompts or those that are too open-ended to receive a well-defined 'good' response. - Higher values here produce better (more reliable) TLM confidence scores, but at higher costs/runtimes. - - use_self_reflection (bool, default = `True`): this controls whether self-reflection is used to have the LLM reflect upon the response it is generating and explicitly self-evaluate whether it seems good or not. - This is a big part of the confidence score, in particular for ensure low scores for responses that are obviously incorrect/bad for a standard prompt that LLMs should be able to handle. - Setting this to False disables the use of self-reflection and may produce worse TLM confidence scores, but can reduce costs/runtimes. - - model (str, default = "gpt-3.5-turbo-16k"): ID of the model to use. Other options: "gpt-4" +class TLM: + """Represents a Trustworthy Language Model (TLM) instance, bound to a Cleanlab Studio account. + TLM should be configured and instantiated using the [`Studio.TLM()`](../studio/#method-tlm) method. Then, using the TLM object, you can [`prompt()`](#method-prompt) the language model, etc. """ - max_tokens: NotRequired[int] - model: NotRequired[str] - max_timeout: NotRequired[int] - num_candidate_responses: NotRequired[int] - num_consistency_samples: NotRequired[int] - use_self_reflection: NotRequired[bool] - - -class TLM: - """TLM interface class.""" - def __init__( self, api_key: str, quality_preset: TLMQualityPreset, - max_concurrent_requests: int = _DEFAULT_MAX_CONCURRENT_TLM_REQUESTS, + *, + options: Optional[TLMOptions] = None, + timeout: Optional[float] = None, + verbose: Optional[bool] = None, ) -> None: - """Initializes TLM interface. + """Initializes a Trustworthy Language Model. - Args: - api_key (str): API key used to authenticate TLM client - quality_preset (TLMQualityPreset): quality preset to use for TLM queries - max_concurrent_requests (int): maximum number of concurrent requests when issuing batch queries. Default is 16. 
- """ + **Objects of this class are not meant to be constructed directly.** Instead, use [`Studio.TLM()`](../studio/#method-tlm), whose documentation also explains the different configuration options.""" self._api_key = api_key - assert ( - max_concurrent_requests < _MAX_CONCURRENT_TLM_REQUESTS_LIMIT - ), f"max_concurrent_requests must be less than {_MAX_CONCURRENT_TLM_REQUESTS_LIMIT}" - if quality_preset not in _VALID_TLM_QUALITY_PRESETS: - raise ValueError( + raise ValidationError( f"Invalid quality preset {quality_preset} -- must be one of {_VALID_TLM_QUALITY_PRESETS}" ) + if options is not None: + validate_tlm_options(options) + + if timeout is not None and not (isinstance(timeout, int) or isinstance(timeout, float)): + raise ValidationError("timeout must be a integer or float value") + + if verbose is not None and not isinstance(verbose, bool): + raise ValidationError("verbose must be a boolean value") + + is_notebook_flag = is_notebook() + self._quality_preset = quality_preset + self._options = options + self._timeout = timeout if timeout is not None and timeout > 0 else None + self._verbose = verbose if verbose is not None else is_notebook_flag - if is_notebook(): + if is_notebook_flag: import nest_asyncio nest_asyncio.apply() self._event_loop = asyncio.get_event_loop() - self._query_semaphore = asyncio.Semaphore(max_concurrent_requests) + self._rate_handler = TlmRateHandler() - def _batch_prompt( + async def _batch_prompt( self, - prompts: List[str], - options: Union[None, TLMOptions, List[Union[TLMOptions, None]]] = None, - timeout: Optional[float] = None, - retries: int = 1, - ) -> List[TLMResponse]: - """Run batch of TLM prompts. + prompts: Sequence[str], + capture_exceptions: bool = False, + ) -> Union[List[TLMResponse], List[Optional[TLMResponse]]]: + """Run batch of TLM prompts. The list returned will have the same length as the input list. + + If capture_exceptions is True, the list will contain None in place of the response for any errors or timeout processing some inputs. + Otherwise, the method will raise an exception for any errors or timeout processing some inputs. Args: prompts (List[str]): list of prompts to run - options (None | TLMOptions | List[TLMOptions | None], optional): list of options (or instance of options) to pass to prompt method. Defaults to None. - timeout (Optional[float], optional): timeout (in seconds) to run all prompts. Defaults to None. - retries (int): number of retries to attempt for each individual prompt in case of error. Defaults to 1. + capture_exceptions (bool): if should return None in place of the response for any errors or timeout processing some inputs Returns: - List[TLMResponse]: TLM responses for each prompt (in supplied order) + Union[List[TLMResponse], List[Optional[TLMResponse]]]: TLM responses for each prompt (in supplied order) """ - if isinstance(options, list): - options_collection = options + if capture_exceptions: + per_query_timeout, per_batch_timeout = self._timeout, None else: - options = cast(Union[None, TLMOptions], options) - options_collection = [options for _ in prompts] - - assert len(prompts) == len(options_collection), "Length of prompts and options must match." 
- - tlm_responses = self._event_loop.run_until_complete( - self._batch_async( - [ - self.prompt_async( - prompt, - option_dict, - retries=retries, - ) - for prompt, option_dict in zip(prompts, options_collection) - ], - timeout=timeout, - ) + per_query_timeout, per_batch_timeout = None, self._timeout + + # run batch of TLM + tlm_responses = await self._batch_async( + [ + self._prompt_async( + prompt, + timeout=per_query_timeout, + capture_exceptions=capture_exceptions, + batch_index=batch_index, + ) + for batch_index, prompt in enumerate(prompts) + ], + per_batch_timeout, ) + if capture_exceptions: + return cast(List[Optional[TLMResponse]], tlm_responses) + return cast(List[TLMResponse], tlm_responses) - def _batch_get_confidence_score( + async def _batch_get_trustworthiness_score( self, - prompts: List[str], - responses: List[str], - options: Union[None, TLMOptions, List[Union[TLMOptions, None]]] = None, - timeout: Optional[float] = None, - retries: int = 1, - ) -> List[float]: + prompts: Sequence[str], + responses: Sequence[str], + capture_exceptions: bool = False, + ) -> Union[List[float], List[Optional[float]]]: """Run batch of TLM get confidence score. + capture_exceptions behavior: + - If true, the list will contain None in place of the response for any errors or timeout processing some inputs. + - Otherwise, the method will raise an exception for any errors or timeout processing some inputs. + + capture_exceptions interaction with timeout: + - If true, timeouts are applied on a per-query basis (i.e. some queries may succeed while others fail) + - If false, a single timeout is applied to the entire batch (i.e. all queries will fail if the timeout is reached) + Args: - prompts (List[str]): list of prompts to run get confidence score for - responses (List[str]): list of responses to run get confidence score for - options (None | TLMOptions | List[TLMOptions | None], optional): list of options (or instance of options) to pass to get confidence score method. Defaults to None. - timeout (Optional[float], optional): timeout (in seconds) to run all prompts. Defaults to None. - retries (int): number of retries to attempt for each individual prompt in case of error. Defaults to 1. + prompts (Sequence[str]): list of prompts to run get confidence score for + responses (Sequence[str]): list of responses to run get confidence score for + capture_exceptions (bool): if should return None in place of the response for any errors or timeout processing some inputs Returns: - List[float]: TLM confidence score for each prompt (in supplied order) + Union[List[float], List[Optional[float]]]: TLM confidence score for each prompt (in supplied order) """ - if isinstance(options, list): - options_collection = options + if capture_exceptions: + per_query_timeout, per_batch_timeout = self._timeout, None else: - options = cast(Union[None, TLMOptions], options) - options_collection = [options for _ in prompts] - - assert len(prompts) == len(responses), "Length of prompts and responses must match." - assert len(prompts) == len(options_collection), "Length of prompts and options must match." 
- - tlm_responses = self._event_loop.run_until_complete( - self._batch_async( - [ - self.get_confidence_score_async( - prompt, - response, - option_dict, - retries=retries, - ) - for prompt, response, option_dict in zip(prompts, responses, options_collection) - ], - timeout=timeout, - ) + per_query_timeout, per_batch_timeout = None, self._timeout + + # run batch of TLM get confidence score + tlm_responses = await self._batch_async( + [ + self._get_trustworthiness_score_async( + prompt, + response, + timeout=per_query_timeout, + capture_exceptions=capture_exceptions, + batch_index=batch_index, + ) + for batch_index, (prompt, response) in enumerate(zip(prompts, responses)) + ], + per_batch_timeout, ) + if capture_exceptions: + return cast(List[Optional[float]], tlm_responses) + return cast(List[float], tlm_responses) async def _batch_async( self, - tlm_coroutines: List[ - Union[Coroutine[None, None, TLMResponse], Coroutine[None, None, float]] - ], - timeout: Optional[float], - ) -> Union[List[TLMResponse], List[float]]: + tlm_coroutines: Sequence[Coroutine[None, None, Union[TLMResponse, float, None]]], + batch_timeout: Optional[float] = None, + ) -> Sequence[Union[TLMResponse, float, None]]: + """Runs batch of TLM queries. + + Args: + tlm_coroutines (List[Coroutine[None, None, Union[TLMResponse, float, None]]]): list of query coroutines to run, returning TLM responses or confidence scores (or None if capture_exceptions is True) + batch_timeout (Optional[float], optional): timeout (in seconds) to run all queries, defaults to None (no timeout) + + Returns: + Sequence[Union[TLMResponse, float, None]]: list of coroutine results, with preserved order + """ tlm_query_tasks = [asyncio.create_task(tlm_coro) for tlm_coro in tlm_coroutines] - return await asyncio.wait_for(asyncio.gather(*tlm_query_tasks), timeout=timeout) # type: ignore[arg-type] + if self._verbose: + gather_task = tqdm_asyncio.gather( + *tlm_query_tasks, + total=len(tlm_query_tasks), + desc="Querying TLM...", + bar_format="{desc} {percentage:3.0f}%|{bar}|", + ) + else: + gather_task = asyncio.gather(*tlm_query_tasks) + + wait_task = asyncio.wait_for(gather_task, timeout=batch_timeout) + try: + return cast( + Sequence[Union[TLMResponse, float, None]], + await wait_task, + ) + except Exception: + # if exception occurs while awaiting batch results, cancel remaining tasks + for query_task in tlm_query_tasks: + query_task.cancel() + + # await remaining tasks to ensure they are cancelled + await asyncio.gather(*tlm_query_tasks, return_exceptions=True) + + raise def prompt( self, - prompt: Union[str, List[str]], - options: Union[None, TLMOptions, List[Union[TLMOptions, None]]] = None, - timeout: Optional[float] = None, - retries: int = 1, + prompt: Union[str, Sequence[str]], + /, ) -> Union[TLMResponse, List[TLMResponse]]: """ - Get response and confidence from TLM. + Gets response and trustworthiness score for any text input. + + This method prompts the TLM with the given prompt(s), producing completions (like a standard LLM) + but also provides trustworthiness scores quantifying the quality of the output. Args: - prompt (str | List[str]): prompt (or list of multiple prompts) for the TLM - options (None | TLMOptions | List[TLMOptions | None], optional): list of options (or instance of options) to pass to prompt method. Defaults to None. - timeout (Optional[float], optional): timeout (in seconds) to run all prompts. Defaults to None. - If the timeout is hit, this method will throw a `TimeoutError`. 
- Larger values give TLM a higher chance to return outputs for all of your prompts. - Smaller values ensure this method does not take too long. - retries (int): number of retries to attempt for each individual prompt in case of internal error. Defaults to 1. - Larger values give TLM a higher chance of returning outputs for all of your prompts, - but this method will also take longer to alert you in cases of an unrecoverable error. - Set to 0 to never attempt any retries. + prompt (str | Sequence[str]): prompt (or list of multiple prompts) for the language model Returns: - TLMResponse | List[TLMResponse]: [TLMResponse](#class-tlmresponse) object containing the response and confidence score. - If multiple prompts were provided in a list, then a list of such objects is returned, one for each prompt. + TLMResponse | List[TLMResponse]: [TLMResponse](#class-tlmresponse) object containing the response and trustworthiness score. + If multiple prompts were provided in a list, then a list of such objects is returned, one for each prompt. + This method will raise an exception if any errors occur or if you hit a timeout (given a timeout is specified), + and is suitable if strict error handling and immediate notification of any exceptions/timeouts is preferred. + However, you could lose any partial results if an exception is raised. + If saving partial results is important to you, you can call this method on smaller chunks of data at a time + (and save intermediate results as desired); you can also consider using the more advanced + [`try_prompt()`](#method-try_prompt) method instead. """ - if isinstance(prompt, list): - if any(not isinstance(p, str) for p in prompt): - raise ValueError("All prompts must be strings.") + validate_tlm_prompt(prompt) - return self._batch_prompt( - prompt, - options, - timeout=timeout, - retries=retries, + if isinstance(prompt, str): + return cast( + TLMResponse, + self._event_loop.run_until_complete( + self._prompt_async(prompt, timeout=self._timeout, capture_exceptions=False), + ), ) - elif isinstance(prompt, str): - if not (options is None or isinstance(options, dict)): - raise ValueError( - "options must be a single TLMOptions object for single prompt.\n" - "See: https://help.cleanlab.ai/reference/python/trustworthy_language_model/#class-tlmoptions" - ) + return cast( + List[TLMResponse], + self._event_loop.run_until_complete( + self._batch_prompt(prompt, capture_exceptions=False), + ), + ) - return self._event_loop.run_until_complete( - self.prompt_async( - prompt, - cast(Union[None, TLMOptions], options), - retries=retries, - ) - ) + def try_prompt( + self, + prompt: Sequence[str], + /, + ) -> List[Optional[TLMResponse]]: + """ + Gets response and trustworthiness score for any text input, + handling any failures (errors of timeouts) by returning None in place of the failures. - else: - raise ValueError("prompt must be a string or list of strings.") + The list returned will have the same length as the input list, if there are any + failures (errors or timeout) processing some inputs, the list will contain None in place of the response. + + If there are any failures (errors or timeouts) processing some inputs, the list returned will have + the same length as the input list. In case of failure, the list will contain None in place of the response. + + Args: + prompt (Sequence[str]): list of multiple prompts for the TLM + Returns: + List[Optional[TLMResponse]]: list of [TLMResponse](#class-tlmresponse) objects containing the response and trustworthiness score. 
+ The returned list will always have the same length as the input list. + In case of failure on any prompt (due to timeouts or other erros), + the return list will contain None in place of the TLM response. + This method is suitable if you prioritize obtaining results for as many inputs as possible, + however you might miss out on certain error messages. + If you would prefer to be notified immediately about any errors or timeouts that might occur, + consider using the [`prompt()`](#method-prompt) method instead. + """ + validate_tlm_try_prompt(prompt) + + return cast( + List[Optional[TLMResponse]], + self._event_loop.run_until_complete( + self._batch_prompt(prompt, capture_exceptions=True), + ), + ) async def prompt_async( + self, + prompt: Union[str, Sequence[str]], + /, + ) -> Union[TLMResponse, List[TLMResponse]]: + """ + Asynchronously get response and trustworthiness score for any text input from TLM. + This method is similar to the [`prompt()`](#method-prompt) method but operates asynchronously. + + Args: + prompt (str | Sequence[str]): prompt (or list of multiple prompts) for the TLM + Returns: + TLMResponse | List[TLMResponse]: [TLMResponse](#class-tlmresponse) object containing the response and trustworthiness score. + If multiple prompts were provided in a list, then a list of such objects is returned, one for each prompt. + This method will raise an exception if any errors occur or if you hit a timeout (given a timeout is specified). + """ + validate_tlm_prompt(prompt) + + async with aiohttp.ClientSession() as session: + if isinstance(prompt, str): + tlm_response = await self._prompt_async( + prompt, session, timeout=self._timeout, capture_exceptions=False + ) + return cast(TLMResponse, tlm_response) + + return cast( + List[TLMResponse], + await self._batch_prompt(prompt, capture_exceptions=False), + ) + + async def _prompt_async( self, prompt: str, - options: Optional[TLMOptions] = None, client_session: Optional[aiohttp.ClientSession] = None, - retries: int = 0, - ) -> TLMResponse: + timeout: Optional[float] = None, + capture_exceptions: bool = False, + batch_index: Optional[int] = None, + ) -> Optional[TLMResponse]: """ - (Asynchronously) Get response and confidence from TLM. + Private asynchronous method to get response and trustworthiness score from TLM. Args: prompt (str): prompt for the TLM - options (Optional[TLMOptions]): options to parameterize TLM with. Defaults to None. - client_session (Optional[aiohttp.ClientSession]): async HTTP session to use for TLM query. Defaults to None. - retries (int): number of retries for TLM query. Defaults to 0. + client_session (aiohttp.ClientSession, optional): async HTTP session to use for TLM query. Defaults to None (creates a new session). + timeout: timeout (in seconds) to run the prompt, defaults to None (no timeout) + capture_exceptions: if should return None in place of the response for any errors + batch_index: index of the prompt in the batch, used for error messages Returns: - TLMResponse: [TLMResponse](#class-tlmresponse) object containing the response and confidence score + TLMResponse: [TLMResponse](#class-tlmresponse) object containing the response and trustworthiness score. 
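[Editor's illustration] A comparison of the strict and lenient batch entry points described above; the prompts are placeholders and `tlm` is assumed to come from `Studio.TLM()`.

```python
from typing import Sequence

from cleanlab_studio.studio.trustworthy_language_model import TLM


def batch_with_partial_results(tlm: TLM, prompts: Sequence[str]) -> None:
    # strict: raises on any error or timeout, so no partial results survive
    responses = tlm.prompt(list(prompts))
    for r in responses:
        print(r["response"], r["trustworthiness_score"])

    # lenient: failed queries come back as None instead of raising
    maybe_responses = tlm.try_prompt(list(prompts))
    for p, r in zip(prompts, maybe_responses):
        if r is None:
            print(f"query failed: {p!r}")
        else:
            print(p, "->", r["response"])
```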
""" - async with self._query_semaphore: - tlm_response = await api.tlm_prompt( - self._api_key, - prompt, - self._quality_preset, - cast(JSONDict, options), - client_session, - retries=retries, + + try: + tlm_response = await asyncio.wait_for( + api.tlm_prompt( + self._api_key, + prompt, + self._quality_preset, + self._options, + self._rate_handler, + client_session, + batch_index=batch_index, + retries=_TLM_MAX_RETRIES, + ), + timeout=timeout, ) + except Exception as e: + if capture_exceptions: + return None + raise e return { "response": tlm_response["response"], - "confidence_score": tlm_response["confidence_score"], + "trustworthiness_score": tlm_response["confidence_score"], } - def get_confidence_score( + def get_trustworthiness_score( self, - prompt: Union[str, List[str]], - response: Union[str, List[str]], - options: Union[None, TLMOptions, List[Union[TLMOptions, None]]] = None, - timeout: Optional[float] = None, - retries: int = 1, + prompt: Union[str, Sequence[str]], + response: Union[str, Sequence[str]], ) -> Union[float, List[float]]: - """Gets confidence score for prompt-response pair(s). + """Gets trustworthiness score for prompt-response pairs. Args: - prompt (str | List[str]): prompt (or list of multiple prompts) for the TLM - response (str | List[str]): response (or list of multiple responses) for the TLM to evaluate - options (None | TLMOptions | List[TLMOptions | None], optional): list of options (or instance of options) to pass to get confidence score method. Defaults to None. - timeout (Optional[float], optional): maximum allowed time (in seconds) to run all prompts and evaluate all responses. Defaults to None. - If the timeout is hit, this method will throw a `TimeoutError`. - Larger values give TLM a higher chance to return outputs for all of your prompts + responses. - Smaller values ensure this method does not take too long. - retries (int): number of retries to attempt for each individual prompt in case of internal error. Defaults to 1. - Larger values give TLM a higher chance of returning outputs for all of your prompts, - but this method will also take longer to alert you in cases of an unrecoverable error. - Set to 0 to never attempt any retries. + prompt (str | Sequence[str]): prompt (or list of prompts) for the TLM to evaluate + response (str | Sequence[str]): response (or list of responses) corresponding to the input prompts Returns: - float (or list of floats if multiple prompt-responses were provided) corresponding to the TLM's confidence score. - The score quantifies how confident TLM is that the given response is good for the given prompt. + float | List[float]: float or list of floats (if multiple prompt-responses were provided) corresponding + to the TLM's trustworthiness score. + The score quantifies how confident TLM is that the given response is good for the given prompt. + This method will raise an exception if any errors occur or if you hit a timeout (given a timeout is specified), + and is suitable if strict error handling and immediate notification of any exceptions/timeouts is preferred. + However, you could lose any partial results if an exception is raised. + If saving partial results is important to you, you can call this method on smaller chunks of data at a time + (and save intermediate results as desired); you can also consider using the more advanced + [`try_get_trustworthiness_score()`](#method-try_get_trustworthiness_score) method instead. 
""" - if isinstance(prompt, list): - if any(not isinstance(p, str) for p in prompt): - raise ValueError("All prompts must be strings.") - if any(not isinstance(r, str) for r in response): - raise ValueError("All responses must be strings.") - - if not isinstance(response, list): - raise ValueError( - "responses must be a list or iterable of strings when prompt is a list or iterable." - ) + validate_tlm_prompt_response(prompt, response) - return self._batch_get_confidence_score( - prompt, - response, - options, - timeout=timeout, - retries=retries, + if isinstance(prompt, str) and isinstance(response, str): + return cast( + float, + self._event_loop.run_until_complete( + self._get_trustworthiness_score_async( + prompt, response, timeout=self._timeout, capture_exceptions=False + ) + ), ) - elif isinstance(prompt, str): - if not (options is None or isinstance(options, dict)): - raise ValueError( - "options must be a single TLMOptions object for single prompt.\n" - "See: https://help.cleanlab.ai/reference/python/trustworthy_language_model/#class-tlmoptions" - ) + return cast( + List[float], + self._event_loop.run_until_complete( + self._batch_get_trustworthiness_score(prompt, response, capture_exceptions=False) + ), + ) - if not isinstance(response, str): - raise ValueError("responses must be a single string for single prompt.") + def try_get_trustworthiness_score( + self, + prompt: Sequence[str], + response: Sequence[str], + ) -> List[Optional[float]]: + """Gets trustworthiness score for prompt-response pairs. + The list returned will have the same length as the input list, if there are any + failures (errors or timeout) processing some inputs, the list will contain None + in place of the response. - return self._event_loop.run_until_complete( - self.get_confidence_score_async( - prompt, - response, - cast(Union[None, TLMOptions], options), - retries=retries, + Args: + prompt (Sequence[str]): list of prompts for the TLM to evaluate + response (Sequence[str]): list of responses corresponding to the input prompts + Returns: + List[float]: list of floats corresponding to the TLM's trustworthiness score. + The score quantifies how confident TLM is that the given response is good for the given prompt. + The returned list will always have the same length as the input list. + In case of failure on any prompt-response pair (due to timeouts or other erros), + the return list will contain None in place of the trustworthiness score. + This method is suitable if you prioritize obtaining results for as many inputs as possible, + however you might miss out on certain error messages. + If you would prefer to be notified immediately about any errors or timeouts that might occur, + consider using the [`get_trustworthiness_score()`](#method-get_trustworthiness_score) method instead. + """ + validate_try_tlm_prompt_response(prompt, response) + + return cast( + List[Optional[float]], + self._event_loop.run_until_complete( + self._batch_get_trustworthiness_score(prompt, response, capture_exceptions=True) + ), + ) + + async def get_trustworthiness_score_async( + self, + prompt: Union[str, Sequence[str]], + response: Union[str, Sequence[str]], + ) -> Union[float, List[float]]: + """Asynchronously gets trustworthiness score for prompt-response pairs. + This method is similar to the [`get_trustworthiness_score()`](#method-get_trustworthiness_score) method but operates asynchronously. 
+ + Args: + prompt (str | Sequence[str]): prompt (or list of prompts) for the TLM to evaluate + response (str | Sequence[str]): response (or list of responses) corresponding to the input prompts + Returns: + float | List[float]: float or list of floats (if multiple prompt-responses were provided) corresponding + to the TLM's trustworthiness score. + The score quantifies how confident TLM is that the given response is good for the given prompt. + This method will raise an exception if any errors occur or if you hit a timeout (given a timeout is specified). + """ + validate_tlm_prompt_response(prompt, response) + + async with aiohttp.ClientSession() as session: + if isinstance(prompt, str) and isinstance(response, str): + trustworthiness_score = await self._get_trustworthiness_score_async( + prompt, response, session, timeout=self._timeout, capture_exceptions=False ) - ) + return cast(float, trustworthiness_score) - else: - raise ValueError("prompt must be a string or list/iterable of strings.") + return cast( + List[float], + await self._batch_get_trustworthiness_score( + prompt, response, capture_exceptions=False + ), + ) - async def get_confidence_score_async( + async def _get_trustworthiness_score_async( self, prompt: str, response: str, - options: Optional[TLMOptions] = None, client_session: Optional[aiohttp.ClientSession] = None, - retries: int = 0, - ) -> float: - """(Asynchronously) gets confidence score for prompt-response pair. + timeout: Optional[float] = None, + capture_exceptions: bool = False, + batch_index: Optional[int] = None, + ) -> Optional[float]: + """Private asynchronous method to get trustworthiness score for prompt-response pairs. Args: - prompt: prompt for the TLM - response: response for the TLM to evaluate - options (Optional[TLMOptions]): options to parameterize TLM with. Defaults to None. - client_session (Optional[aiohttp.ClientSession]): async HTTP session to use for TLM query. Defaults to None. - retries (int): number of retries for TLM query. Defaults to 0. + prompt: prompt for the TLM to evaluate + response: response corresponding to the input prompt + client_session: async HTTP session to use for TLM query. Defaults to None. + timeout: timeout (in seconds) to run the prompt, defaults to None (no timeout) + capture_exceptions: if should return None in place of the response for any errors + batch_index: index of the prompt in the batch, used for error messages Returns: - float corresponding to the TLM's confidence score + float corresponding to the TLM's trustworthiness score """ if self._quality_preset == "base": - raise ValueError( + raise ValidationError( "Cannot get confidence score with `base` quality_preset -- choose a higher preset." ) - async with self._query_semaphore: - return cast( - float, - ( - await api.tlm_get_confidence_score( - self._api_key, - prompt, - response, - self._quality_preset, - cast(JSONDict, options), - client_session, - retries=retries, - ) - )["confidence_score"], + try: + tlm_response = await asyncio.wait_for( + api.tlm_get_confidence_score( + self._api_key, + prompt, + response, + self._quality_preset, + self._options, + self._rate_handler, + client_session, + batch_index=batch_index, + retries=_TLM_MAX_RETRIES, + ), + timeout=timeout, ) + return cast(float, tlm_response["confidence_score"]) + + except Exception as e: + if capture_exceptions: + return None + raise e + + +class TLMResponse(TypedDict): + """A typed dict containing the response and trustworthiness score from the Trustworthy Language Model. 
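[Editor's illustration] A usage sketch for the scoring methods documented in this hunk; the prompt/response pairs are placeholders and `tlm` is assumed to come from `Studio.TLM()`.

```python
from typing import List, Optional


def score_pairs(tlm) -> None:
    # single pair: returns a float between 0 and 1 (higher = more trustworthy)
    score = tlm.get_trustworthiness_score(
        "What is the capital of France?",
        "The capital of France is Paris.",
    )
    print(f"trustworthiness: {score:.3f}")

    # batch variant that tolerates per-item failures: None marks a failed pair
    scores: List[Optional[float]] = tlm.try_get_trustworthiness_score(
        ["2 + 2 = ?", "Largest planet in the solar system?"],
        ["4", "Jupiter"],
    )
    print(scores)
```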
+ + Attributes: + response (str): text response from the Trustworthy Language Model. + + trustworthiness_score (float, optional): score between 0-1 corresponding to the trustworthiness of the response. + A higher score indicates a higher confidence that the response is correct/trustworthy. The trustworthiness score + is omitted if TLM is run with quality preset "base". + """ + + response: str + trustworthiness_score: Optional[float] + + +class TLMOptions(TypedDict): + """Typed dict containing advanced configuration options for the Trustworthy Language Model. + Many of these arguments are automatically determined by the quality preset selected + (see the arguments in the TLM [initialization method](../studio#method-tlm) to learn more about the various quality presets), + but specifying custom values here will override any default values from the quality preset. + + Args: + model (str, default = "gpt-3.5-turbo-16k"): underlying LLM to use (better models will yield better results). + Models currently supported include "gpt-3.5-turbo-16k", "gpt-4". + + max_tokens (int, default = 512): the maximum number of tokens to generate in the TLM response. + The minimum value for this parameter is 64, and the maximum is 512. + + num_candidate_responses (int, default = 1): this controls how many candidate responses are internally generated. + TLM scores the trustworthiness of each candidate response, and then returns the most trustworthy one. + Higher values here can produce better (more accurate) responses from the TLM, but at higher costs/runtimes. + The minimum value for this parameter is 1, and the maximum is 20. + + num_consistency_samples (int, default = 5): this controls how many samples are internally generated to evaluate the LLM-response-consistency. + This is a big part of the returned trustworthiness_score, in particular to evaluate strange input prompts or prompts that are too open-ended + to receive a clearly defined 'good' response. + Higher values here produce better (more reliable) TLM confidence scores, but at higher costs/runtimes. + The minimum value for this parameter is 0, and the maximum is 20. + + use_self_reflection (bool, default = `True`): this controls whether self-reflection is used to have the LLM reflect upon the response it is + generating and explicitly self-evaluate the accuracy of that response. + This is a big part of the trustworthiness score, in particular for evaluating responses that are obviously incorrect/bad for a + standard prompt (with well-defined answers) that LLMs should be able to handle. + Setting this to False disables the use of self-reflection and may produce worse TLM trustworthiness scores, but will reduce costs/runtimes. + """ + + model: NotRequired[str] + max_tokens: NotRequired[int] + num_candidate_responses: NotRequired[int] + num_consistency_samples: NotRequired[int] + use_self_reflection: NotRequired[bool] + def is_notebook() -> bool: """Returns True if running in a notebook, False otherwise. 
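[Editor's illustration] The option ranges documented above match the validation constants added in cleanlab_studio/internal/constants.py; the dict below (illustrative values) passes validate_tlm_options.

```python
from cleanlab_studio.internal.tlm.validation import validate_tlm_options
from cleanlab_studio.studio.trustworthy_language_model import TLMOptions

options: TLMOptions = {
    "model": "gpt-4",              # must be one of _VALID_TLM_MODELS
    "max_tokens": 256,             # within TLM_MAX_TOKEN_RANGE (64, 512)
    "num_candidate_responses": 3,  # within TLM_NUM_CANDIDATE_RESPONSES_RANGE (1, 20)
    "num_consistency_samples": 8,  # within TLM_NUM_CONSISTENCY_SAMPLES_RANGE (0, 20)
    "use_self_reflection": True,   # must be a boolean
}

validate_tlm_options(options)  # raises ValidationError if any value is invalid
```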
diff --git a/cleanlab_studio/version.py b/cleanlab_studio/version.py index 7af65835..1ae40d9e 100644 --- a/cleanlab_studio/version.py +++ b/cleanlab_studio/version.py @@ -1,4 +1,4 @@ -__version__ = "1.3.3" +__version__ = "2.0.1" SCHEMA_VERSION = "0.2.0" MIN_SCHEMA_VERSION = "0.1.0" diff --git a/tests/requirements_test.txt b/tests/requirements_test.txt new file mode 100644 index 00000000..a2081532 --- /dev/null +++ b/tests/requirements_test.txt @@ -0,0 +1,3 @@ +pytest +pytest-asyncio +pytest-benchmark diff --git a/tests/tlm/__init__.py b/tests/tlm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/tlm/conftest.py b/tests/tlm/conftest.py new file mode 100644 index 00000000..44b68262 --- /dev/null +++ b/tests/tlm/conftest.py @@ -0,0 +1,32 @@ +import os + +import pytest + +from cleanlab_studio import Studio +from cleanlab_studio.studio.trustworthy_language_model import TLM +from cleanlab_studio.internal.tlm.concurrency import TlmRateHandler + + +@pytest.fixture(scope="module") +def studio() -> Studio: + """Creates a Studio with default settings.""" + try: + # uses environment API key + return Studio(None) + except Exception as e: + environment = os.environ.get("CLEANLAB_API_BASE_URL") + pytest.skip( + f"Failed to create Studio: {e}. Check your API key and environment: ({environment})." + ) + + +@pytest.fixture(scope="module") +def tlm(studio: Studio) -> TLM: + """Creates a TLM with default settings.""" + return studio.TLM() + + +@pytest.fixture +def tlm_rate_handler() -> TlmRateHandler: + """Creates a TlmRateHandler with default settings.""" + return TlmRateHandler() diff --git a/tests/tlm/test_concurrency.py b/tests/tlm/test_concurrency.py new file mode 100644 index 00000000..db28e4b2 --- /dev/null +++ b/tests/tlm/test_concurrency.py @@ -0,0 +1,140 @@ +import math + +import pytest + +from cleanlab_studio.errors import RateLimitError +from cleanlab_studio.internal.tlm.concurrency import TlmRateHandler + + +@pytest.mark.asyncio +async def test_rate_handler_slow_start(tlm_rate_handler: TlmRateHandler) -> None: + """Tests rate handler increase behavior in slow start. + + Expected behavior: + - Limiter increases congestion window exponentially up to slow start threshold. + - Limiter send semaphore value matches congestion window + """ + # compute number of expected slow start increases + expected_slow_start_increases = int( + math.log( + tlm_rate_handler.DEFAULT_SLOW_START_THRESHOLD, + tlm_rate_handler.SLOW_START_INCREASE_FACTOR, + ) + / tlm_rate_handler.DEFAULT_CONGESTION_WINDOW + ) + + # after every rate limiter acquisition, assert: + # - congestion window *= SLOW_START_INCREASE_FACTOR + # - congestion window == send_semaphore value + for i in range(1, expected_slow_start_increases + 1): + async with tlm_rate_handler: + pass + + expected_congestion_window = tlm_rate_handler.DEFAULT_CONGESTION_WINDOW * ( + tlm_rate_handler.SLOW_START_INCREASE_FACTOR**i + ) + assert ( + tlm_rate_handler._congestion_window == expected_congestion_window + ), "Congestion window is not increased exponentially in slow start" + assert ( + tlm_rate_handler._send_semaphore._value == tlm_rate_handler._congestion_window + ), "Send semaphore value does not match congestion window in slow start" + + +@pytest.mark.asyncio +async def test_rate_handler_additive_increase( + tlm_rate_handler: TlmRateHandler, num_additive_increases: int = 100 +) -> None: + """Tests rate handler increase behavior in congestion control / additive increase phase. 
+
+    Expected behavior:
+    - Limiter increases congestion window linearly beyond slow start window
+    - Limiter send semaphore value matches congestion window
+    """
+    # arrange -- skip past slow start phase
+    current_limit_value = tlm_rate_handler.DEFAULT_SLOW_START_THRESHOLD
+    tlm_rate_handler._congestion_window = current_limit_value
+    tlm_rate_handler._send_semaphore._value = current_limit_value
+
+    # after every rate limiter acquisition, assert:
+    #   - congestion window += 1 (additive increase)
+    #   - congestion window == send_semaphore value
+    for expected_limit_value in range(
+        current_limit_value + 1, num_additive_increases + current_limit_value + 1
+    ):
+        async with tlm_rate_handler:
+            pass
+
+        assert (
+            tlm_rate_handler._congestion_window == expected_limit_value
+        ), "Congestion window is not increased linearly in congestion control"
+        assert (
+            tlm_rate_handler._send_semaphore._value == tlm_rate_handler._congestion_window
+        ), "Send semaphore value does not match congestion window in congestion control"
+
+
+@pytest.mark.parametrize("initial_congestion_window", [4, 5, 10, 101])
+@pytest.mark.asyncio
+async def test_rate_handler_rate_limit_error(
+    tlm_rate_handler: TlmRateHandler,
+    initial_congestion_window: int,
+) -> None:
+    """Tests rate handler decrease behavior on a rate limit error.
+
+    Expected behavior:
+    - Limiter decreases congestion window multiplicatively
+    - Limiter send semaphore value matches congestion window
+    - RateLimitError is raised (not suppressed by context manager)
+    """
+    # arrange -- set current congestion window
+    tlm_rate_handler._congestion_window = initial_congestion_window
+    tlm_rate_handler._send_semaphore._value = initial_congestion_window
+
+    # acquire rate limit and raise rate limit error, check that:
+    #   - congestion window is decreased multiplicatively
+    #   - send semaphore value matches congestion window
+    #   - rate limit error is raised
+    with pytest.raises(RateLimitError):
+        async with tlm_rate_handler:
+            raise RateLimitError("", 0)
+
+    assert (
+        tlm_rate_handler._congestion_window
+        == initial_congestion_window // tlm_rate_handler.MULTIPLICATIVE_DECREASE_FACTOR
+    ), "Congestion window is not decreased multiplicatively in congestion avoidance"
+    assert (
+        tlm_rate_handler._send_semaphore._value == tlm_rate_handler._congestion_window
+    ), "Send semaphore value does not match congestion window in congestion avoidance"
+
+
+@pytest.mark.parametrize("initial_congestion_window", [4, 5, 10, 101])
+@pytest.mark.asyncio
+async def test_rate_handler_non_rate_limit_error(
+    tlm_rate_handler: TlmRateHandler,
+    initial_congestion_window: int,
+) -> None:
+    """Tests rate handler behavior on a NON rate limit error.
+
+    Expected behavior:
+    - Limiter congestion window stays the same
+    - Limiter send semaphore value matches congestion window
+    - error is raised (not suppressed by context manager)
+    """
+    # arrange -- set current congestion window
+    tlm_rate_handler._congestion_window = initial_congestion_window
+    tlm_rate_handler._send_semaphore._value = initial_congestion_window
+
+    # acquire rate limit and raise a non rate limit error, check that:
+    #   - congestion window stays the same
+    #   - send semaphore value matches congestion window
+    #   - the error is re-raised
+    with pytest.raises(ValueError):
+        async with tlm_rate_handler:
+            raise ValueError
+
+    assert (
+        tlm_rate_handler._congestion_window == initial_congestion_window
+    ), "Congestion window is kept same for non rate limit error"
+    assert (
+        tlm_rate_handler._send_semaphore._value == tlm_rate_handler._congestion_window
+    ), "Send semaphore value does not match congestion window after non rate limit error"
diff --git a/tests/tlm/test_get_trustworthiness_score.py b/tests/tlm/test_get_trustworthiness_score.py
new file mode 100644
index 00000000..fc97ed43
--- /dev/null
+++ b/tests/tlm/test_get_trustworthiness_score.py
@@ -0,0 +1,127 @@
+import asyncio
+from typing import Any
+
+import pytest
+
+from cleanlab_studio.studio.trustworthy_language_model import TLM
+
+
+def is_trustworthiness_score(response: Any) -> bool:
+    """Returns True if the response is a trustworthiness score."""
+    return isinstance(response, float)
+
+
+def test_single_get_trustworthiness_score(tlm: TLM) -> None:
+    """Tests running a single get_trustworthiness_score in the TLM.
+
+    Expected:
+    - TLM should return a single trustworthiness score
+    - Score should be non-None
+    - No exceptions are raised
+    """
+    # act -- run a single get_trustworthiness_score
+    response = tlm.get_trustworthiness_score("What is the capital of France?", "Paris")
+
+    # assert
+    # - response is not None
+    # - a single trustworthiness score (float) is returned
+    # - no exceptions are raised (implicit)
+    assert response is not None
+    assert is_trustworthiness_score(response)
+
+
+def test_batch_get_trustworthiness_score(tlm: TLM) -> None:
+    """Tests running a batch get_trustworthiness_score in the TLM.
+
+    Expected:
+    - TLM should return a list of trustworthiness scores
+    - Scores should be non-None
+    - No exceptions are raised
+    - Each score should be a float
+    """
+    # act -- run a batch get_trustworthiness_score
+    response = tlm.get_trustworthiness_score(
+        ["What is the capital of France?"] * 3,
+        ["Paris"] * 3,
+    )
+
+    # assert
+    # - response is not None
+    # - a list of trustworthiness scores (floats) is returned
+    # - no exceptions are raised (implicit)
+    assert response is not None
+    assert isinstance(response, list)
+    assert all(is_trustworthiness_score(r) for r in response)
+
+
+def test_batch_get_trustworthiness_score_force_timeouts(tlm: TLM) -> None:
+    """Tests running a batch get_trustworthiness_score in the TLM, forcing timeouts.
+
+    Sets timeout to 0.0001 seconds, which should force a timeout for all get_trustworthiness_scores.
+
+    This should result in a timeout error being thrown.
+
+    Expected:
+    - TLM should raise a timeout error
+    """
+    # arrange -- override timeout
+    tlm._timeout = 0.0001
+
+    # assert -- timeout is thrown
+    with pytest.raises(asyncio.TimeoutError):
+        # act -- run a batch get_trustworthiness_score
+        tlm.get_trustworthiness_score(
+            ["What is the capital of France?"] * 3,
+            ["Paris"] * 3,
+        )
+
+
+def test_batch_try_get_trustworthiness_score(tlm: TLM) -> None:
+    """Tests running a batch try get_trustworthiness_score in the TLM.
+
+    Expected:
+    - TLM should return a list of responses
+    - Responses can be None or a trustworthiness score (float)
+    - No exceptions are raised
+    """
+    # act -- run a batch get_trustworthiness_score
+    response = tlm.try_get_trustworthiness_score(
+        ["What is the capital of France?"] * 3,
+        ["Paris"] * 3,
+    )
+
+    # assert
+    # - response is not None
+    # - a list of trustworthiness scores (or None) is returned
+    # - no exceptions are raised (implicit)
+    assert response is not None
+    assert isinstance(response, list)
+    assert all(r is None or is_trustworthiness_score(r) for r in response)
+
+
+def test_batch_try_get_trustworthiness_score_force_timeouts(tlm: TLM) -> None:
+    """Tests running a batch try get_trustworthiness_score in the TLM, forcing timeouts.
+
+    Sets timeout to 0.0001 seconds, which should force a timeout for all get_trustworthiness_scores.
+    This should result in None responses for all get_trustworthiness_scores.
+
+    Expected:
+    - TLM should return a list of responses
+    - Responses can be None or a trustworthiness score (float)
+    - No exceptions are raised
+    """
+    # arrange -- override timeout
+    tlm._timeout = 0.0001
+
+    # act -- run a batch get_trustworthiness_score
+    response = tlm.try_get_trustworthiness_score(
+        ["What is the capital of France?"] * 3,
+        ["Paris"] * 3,
+    )
+
+    # assert
+    # - response is not None
+    # - all responses timed out and are None
+    # - no exceptions are raised (implicit)
+    assert response is not None
+    assert isinstance(response, list)
+    assert all(r is None for r in response)
diff --git a/tests/tlm/test_prompt.py b/tests/tlm/test_prompt.py
new file mode 100644
index 00000000..53af9538
--- /dev/null
+++ b/tests/tlm/test_prompt.py
@@ -0,0 +1,119 @@
+import asyncio
+from typing import Any
+
+import pytest
+
+from cleanlab_studio.studio.trustworthy_language_model import TLM
+
+
+def is_tlm_response(response: Any) -> bool:
+    """Returns True if the response is a TLMResponse."""
+    return (
+        isinstance(response, dict)
+        and "response" in response
+        and "trustworthiness_score" in response
+    )
+
+
+def test_single_prompt(tlm: TLM) -> None:
+    """Tests running a single prompt in the TLM.
+
+    Expected:
+    - TLM should return a single response
+    - Response should be non-None
+    - No exceptions are raised
+    """
+    # act -- run a single prompt
+    response = tlm.prompt("What is the capital of France?")
+
+    # assert
+    # - response is not None
+    # - a single response of type TLMResponse is returned
+    # - no exceptions are raised (implicit)
+    assert response is not None
+    assert is_tlm_response(response)
+
+
+def test_batch_prompt(tlm: TLM) -> None:
+    """Tests running a batch prompt in the TLM.
+
+    Expected:
+    - TLM should return a list of responses
+    - Responses should be non-None
+    - No exceptions are raised
+    - Each response should be of type TLMResponse
+    """
+    # act -- run a batch prompt
+    response = tlm.prompt(["What is the capital of France?"] * 3)
+
+    # assert
+    # - response is not None
+    # - a list of responses of type TLMResponse is returned
+    # - no exceptions are raised (implicit)
+    assert response is not None
+    assert isinstance(response, list)
+    assert all(is_tlm_response(r) for r in response)
+
+
+def test_batch_prompt_force_timeouts(tlm: TLM) -> None:
+    """Tests running a batch prompt in the TLM, forcing timeouts.
+
+    Sets timeout to 0.0001 seconds, which should force a timeout for all prompts.
+    This should result in a timeout error being thrown.
+
+    Expected:
+    - TLM should raise a timeout error
+    """
+    # arrange -- override timeout
+    tlm._timeout = 0.0001
+
+    # assert -- timeout is thrown
+    with pytest.raises(asyncio.TimeoutError):
+        # act -- run a batch prompt
+        tlm.prompt(["What is the capital of France?"] * 3)
+
+
+def test_batch_try_prompt(tlm: TLM) -> None:
+    """Tests running a batch try prompt in the TLM.
+
+    Expected:
+    - TLM should return a list of responses
+    - Responses can be None or of type TLMResponse
+    - No exceptions are raised
+    """
+    # act -- run a batch prompt
+    response = tlm.try_prompt(["What is the capital of France?"] * 3)
+
+    # assert
+    # - response is not None
+    # - a list of responses of type TLMResponse or None is returned
+    # - no exceptions are raised (implicit)
+    assert response is not None
+    assert isinstance(response, list)
+    assert all(r is None or is_tlm_response(r) for r in response)
+
+
+def test_batch_try_prompt_force_timeouts(tlm: TLM) -> None:
+    """Tests running a batch try prompt in the TLM, forcing timeouts.
+
+    Sets timeout to 0.0001 seconds, which should force a timeout for all prompts.
+    This should result in None responses for all prompts.
+
+    Expected:
+    - TLM should return a list of responses
+    - Responses can be None or of type TLMResponse
+    - No exceptions are raised
+    """
+    # arrange -- override timeout
+    tlm._timeout = 0.0001
+
+    # act -- run a batch prompt
+    response = tlm.try_prompt(["What is the capital of France?"] * 3)
+
+    # assert
+    # - response is not None
+    # - all responses timed out and are None
+    # - no exceptions are raised (implicit)
+    assert response is not None
+    assert isinstance(response, list)
+    assert all(r is None for r in response)
diff --git a/tests/tlm/test_validation.py b/tests/tlm/test_validation.py
new file mode 100644
index 00000000..322b8fd1
--- /dev/null
+++ b/tests/tlm/test_validation.py
@@ -0,0 +1,239 @@
+import numpy as np
+import pytest
+
+from cleanlab_studio.studio.studio import Studio
+from cleanlab_studio.studio.trustworthy_language_model import TLM
+from cleanlab_studio.errors import TlmBadRequest, ValidationError
+
+np.random.seed(0)
+
+
+MAX_PROMPT_LENGTH_TOKENS: int = 15_000
+MAX_RESPONSE_LENGTH_TOKENS: int = 15_000
+MAX_COMBINED_LENGTH_TOKENS: int = 15_000
+
+CHARACTERS_PER_TOKEN: int = 5
+
+
+def test_prompt_too_long_exception_single_prompt(tlm: TLM):
+    """Tests that bad request error is raised when prompt is too long when calling tlm.prompt with a single prompt."""
+    with pytest.raises(TlmBadRequest, match="^Prompt length exceeds.*"):
+        tlm.prompt(
+            "a" * (MAX_PROMPT_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN,
+        )
+
+
+@pytest.mark.parametrize("num_prompts", [1, 2, 5])
+def test_prompt_too_long_exception_batch_prompt(tlm: TLM, num_prompts: int):
+    """Tests that bad request error is raised when prompt is too long when calling tlm.prompt with a batch of prompts.
+
+    Error message should indicate the batch index for which the prompt is too long.
+    """
+    # create batch of prompts with one prompt that is too long
+    prompts = ["What is the capital of France?"] * num_prompts
+    prompt_too_long_index = np.random.randint(0, num_prompts)
+    prompts[prompt_too_long_index] = "a" * (MAX_PROMPT_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN
+
+    with pytest.raises(
+        TlmBadRequest,
+        match=f"^Error executing query at index {prompt_too_long_index}:\nPrompt length exceeds.*",
+    ):
+        tlm.prompt(
+            prompts,
+        )
+
+
+@pytest.mark.parametrize("num_prompts", [1, 2, 5])
+def test_prompt_too_long_exception_try_prompt(tlm: TLM, num_prompts: int):
+    """Tests that None is returned when prompt is too long when calling tlm.try_prompt with a batch of prompts."""
+    # create batch of prompts with one prompt that is too long
+    prompts = ["What is the capital of France?"] * num_prompts
+    prompt_too_long_index = np.random.randint(0, num_prompts)
+    prompts[prompt_too_long_index] = "a" * (MAX_PROMPT_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN
+
+    tlm_responses = tlm.try_prompt(
+        prompts,
+    )
+
+    # assert -- None is returned at correct index
+    assert tlm_responses[prompt_too_long_index] is None
+
+
+def test_response_too_long_exception_single_score(tlm: TLM):
+    """Tests that bad request error is raised when response is too long when calling tlm.get_trustworthiness_score with a single prompt."""
+    with pytest.raises(TlmBadRequest, match="^Response length exceeds.*"):
+        tlm.get_trustworthiness_score(
+            "a",
+            "a" * (MAX_RESPONSE_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN,
+        )
+
+
+@pytest.mark.parametrize("num_prompts", [1, 2, 5])
+def test_response_too_long_exception_batch_score(tlm: TLM, num_prompts: int):
+    """Tests that bad request error is raised when a response is too long when calling tlm.get_trustworthiness_score with a batch of prompts.
+
+    Error message should indicate the batch index for which the response is too long.
+    """
+    # create batch of prompt-response pairs with one response that is too long
+    prompts = ["What is the capital of France?"] * num_prompts
+    responses = ["Paris"] * num_prompts
+    response_too_long_index = np.random.randint(0, num_prompts)
+    responses[response_too_long_index] = (
+        "a" * (MAX_RESPONSE_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN
+    )
+
+    with pytest.raises(
+        TlmBadRequest,
+        match=f"^Error executing query at index {response_too_long_index}:\nResponse length exceeds.*",
+    ):
+        tlm.get_trustworthiness_score(
+            prompts,
+            responses,
+        )
+
+
+@pytest.mark.parametrize("num_prompts", [1, 2, 5])
+def test_response_too_long_exception_try_score(tlm: TLM, num_prompts: int):
+    """Tests that None is returned when a response is too long when calling tlm.try_get_trustworthiness_score with a batch of prompts."""
+    # create batch of prompt-response pairs with one response that is too long
+    prompts = ["What is the capital of France?"] * num_prompts
+    responses = ["Paris"] * num_prompts
+    response_too_long_index = np.random.randint(0, num_prompts)
+    responses[response_too_long_index] = (
+        "a" * (MAX_RESPONSE_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN
+    )
+
+    tlm_responses = tlm.try_get_trustworthiness_score(
+        prompts,
+        responses,
+    )
+
+    # assert -- None is returned at correct index
+    assert tlm_responses[response_too_long_index] is None
+
+
+def test_prompt_too_long_exception_single_score(tlm: TLM):
+    """Tests that bad request error is raised when prompt is too long when calling tlm.get_trustworthiness_score with a single prompt."""
+    with pytest.raises(TlmBadRequest, match="^Prompt length exceeds.*"):
+        tlm.get_trustworthiness_score(
+            "a" * (MAX_PROMPT_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN,
+            "a",
+        )
+
+
+@pytest.mark.parametrize("num_prompts", [1, 2, 5])
+def test_prompt_too_long_exception_batch_score(tlm: TLM, num_prompts: int):
+    """Tests that bad request error is raised when prompt is too long when calling tlm.get_trustworthiness_score with a batch of prompts.
+
+    Error message should indicate the batch index for which the prompt is too long.
+    """
+    # create batch of prompts with one prompt that is too long
+    prompts = ["What is the capital of France?"] * num_prompts
+    responses = ["Paris"] * num_prompts
+    prompt_too_long_index = np.random.randint(0, num_prompts)
+    prompts[prompt_too_long_index] = "a" * (MAX_PROMPT_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN
+
+    with pytest.raises(
+        TlmBadRequest,
+        match=f"^Error executing query at index {prompt_too_long_index}:\nPrompt length exceeds.*",
+    ):
+        tlm.get_trustworthiness_score(
+            prompts,
+            responses,
+        )
+
+
+@pytest.mark.parametrize("num_prompts", [1, 2, 5])
+def test_prompt_too_long_exception_try_score(tlm: TLM, num_prompts: int):
+    """Tests that None is returned when prompt is too long when calling tlm.try_get_trustworthiness_score with a batch of prompts."""
+    # create batch of prompts with one prompt that is too long
+    prompts = ["What is the capital of France?"] * num_prompts
+    responses = ["Paris"] * num_prompts
+    prompt_too_long_index = np.random.randint(0, num_prompts)
+    prompts[prompt_too_long_index] = "a" * (MAX_PROMPT_LENGTH_TOKENS + 1) * CHARACTERS_PER_TOKEN
+
+    tlm_responses = tlm.try_get_trustworthiness_score(
+        prompts,
+        responses,
+    )
+
+    # assert -- None is returned at correct index
+    assert tlm_responses[prompt_too_long_index] is None
+
+
+def test_combined_too_long_exception_single_score(tlm: TLM):
+    """Tests that bad request error is raised when prompt + response combined length is too long when calling tlm.get_trustworthiness_score with a single prompt."""
+    with pytest.raises(TlmBadRequest, match="^Prompt and response combined length exceeds.*"):
+        tlm.get_trustworthiness_score(
+            "a" * (MAX_PROMPT_LENGTH_TOKENS // 2 + 1) * CHARACTERS_PER_TOKEN,
+            "a" * (MAX_PROMPT_LENGTH_TOKENS // 2 + 1) * CHARACTERS_PER_TOKEN,
+        )
+
+
+@pytest.mark.parametrize("num_prompts", [1, 2, 5])
+def test_combined_too_long_exception_batch_score(tlm: TLM, num_prompts: int):
+    """Tests that bad request error is raised when prompt + response combined length is too long when calling tlm.get_trustworthiness_score with a batch of prompts.
+
+    Error message should indicate the batch index for which the prompt + response combined length is too long.
+    """
+    # create batch of prompt-response pairs with one pair whose combined length is too long
+    prompts = ["What is the capital of France?"] * num_prompts
+    responses = ["Paris"] * num_prompts
+    combined_too_long_index = np.random.randint(0, num_prompts)
+    prompts[combined_too_long_index] = (
+        "a" * (MAX_PROMPT_LENGTH_TOKENS // 2 + 1) * CHARACTERS_PER_TOKEN
+    )
+    responses[combined_too_long_index] = (
+        "a" * (MAX_PROMPT_LENGTH_TOKENS // 2 + 1) * CHARACTERS_PER_TOKEN
+    )
+
+    with pytest.raises(
+        TlmBadRequest,
+        match=f"^Error executing query at index {combined_too_long_index}:\nPrompt and response combined length exceeds.*",
+    ):
+        tlm.get_trustworthiness_score(
+            prompts,
+            responses,
+        )
+
+
+@pytest.mark.parametrize("num_prompts", [1, 2, 5])
+def test_combined_too_long_exception_try_score(tlm: TLM, num_prompts: int):
+    """Tests that None is returned when prompt + response combined length is too long when calling tlm.try_get_trustworthiness_score with a batch of prompts."""
+    # create batch of prompt-response pairs with one pair whose combined length is too long
+    prompts = ["What is the capital of France?"] * num_prompts
+    responses = ["Paris"] * num_prompts
+    combined_too_long_index = np.random.randint(0, num_prompts)
+    prompts[combined_too_long_index] = (
+        "a" * (MAX_PROMPT_LENGTH_TOKENS // 2 + 1) * CHARACTERS_PER_TOKEN
+    )
+    responses[combined_too_long_index] = (
+        "a" * (MAX_PROMPT_LENGTH_TOKENS // 2 + 1) * CHARACTERS_PER_TOKEN
+    )
+
+    tlm_responses = tlm.try_get_trustworthiness_score(
+        prompts,
+        responses,
+    )
+
+    # assert -- None is returned at correct index
+    assert tlm_responses[combined_too_long_index] is None
+
+
+def test_invalid_option_passed(studio: Studio):
+    """Tests that validation error is thrown when an invalid option is passed to the TLM."""
+    invalid_option = "invalid_option"
+
+    with pytest.raises(
+        ValidationError, match=f"^Invalid keys in options dictionary: {{'{invalid_option}'}}.*"
+    ):
+        studio.TLM(options={invalid_option: "invalid_value"})
+
+
+def test_max_tokens_invalid_option_passed(studio: Studio):
+    """Tests that validation error is thrown when an invalid max_tokens option value is passed to the TLM."""
+    option = "max_tokens"
+    option_value = -1
+
+    with pytest.raises(ValidationError, match=f"Invalid value {option_value}, max_tokens.*"):
+        studio.TLM(options={option: option_value})
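As the try_* tests above illustrate, failed or timed-out queries surface as None entries rather than raising. A minimal sketch of how a caller might re-submit only the failed indices (the retry loop and variable names are illustrative, not part of the client):

from typing import List

from cleanlab_studio import Studio

studio = Studio(None)  # API key taken from the environment
tlm = studio.TLM()

prompts: List[str] = ["What is the capital of France?"] * 3
results = tlm.try_prompt(prompts)  # list of TLMResponse dicts, with None for failed queries

# collect the indices whose first attempt failed and re-submit just those prompts
failed_indices = [i for i, r in enumerate(results) if r is None]
if failed_indices:
    retried = tlm.try_prompt([prompts[i] for i in failed_indices])
    for i, r in zip(failed_indices, retried):
        results[i] = r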