Commit: Merge branch 'main' into fix/cl-cols-index
Showing 17 changed files with 1,544 additions and 308 deletions.
@@ -1,19 +1,21 @@
import asyncio
import io
from math import e
import os
import time
from typing import Callable, cast, List, Optional, Tuple, Dict, Union, Any

from cleanlab_studio.errors import (
    APIError,
    IngestionError,
    InvalidProjectConfiguration,
    RateLimitError,
    TlmBadRequest,
    TlmServerError,
)
from cleanlab_studio.internal.util import get_basic_info, obfuscate_stack_trace
from cleanlab_studio.internal.tlm.concurrency import TlmRateHandler

import aiohttp
import aiohttp.client_exceptions
import requests
from tqdm import tqdm
import pandas as pd
@@ -88,18 +90,40 @@ def handle_rate_limit_error_from_resp(resp: aiohttp.ClientResponse) -> None:
        )


async def handle_tlm_client_error_from_resp(resp: aiohttp.ClientResponse) -> None:
async def handle_tlm_client_error_from_resp(
    resp: aiohttp.ClientResponse, batch_index: Optional[int]
) -> None:
    """Catches 4XX (client error) errors."""
    if 400 <= resp.status < 500:
        try:
            res_json = await resp.json()
            error_message = res_json["error"]
        except Exception:
            error_message = "Client error occurred."
            error_message = "TLM query failed. Please try again and contact [email protected] if the problem persists."

        if batch_index is not None:
            error_message = f"Error executing query at index {batch_index}:\n{error_message}"

        raise TlmBadRequest(error_message)


async def handle_tlm_api_error_from_resp(
    resp: aiohttp.ClientResponse, batch_index: Optional[int]
) -> None:
    """Catches 5XX (server error) errors."""
    if 500 <= resp.status < 600:
        try:
            res_json = await resp.json()
            error_message = res_json["error"]
        except Exception:
            error_message = "TLM query failed. Please try again and contact [email protected] if the problem persists."

        if batch_index is not None:
            error_message = f"Error executing query at index {batch_index}:\n{error_message}"

        raise TlmServerError(error_message, resp.status)


def validate_api_key(api_key: str) -> bool:
    res = requests.get(
        cli_base_url + "/validate",
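Not part of the diff: a minimal sketch of what the new batch_index argument buys a caller, assuming the two handlers above are in scope and resp is a failed aiohttp response.

# Illustrative sketch only -- not part of this commit. Assumes the handlers defined
# above are in scope and `resp` is a 4XX aiohttp.ClientResponse.
import aiohttp

from cleanlab_studio.errors import TlmBadRequest


async def report_failure(resp: aiohttp.ClientResponse, batch_index: int) -> None:
    try:
        await handle_tlm_client_error_from_resp(resp, batch_index)
    except TlmBadRequest as err:
        # the raised message is prefixed with the failing position, e.g.
        # "Error executing query at index 3:\n<server-provided error>"
        print(err)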
@@ -556,6 +580,9 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
            await asyncio.sleep(sleep_time)
            try:
                return await func(*args, **kwargs)
            except aiohttp.client_exceptions.ClientConnectorError as e:
                # note: we don't increment num_try here, because we don't want connection errors to count against the total number of retries
                sleep_time = 2**num_try
            except RateLimitError as e:
                # note: we don't increment num_try here, because we don't want rate limit retries to count against the total number of retries
                sleep_time = e.retry_after
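Not part of the diff: a rough sketch of the retry loop this hunk slots into, under the assumption stated in the comments above that connection and rate-limit retries never consume the retry budget. This is an assumed shape, not the library's actual decorator.

# Rough sketch of the retry pattern shown in the hunk above -- an assumed shape, not the
# library's actual decorator. RateLimitError is the class imported earlier in this diff.
import asyncio
import functools
from typing import Any, Callable

import aiohttp
import aiohttp.client_exceptions

from cleanlab_studio.errors import RateLimitError


def retry_sketch(func: Callable[..., Any], num_retries: int = 3) -> Callable[..., Any]:
    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        num_try = 0
        sleep_time: float = 0
        last_error: Exception = Exception("TLM query failed.")
        while num_try < num_retries + 1:
            await asyncio.sleep(sleep_time)
            try:
                return await func(*args, **kwargs)
            except aiohttp.client_exceptions.ClientConnectorError:
                # connection errors back off exponentially and do not consume a retry,
                # mirroring the policy in the hunk above
                sleep_time = 2**num_try
            except RateLimitError as e:
                # rate-limit errors wait for the server-suggested interval, no retry consumed
                sleep_time = e.retry_after
            except Exception as e:
                # any other error consumes one retry before the next attempt
                last_error = e
                num_try += 1
                sleep_time = 2**num_try
        raise last_error

    return wrapper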
@@ -578,7 +605,9 @@ async def tlm_prompt(
    prompt: str,
    quality_preset: str,
    options: Optional[JSONDict],
    rate_handler: TlmRateHandler,
    client_session: Optional[aiohttp.ClientSession] = None,
    batch_index: Optional[int] = None,
) -> JSONDict:
    """
    Prompt Trustworthy Language Model with a question, and get back its answer along with a confidence score

@@ -588,7 +617,9 @@ async def tlm_prompt(
        prompt (str): prompt for TLM to respond to
        quality_preset (str): quality preset to use to generate response
        options (JSONDict): additional parameters for TLM
        rate_handler (TlmRateHandler): concurrency handler used to manage TLM request rate
        client_session (aiohttp.ClientSession): client session used to issue TLM request
        batch_index (Optional[int], optional): index of prompt in batch, used for error messages. Defaults to None if not in batch.
    Returns:
        JSONDict: dictionary with TLM response and confidence score

@@ -599,16 +630,18 @@ async def tlm_prompt(
        local_scoped_client = True

    try:
        res = await client_session.post(
            f"{tlm_base_url}/prompt",
            json=dict(prompt=prompt, quality=quality_preset, options=options or {}),
            headers=_construct_headers(api_key),
        )
        res_json = await res.json()
        async with rate_handler:
            res = await client_session.post(
                f"{tlm_base_url}/prompt",
                json=dict(prompt=prompt, quality=quality_preset, options=options or {}),
                headers=_construct_headers(api_key),
            )

        handle_rate_limit_error_from_resp(res)
        await handle_tlm_client_error_from_resp(res)
        handle_api_error_from_json(res_json)
            res_json = await res.json()

            handle_rate_limit_error_from_resp(res)
            await handle_tlm_client_error_from_resp(res, batch_index)
            await handle_tlm_api_error_from_resp(res, batch_index)

    finally:
        if local_scoped_client:
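Not part of the diff: a hypothetical sketch of how a batch of prompts might fan out over one shared aiohttp session and one shared TlmRateHandler, tagging each task with its batch_index. The api_key parameter name, preset, and prompt values are placeholders, and tlm_prompt is assumed to be imported from the module patched above.

# Hypothetical batch usage -- not code from this commit. Assumes tlm_prompt (patched above)
# is importable from its module; api_key, preset, and prompts are placeholders.
import asyncio
from typing import Any, Dict, List

import aiohttp

from cleanlab_studio.internal.tlm.concurrency import TlmRateHandler


async def prompt_batch(api_key: str, prompts: List[str]) -> List[Dict[str, Any]]:
    rate_handler = TlmRateHandler()
    async with aiohttp.ClientSession() as session:
        # one shared session and one shared rate handler for the whole batch;
        # batch_index tags any raised error with the failing prompt's position
        tasks = [
            tlm_prompt(
                api_key=api_key,
                prompt=prompt,
                quality_preset="medium",  # one of the _VALID_TLM_QUALITY_PRESETS
                options=None,
                rate_handler=rate_handler,
                client_session=session,
                batch_index=i,
            )
            for i, prompt in enumerate(prompts)
        ]
        return await asyncio.gather(*tasks)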
@@ -624,7 +657,9 @@ async def tlm_get_confidence_score(
    response: str,
    quality_preset: str,
    options: Optional[JSONDict],
    rate_handler: TlmRateHandler,
    client_session: Optional[aiohttp.ClientSession] = None,
    batch_index: Optional[int] = None,
) -> JSONDict:
    """
    Query Trustworthy Language Model for a confidence score for the prompt-response pair.

@@ -635,7 +670,9 @@ async def tlm_get_confidence_score(
        response (str): response for TLM to get confidence score for
        quality_preset (str): quality preset to use to generate confidence score
        options (JSONDict): additional parameters for TLM
        rate_handler (TlmRateHandler): concurrency handler used to manage TLM request rate
        client_session (aiohttp.ClientSession): client session used to issue TLM request
        batch_index (Optional[int], optional): index of prompt in batch, used for error messages. Defaults to None if not in batch.
    Returns:
        JSONDict: dictionary with TLM confidence score

@@ -646,21 +683,20 @@ async def tlm_get_confidence_score(
        local_scoped_client = True

    try:
        res = await client_session.post(
            f"{tlm_base_url}/get_confidence_score",
            json=dict(
                prompt=prompt, response=response, quality=quality_preset, options=options or {}
            ),
            headers=_construct_headers(api_key),
        )
        res_json = await res.json()
        async with rate_handler:
            res = await client_session.post(
                f"{tlm_base_url}/get_confidence_score",
                json=dict(
                    prompt=prompt, response=response, quality=quality_preset, options=options or {}
                ),
                headers=_construct_headers(api_key),
            )

        if local_scoped_client:
            await client_session.close()
            res_json = await res.json()

        handle_rate_limit_error_from_resp(res)
        await handle_tlm_client_error_from_resp(res)
        handle_api_error_from_json(res_json)
            handle_rate_limit_error_from_resp(res)
            await handle_tlm_client_error_from_resp(res, batch_index)
            await handle_tlm_api_error_from_resp(res, batch_index)

    finally:
        if local_scoped_client:
@@ -1,6 +1,10 @@
from typing import List
from typing import List, Tuple

# TLM constants
# prepend constants with _ so that they don't show up in help.cleanlab.ai docs
_DEFAULT_MAX_CONCURRENT_TLM_REQUESTS: int = 16
_MAX_CONCURRENT_TLM_REQUESTS_LIMIT: int = 128
_VALID_TLM_QUALITY_PRESETS: List[str] = ["best", "high", "medium", "low", "base"]
_VALID_TLM_MODELS: List[str] = ["gpt-3.5-turbo-16k", "gpt-4"]
_TLM_MAX_RETRIES: int = 3  # TODO: finalize this number
TLM_MAX_TOKEN_RANGE: Tuple[int, int] = (64, 512)  # (min, max)
TLM_NUM_CANDIDATE_RESPONSES_RANGE: Tuple[int, int] = (1, 20)  # (min, max)
TLM_NUM_CONSISTENCY_SAMPLES_RANGE: Tuple[int, int] = (0, 20)  # (min, max)
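Not part of the diff: a hypothetical helper showing one way the new (min, max) range constants could gate user-supplied TLM options; the option names and the helper itself are assumptions, not code from this commit.

# Hypothetical helper -- not part of this commit. Assumes the range constants defined
# above are in scope; the option names here are assumptions.
def validate_tlm_option_ranges(max_tokens: int, num_candidate_responses: int) -> None:
    low, high = TLM_MAX_TOKEN_RANGE
    if not low <= max_tokens <= high:
        raise ValueError(f"max_tokens must be between {low} and {high}")

    low, high = TLM_NUM_CANDIDATE_RESPONSES_RANGE
    if not low <= num_candidate_responses <= high:
        raise ValueError(f"num_candidate_responses must be between {low} and {high}")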
Empty file.
@@ -0,0 +1,113 @@
import asyncio
from types import TracebackType
from typing import Optional, Type

from cleanlab_studio.errors import RateLimitError, TlmServerError


class TlmRateHandler:
    """Concurrency handler for TLM queries.

    Implements additive increase / multiplicative decrease congestion control algorithm.
    """

    DEFAULT_CONGESTION_WINDOW: int = 4
    DEFAULT_SLOW_START_THRESHOLD: int = 16

    SLOW_START_INCREASE_FACTOR: int = 2
    ADDITIVE_INCREMENT: int = 1
    MULTIPLICATIVE_DECREASE_FACTOR: int = 2

    MAX_CONCURRENT_REQUESTS: int = 512

    def __init__(
        self,
        congestion_window: int = DEFAULT_CONGESTION_WINDOW,
        slow_start_threshold: int = DEFAULT_SLOW_START_THRESHOLD,
    ):
        """Initializes TLM rate handler."""
        self._congestion_window: int = congestion_window
        self._slow_start_threshold = slow_start_threshold

        # create send semaphore and seed w/ initial congestion window
        self._send_semaphore = asyncio.Semaphore(value=self._congestion_window)

    async def __aenter__(self) -> None:
        """Acquires send semaphore, blocking until it can be acquired."""
        await self._send_semaphore.acquire()
        return

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        traceback_type: Optional[TracebackType],
    ) -> bool:
        """Handles exiting from rate limit context. Never suppresses exceptions.

        If request succeeded, increase congestion window.
        If request failed due to rate limit error, decrease congestion window.
        If request failed due to 503, decrease congestion window.
        Else if request failed for other reason, don't change congestion window, just exit.
        """
        if exc_type is None:
            await self._increase_congestion_window()

        elif (
            isinstance(exc, RateLimitError)
            or isinstance(exc, TlmServerError)
            and exc.status_code == 503
        ):
            await self._decrease_congestion_window()

        # release acquired send semaphore from aenter
        self._send_semaphore.release()

        return False

    async def _increase_congestion_window(
        self,
        slow_start_increase_factor: int = SLOW_START_INCREASE_FACTOR,
        additive_increment: int = ADDITIVE_INCREMENT,
        max_concurrent_requests: int = MAX_CONCURRENT_REQUESTS,
    ) -> None:
        """Increases TLM congestion window.

        If in slow start, increase is exponential.
        Otherwise, increase is linear.

        After increasing congestion window, notify on send condition with n=increase
        """
        # track previous congestion window size
        prev_congestion_window = self._congestion_window

        # increase congestion window
        if self._congestion_window < self._slow_start_threshold:
            self._congestion_window *= slow_start_increase_factor

        else:
            self._congestion_window += additive_increment

        # cap congestion window at max concurrent requests
        self._congestion_window = min(self._congestion_window, max_concurrent_requests)

        # release <congestion_window_increase> from send semaphore
        congestion_window_increase = self._congestion_window - prev_congestion_window
        for _ in range(congestion_window_increase):
            self._send_semaphore.release()

    async def _decrease_congestion_window(
        self,
        multiplicative_decrease_factor: int = MULTIPLICATIVE_DECREASE_FACTOR,
    ) -> None:
        """Decreases TLM congestion window, to minimum of 1."""
        if self._congestion_window <= 1:
            return

        prev_congestion_window = self._congestion_window
        self._congestion_window //= multiplicative_decrease_factor

        # acquire congestion window decrease from send semaphore
        congestion_window_decrease = prev_congestion_window - self._congestion_window
        for _ in range(congestion_window_decrease):
            await self._send_semaphore.acquire()
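Not part of the diff: a minimal, self-contained sketch of the handler's AIMD behavior, where fake_request stands in for a real TLM call and TlmServerError(message, 503) simulates server pushback (its constructor mirrors the usage in the api changes above).

# Illustrative sketch only -- not part of this commit. fake_request stands in for a TLM call.
import asyncio

from cleanlab_studio.errors import TlmServerError
from cleanlab_studio.internal.tlm.concurrency import TlmRateHandler


async def fake_request(i: int) -> int:
    await asyncio.sleep(0.01)
    if i % 7 == 0:
        # simulate server overload; raising inside the context halves the congestion window
        raise TlmServerError("simulated overload", 503)
    return i


async def run_one(handler: TlmRateHandler, i: int) -> None:
    try:
        # at most `congestion_window` requests run concurrently; each success grows the
        # window, each 503 or rate-limit error shrinks it
        async with handler:
            await fake_request(i)
    except TlmServerError:
        pass


async def main() -> None:
    handler = TlmRateHandler()  # initial congestion window of 4, slow-start threshold of 16
    await asyncio.gather(*(run_one(handler, i) for i in range(32)))


if __name__ == "__main__":
    asyncio.run(main())

Each successful exit doubles the window while it is below the slow-start threshold and adds one afterwards, while a rate-limit error or 503 halves it, so concurrency ramps up quickly and backs off under load.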