From 9c414d99deee0a439f30a613cc956a498c3c54d4 Mon Sep 17 00:00:00 2001 From: neka-nat Date: Sat, 11 May 2024 15:19:45 +0900 Subject: [PATCH] fix using cuda --- examples/cpd_affine3d_cuda.py | 2 -- examples/cpd_nonrigid3d_cuda.py | 1 - examples/cpd_rigid_cuda.py | 4 +--- probreg/cpd.py | 11 +++++++++-- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/examples/cpd_affine3d_cuda.py b/examples/cpd_affine3d_cuda.py index 96061b6..eeaa814 100644 --- a/examples/cpd_affine3d_cuda.py +++ b/examples/cpd_affine3d_cuda.py @@ -7,9 +7,7 @@ else: cp = np to_cpu = lambda x: x -import open3d as o3 from probreg import cpd -from probreg import callbacks import utils import time diff --git a/examples/cpd_nonrigid3d_cuda.py b/examples/cpd_nonrigid3d_cuda.py index 03d2985..3d51b44 100644 --- a/examples/cpd_nonrigid3d_cuda.py +++ b/examples/cpd_nonrigid3d_cuda.py @@ -9,7 +9,6 @@ to_cpu = lambda x: x import open3d as o3 from probreg import cpd -from probreg import callbacks import utils import time diff --git a/examples/cpd_rigid_cuda.py b/examples/cpd_rigid_cuda.py index 20ecfbc..6fb75c4 100644 --- a/examples/cpd_rigid_cuda.py +++ b/examples/cpd_rigid_cuda.py @@ -7,10 +7,8 @@ else: cp = np to_cpu = lambda x: x -import open3d as o3 -import transforms3d as trans +import transforms3d as t3d from probreg import cpd -from probreg import callbacks import utils import time diff --git a/probreg/cpd.py b/probreg/cpd.py index 306b779..afb20bf 100644 --- a/probreg/cpd.py +++ b/probreg/cpd.py @@ -25,6 +25,14 @@ """ +class DistModule: + def __init__(self, xp): + self.xp = xp + + def cdist(self, x1, x2, metric): + return self.xp.stack([self.xp.sum(self.xp.square(x2 - ts), axis=1) for ts in x1]) + + @six.add_metaclass(abc.ABCMeta) class CoherentPointDrift: """Coherent Point Drift algorithm. @@ -50,12 +58,11 @@ def __init__(self, source: Optional[np.ndarray] = None, use_color: bool = False, self._use_color = use_color if use_cuda: import cupy as cp - from cupyx.scipy.spatial import distance as cupy_distance from . import cupy_utils self.xp = cp - self.distance_module = cupy_distance + self.distance_module = DistModule(cp) self.cupy_utils = cupy_utils self._squared_kernel_sum = cupy_utils.squared_kernel_sum else: