Commit
Merge branch 'main' into fix/cl-cols-index
axl1313 committed Apr 9, 2024
2 parents 314538a + 188662c commit b765323
Showing 17 changed files with 1,544 additions and 308 deletions.
31 changes: 31 additions & 0 deletions .github/workflows/ci.yml
@@ -3,6 +3,37 @@ name: CI
on: [ push, pull_request ]

jobs:
test:
name: "Test: Python ${{ matrix.python }} on ${{ matrix.os }}"
runs-on: ${{ matrix.os }}
strategy:
matrix:
os:
- ubuntu-latest
- macos-latest
- windows-latest
python:
- "3.8"
- "3.9"
- "3.10"
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python }}
- name: Install dependencies
run: |
pip install --upgrade pip
pip install .
pip install -r tests/requirements_test.txt
- name: Install Cleanlab Studio client
run: pip install -e .
- name: Cleanlab login
run: cleanlab login --key "${{ secrets.CLEANLAB_STUDIO_CI_API_KEY }}"
- name: Run tests
run: |
pytest --verbose
typecheck:
name: Type check
runs-on: ubuntu-latest
17 changes: 15 additions & 2 deletions cleanlab_studio/errors.py
@@ -1,3 +1,6 @@
from asyncio import Handle


class HandledError(Exception):
pass

@@ -34,6 +37,10 @@ class SettingsError(HandledError):
pass


class ValidationError(HandledError):
pass


class UploadError(HandledError):
pass

@@ -63,16 +70,22 @@ class APITimeoutError(HandledError):
pass


class RateLimitError(APIError):
class RateLimitError(HandledError):
def __init__(self, message: str, retry_after: int):
self.message = message
self.retry_after = retry_after


class TlmBadRequest(APIError):
class TlmBadRequest(HandledError):
pass


class TlmServerError(APIError):
def __init__(self, message: str, status_code: int) -> None:
self.message = message
self.status_code = status_code


class UnsupportedVersionError(HandledError):
def __init__(self) -> None:
super().__init__(
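The reworked hierarchy above moves RateLimitError and TlmBadRequest off of APIError and onto HandledError, and adds TlmServerError carrying an HTTP status code. Below is a minimal sketch of how calling code might branch on these types; query_tlm is a hypothetical stand-in for any function that raises them, not part of the library.

import time

from cleanlab_studio.errors import RateLimitError, TlmBadRequest, TlmServerError


def call_with_basic_handling(query_tlm, prompt):
    # query_tlm is a hypothetical callable, used only for illustration
    try:
        return query_tlm(prompt)
    except RateLimitError as e:
        # RateLimitError carries the server-suggested retry delay
        time.sleep(e.retry_after)
        return query_tlm(prompt)
    except TlmServerError as e:
        # TlmServerError carries the HTTP status code alongside the message
        raise RuntimeError(f"TLM server error ({e.status_code}): {e.message}")
    except TlmBadRequest:
        # 4XX client errors already carry a user-facing message
        raise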
88 changes: 62 additions & 26 deletions cleanlab_studio/internal/api/api.py
@@ -1,19 +1,21 @@
import asyncio
import io
from math import e
import os
import time
from typing import Callable, cast, List, Optional, Tuple, Dict, Union, Any

from cleanlab_studio.errors import (
APIError,
IngestionError,
InvalidProjectConfiguration,
RateLimitError,
TlmBadRequest,
TlmServerError,
)
from cleanlab_studio.internal.util import get_basic_info, obfuscate_stack_trace
from cleanlab_studio.internal.tlm.concurrency import TlmRateHandler

import aiohttp
import aiohttp.client_exceptions
import requests
from tqdm import tqdm
import pandas as pd
@@ -88,18 +90,40 @@ def handle_rate_limit_error_from_resp(resp: aiohttp.ClientResponse) -> None:
)


async def handle_tlm_client_error_from_resp(resp: aiohttp.ClientResponse) -> None:
async def handle_tlm_client_error_from_resp(
resp: aiohttp.ClientResponse, batch_index: Optional[int]
) -> None:
"""Catches 4XX (client error) errors."""
if 400 <= resp.status < 500:
try:
res_json = await resp.json()
error_message = res_json["error"]
except Exception:
error_message = "Client error occurred."
error_message = "TLM query failed. Please try again and contact [email protected] if the problem persists."

if batch_index is not None:
error_message = f"Error executing query at index {batch_index}:\n{error_message}"

raise TlmBadRequest(error_message)


async def handle_tlm_api_error_from_resp(
resp: aiohttp.ClientResponse, batch_index: Optional[int]
) -> None:
"""Catches 5XX (server error) errors."""
if 500 <= resp.status < 600:
try:
res_json = await resp.json()
error_message = res_json["error"]
except Exception:
error_message = "TLM query failed. Please try again and contact [email protected] if the problem persists."

if batch_index is not None:
error_message = f"Error executing query at index {batch_index}:\n{error_message}"

raise TlmServerError(error_message, resp.status)


def validate_api_key(api_key: str) -> bool:
res = requests.get(
cli_base_url + "/validate",
@@ -556,6 +580,9 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
await asyncio.sleep(sleep_time)
try:
return await func(*args, **kwargs)
except aiohttp.client_exceptions.ClientConnectorError as e:
# note: we don't increment num_try here, because we don't want connection errors to count against the total number of retries
sleep_time = 2**num_try
except RateLimitError as e:
# note: we don't increment num_try here, because we don't want rate limit retries to count against the total number of retries
sleep_time = e.retry_after
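The hunk above special-cases connection failures and rate limits so they back off without consuming a retry. The rest of the wrapper is elided from the diff, so the following is only a condensed sketch of that pattern under assumed names (MAX_RETRIES and do_request are illustrative, not the library's).

import asyncio

import aiohttp.client_exceptions

from cleanlab_studio.errors import RateLimitError

MAX_RETRIES = 3  # illustrative cap


async def request_with_retries(do_request):
    num_try = 0
    sleep_time = 0.0
    while True:
        await asyncio.sleep(sleep_time)
        try:
            return await do_request()
        except aiohttp.client_exceptions.ClientConnectorError:
            # transient connection failure: back off, but don't consume a retry
            sleep_time = 2**num_try
        except RateLimitError as e:
            # the server said when to come back: wait, don't consume a retry
            sleep_time = e.retry_after
        except Exception:
            num_try += 1
            if num_try > MAX_RETRIES:
                raise
            sleep_time = 2**num_try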
@@ -578,7 +605,9 @@ async def tlm_prompt(
prompt: str,
quality_preset: str,
options: Optional[JSONDict],
rate_handler: TlmRateHandler,
client_session: Optional[aiohttp.ClientSession] = None,
batch_index: Optional[int] = None,
) -> JSONDict:
"""
Prompt Trustworthy Language Model with a question, and get back its answer along with a confidence score
@@ -588,7 +617,9 @@
prompt (str): prompt for TLM to respond to
quality_preset (str): quality preset to use to generate response
options (JSONDict): additional parameters for TLM
rate_handler (TlmRateHandler): concurrency handler used to manage TLM request rate
client_session (aiohttp.ClientSession): client session used to issue TLM request
batch_index (Optional[int], optional): index of prompt in batch, used for error messages. Defaults to None if not in batch.
Returns:
JSONDict: dictionary with TLM response and confidence score
@@ -599,16 +630,18 @@
local_scoped_client = True

try:
res = await client_session.post(
f"{tlm_base_url}/prompt",
json=dict(prompt=prompt, quality=quality_preset, options=options or {}),
headers=_construct_headers(api_key),
)
res_json = await res.json()
async with rate_handler:
res = await client_session.post(
f"{tlm_base_url}/prompt",
json=dict(prompt=prompt, quality=quality_preset, options=options or {}),
headers=_construct_headers(api_key),
)

handle_rate_limit_error_from_resp(res)
await handle_tlm_client_error_from_resp(res)
handle_api_error_from_json(res_json)
res_json = await res.json()

handle_rate_limit_error_from_resp(res)
await handle_tlm_client_error_from_resp(res, batch_index)
await handle_tlm_api_error_from_resp(res, batch_index)

finally:
if local_scoped_client:
@@ -624,7 +657,9 @@ async def tlm_get_confidence_score(
response: str,
quality_preset: str,
options: Optional[JSONDict],
rate_handler: TlmRateHandler,
client_session: Optional[aiohttp.ClientSession] = None,
batch_index: Optional[int] = None,
) -> JSONDict:
"""
Query Trustworthy Language Model for a confidence score for the prompt-response pair.
@@ -635,7 +670,9 @@
response (str): response for TLM to get confidence score for
quality_preset (str): quality preset to use to generate confidence score
options (JSONDict): additional parameters for TLM
rate_handler (TlmRateHandler): concurrency handler used to manage TLM request rate
client_session (aiohttp.ClientSession): client session used to issue TLM request
batch_index (Optional[int], optional): index of prompt in batch, used for error messages. Defaults to None if not in batch.
Returns:
JSONDict: dictionary with TLM confidence score
@@ -646,21 +683,20 @@
local_scoped_client = True

try:
res = await client_session.post(
f"{tlm_base_url}/get_confidence_score",
json=dict(
prompt=prompt, response=response, quality=quality_preset, options=options or {}
),
headers=_construct_headers(api_key),
)
res_json = await res.json()
async with rate_handler:
res = await client_session.post(
f"{tlm_base_url}/get_confidence_score",
json=dict(
prompt=prompt, response=response, quality=quality_preset, options=options or {}
),
headers=_construct_headers(api_key),
)

if local_scoped_client:
await client_session.close()
res_json = await res.json()

handle_rate_limit_error_from_resp(res)
await handle_tlm_client_error_from_resp(res)
handle_api_error_from_json(res_json)
handle_rate_limit_error_from_resp(res)
await handle_tlm_client_error_from_resp(res, batch_index)
await handle_tlm_api_error_from_resp(res, batch_index)

finally:
if local_scoped_client:
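Taken together, the new rate_handler and batch_index parameters let a batch of prompts share one congestion controller and report failures by position. The sketch below is illustrative only: prompt_batch is not a library function, and the leading api_key argument is assumed from context since the hunk elides the start of the signature.

import asyncio

import aiohttp

from cleanlab_studio.internal.api import api
from cleanlab_studio.internal.tlm.concurrency import TlmRateHandler


async def prompt_batch(api_key, prompts):
    # one congestion handler and one HTTP session shared across the batch
    rate_handler = TlmRateHandler()
    async with aiohttp.ClientSession() as session:
        tasks = [
            api.tlm_prompt(
                api_key,
                prompt,
                "medium",  # quality preset
                None,      # options
                rate_handler,
                client_session=session,
                batch_index=i,  # surfaces the failing index in error messages
            )
            for i, prompt in enumerate(prompts)
        ]
        return await asyncio.gather(*tasks)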
10 changes: 7 additions & 3 deletions cleanlab_studio/internal/constants.py
@@ -1,6 +1,10 @@
from typing import List
from typing import List, Tuple

# TLM constants
# prepend constants with _ so that they don't show up in help.cleanlab.ai docs
_DEFAULT_MAX_CONCURRENT_TLM_REQUESTS: int = 16
_MAX_CONCURRENT_TLM_REQUESTS_LIMIT: int = 128
_VALID_TLM_QUALITY_PRESETS: List[str] = ["best", "high", "medium", "low", "base"]
_VALID_TLM_MODELS: List[str] = ["gpt-3.5-turbo-16k", "gpt-4"]
_TLM_MAX_RETRIES: int = 3 # TODO: finalize this number
TLM_MAX_TOKEN_RANGE: Tuple[int, int] = (64, 512) # (min, max)
TLM_NUM_CANDIDATE_RESPONSES_RANGE: Tuple[int, int] = (1, 20) # (min, max)
TLM_NUM_CONSISTENCY_SAMPLES_RANGE: Tuple[int, int] = (0, 20) # (min, max)
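These tuples read as inclusive (min, max) bounds for TLM options. As one hedged illustration of how they might pair with the ValidationError added to errors.py in this same commit (the pairing is an assumption, and check_max_tokens is a hypothetical helper):

from cleanlab_studio.errors import ValidationError
from cleanlab_studio.internal.constants import TLM_MAX_TOKEN_RANGE


def check_max_tokens(max_tokens):
    # hypothetical helper: enforce the (min, max) bounds defined above
    lo, hi = TLM_MAX_TOKEN_RANGE
    if not (lo <= max_tokens <= hi):
        raise ValidationError(
            f"max_tokens must be between {lo} and {hi}, got {max_tokens}"
        )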
Empty file.
113 changes: 113 additions & 0 deletions cleanlab_studio/internal/tlm/concurrency.py
@@ -0,0 +1,113 @@
import asyncio
from types import TracebackType
from typing import Optional, Type

from cleanlab_studio.errors import RateLimitError, TlmServerError


class TlmRateHandler:
"""Concurrency handler for TLM queries.
Implements additive increase / multiplicative decrease congestion control algorithm.
"""

DEFAULT_CONGESTION_WINDOW: int = 4
DEFAULT_SLOW_START_THRESHOLD: int = 16

SLOW_START_INCREASE_FACTOR: int = 2
ADDITIVE_INCREMENT: int = 1
MULTIPLICATIVE_DECREASE_FACTOR: int = 2

MAX_CONCURRENT_REQUESTS: int = 512

def __init__(
self,
congestion_window: int = DEFAULT_CONGESTION_WINDOW,
slow_start_threshold: int = DEFAULT_SLOW_START_THRESHOLD,
):
"""Initializes TLM rate handler."""
self._congestion_window: int = congestion_window
self._slow_start_threshold = slow_start_threshold

# create send semaphore and seed w/ initial congestion window
self._send_semaphore = asyncio.Semaphore(value=self._congestion_window)

async def __aenter__(self) -> None:
"""Acquires send semaphore, blocking until it can be acquired."""
await self._send_semaphore.acquire()
return

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
traceback_type: Optional[TracebackType],
) -> bool:
"""Handles exiting from rate limit context. Never suppresses exceptions.
If request succeeded, increase congestion window.
If request failed due to rate limit error, decrease congestion window.
If request failed due to 503, decrease congestion window.
Else if request failed for other reason, don't change congestion window, just exit.
"""
if exc_type is None:
await self._increase_congestion_window()

elif (
isinstance(exc, RateLimitError)
or isinstance(exc, TlmServerError)
and exc.status_code == 503
):
await self._decrease_congestion_window()

# release acquired send semaphore from aenter
self._send_semaphore.release()

return False

async def _increase_congestion_window(
self,
slow_start_increase_factor: int = SLOW_START_INCREASE_FACTOR,
additive_increment: int = ADDITIVE_INCREMENT,
max_concurrent_requests: int = MAX_CONCURRENT_REQUESTS,
) -> None:
"""Increases TLM congestion window
If in slow start, increase is exponential.
Otherwise, increase is linear.
After increasing congestion window, notify on send condition with n=increase
"""
# track previous congestion window size
prev_congestion_window = self._congestion_window

# increase congestion window
if self._congestion_window < self._slow_start_threshold:
self._congestion_window *= slow_start_increase_factor

else:
self._congestion_window += additive_increment

# cap congestion window at max concurrent requests
self._congestion_window = min(self._congestion_window, max_concurrent_requests)

# release <congestion_window_increase> from send semaphore
congestion_window_increase = self._congestion_window - prev_congestion_window
for _ in range(congestion_window_increase):
self._send_semaphore.release()

async def _decrease_congestion_window(
self,
multiplicative_decrease_factor: int = MULTIPLICATIVE_DECREASE_FACTOR,
) -> None:
"""Decreases TLM congestion window, to minimum of 1."""
if self._congestion_window <= 1:
return

prev_congestion_window = self._congestion_window
self._congestion_window //= multiplicative_decrease_factor

# acquire congestion window decrease from send semaphore
congestion_window_decrease = prev_congestion_window - self._congestion_window
for _ in range(congestion_window_decrease):
await self._send_semaphore.acquire()
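A minimal usage sketch for the handler above: wrapping each outbound request in its async context means successful requests widen the congestion window (doubling during slow start, then +1 per success) while rate-limit and 503 responses halve it. send_request is a placeholder coroutine, not part of the library.

import asyncio

from cleanlab_studio.internal.tlm.concurrency import TlmRateHandler


async def send_all(requests, send_request):
    rate_handler = TlmRateHandler()

    async def send_one(req):
        async with rate_handler:  # blocks when the congestion window is exhausted
            return await send_request(req)

    return await asyncio.gather(*(send_one(r) for r in requests))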