Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/inference api #117

Merged
merged 43 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
cafec61
adds inference API (v0)
ryansingman Jul 12, 2023
55f93a2
modify client to fit backend api endpoints
Jul 19, 2023
9257172
modify cli to make model prediction work
Jul 21, 2023
79b5cde
black
Jul 21, 2023
92d973c
modify cli after testing api endpoints
Jul 24, 2023
967722b
black
Jul 24, 2023
3d63153
integrate into Studio class
ryansingman Jul 24, 2023
d81b209
remove base url override
ryansingman Jul 24, 2023
865cf03
fix api endpoint for client to work
Jul 24, 2023
c88f90b
modify code to support text files without headers
Jul 24, 2023
7db17d8
remove test file for local testing
Jul 24, 2023
2d27b3c
change response for upload api and remove logic for comparing headers…
Jul 24, 2023
25d33ac
modify invoke lambda api to send only query_id as param
Jul 24, 2023
ba0a5f9
remove test file
Jul 25, 2023
9ab9d28
fix mypy errors
Jul 25, 2023
7db2438
remove TypeAlias
Jul 25, 2023
7b1a48a
fix mypy for Batch type
Jul 25, 2023
b0abe97
Use Union instead of | for multi generic typing
Jul 25, 2023
d5fb4d4
more typing fixes and timeout in prediction
Jul 25, 2023
0b2a73c
change timeout to adn
Jul 25, 2023
4846bf3
remove print statement
Jul 25, 2023
af6ce97
remove test files again
Jul 25, 2023
a2a466c
fix code review comments
Jul 25, 2023
a3ce52f
remove test files
Jul 25, 2023
9bcb15f
for updating pr
Jul 25, 2023
8de5c4a
change timeout to new var name
Jul 25, 2023
fd758ea
remove header replace logic
Jul 26, 2023
658b5fa
modify predict function to take care of text inputs
Jul 26, 2023
612a522
remove download api endpoint and supply url directly to pandas
Jul 27, 2023
5435cd7
Merge remote-tracking branch 'origin/main' into feature/inference-api
Jul 28, 2023
40bffbf
update docstring to match documentation format
Jul 28, 2023
d83e5e1
fix predict timeout
ryansingman Jul 31, 2023
e266518
mypy fix
ryansingman Jul 31, 2023
8679aad
mypy fix
ryansingman Jul 31, 2023
c871c8c
add typing extensions req
ryansingman Jul 31, 2023
13b4c8c
fix incorrect return types, return predictions separate from class probs
ryansingman Jul 31, 2023
8b7b8de
mypy fix
ryansingman Jul 31, 2023
4261e5a
clean up polling interface, angelas comments
ryansingman Aug 1, 2023
24b4733
fix sleep placement in poll loop
ryansingman Aug 1, 2023
2603868
mypy fix
ryansingman Aug 1, 2023
f7cdefa
mypy fix
ryansingman Aug 1, 2023
f1d5102
fix results name
ryansingman Aug 2, 2023
af4d02f
fix nits
ryansingman Aug 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion cleanlab_studio/internal/api/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import io
import os
import time
from typing import Callable, List, Optional, Tuple, Union, Any
from itertools import chain
from typing import Callable, List, Optional, Tuple, Dict, Union, Any
from cleanlab_studio.errors import APIError

import requests
Expand All @@ -19,12 +21,14 @@
from cleanlab_studio.internal.types import JSONDict
from cleanlab_studio.version import __version__


# Base API URL; overridable via the CLEANLAB_API_BASE_URL environment variable
# (e.g. to point the client at a staging deployment).
base_url = os.environ.get("CLEANLAB_API_BASE_URL", "https://api.cleanlab.ai/api")
cli_base_url = f"{base_url}/cli/v0"
upload_base_url = f"{base_url}/upload/v0"
dataset_base_url = f"{base_url}/datasets"
project_base_url = f"{base_url}/projects"
cleanset_base_url = f"{base_url}/cleansets"
# Endpoints for deployed-model inference (upload / predict / status).
model_base_url = f"{base_url}/v1/deployment"


def _construct_headers(
Expand Down Expand Up @@ -330,3 +334,51 @@ def poll_progress(
res = request_function(progress_id)
pbar.update(float(1) - pbar.n)
return res


def upload_predict_batch(api_key: str, model_id: str, batch: io.StringIO) -> str:
    """Uploads a prediction batch for a deployed model and returns its query ID.

    :param api_key: Cleanlab Studio API key used to authenticate the request
    :param model_id: ID of the deployed model to query
    :param batch: in-memory CSV file containing the examples to predict on
    :return: query ID identifying this uploaded prediction batch
    :raises APIError: if requesting the presigned upload URL fails
    :raises requests.HTTPError: if the upload to object storage fails
    """
    url = f"{model_base_url}/{model_id}/upload"
    res = requests.post(
        url,
        headers=_construct_headers(api_key),
    )
    handle_api_error(res)

    # Parse the response body once instead of re-reading the JSON per field.
    res_json = res.json()
    presigned_url = res_json["upload_url"]
    query_id: str = res_json["query_id"]

    # Upload the batch via the presigned POST. The original code ignored this
    # response entirely, so a failed upload would only surface later as a
    # confusing prediction error; fail fast here instead.
    upload_res = requests.post(
        presigned_url["url"], data=presigned_url["fields"], files={"file": batch}
    )
    upload_res.raise_for_status()

    return query_id


def start_prediction(api_key: str, model_id: str, query_id: str) -> None:
    """Kicks off model prediction for a previously uploaded query batch.

    :param api_key: Cleanlab Studio API key used to authenticate the request
    :param model_id: ID of the deployed model to run
    :param query_id: ID of the uploaded prediction batch
    """
    url = f"{model_base_url}/{model_id}/predict/{query_id}"
    response = requests.post(url, headers=_construct_headers(api_key))
    handle_api_error(response)


def get_prediction_status(api_key: str, query_id: str) -> Dict[str, str]:
    """Gets the status of a model prediction query.

    :param api_key: Cleanlab Studio API key used to authenticate the request
    :param query_id: ID of the prediction batch to check
    :return: dict with a normalized "status" key -- "done" (plus "result_url"),
        "error" (plus "error_msg"), or "running"
    :raises APIError: if the status request itself fails
    """
    res = requests.get(
        f"{model_base_url}/predict/{query_id}",
        headers=_construct_headers(api_key),
    )
    handle_api_error(res)

    prediction_results = res.json()
    status = prediction_results["status"]

    # Only index the fields relevant to each status: the original code read
    # "results" and "error_msg" unconditionally, which raises KeyError if the
    # backend omits them while the query is still running.
    if status == "COMPLETE":
        return {"status": "done", "result_url": prediction_results["results"]}
    elif status == "FAILED":
        return {"status": "error", "error_msg": prediction_results["error_msg"]}
    else:
        return {"status": "running"}
101 changes: 101 additions & 0 deletions cleanlab_studio/studio/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import abc
import csv
import functools
import io
import time
from typing import List, Union, Optional

import numpy as np
import numpy.typing as npt
import pandas as pd

from cleanlab_studio.errors import APIError
ryansingman marked this conversation as resolved.
Show resolved Hide resolved
from cleanlab_studio.internal.api import api


# Supported input batch formats for deployed model inference.
TextBatch = Union[List[str], npt.NDArray[np.str_], pd.Series]
TabularBatch = pd.DataFrame  # Union[pd.DataFrame] collapses to pd.DataFrame anyway
Batch = Union[TextBatch, TabularBatch]

# Model outputs: integer or string class predictions, and per-class probabilities.
# NOTE(review): "ClassProbablities" is misspelled, but it is part of the module's
# public surface, so the name is kept as-is.
Predictions = Union[npt.NDArray[np.int_], npt.NDArray[np.str_]]
ClassProbablities = pd.DataFrame


class Model(abc.ABC):
    """Base class for deployed model inference."""

    def __init__(self, api_key: str, model_id: str):
        """Initializes model class w/ API key and model ID."""
        self._api_key = api_key
        self._model_id = model_id

    def predict(
        self,
        batch: "Batch",
        timeout: int = 600,
    ) -> "Union[str, Predictions]":
        """
        Gets predictions for batch of examples.

        Args:
            batch: batch of examples to predict classes for
            timeout: optional parameter to set timeout for predictions in seconds

        Returns:
            predictions from batch as a numpy array

        Raises:
            APIError: if the prediction fails or does not complete before ``timeout``
        """
        csv_batch = self._convert_batch_to_csv(batch)
        return self._predict(csv_batch, timeout)

    def _predict(self, batch: io.StringIO, timeout: int) -> "Union[str, Predictions]":
        """Gets predictions for batch of examples.

        :param batch: batch of examples to predict classes for, as in-memory CSV file
        :param timeout: maximum number of seconds to wait for predictions
        :return: predictions from batch
        :raises APIError: if the prediction fails or times out
        """
        query_id: str = api.upload_predict_batch(self._api_key, self._model_id, batch)
        api.start_prediction(self._api_key, self._model_id, query_id)

        # Poll until the prediction reaches a terminal state or the timeout
        # elapses. The timeout prevents callers from blocking indefinitely when
        # the backend never finishes.
        deadline = time.time() + timeout
        while True:
            resp = api.get_prediction_status(self._api_key, query_id)
            status = resp["status"]
            if status != "running":
                break
            if time.time() >= deadline:
                # BUG FIX: the original fell out of its poll loop on timeout
                # with status still "running" and then indexed
                # resp["result_url"], raising a confusing KeyError.
                raise APIError(f"Prediction timed out after {timeout} seconds")
            # Sleep between polls so the loop doesn't flood the backend with
            # API calls; only sleep while still running (the original also
            # slept once after reaching a terminal status).
            time.sleep(3)

        if status == "error":
            raise APIError(resp["error_msg"])

        # Prediction completed: load the results directly from the result URL.
        result_url = resp["result_url"]
        results_converted: "Predictions" = pd.read_csv(result_url).to_numpy()
        return results_converted

    @staticmethod
    def _convert_batch_to_csv(batch: "Batch") -> io.StringIO:
        """Converts a batch of examples to an in-memory CSV file.

        :param batch: text batch (list / array / Series of strings) or tabular
            batch (DataFrame)
        :return: in-memory CSV file, seeked back to the start
        :raises TypeError: if the batch is not a supported type
        """
        sio = io.StringIO()

        # handle text batches: single "text" header column, one row per example
        if isinstance(batch, (list, np.ndarray, pd.Series)):
            writer = csv.writer(sio)
            writer.writerow(["text"])
            for input_data in batch:
                writer.writerow([input_data])

        # handle tabular batches
        # NOTE(review): the DataFrame index is included in the CSV (pandas
        # default) -- confirm the backend expects/ignores the index column.
        elif isinstance(batch, pd.DataFrame):
            batch.to_csv(sio)

        else:
            raise TypeError(f"Invalid type of batch: {type(batch)}")

        sio.seek(0)
        return sio
14 changes: 13 additions & 1 deletion cleanlab_studio/studio/studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy.typing as npt
import pandas as pd

from . import clean, upload
from . import clean, upload, inference
from cleanlab_studio.internal.api import api
from cleanlab_studio.internal.util import (
init_dataset_source,
Expand Down Expand Up @@ -290,6 +290,18 @@ def delete_project(self, project_id: str) -> None:
api.delete_project(self._api_key, project_id)
print(f"Successfully deleted project: {project_id}")

def get_model(self, model_id: str) -> inference.Model:
    """
    Gets a model deployed by Cleanlab Studio.

    Args:
        model_id: ID of the model to get. This ID can be found on the deployments page of the app UI.

    Returns:
        Model object with methods to run predictions on new input data.
    """
    return inference.Model(self._api_key, model_id)

class Experimental:
def __init__(self, outer): # type: ignore
self._outer = outer
Expand Down
Loading