sdk/python: object group num_workers
Signed-off-by: Ryan Koo <[email protected]>
rkoo19 committed Sep 17, 2024
1 parent e40ec36 commit c0faf87
Showing 4 changed files with 78 additions and 14 deletions.
python/aistore/sdk/multiobj/object_group.py (13 additions, 1 deletion)

@@ -183,9 +183,11 @@ def prefetch(
         Args:
             latest (bool, optional): GET the latest object version from the associated remote bucket
             continue_on_error (bool, optional): Whether to continue if there is an error prefetching a single object
-            num_workers (int, optional): Number of concurrent workers; number of target mountpaths if omitted or zero
             blob_threshold (int, optional): Utilize built-in blob-downloader for remote objects
                 greater than the specified (threshold) size in bytes
+            num_workers (int, optional): Number of concurrent workers (readers). Defaults to the number of target
+                mountpaths if omitted or zero. A value of -1 indicates no workers at all (i.e., single-threaded
+                execution). Any positive value will be adjusted not to exceed the number of target CPUs.

         Raises:
             aistore.sdk.errors.AISError: All other types of errors with AIStore
@@ -225,6 +227,7 @@ def copy(
         force: bool = False,
         latest: bool = False,
         sync: bool = False,
+        num_workers: int = None,
     ):
         """
         Copies a list or range of objects in a bucket
@@ -238,6 +241,9 @@ def copy(
(see "limited coexistence" and xact/xreg/xreg.go)
latest (bool, optional): GET the latest object version from the associated remote bucket
sync (bool, optional): synchronize destination bucket with its remote (e.g., Cloud or remote AIS) source
num_workers (int, optional): Number of concurrent workers (readers). Defaults to the number of target
mountpaths if omitted or zero. A value of -1 indicates no workers at all (i.e., single-threaded
execution). Any positive value will be adjusted not to exceed the number of target CPUs.
Raises:
aistore.sdk.errors.AISError: All other types of errors with AIStore
Expand Down Expand Up @@ -268,6 +274,7 @@ def copy(
             tc_msg=TCBckMsg(copy_msg=copy_msg),
             object_selection=self._obj_collection.get_value(),
             continue_on_err=continue_on_error,
+            num_workers=num_workers,
         ).as_dict()

         return self.bck.make_request(
@@ -288,6 +295,7 @@ def transform(
         force: bool = False,
         latest: bool = False,
         sync: bool = False,
+        num_workers: int = None,
     ):
         """
         Performs ETL operation on a list or range of objects in a bucket, placing the results in the destination bucket
@@ -303,6 +311,9 @@ def transform(
(see "limited coexistence" and xact/xreg/xreg.go)
latest (bool, optional): GET the latest object version from the associated remote bucket
sync (bool, optional): synchronize destination bucket with its remote (e.g., Cloud or remote AIS) source
num_workers (int, optional): Number of concurrent workers (readers). Defaults to the number of target
mountpaths if omitted or zero. A value of -1 indicates no workers at all (i.e., single-threaded
execution). Any positive value will be adjusted not to exceed the number of target CPUs.
Raises:
aistore.sdk.errors.AISError: All other types of errors with AIStore
Expand Down Expand Up @@ -333,6 +344,7 @@ def transform(
             tc_msg=TCBckMsg(transform_msg=transform_msg, copy_msg=copy_msg),
             object_selection=self._obj_collection.get_value(),
             continue_on_err=continue_on_error,
+            num_workers=num_workers,
         ).as_dict()
         return self.bck.make_request(
             HTTP_METHOD_POST, ACT_TRANSFORM_OBJECTS, value=value
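For reference, calling the updated methods from the SDK looks like the following minimal sketch. It is not part of the diff: the endpoint, bucket names, and object names are placeholders, and a reachable AIStore cluster holding those objects is assumed.

from aistore.sdk import Client

client = Client("http://localhost:8080")  # assumed cluster endpoint
bucket = client.bucket("my-bucket")       # hypothetical bucket

# Prefetch with 4 concurrent workers per target; omitting num_workers
# (or passing 0) falls back to one worker per target mountpath, and
# num_workers=-1 requests single-threaded execution.
job_id = bucket.objects(obj_names=["obj-1", "obj-2", "obj-3"]).prefetch(num_workers=4)
client.job(job_id).wait(timeout=120)

# As of this commit, copy() and transform() accept the same knob.
to_bck = client.bucket("dest-bucket")     # hypothetical destination bucket
copy_job = bucket.objects(obj_names=["obj-1", "obj-2"]).copy(to_bck, num_workers=4)
client.job(job_id=copy_job).wait_for_idle(timeout=120)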
python/aistore/sdk/types.py (7 additions, 1 deletion)

@@ -29,7 +29,7 @@
 )


-# pylint: disable=too-few-public-methods,unused-variable,missing-function-docstring
+# pylint: disable=too-few-public-methods,unused-variable,missing-function-docstring,too-many-lines


 class ActionMsg(BaseModel):
@@ -516,13 +516,16 @@ class PrefetchMsg(BaseModel):
     continue_on_err: bool
     latest: bool
     blob_threshold: int = None
+    num_workers: int = None

     def as_dict(self):
         dict_rep = self.object_selection
         dict_rep["coer"] = self.continue_on_err
         dict_rep["latest-ver"] = self.latest
         if self.blob_threshold:
             dict_rep["blob-threshold"] = self.blob_threshold
+        if self.num_workers:
+            dict_rep["num-workers"] = self.num_workers
         return dict_rep

@@ -535,6 +538,7 @@ class TCMultiObj(BaseModel):
     tc_msg: TCBckMsg = None
     continue_on_err: bool
     object_selection: Dict
+    num_workers: int = None

     def as_dict(self):
         dict_rep = self.object_selection
@@ -543,6 +547,8 @@ def as_dict(self):
             dict_rep[key] = val
         dict_rep["tobck"] = self.to_bck.as_dict()
         dict_rep["coer"] = self.continue_on_err
+        if self.num_workers:
+            dict_rep["num-workers"] = self.num_workers
         return dict_rep

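To make the wire format concrete, the sketch below shows what the updated as_dict() produces. The "objnames" selection payload is an assumption (the dict shape that the object-group's get_value() is expected to return for a name list); the fields passed to PrefetchMsg are the ones visible in this diff.

from aistore.sdk.types import PrefetchMsg

msg = PrefetchMsg(
    object_selection={"objnames": ["obj-1", "obj-2"]},  # hypothetical selection dict
    continue_on_err=False,
    latest=False,
    num_workers=4,
)
print(msg.as_dict())
# -> {'objnames': ['obj-1', 'obj-2'], 'coer': False, 'latest-ver': False, 'num-workers': 4}

Note that as_dict() guards with a truthiness check (if self.num_workers), so num_workers=0 is simply omitted from the payload and the cluster-side default (one worker per target mountpath) applies, while -1, being truthy, is sent through.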
python/tests/integration/sdk/test_object_group_ops.py (50 additions, 12 deletions)

@@ -56,18 +56,31 @@ def test_evict_objects_local(self):
         with self.assertRaises(InvalidBckProvider):
             local_bucket.objects(obj_names=[]).evict()

-    @unittest.skipIf(
-        not REMOTE_SET,
-        "Remote bucket is not set",
-    )
-    def test_prefetch_list(self):
+    def _prefetch_objects_test_helper(self, num_workers=None):
         obj_group = self.bucket.objects(obj_names=self.obj_names[1:])
         self._evict_all_objects()
+        prefetch_kwargs = {}
+        if num_workers is not None:
+            prefetch_kwargs["num_workers"] = num_workers
         # Fetch back a specific object group and verify cache status
-        job_id = obj_group.prefetch()
+        job_id = obj_group.prefetch(**prefetch_kwargs)
         self.client.job(job_id).wait(timeout=TEST_TIMEOUT * 2)
         self._verify_cached_objects(OBJECT_COUNT, range(1, OBJECT_COUNT))

+    @unittest.skipIf(
+        not REMOTE_SET,
+        "Remote bucket is not set",
+    )
+    def test_prefetch_list(self):
+        self._prefetch_objects_test_helper()
+
+    @unittest.skipIf(
+        not REMOTE_SET,
+        "Remote bucket is not set",
+    )
+    def test_prefetch_list_with_num_workers(self):
+        self._prefetch_objects_test_helper(num_workers=3)
+
     @unittest.skipIf(
         not REMOTE_SET,
         "Remote bucket is not set",
@@ -119,7 +132,7 @@ def test_prefetch_objects_local(self):
         with self.assertRaises(InvalidBckProvider):
             local_bucket.objects(obj_names=[]).prefetch()

-    def test_copy_objects(self):
+    def _copy_objects_test_helper(self, num_workers=None):
         to_bck_name = "destination-bucket"
         to_bck = self._create_bucket(to_bck_name)
         self.assertEqual(0, len(to_bck.list_all_objects(prefix=self.obj_prefix)))
@@ -128,15 +141,25 @@ def test_copy_objects(self):
         )

         new_prefix = PREFIX_NAME
+        copy_kwargs = {"prepend": new_prefix}
+        if num_workers is not None:
+            copy_kwargs["num_workers"] = num_workers
+
         copy_job = self.bucket.objects(obj_names=self.obj_names[1:5]).copy(
-            to_bck, prepend=new_prefix
+            to_bck, **copy_kwargs
         )
         self.client.job(job_id=copy_job).wait_for_idle(timeout=TEST_TIMEOUT)

         self.assertEqual(
             4, len(to_bck.list_all_objects(prefix=new_prefix + self.obj_prefix))
         )

+    def test_copy_objects(self):
+        self._copy_objects_test_helper()
+
+    def test_copy_objects_with_num_workers(self):
+        self._copy_objects_test_helper(num_workers=3)
+
     @unittest.skipIf(
         not REMOTE_SET,
         "Remote bucket is not set",
@@ -302,9 +325,8 @@ def _archive_exec_assert(self, arch_name, src_bck, res_bck, **kwargs):
             member_names.append(member.name)
         self.assertEqual(set(archived_names), set(member_names))

-    @pytest.mark.etl
-    def test_transform_objects(self):
-        # Define an etl with code that hashes the contents of each object
+    def _transform_objects_test_helper(self, num_workers=None):
+        # Define an ETL that hashes the contents of each object
         etl_name = "etl-" + random_string(5)

         def transform(input_bytes):
@@ -323,8 +345,16 @@ def transform(input_bytes):
             OBJECT_COUNT, len(self.bucket.list_all_objects(prefix=self.obj_prefix))
         )

+        transform_kwargs = {
+            "to_bck": to_bck,
+            "etl_name": md5_etl.name,
+            "prepend": new_prefix,
+        }
+        if num_workers is not None:
+            transform_kwargs["num_workers"] = num_workers
+
         transform_job = self.bucket.objects(obj_names=self.obj_names).transform(
-            to_bck, etl_name=md5_etl.name, prepend=new_prefix
+            **transform_kwargs
         )
         self.client.job(job_id=transform_job).wait_for_idle(timeout=TEST_TIMEOUT)

@@ -338,6 +368,14 @@ def transform(input_bytes):
         ]
         self.assertEqual(to_obj_values, from_obj_hashes)

+    @pytest.mark.etl
+    def test_transform_objects(self):
+        self._transform_objects_test_helper()
+
+    @pytest.mark.etl
+    def test_transform_objects_with_num_workers(self):
+        self._transform_objects_test_helper(num_workers=3)
+
     def _evict_all_objects(self, num_obj=OBJECT_COUNT):
         job_id = self.bucket.objects(obj_names=self.obj_names).evict()
         self.client.job(job_id).wait(timeout=TEST_TIMEOUT)
python/tests/unit/sdk/multiobj/test_object_group.py (8 additions, 0 deletions)

@@ -99,11 +99,13 @@ def test_prefetch(self):
         prefetch_expected_val = self.expected_value.copy()
         prefetch_expected_val["coer"] = False
         prefetch_expected_val["latest-ver"] = False
+        prefetch_expected_val["num-workers"] = 3
         self.object_group_test_helper(
             self.object_group.prefetch,
             HTTP_METHOD_POST,
             ACT_PREFETCH_OBJECTS,
             prefetch_expected_val,
+            num_workers=3,
         )

     def test_prefetch_with_blob_threshold(self):
@@ -112,13 +114,15 @@ def test_prefetch_with_blob_threshold(self):
         prefetch_expected_val["latest-ver"] = False
         blob_threshold_value = LARGE_FILE_SIZE
         prefetch_expected_val["blob-threshold"] = blob_threshold_value
+        prefetch_expected_val["num-workers"] = 3

         self.object_group_test_helper(
             self.object_group.prefetch,
             HTTP_METHOD_POST,
             ACT_PREFETCH_OBJECTS,
             prefetch_expected_val,
             blob_threshold=blob_threshold_value,
+            num_workers=3,
         )

     def test_copy(self):
@@ -146,6 +150,7 @@ def test_copy(self):
         self.expected_value["coer"] = True
         self.expected_value["latest-ver"] = False
         self.expected_value["synchronize"] = False
+        self.expected_value["num-workers"] = 3

         self.object_group_test_helper(
             self.object_group.copy,
@@ -157,6 +162,7 @@ def test_copy(self):
             force=True,
             dry_run=True,
             continue_on_error=True,
+            num_workers=3,
         )

     @patch("aistore.sdk.multiobj.object_group.logging")
@@ -197,6 +203,7 @@ def test_transform(self):
         self.expected_value["request_timeout"] = timeout
         self.expected_value["dry_run"] = True
         self.expected_value["force"] = True
+        self.expected_value["num-workers"] = 3
         self.object_group_test_helper(
             self.object_group.transform,
             HTTP_METHOD_POST,
@@ -209,6 +216,7 @@ def test_transform(self):
             dry_run=True,
             force=True,
             continue_on_error=True,
+            num_workers=3,
         )

     @patch("aistore.sdk.multiobj.object_group.logging")
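The repo's object_group_test_helper is not shown in this diff. As a rough standalone illustration of the same assertion idea, one could exercise TCMultiObj directly; BucketModel's constructor arguments below are assumptions, as is the "objnames" selection dict.

from aistore.sdk.types import BucketModel, TCMultiObj

to_bck = BucketModel(name="dest-bucket", provider="ais")  # assumed field names
value = TCMultiObj(
    to_bck=to_bck,
    continue_on_err=False,
    object_selection={"objnames": ["obj-1"]},  # hypothetical selection dict
    num_workers=3,
).as_dict()

# num_workers surfaces as "num-workers" in the control message
assert value["num-workers"] == 3
assert value["coer"] is False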
