Skip to content

Commit

Permalink
Fixing GPU checks inside test_histogram
Browse files Browse the repository at this point in the history
Signed-off-by: Julio Faracco <[email protected]>
  • Loading branch information
jcfaracco committed Sep 29, 2024
1 parent dba63a9 commit f20896a
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion tests/feature_extraction/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np

from dasf.feature_extraction import Histogram
from dasf.utils.funcs import is_gpu_supported

try:
import cupy as cp
Expand All @@ -26,6 +27,8 @@ def test_histogram_cpu(self):
self.assertEqual(bins[0].shape, (5,))
self.assertTrue(all([a == b for a, b in zip(bins[0], np.array([3, 9, 3, 2, 3]))]))

@unittest.skipIf(not is_gpu_supported(),
"not supported CUDA in this platform")
def test_histogram_gpu(self):
data = cp.array(self._data)

Expand All @@ -46,6 +49,8 @@ def test_histogram_mcpu(self):
self.assertEqual(bins[0].shape, (5,))
self.assertTrue(all([a == b for a, b in zip(bins[0], np.array([3, 9, 3, 2, 3]))]))

@unittest.skipIf(not is_gpu_supported(),
"not supported CUDA in this platform")
def test_histogram_mgpu(self):
data = da.from_array(cp.array(self._data))

Expand All @@ -61,4 +66,4 @@ def test_histogram_dask_without_range(self):

hist = Histogram(bins=5)

self.assertRaises(ValueError, hist._lazy_transform_gpu, X=data)
self.assertRaises(ValueError, hist._lazy_transform_cpu, X=data)

0 comments on commit f20896a

Please sign in to comment.