
Commit

Format
anishathalye committed Dec 20, 2022
1 parent 17d3349 commit 19ceb3b
Showing 8 changed files with 64 additions and 34 deletions.
11 changes: 7 additions & 4 deletions cleanlab_studio/cli/api_service.py
@@ -16,17 +16,20 @@
from cleanlab_studio.cli.dataset.image_utils import get_image_filepath
from cleanlab_studio.internal.schema import Schema
from cleanlab_studio.internal.types import JSONDict, IDType, Modality
from cleanlab_studio.internal.api import base_url, _construct_headers, handle_api_error, handle_api_error_from_json, get_presigned_posts

from cleanlab_studio.internal.api import (
base_url,
_construct_headers,
handle_api_error,
handle_api_error_from_json,
get_presigned_posts,
)


MAX_PARALLEL_UPLOADS = 32 # XXX choose this dynamically?
INITIAL_BACKOFF = 0.25 # seconds
MAX_RETRIES = 4




async def upload_rows_async(
session: aiohttp.ClientSession,
api_key: str,
4 changes: 0 additions & 4 deletions cleanlab_studio/cli/dataset/schema_helpers.py
@@ -20,8 +20,6 @@
from cleanlab_studio.version import MAX_SCHEMA_VERSION, MIN_SCHEMA_VERSION, SCHEMA_VERSION




def load_schema(filepath: str) -> Schema:
with open(filepath, "r") as f:
schema_dict = json.load(f)
@@ -108,8 +106,6 @@ def validate_schema(schema: Schema, columns: Collection[str]) -> None:
raise ValueError("Dataset modality is text, but none of the fields is a text column.")




def save_schema(schema: Schema, filename: Optional[str]) -> None:
"""
2 changes: 1 addition & 1 deletion cleanlab_studio/cli/dataset/upload_helpers.py
@@ -103,7 +103,7 @@ def validate_and_process_record(
fields = schema.fields
id_column = schema.metadata.id_column
columns = list(fields)
dataset_filepath = dataset.filepath if hasattr(dataset, 'filepath') else None # type:ignore
dataset_filepath = dataset.filepath if hasattr(dataset, "filepath") else None # type:ignore
row_id = record.get(id_column, None)

if row_id == "" or row_id is None:
2 changes: 0 additions & 2 deletions cleanlab_studio/cli/types.py
@@ -4,5 +4,3 @@
class CommandState(TypedDict):
command: Optional[str]
args: Dict[str, Optional[str]]


6 changes: 5 additions & 1 deletion cleanlab_studio/errors.py
@@ -1,9 +1,13 @@
class APIError(Exception):
pass


class UnsupportedVersionError(APIError):
def __init__(self) -> None:
super().__init__("cleanlab-studio is out of date and must be upgraded. Run 'pip install --upgrade cleanlab-studio'.")
super().__init__(
"cleanlab-studio is out of date and must be upgraded. Run 'pip install --upgrade cleanlab-studio'."
)


class AuthError(APIError):
def __init__(self) -> None:
1 change: 1 addition & 0 deletions cleanlab_studio/internal/dataset/pandas_dataset.py
@@ -5,6 +5,7 @@
from .dataset import Dataset
from ..types import RecordType


class PandasDataset(Dataset):
def __init__(self, df: pd.DataFrame):
super().__init__()
2 changes: 1 addition & 1 deletion cleanlab_studio/internal/schema.py
@@ -28,7 +28,7 @@ class DataType(Enum):
def as_numpy_type(self) -> Any:
return {
DataType.string: str,
DataType.integer: np.int64, # XXX backend might use big integers
DataType.integer: np.int64, # XXX backend might use big integers
DataType.float: np.float64,
DataType.boolean: bool,
}[self]
70 changes: 49 additions & 21 deletions cleanlab_studio/studio/upload.py
@@ -1,6 +1,6 @@
'''
"""
Utilities to upload datasets.
'''
"""

# This implementation overlaps with the one in
# cleanlab_studio.cli.upload_helpers. The two should be unified and put in
@@ -26,7 +26,13 @@
IMAGE_UPLOAD_CHECKPOINT_SIZE = 100
MAX_PARALLEL_UPLOADS = 32

def upload_tabular_dataset(api_key: str, dataset: PandasDataset, schema: Optional[Schema] = None, dataset_id: Optional[str] = None) -> str:

def upload_tabular_dataset(
api_key: str,
dataset: PandasDataset,
schema: Optional[Schema] = None,
dataset_id: Optional[str] = None,
) -> str:
if dataset_id is None:
# if ID is not specified, initialize dataset
assert schema is not None
@@ -49,13 +55,15 @@ def upload_tabular_dataset(api_key: str, dataset: PandasDataset, schema: Optiona
if row is None or row[id_column] in seen:
continue
seen.add(row[id_column])
to_upload.append(list(row.values())) # ordered dict, order is preserved
to_upload.append(list(row.values())) # ordered dict, order is preserved

# split upload into chunks
# estimate size per row
nelem = min(len(to_upload), 10)
size_per_row = len(gzip.compress(json.dumps(random.sample(to_upload, nelem)).encode('utf8'))) / nelem
num_per_chunk = max(int(10*1024*1024/size_per_row), 1)
size_per_row = (
len(gzip.compress(json.dumps(random.sample(to_upload, nelem)).encode("utf8"))) / nelem
)
num_per_chunk = max(int(10 * 1024 * 1024 / size_per_row), 1)

chunks = split_into_chunks(to_upload, num_per_chunk)

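For a rough sense of the chunking arithmetic above: if the 10-row sample gzip-compresses to about 2,000 bytes, size_per_row is ~200 bytes and num_per_chunk = int(10 * 1024 * 1024 / 200) = 52,428 rows, so each chunk targets roughly 10 MiB of compressed JSON (the sample size here is hypothetical).
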
@@ -71,24 +79,33 @@ def upload_tabular_dataset(api_key: str, dataset: PandasDataset, schema: Optiona


def get_type(schema, path):
while '.' in path:
first, path = path.split('.', 1)
while "." in path:
first, path = path.split(".", 1)
schema = schema[first].dataType
return schema[path].dataType

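A hedged sketch of how get_type resolves a dotted path against a nested Spark schema; the schema and field names below are invented for illustration and assume pyspark is installed:

    from pyspark.sql.types import BinaryType, StringType, StructField, StructType

    def get_type(schema, path):  # same helper as defined in the hunk above
        while "." in path:
            first, path = path.split(".", 1)
            schema = schema[first].dataType  # descend into the nested struct
        return schema[path].dataType

    # Hypothetical schema: a top-level id plus an image struct with path and content.
    schema = StructType([
        StructField("id", StringType()),
        StructField("image", StructType([
            StructField("path", StringType()),
            StructField("content", BinaryType()),
        ])),
    ])

    assert get_type(schema, "image.content").typeName() == "binary"
    assert get_type(schema, "id").typeName() == "string"
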

# dataset here is a pyspark DataFrame
def upload_image_dataset(api_key: str, dataframe, name: str, id_column: str, path_column: str, content_column: str, label_column: str, dataset_id: Optional[str] = None) -> str:
def upload_image_dataset(
api_key: str,
dataframe,
name: str,
id_column: str,
path_column: str,
content_column: str,
label_column: str,
dataset_id: Optional[str] = None,
) -> str:
spark = dataframe.sparkSession
if get_type(dataframe.schema, content_column).typeName() != 'binary':
raise ValueError('content column must have binary type')
if get_type(dataframe.schema, content_column).typeName() != "binary":
raise ValueError("content column must have binary type")
if dataset_id is None:
# if ID is not specified, initialize dataset
# create an appropriate schema, like we do in simple_image_upload
label_column_type = get_type(dataframe.schema, label_column).typeName()
# note: 'string' and 'integer' are DataType values; spark names line up with our internal names
if label_column_type not in ['string', 'integer']:
raise ValueError('label column must have string or integer type')
if label_column_type not in ["string", "integer"]:
raise ValueError("label column must have string or integer type")
schema = Schema.create(
metadata={
"id_column": id_column,
@@ -129,19 +146,30 @@ def upload_image_dataset(api_key: str, dataframe, name: str, id_column: str, pat
for chunk in chunks:
row_ids = [row[0] for row in chunk]
filepaths = [row[1] for row in chunk]
filepath_to_post = api.get_presigned_posts(api_key, dataset_id, filepaths=filepaths, row_ids=row_ids, media_type='image')
filepath_to_post = api.get_presigned_posts(
api_key, dataset_id, filepaths=filepaths, row_ids=row_ids, media_type="image"
)
# get contents for this chunk at once
ids_df = spark.createDataFrame(pd.DataFrame({id_column: row_ids}))
contents = {row[0]: row[1] for row in ids_df.join(dataframe, on=id_column, how='left').select([id_column, content_column]).collect()}
contents = {
row[0]: row[1]
for row in ids_df.join(dataframe, on=id_column, how="left")
.select([id_column, content_column])
.collect()
}

def upload_row(row):
image_file = io.BytesIO(contents[row[0]])
post_data = filepath_to_post[row[1]] # indexed by filepath
post_data = filepath_to_post[row[1]] # indexed by filepath
presigned_post = post_data["post"]
res = requests.post(url=presigned_post["url"],
data=presigned_post["fields"],
files={"file": image_file})
res = requests.post(
url=presigned_post["url"],
data=presigned_post["fields"],
files={"file": image_file},
)
if not res.ok:
raise Exception(f'failure while uploading id {row[0]}')
raise Exception(f"failure while uploading id {row[0]}")

with multiprocessing.dummy.Pool(MAX_PARALLEL_UPLOADS) as p:
# thread pool, not process pool; requests releases GIL in post
p.map(upload_row, chunk)
Expand All @@ -157,4 +185,4 @@ def upload_row(row):

def split_into_chunks(l, n):
for i in range(0, len(l), n):
yield l[i:i+n]
yield l[i : i + n]
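For example, split_into_chunks([1, 2, 3, 4, 5], 2) yields [1, 2], [3, 4], and [5].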
