Skip to content

Commit

Permalink
Including more test cases for dasf-core
Browse files Browse the repository at this point in the history
Signed-off-by: Julio Faracco <[email protected]>
  • Loading branch information
jcfaracco committed Mar 11, 2024
1 parent 9feac1a commit 8023f33
Show file tree
Hide file tree
Showing 5 changed files with 304 additions and 8 deletions.
12 changes: 9 additions & 3 deletions dasf/pipeline/executors/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def __init__(
rmm.reinitialize(managed_memory=True)
cp.cuda.set_allocator(rmm.rmm_cupy_allocator)
else:
raise Exception(f"'{gpu_allocator}' GPU Memory allocator is not "
"known")
raise ValueError(f"'{gpu_allocator}' GPU Memory allocator is not "
"known")
else:
self.dtype = TaskExecutorType.multi_cpu

Expand Down Expand Up @@ -141,7 +141,7 @@ def __init__(

@property
def ngpus(self):
return len(get_dask_gpu_count())
return get_dask_gpu_count()

@property
def is_connected(self):
Expand Down Expand Up @@ -181,6 +181,9 @@ def shutdown(self, gracefully=True):
else:
self.client.shutdown()

def close(self):
    """Disconnect this executor's Dask client.

    Unlike ``shutdown()`` (which calls ``client.shutdown()``), this only
    closes the client-side connection — presumably leaving any external
    cluster running; confirm against dask.distributed semantics.
    """
    self.client.close()


class DaskTasksPipelineExecutor(DaskPipelineExecutor):
"""
Expand Down Expand Up @@ -293,6 +296,9 @@ def shutdown(self, gracefully=True):
else:
self.client.shutdown()

def close(self):
    """Disconnect this executor's Dask client (client-side close only).

    Mirrors ``DaskPipelineExecutor.close`` so both executor classes expose
    the same teardown API alongside ``shutdown()``.
    """
    self.client.close()


class DaskPBSPipelineExecutor(Executor):
def __init__(self, **kwargs):
Expand Down
3 changes: 3 additions & 0 deletions dasf/transforms/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def transform(self, X):
if self.x > 1 or self.y > 1 or self.z > 1:
raise Exception("Percentages cannot be higher than 100% (1.0)")

if self.x <= 0 or self.y <= 0 or self.z <= 0:
raise Exception("Percentages cannot be negative or 0")

if X.ndim == 1:
return X[0 : int(self.x * X.shape[0])]
elif X.ndim == 2:
Expand Down
7 changes: 6 additions & 1 deletion dasf/utils/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,12 @@ def get_worker_info(client) -> list:
Returns a list of workers (sorted), and the DNS name for the master host
The master is the 0th worker's host
"""
workers = client.scheduler_info()["workers"]
info = client.scheduler_info()

if "workers" not in info:
return []

workers = info["workers"]
worker_keys = sorted(workers.keys())
workers_by_host = {}
for key in worker_keys:
Expand Down
112 changes: 108 additions & 4 deletions tests/pipeline/executors/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,48 @@
import unittest
import urllib.parse

from mock import patch

from dask.distributed import Client, LocalCluster

from dasf.utils.funcs import is_gpu_supported
from dasf.pipeline.executors import DaskPipelineExecutor
from dasf.pipeline.executors.dask import setup_dask_protocol


class TestDaskProtocol(unittest.TestCase):
    """Unit tests for setup_dask_protocol() scheme-string construction."""

    def test_setup_dask_protocol_none(self):
        # No argument: the default scheme is plain TCP.
        self.assertEqual(setup_dask_protocol(), "tcp://")

    def test_setup_dask_protocol_tcp(self):
        # Explicit "tcp" yields the same scheme prefix.
        self.assertEqual(setup_dask_protocol("tcp"), "tcp://")

    def test_setup_dask_protocol_ucx(self):
        # UCX transport is supported for GPU-aware networking setups.
        self.assertEqual(setup_dask_protocol("ucx"), "ucx://")

    def test_setup_dask_protocol_foo(self):
        # Any unrecognized protocol name must be rejected with a clear error.
        with self.assertRaises(Exception) as raised:
            setup_dask_protocol("foo")

        self.assertIn('Protocol foo is not supported.', str(raised.exception))


class TestDaskExecutor(unittest.TestCase):
def setUp(self):
    """Precompute the absolute path of the temp scheduler file used by tests."""
    tmpdir = tempfile.gettempdir()
    self.scheduler_file = os.path.abspath(os.path.join(tmpdir, "scheduler.json"))

def test_dask_executor_remote(self):
    """Attach an executor to an already-running LocalCluster via host/port."""
    with LocalCluster() as cluster:
        parsed = urllib.parse.urlsplit(cluster.scheduler.address)

        executor = DaskPipelineExecutor(address=parsed.hostname,
                                        port=parsed.port)

        # Tear down in order: client connection, graceful shutdown,
        # then the executor's own close().
        executor.client.close()
        executor.shutdown(gracefully=True)
        executor.close()

        self.assertFalse(executor.is_connected)

def test_dask_executor_local_no_args(self):
dask = DaskPipelineExecutor()
Expand All @@ -32,18 +57,94 @@ def test_dask_executor_local_no_args(self):

# Compute everything to gracefully shutdown
client.close()
dask.client.close()
dask.shutdown(gracefully=True)
dask.close()

self.assertFalse(dask.is_connected)

def test_dask_executor_local_no_args_no_gracefully(self):
    """Default-constructed executor torn down with gracefully=False."""
    executor = DaskPipelineExecutor()

    current = Client.current()

    # The executor must reuse the process-wide current Dask client.
    self.assertEqual(hash(executor.client), hash(current))

    # Tear down without the graceful path, then close the executor.
    current.close()
    executor.shutdown(gracefully=False)
    executor.close()

    self.assertFalse(executor.is_connected)

def test_dask_executor_local(self):
    """Executor created with local=True shares the current Dask client."""
    executor = DaskPipelineExecutor(local=True)

    current = Client.current()

    self.assertTrue(executor.is_connected)
    # Both handles must point at the very same client object.
    self.assertEqual(hash(executor.client), hash(current))

    # Full teardown sequence: both client handles, then shutdown/close.
    current.close()
    executor.client.close()
    executor.shutdown(gracefully=True)
    executor.close()

    self.assertFalse(executor.is_connected)

@unittest.skipIf(not is_gpu_supported(),
                 "not supported CUDA in this platform")
def test_dask_executor_local_gpu(self):
    """Local GPU executor must report at least one visible GPU."""
    with patch.dict(os.environ, {'CUDA_VISIBLE_DEVICES': '0'}):
        executor = DaskPipelineExecutor(local=True, use_gpu=True)

        current = Client.current()

        self.assertEqual(hash(executor.client), hash(current))
        # With CUDA_VISIBLE_DEVICES pinned to device 0, ngpus must be > 0.
        self.assertGreater(executor.ngpus, 0)

        current.close()
        executor.shutdown(gracefully=True)
        executor.close()

        self.assertFalse(executor.is_connected)

@unittest.skipIf(not is_gpu_supported(),
                 "not supported CUDA in this platform")
def test_dask_executor_local_gpu_and_rmm(self):
    """GPU executor using the RMM allocator still attaches to the client."""
    with patch.dict(os.environ, {'CUDA_VISIBLE_DEVICES': '0'}):
        executor = DaskPipelineExecutor(local=True, use_gpu=True,
                                        gpu_allocator="rmm")

        current = Client.current()

        # RMM selection must not change which client the executor uses.
        self.assertEqual(hash(executor.client), hash(current))

        current.close()
        executor.shutdown(gracefully=True)
        executor.close()

        self.assertFalse(executor.is_connected)

@unittest.skipIf(not is_gpu_supported(),
                 "not supported CUDA in this platform")
def test_dask_executor_local_gpu_and_unknown_allocator(self):
    """Constructing with an unknown GPU allocator must raise ValueError.

    The constructor raises before an executor instance is ever bound, so
    the follow-up client/shutdown/close calls that used to sit after it
    inside the ``with`` block were unreachable dead code (and referenced a
    never-assigned name); they have been removed.
    """
    with self.assertRaises(ValueError) as raised:
        DaskPipelineExecutor(local=True, use_gpu=True, gpu_allocator="foo")

    # Message comes from the executor's allocator dispatch:
    # "'foo' GPU Memory allocator is not known"
    self.assertIn("GPU Memory allocator is not", str(raised.exception))

def test_dask_executor_scheduler_file(self):
with LocalCluster() as cluster:
Expand All @@ -63,7 +164,10 @@ def test_dask_executor_scheduler_file(self):
self.assertEqual(hash(dask.client), hash(client))

# Compute everything to gracefully shutdown
dask.client.close()
dask.shutdown(gracefully=True)
dask.close()

self.assertFalse(dask.is_connected)

def tearDown(self):
if os.path.isfile(self.scheduler_file) or os.path.islink(self.scheduler_file):
Expand Down
Loading

0 comments on commit 8023f33

Please sign in to comment.