diff --git a/cleanlab_studio/cli/api_service.py b/cleanlab_studio/cli/api_service.py
index 0dbeb091..6494db5a 100644
--- a/cleanlab_studio/cli/api_service.py
+++ b/cleanlab_studio/cli/api_service.py
@@ -16,8 +16,13 @@
 from cleanlab_studio.cli.dataset.image_utils import get_image_filepath
 from cleanlab_studio.internal.schema import Schema
 from cleanlab_studio.internal.types import JSONDict, IDType, Modality
-from cleanlab_studio.internal.api import base_url, _construct_headers, handle_api_error, handle_api_error_from_json, get_presigned_posts
-
+from cleanlab_studio.internal.api import (
+    base_url,
+    _construct_headers,
+    handle_api_error,
+    handle_api_error_from_json,
+    get_presigned_posts,
+)

 MAX_PARALLEL_UPLOADS = 32  # XXX choose this dynamically?

@@ -25,8 +30,6 @@
 MAX_RETRIES = 4


-
-
 async def upload_rows_async(
     session: aiohttp.ClientSession,
     api_key: str,
diff --git a/cleanlab_studio/cli/dataset/schema_helpers.py b/cleanlab_studio/cli/dataset/schema_helpers.py
index 73695688..eaa40092 100644
--- a/cleanlab_studio/cli/dataset/schema_helpers.py
+++ b/cleanlab_studio/cli/dataset/schema_helpers.py
@@ -20,8 +20,6 @@
 from cleanlab_studio.version import MAX_SCHEMA_VERSION, MIN_SCHEMA_VERSION, SCHEMA_VERSION


-
-
 def load_schema(filepath: str) -> Schema:
     with open(filepath, "r") as f:
         schema_dict = json.load(f)
@@ -108,8 +106,6 @@ def validate_schema(schema: Schema, columns: Collection[str]) -> None:
         raise ValueError("Dataset modality is text, but none of the fields is a text column.")


-
-
 def save_schema(schema: Schema, filename: Optional[str]) -> None:
     """
diff --git a/cleanlab_studio/cli/dataset/upload_helpers.py b/cleanlab_studio/cli/dataset/upload_helpers.py
index 3582c501..57844456 100644
--- a/cleanlab_studio/cli/dataset/upload_helpers.py
+++ b/cleanlab_studio/cli/dataset/upload_helpers.py
@@ -103,7 +103,7 @@ def validate_and_process_record(
     fields = schema.fields
     id_column = schema.metadata.id_column
     columns = list(fields)
-    dataset_filepath = dataset.filepath if hasattr(dataset, 'filepath') else None  # type:ignore
+    dataset_filepath = dataset.filepath if hasattr(dataset, "filepath") else None  # type:ignore

     row_id = record.get(id_column, None)
     if row_id == "" or row_id is None:
diff --git a/cleanlab_studio/cli/types.py b/cleanlab_studio/cli/types.py
index 3b052155..d4410852 100644
--- a/cleanlab_studio/cli/types.py
+++ b/cleanlab_studio/cli/types.py
@@ -4,5 +4,3 @@
 class CommandState(TypedDict):
     command: Optional[str]
     args: Dict[str, Optional[str]]
-
-
diff --git a/cleanlab_studio/errors.py b/cleanlab_studio/errors.py
index 7aa37f23..ad6b5112 100644
--- a/cleanlab_studio/errors.py
+++ b/cleanlab_studio/errors.py
@@ -1,9 +1,13 @@
 class APIError(Exception):
     pass

+
 class UnsupportedVersionError(APIError):
     def __init__(self) -> None:
-        super().__init__("cleanlab-studio is out of date and must be upgraded. Run 'pip install --upgrade cleanlab-studio'.")
+        super().__init__(
+            "cleanlab-studio is out of date and must be upgraded. Run 'pip install --upgrade cleanlab-studio'."
+        )
+

 class AuthError(APIError):
     def __init__(self) -> None:
diff --git a/cleanlab_studio/internal/dataset/pandas_dataset.py b/cleanlab_studio/internal/dataset/pandas_dataset.py
index 5afebcd5..cf609f32 100644
--- a/cleanlab_studio/internal/dataset/pandas_dataset.py
+++ b/cleanlab_studio/internal/dataset/pandas_dataset.py
@@ -5,6 +5,7 @@
 from .dataset import Dataset
 from ..types import RecordType

+
 class PandasDataset(Dataset):
     def __init__(self, df: pd.DataFrame):
         super().__init__()
diff --git a/cleanlab_studio/internal/schema.py b/cleanlab_studio/internal/schema.py
index 70e785f5..c203fc86 100644
--- a/cleanlab_studio/internal/schema.py
+++ b/cleanlab_studio/internal/schema.py
@@ -28,7 +28,7 @@ class DataType(Enum):
     def as_numpy_type(self) -> Any:
         return {
             DataType.string: str,
-            DataType.integer: np.int64, # XXX backend might use big integers
+            DataType.integer: np.int64,  # XXX backend might use big integers
             DataType.float: np.float64,
             DataType.boolean: bool,
         }[self]
diff --git a/cleanlab_studio/studio/upload.py b/cleanlab_studio/studio/upload.py
index 55d93da0..68664916 100644
--- a/cleanlab_studio/studio/upload.py
+++ b/cleanlab_studio/studio/upload.py
@@ -1,6 +1,6 @@
-'''
+"""
 Utilities to upload datasets.
-'''
+"""

 # This implementation overlaps with the one in
 # cleanlab_studio.cli.upload_helpers. The two should be unified and put in
@@ -26,7 +26,13 @@
 IMAGE_UPLOAD_CHECKPOINT_SIZE = 100

 MAX_PARALLEL_UPLOADS = 32

-def upload_tabular_dataset(api_key: str, dataset: PandasDataset, schema: Optional[Schema] = None, dataset_id: Optional[str] = None) -> str:
+
+def upload_tabular_dataset(
+    api_key: str,
+    dataset: PandasDataset,
+    schema: Optional[Schema] = None,
+    dataset_id: Optional[str] = None,
+) -> str:
     if dataset_id is None:  # if ID is not specified, initialize dataset
         assert schema is not None
@@ -49,13 +55,15 @@
         if row is None or row[id_column] in seen:
             continue
         seen.add(row[id_column])
-        to_upload.append(list(row.values())) # ordered dict, order is preserved
+        to_upload.append(list(row.values()))  # ordered dict, order is preserved

     # split upload into chunks
     # estimate size per row
     nelem = min(len(to_upload), 10)
-    size_per_row = len(gzip.compress(json.dumps(random.sample(to_upload, nelem)).encode('utf8'))) / nelem
-    num_per_chunk = max(int(10*1024*1024/size_per_row), 1)
+    size_per_row = (
+        len(gzip.compress(json.dumps(random.sample(to_upload, nelem)).encode("utf8"))) / nelem
+    )
+    num_per_chunk = max(int(10 * 1024 * 1024 / size_per_row), 1)

     chunks = split_into_chunks(to_upload, num_per_chunk)

@@ -71,24 +79,33 @@


 def get_type(schema, path):
-    while '.' in path:
-        first, path = path.split('.', 1)
+    while "." in path:
+        first, path = path.split(".", 1)
         schema = schema[first].dataType
     return schema[path].dataType


 # dataset here is a pyspark DataFrame
-def upload_image_dataset(api_key: str, dataframe, name: str, id_column: str, path_column: str, content_column: str, label_column: str, dataset_id: Optional[str] = None) -> str:
+def upload_image_dataset(
+    api_key: str,
+    dataframe,
+    name: str,
+    id_column: str,
+    path_column: str,
+    content_column: str,
+    label_column: str,
+    dataset_id: Optional[str] = None,
+) -> str:
     spark = dataframe.sparkSession
-    if get_type(dataframe.schema, content_column).typeName() != 'binary':
-        raise ValueError('content column must have binary type')
+    if get_type(dataframe.schema, content_column).typeName() != "binary":
+        raise ValueError("content column must have binary type")

     if dataset_id is None:  # if ID is not specified, initialize dataset
         # create an appropriate schema, like we do in simple_image_upload
         label_column_type = get_type(dataframe.schema, label_column).typeName()
         # note: 'string' and 'integer' are DataType values; spark names line up with our internal names
-        if label_column_type not in ['string', 'integer']:
-            raise ValueError('label column must have string or integer type')
+        if label_column_type not in ["string", "integer"]:
+            raise ValueError("label column must have string or integer type")
         schema = Schema.create(
             metadata={
                 "id_column": id_column,
@@ -129,19 +146,30 @@ def upload_image_dataset(api_key: str, dataframe, name: str, id_column: str, pat
     for chunk in chunks:
         row_ids = [row[0] for row in chunk]
         filepaths = [row[1] for row in chunk]
-        filepath_to_post = api.get_presigned_posts(api_key, dataset_id, filepaths=filepaths, row_ids=row_ids, media_type='image')
+        filepath_to_post = api.get_presigned_posts(
+            api_key, dataset_id, filepaths=filepaths, row_ids=row_ids, media_type="image"
+        )
         # get contents for this chunk at once
         ids_df = spark.createDataFrame(pd.DataFrame({id_column: row_ids}))
-        contents = {row[0]: row[1] for row in ids_df.join(dataframe, on=id_column, how='left').select([id_column, content_column]).collect()}
+        contents = {
+            row[0]: row[1]
+            for row in ids_df.join(dataframe, on=id_column, how="left")
+            .select([id_column, content_column])
+            .collect()
+        }
+
         def upload_row(row):
             image_file = io.BytesIO(contents[row[0]])
-            post_data = filepath_to_post[row[1]] # indexed by filepath
+            post_data = filepath_to_post[row[1]]  # indexed by filepath
             presigned_post = post_data["post"]
-            res = requests.post(url=presigned_post["url"],
-                                data=presigned_post["fields"],
-                                files={"file": image_file})
+            res = requests.post(
+                url=presigned_post["url"],
+                data=presigned_post["fields"],
+                files={"file": image_file},
+            )
             if not res.ok:
-                raise Exception(f'failure while uploading id {row[0]}')
+                raise Exception(f"failure while uploading id {row[0]}")
+
         with multiprocessing.dummy.Pool(MAX_PARALLEL_UPLOADS) as p:
             # thread pool, not process pool; requests releases GIL in post
             p.map(upload_row, chunk)
@@ -157,4 +185,4 @@ def upload_row(row):

 def split_into_chunks(l, n):
     for i in range(0, len(l), n):
-        yield l[i:i+n]
+        yield l[i : i + n]