From fae685dfd9082aca438d916b916e902951e7f03d Mon Sep 17 00:00:00 2001 From: Julio Faracco Date: Mon, 23 Sep 2024 01:07:17 -0300 Subject: [PATCH] Adding more test cases to SOM Signed-off-by: Julio Faracco --- tests/ml/cluster/test_som.py | 58 ++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/tests/ml/cluster/test_som.py b/tests/ml/cluster/test_som.py index 5f0e241..a2565d0 100644 --- a/tests/ml/cluster/test_som.py +++ b/tests/ml/cluster/test_som.py @@ -126,3 +126,61 @@ def test_som_mcpu_local(self): q2 = som.quantization_error(da_X) self.assertTrue(q1 > q2) + + def test_som_2_cpu(self): + som = SOM(x=3, y=2, input_len=2, num_epochs=300) + + q1 = som._quantization_error_cpu(self.X) + + y = som._fit_predict_cpu(self.X) + + q2 = som._quantization_error_cpu(self.X) + + self.assertTrue(is_cpu_array(y)) + self.assertTrue(q1 > q2) + + def test_som_2_mcpu(self): + som = SOM(x=3, y=2, input_len=2, num_epochs=300) + + da_X = da.from_array(self.X, meta=np.array((), dtype=np.float32)) + + q1 = som._lazy_quantization_error_cpu(da_X) + + y = som._lazy_fit_predict_cpu(da_X) + + q2 = som._lazy_quantization_error_cpu(da_X) + + self.assertTrue(is_dask_cpu_array(y)) + self.assertTrue(q1 > q2) + + @unittest.skipIf(not is_gpu_supported(), + "not supported CUDA in this platform") + def test_som_2_gpu(self): + som = SOM(x=3, y=2, input_len=2, num_epochs=300) + + cp_X = cp.asarray(self.X) + + q1 = som._quantization_error_gpu(cp_X) + + y = som._fit_predict_gpu(cp_X) + + q2 = som._quantization_error_gpu(cp_X) + + self.assertTrue(is_gpu_array(y)) + self.assertTrue(q1 > q2) + + @unittest.skipIf(not is_gpu_supported(), + "not supported CUDA in this platform") + def test_som_2_mgpu(self): + som = SOM(x=3, y=2, input_len=2, num_epochs=300) + + da_X = da.from_array(cp.asarray(self.X), meta=cp.array((), dtype=cp.float32)) + + q1 = som._lazy_quantization_error_gpu(da_X) + + y = som._lazy_fit_predict_gpu(da_X) + + q2 = som._lazy_quantization_error_gpu(da_X) + + self.assertTrue(is_dask_gpu_array(y)) + self.assertTrue(q1 > q2)