Skip to content

Commit

Permalink
Merge pull request #10 from skhrg/transform_refactor
Browse files Browse the repository at this point in the history
Transform refactor
  • Loading branch information
skhrg authored Apr 2, 2024
2 parents be397c4 + edbcd75 commit ac5086d
Showing 1 changed file with 65 additions and 89 deletions.
154 changes: 65 additions & 89 deletions megham/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
def get_shift(
src: NDArray[np.floating],
dst: NDArray[np.floating],
row_basis: bool = True,
method: str = "median",
weights: Optional[NDArray[np.floating]] = None,
) -> NDArray[np.floating]:
Expand All @@ -29,11 +28,7 @@ def get_shift(
but really any array broadcastable with src is accepted.
Some useful options are:
* np.zeros(1) to align with the origin
* A (ndim, 1) array to align with an arbitrary point
row_basis : bool, default: True
If the basis of the points is row.
If row basis then each row of src and dst is a point.
If col basis then each col of src and dst is a point.
* A (ndim,) array to align with an arbitrary point
method : str, default: 'median'
Method to use to align points.
Current accepted values are: 'median' and 'mean'
Expand All @@ -46,7 +41,6 @@ def get_shift(
-------
shift : NDArray[np.floating]
The (ndim,) shift to apply after transformation.
If point are in col basis will be returned as a column vector.
Raises
------
Expand All @@ -56,30 +50,22 @@ def get_shift(
if method not in ["median", "mean"]:
raise ValueError(f"Invalid method: {method}")

if row_basis:
src = src.T
dst = np.atleast_2d(dst).T

shift = np.zeros(src.shape[0])
shift = np.zeros(src.shape[1])
if method == "median":
shift = np.median(dst - src, axis=-1)
shift = np.median(dst - src, axis=0)
elif method == "mean":
if weights is None:
shift = np.mean(dst - src, axis=-1)
shift = np.mean(dst - src, axis=0)
else:
wdiff = weights * (dst - src)
shift = np.nansum(wdiff, axis=1) / np.nansum(weights)

if not row_basis:
shift = shift[..., np.newaxis]
wdiff = weights[..., None] * (dst - src)
shift = np.nansum(wdiff, axis=0) / np.nansum(weights)

return shift


def get_rigid(
src: NDArray[np.floating],
dst: NDArray[np.floating],
row_basis: bool = True,
center_dst: bool = True,
**kwargs,
) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
Expand All @@ -88,19 +74,14 @@ def get_rigid(
It is assumed that the point clouds have the same registration,
ie. src[i] corresponds to dst[i].
Transformation is dst = src@rot + shift in row basis,
and dst = rot@src + shift in col basis.
Transformation is dst = src@rot + shift.
Parameters
----------
src : NDArray[np.floating]
A (ndim, npoints) array of source points.
A (npoints, ndim) array of source points.
dst : NDArray[np.floating]
A (ndim, npoints) array of destination points.
row_basis : bool, default: True
If the basis of the points is row.
If row basis then each row of src and dst is a point.
If col basis then each col of src and dst is a point.
A (npoints, ndim) array of destination points.
center_dst : bool, default: True
If True, dst will be recentered at the origin before computing transformation.
This is done with get_shift, but weights will not be used if provided.
Expand All @@ -123,70 +104,66 @@ def get_rigid(
"""
if src.shape != dst.shape:
raise ValueError("Input point clouds should have the same shape")
if row_basis:
src = src.T
dst = dst.T

msk = np.isfinite(src).all(axis=0) * np.isfinite(dst).all(axis=0)
ndim = len(src)
msk = np.isfinite(src).all(axis=1) * np.isfinite(dst).all(axis=1)
ndim = src.shape[1]
if np.sum(msk) < ndim * (ndim - 1) / 2:
raise ValueError("Not enough finite points to compute transformation")

_dst = dst[:, msk].copy()
_dst = dst[msk].copy()
if center_dst:
_kwargs = kwargs.copy()
_kwargs.update({"weights": None})
_dst += get_shift(_dst, np.zeros(1), False, **_kwargs)
_src = src[:, msk].copy()
_src += get_shift(_src, _dst, False, **kwargs)
_dst += get_shift(_dst, np.zeros(1), **_kwargs)
_src = src[msk].copy()
_src += get_shift(_src, _dst, **kwargs)

M = _src @ (_dst.T)
M = _src.T @ (_dst)
u, _, vh = la.svd(M)
v = vh.T
uT = u.T

corr = np.eye(ndim)
corr[-1, -1] = la.det((v) @ (uT))
rot = v @ corr @ uT
rot = rot.T

transformed = rot @ src[:, msk]
shift = get_shift(transformed, dst[:, msk], False, **kwargs)

if row_basis:
rot = rot.T
shift = shift[:, 0]
transformed = src[msk] @ rot
shift = get_shift(transformed, dst[msk], **kwargs)

return rot, shift


def get_affine(
src: NDArray[np.floating],
dst: NDArray[np.floating],
row_basis: bool = True,
weights: Optional[NDArray[np.floating]] = None,
center_dst: bool = True,
force_svd: bool = False,
**kwargs,
) -> tuple[NDArray[np.floating], NDArray[np.floating]]:
"""
Get affine transformation between two point clouds.
It is assumed that the point clouds have the same registration,
ie. src[i] corresponds to dst[i].
Transformation is dst = src@affine + shift in row basis,
and dst = affine@src + shift in col basis.
Transformation is dst = src@affine + shift.
Parameters
----------
src : NDArray[np.floating]
A (npoints, ndim) or (ndim, npoints) array of source points.
A (npoints, ndim) array of source points.
dst : NDArray[np.floating]
A ((npoints, ndim) or (ndim, npoints) array of destination points.
row_basis : bool, default: True
If the basis of the points is row.
If row basis then each row of src and dst is a point.
If col basis then each col of src and dst is a point.
A (npoints, ndim) array of destination points.
weights : Optional[NDArray[np.floating]], default: None
(npoints,) array of weights to use.
If provided a weighted least squares is done instead of an SVD.
center_dst : bool, default: True
If True, dst will be recentered at the origin before computing transformation.
This is done with get_shift, but weights will not be used if provided.
force_svd : bool, default: False
If True the SVD is used even if there are a small number of points
or weights are present.
**kwargs
Arguments to pass to get_shift.
Expand All @@ -196,7 +173,6 @@ def get_affine(
The (ndim, ndim) transformation matrix.
shift : NDArray[np.floating]
The (ndim,) shift to apply after transformation.
If point are in col basis will be returned as a column vector.
Raises
------
Expand All @@ -206,35 +182,44 @@ def get_affine(
"""
if src.shape != dst.shape:
raise ValueError("Input point clouds should have the same shape")
if row_basis:
src = src.T
dst = dst.T

msk = np.isfinite(src).all(axis=0) * np.isfinite(dst).all(axis=0)
if np.sum(msk) < len(src) + 1:
msk = np.isfinite(src).all(axis=1) * np.isfinite(dst).all(axis=1)
if np.sum(msk) < src.shape[1] + 1:
raise ValueError("Not enough finite points to compute transformation")

_dst = dst[:, msk].copy()
if center_dst:
_kwargs = kwargs.copy()
_kwargs.update({"weights": None})
_dst += get_shift(_dst, np.zeros(1), False, **_kwargs)
_src = src[:, msk].copy()
_src += get_shift(_src, _dst, False, **kwargs)

M = np.vstack((_src, _dst)).T
*_, vh = la.svd(M)
vh_splits = [
quad for half in np.split(vh.T, 2, axis=0) for quad in np.split(half, 2, axis=1)
]
affine = np.dot(vh_splits[2], la.pinv(vh_splits[0]))
# When we have a small number of points lstsq is better than SVD
# Condition is a bit arbitrary for now
if force_svd is False and weights is None and np.sum(msk) < 50 * src.shape[1]:
weights = np.ones(len(src))

transformed = affine @ src[:, msk]
shift = get_shift(transformed, dst[:, msk], False, **kwargs)
_dst = dst[msk].copy()
if center_dst:
_dst += get_shift(_dst, np.zeros(1), **kwargs)
_src = src[msk].copy()
init_shift = get_shift(_src, _dst, weights=weights, **kwargs)

if force_svd or weights is None:
M = np.vstack((_src.T, (_dst - init_shift).T)).T
*_, vh = la.svd(M)
vh_splits = [
quad
for half in np.split(vh.T, 2, axis=0)
for quad in np.split(half, 2, axis=1)
]
affine = np.dot(vh_splits[2], la.pinv(vh_splits[0])).T
shift = init_shift
else:
rt_weight = np.sqrt(weights[msk])[..., None]
wsrc = rt_weight * _src
wdst = rt_weight * (_dst - init_shift)
x, *_ = la.lstsq(
np.column_stack((wsrc, np.ones(len(wsrc)))), wdst, check_finite=False
)
affine = x[:-1]
shift = x[-1] + init_shift

if row_basis:
affine = affine.T
shift = shift[:, 0]
transformed = src[msk] @ affine + shift
shift += get_shift(transformed, dst[msk], **kwargs)

return affine, shift

Expand All @@ -243,7 +228,6 @@ def apply_transform(
src: NDArray[np.floating],
transform: NDArray[np.floating],
shift: NDArray[np.floating],
row_basis: bool = True,
) -> NDArray[np.floating]:
"""
Apply a transformation to a set of points.
Expand All @@ -252,18 +236,13 @@ def apply_transform(
----------
src : NDArray[np.floating]
The points to transform.
Should have shape (ndim, npoints) or (npoints, ndim).
Should have shape (npoints, ndim).
transform: NDArray[np.floating]
The transformation matrix.
Should have shape (ndim, ndim).
shift : NDArray[np.floating]
The shift to apply after the affine tranrform.
Should have shape (ndim,).
row_basis : bool, default: True
Whether or not the input and output need to be transposed.
This is the case when src is (npoints, ndim).
By default the function will try to figure this out in its own,
this is only used in the case where it can't because src is (ndim, ndim).
Returns
-------
Expand All @@ -287,10 +266,7 @@ def apply_transform(
if len(src_shape) != 2:
raise ValueError(f"src should be a 2d array, not {len(src.shape)}d")

if row_basis:
transformed = src @ transform + shift
else:
transformed = transform @ src + shift
transformed = src @ transform + shift
return transformed


Expand Down

0 comments on commit ac5086d

Please sign in to comment.