diff --git a/tests/feature_extraction/test_histogram.py b/tests/feature_extraction/test_histogram.py index 50fdd05..6b4f06e 100644 --- a/tests/feature_extraction/test_histogram.py +++ b/tests/feature_extraction/test_histogram.py @@ -6,6 +6,7 @@ import numpy as np from dasf.feature_extraction import Histogram +from dasf.utils.funcs import is_gpu_supported try: import cupy as cp @@ -26,6 +27,8 @@ def test_histogram_cpu(self): self.assertEqual(bins[0].shape, (5,)) self.assertTrue(all([a == b for a, b in zip(bins[0], np.array([3, 9, 3, 2, 3]))])) + @unittest.skipIf(not is_gpu_supported(), + "not supported CUDA in this platform") def test_histogram_gpu(self): data = cp.array(self._data) @@ -46,6 +49,8 @@ def test_histogram_mcpu(self): self.assertEqual(bins[0].shape, (5,)) self.assertTrue(all([a == b for a, b in zip(bins[0], np.array([3, 9, 3, 2, 3]))])) + @unittest.skipIf(not is_gpu_supported(), + "not supported CUDA in this platform") def test_histogram_mgpu(self): data = da.from_array(cp.array(self._data)) @@ -61,4 +66,4 @@ def test_histogram_dask_without_range(self): hist = Histogram(bins=5) - self.assertRaises(ValueError, hist._lazy_transform_gpu, X=data) + self.assertRaises(ValueError, hist._lazy_transform_cpu, X=data)