Skip to content

Commit

Permalink
Merge pull request #112 from neka-nat/color_cpd
Browse files Browse the repository at this point in the history
add color cpd
  • Loading branch information
neka-nat authored May 10, 2024
2 parents 9c838b7 + 913811e commit 6da9474
Show file tree
Hide file tree
Showing 12 changed files with 155 additions and 53 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ This package implements several algorithms using stochastic models and provides
* Maximum likelihood when the target or source point cloud is observation data
* [Coherent Point Drift (2010)](https://arxiv.org/pdf/0905.2635.pdf)
* [Extended Coherent Point Drift (2016)](https://ieeexplore.ieee.org/abstract/document/7477719) (add correspondence priors to CPD)
* [Color Coherent Point Drift (2018)](https://arxiv.org/pdf/1802.01516)
* [FilterReg (CVPR2019)](https://arxiv.org/pdf/1811.10136.pdf)
* Variational Bayesian inference
* [Bayesian Coherent Point Drift (2020)](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8985307)
Expand Down
22 changes: 22 additions & 0 deletions examples/color_cpd_rigid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np
import open3d as o3
import transforms3d as t3d
from probreg import cpd
from probreg import callbacks
import logging
log = logging.getLogger('probreg')
log.setLevel(logging.DEBUG)

voxel_size = 0.05
source = o3.io.read_point_cloud("frag_115.ply")
source = source.voxel_down_sample(voxel_size=voxel_size)

target = o3.io.read_point_cloud("frag_116.ply")
target = target.voxel_down_sample(voxel_size=voxel_size)

cbs = [callbacks.Open3dVisualizerCallback(source, target)]
tf_param, _, _ = cpd.registration_cpd(source, target,
callbacks=cbs, use_color=True)

print("result: ", np.rad2deg(t3d.euler.mat2euler(tf_param.rot)),
tf_param.scale, tf_param.t)
Binary file added examples/frag_115.ply
Binary file not shown.
Binary file added examples/frag_116.ply
Binary file not shown.
3 changes: 1 addition & 2 deletions probreg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from . import (bcpd, callbacks, cpd, filterreg, gmmtree, l2dist_regs, log,
math_utils, transformation)
from . import bcpd, callbacks, cpd, filterreg, gmmtree, l2dist_regs, log, math_utils, transformation
from .version import __version__
6 changes: 3 additions & 3 deletions probreg/bcpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def expectation_step(self, t_source, target, scale, alpha, sigma_mat, sigma2, w=
pmat = np.exp(-pmat / (2.0 * sigma2))
pmat /= (2.0 * np.pi * sigma2) ** (dim * 0.5)
pmat = pmat.T
pmat *= np.exp(-(scale ** 2) / (2 * sigma2) * np.diag(sigma_mat) * dim)
pmat *= np.exp(-(scale**2) / (2 * sigma2) * np.diag(sigma_mat) * dim)
pmat *= (1.0 - w) * alpha
den = w / target.shape[0] + np.sum(pmat, axis=1)
den[den == 0] = np.finfo(np.float32).eps
Expand Down Expand Up @@ -126,7 +126,7 @@ def _maximization_step(source, target, rigid_trans, estep_res, gmat_inv, lmd, k,
nu_d, nu, n_p, px, x_hat = estep_res
dim = source.shape[1]
m = source.shape[0]
s2s2 = rigid_trans.scale ** 2 / (sigma2_p ** 2)
s2s2 = rigid_trans.scale**2 / (sigma2_p**2)
sigma_mat_inv = lmd * gmat_inv + s2s2 * np.diag(nu)
sigma_mat = np.linalg.inv(sigma_mat_inv)
residual = rigid_trans.inverse().transform(x_hat) - source
Expand All @@ -152,7 +152,7 @@ def _maximization_step(source, target, rigid_trans, estep_res, gmat_inv, lmd, k,
s1 = np.dot(target.ravel(), np.kron(nu_d, np.ones(dim)) * target.ravel())
s2 = np.dot(px.ravel(), y_hat.ravel())
s3 = np.dot(y_hat.ravel(), np.kron(nu, np.ones(dim)) * y_hat.ravel())
sigma2 = (s1 - 2.0 * s2 + s3) / (n_p * dim) + scale ** 2 * sigma2_m
sigma2 = (s1 - 2.0 * s2 + s3) / (n_p * dim) + scale**2 * sigma2_m
return MstepResult(tf.CombinedTransformation(rot, t, scale, v_hat), u_hat, sigma_mat, alpha, sigma2)


Expand Down
24 changes: 16 additions & 8 deletions probreg/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def asnumpy(x):
from .transformation import Transformation


class Plot2DCallback(object):
class Plot2DCallback:
"""Display the 2D registration result of each iteration.
Args:
Expand Down Expand Up @@ -62,12 +62,12 @@ def __call__(self, transformation: Transformation) -> None:
self._cnt += 1


class Open3dVisualizerCallback(object):
class Open3dVisualizerCallback:
"""Display the 3D registration result of each iteration.
Args:
source (numpy.ndarray): Source point cloud data.
target (numpy.ndarray): Target point cloud data.
source (open3d.geometry.PointCloud): Source point cloud data.
target (open3d.geometry.PointCloud): Target point cloud data.
save (bool, optional): If this flag is True,
each iteration image is saved in a sequential number.
keep_window (bool, optional): If this flag is True,
Expand All @@ -76,7 +76,12 @@ class Open3dVisualizerCallback(object):
"""

def __init__(
self, source: np.ndarray, target: np.ndarray, save: bool = False, keep_window: bool = True, fov: Any = None
self,
source: o3.geometry.PointCloud,
target: o3.geometry.PointCloud,
save: bool = False,
keep_window: bool = True,
fov: Any = None,
):
self._vis = o3.visualization.Visualizer()
self._vis.create_window()
Expand All @@ -85,9 +90,12 @@ def __init__(
self._result = copy.deepcopy(self._source)
self._save = save
self._keep_window = keep_window
self._source.paint_uniform_color([1, 0, 0])
self._target.paint_uniform_color([0, 1, 0])
self._result.paint_uniform_color([0, 0, 1])
if not self._source.has_colors():
self._source.paint_uniform_color([1, 0, 0])
if not self._target.has_colors():
self._target.paint_uniform_color([0, 1, 0])
if not self._result.has_colors():
self._result.paint_uniform_color([0, 0, 1])
self._vis.add_geometry(self._source)
self._vis.add_geometry(self._target)
self._vis.add_geometry(self._result)
Expand Down
4 changes: 2 additions & 2 deletions probreg/cost_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def __call__(self, theta: np.ndarray, *args):
def compute_l2_dist(
mu_source: np.ndarray, phi_source: np.ndarray, mu_target: np.ndarray, phi_target: np.ndarray, sigma: float
):
z = np.power(2.0 * np.pi * sigma ** 2, mu_source.shape[1] * 0.5)
z = np.power(2.0 * np.pi * sigma**2, mu_source.shape[1] * 0.5)
gtrans = gt.GaussTransform(mu_target, np.sqrt(2.0) * sigma)
phi_j_e = gtrans.compute(mu_source, phi_target / z)
phi_mu_j_e = gtrans.compute(mu_source, phi_target * mu_target.T / z).T
g = (phi_source * phi_j_e * mu_source.T - phi_source * phi_mu_j_e.T).T / (2.0 * sigma ** 2)
g = (phi_source * phi_j_e * mu_source.T - phi_source * phi_mu_j_e.T).T / (2.0 * sigma**2)
return -np.dot(phi_source, phi_j_e), g


Expand Down
Loading

0 comments on commit 6da9474

Please sign in to comment.