diff --git a/megham/transform.py b/megham/transform.py index 7eec041..5e7ef48 100644 --- a/megham/transform.py +++ b/megham/transform.py @@ -12,7 +12,6 @@ def get_shift( src: NDArray[np.floating], dst: NDArray[np.floating], - row_basis: bool = True, method: str = "median", weights: Optional[NDArray[np.floating]] = None, ) -> NDArray[np.floating]: @@ -29,11 +28,7 @@ def get_shift( but really any array broadcastable with src is accepted. Some useful options are: * np.zeros(1) to align with the origin - * A (ndim, 1) array to align with an arbitrary point - row_basis : bool, default: True - If the basis of the points is row. - If row basis then each row of src and dst is a point. - If col basis then each col of src and dst is a point. + * A (ndim,) array to align with an arbitrary point method : str, default: 'median' Method to use to align points. Current accepted values are: 'median' and 'mean' @@ -46,7 +41,6 @@ def get_shift( ------- shift : NDArray[np.floating] The (ndim,) shift to apply after transformation. - If point are in col basis will be returned as a column vector. Raises ------ @@ -56,22 +50,15 @@ def get_shift( if method not in ["median", "mean"]: raise ValueError(f"Invalid method: {method}") - if row_basis: - src = src.T - dst = np.atleast_2d(dst).T - - shift = np.zeros(src.shape[0]) + shift = np.zeros(src.shape[1]) if method == "median": - shift = np.median(dst - src, axis=-1) + shift = np.median(dst - src, axis=0) elif method == "mean": if weights is None: - shift = np.mean(dst - src, axis=-1) + shift = np.mean(dst - src, axis=0) else: - wdiff = weights * (dst - src) - shift = np.nansum(wdiff, axis=1) / np.nansum(weights) - - if not row_basis: - shift = shift[..., np.newaxis] + wdiff = weights[..., None] * (dst - src) + shift = np.nansum(wdiff, axis=0) / np.nansum(weights) return shift @@ -79,7 +66,6 @@ def get_shift( def get_rigid( src: NDArray[np.floating], dst: NDArray[np.floating], - row_basis: bool = True, center_dst: bool = True, **kwargs, ) -> tuple[NDArray[np.floating], NDArray[np.floating]]: @@ -88,19 +74,14 @@ def get_rigid( It is assumed that the point clouds have the same registration, ie. src[i] corresponds to dst[i]. - Transformation is dst = src@rot + shift in row basis, - and dst = rot@src + shift in col basis. + Transformation is dst = src@rot + shift. Parameters ---------- src : NDArray[np.floating] - A (ndim, npoints) array of source points. + A (npoints, ndim) array of source points. dst : NDArray[np.floating] - A (ndim, npoints) array of destination points. - row_basis : bool, default: True - If the basis of the points is row. - If row basis then each row of src and dst is a point. - If col basis then each col of src and dst is a point. + A (npoints, ndim) array of destination points. center_dst : bool, default: True If True, dst will be recentered at the origin before computing transformation. This is done with get_shift, but weights will not be used if provided. @@ -123,24 +104,21 @@ def get_rigid( """ if src.shape != dst.shape: raise ValueError("Input point clouds should have the same shape") - if row_basis: - src = src.T - dst = dst.T - msk = np.isfinite(src).all(axis=0) * np.isfinite(dst).all(axis=0) - ndim = len(src) + msk = np.isfinite(src).all(axis=1) * np.isfinite(dst).all(axis=1) + ndim = src.shape[1] if np.sum(msk) < ndim * (ndim - 1) / 2: raise ValueError("Not enough finite points to compute transformation") - _dst = dst[:, msk].copy() + _dst = dst[msk].copy() if center_dst: _kwargs = kwargs.copy() _kwargs.update({"weights": None}) - _dst += get_shift(_dst, np.zeros(1), False, **_kwargs) - _src = src[:, msk].copy() - _src += get_shift(_src, _dst, False, **kwargs) + _dst += get_shift(_dst, np.zeros(1), **_kwargs) + _src = src[msk].copy() + _src += get_shift(_src, _dst, **kwargs) - M = _src @ (_dst.T) + M = _src.T @ (_dst) u, _, vh = la.svd(M) v = vh.T uT = u.T @@ -148,13 +126,10 @@ def get_rigid( corr = np.eye(ndim) corr[-1, -1] = la.det((v) @ (uT)) rot = v @ corr @ uT + rot = rot.T - transformed = rot @ src[:, msk] - shift = get_shift(transformed, dst[:, msk], False, **kwargs) - - if row_basis: - rot = rot.T - shift = shift[:, 0] + transformed = src[msk] @ rot + shift = get_shift(transformed, dst[msk], **kwargs) return rot, shift @@ -162,8 +137,9 @@ def get_rigid( def get_affine( src: NDArray[np.floating], dst: NDArray[np.floating], - row_basis: bool = True, + weights: Optional[NDArray[np.floating]] = None, center_dst: bool = True, + force_svd: bool = False, **kwargs, ) -> tuple[NDArray[np.floating], NDArray[np.floating]]: """ @@ -171,22 +147,23 @@ def get_affine( It is assumed that the point clouds have the same registration, ie. src[i] corresponds to dst[i]. - Transformation is dst = src@affine + shift in row basis, - and dst = affine@src + shift in col basis. + Transformation is dst = src@affine + shift. Parameters ---------- src : NDArray[np.floating] - A (npoints, ndim) or (ndim, npoints) array of source points. + A (npoints, ndim) array of source points. dst : NDArray[np.floating] - A ((npoints, ndim) or (ndim, npoints) array of destination points. - row_basis : bool, default: True - If the basis of the points is row. - If row basis then each row of src and dst is a point. - If col basis then each col of src and dst is a point. + A (npoints, ndim) array of destination points. + weights : Optional[NDArray[np.floating]], default: None + (npoints,) array of weights to use. + If provided a weighted least squares is done instead of an SVD. center_dst : bool, default: True If True, dst will be recentered at the origin before computing transformation. This is done with get_shift, but weights will not be used if provided. + force_svd : bool, default: False + If True the SVD is used even if there are a small number of points + or weights are present. **kwargs Arguments to pass to get_shift. @@ -196,7 +173,6 @@ def get_affine( The (ndim, ndim) transformation matrix. shift : NDArray[np.floating] The (ndim,) shift to apply after transformation. - If point are in col basis will be returned as a column vector. Raises ------ @@ -206,35 +182,44 @@ def get_affine( """ if src.shape != dst.shape: raise ValueError("Input point clouds should have the same shape") - if row_basis: - src = src.T - dst = dst.T - msk = np.isfinite(src).all(axis=0) * np.isfinite(dst).all(axis=0) - if np.sum(msk) < len(src) + 1: + msk = np.isfinite(src).all(axis=1) * np.isfinite(dst).all(axis=1) + if np.sum(msk) < src.shape[1] + 1: raise ValueError("Not enough finite points to compute transformation") - _dst = dst[:, msk].copy() - if center_dst: - _kwargs = kwargs.copy() - _kwargs.update({"weights": None}) - _dst += get_shift(_dst, np.zeros(1), False, **_kwargs) - _src = src[:, msk].copy() - _src += get_shift(_src, _dst, False, **kwargs) - - M = np.vstack((_src, _dst)).T - *_, vh = la.svd(M) - vh_splits = [ - quad for half in np.split(vh.T, 2, axis=0) for quad in np.split(half, 2, axis=1) - ] - affine = np.dot(vh_splits[2], la.pinv(vh_splits[0])) + # When we have a small number of points lstsq is better than SVD + # Condition is a bit arbitrary for now + if force_svd is False and weights is None and np.sum(msk) < 50 * src.shape[1]: + weights = np.ones(len(src)) - transformed = affine @ src[:, msk] - shift = get_shift(transformed, dst[:, msk], False, **kwargs) + _dst = dst[msk].copy() + if center_dst: + _dst += get_shift(_dst, np.zeros(1), **kwargs) + _src = src[msk].copy() + init_shift = get_shift(_src, _dst, weights=weights, **kwargs) + + if force_svd or weights is None: + M = np.vstack((_src.T, (_dst - init_shift).T)).T + *_, vh = la.svd(M) + vh_splits = [ + quad + for half in np.split(vh.T, 2, axis=0) + for quad in np.split(half, 2, axis=1) + ] + affine = np.dot(vh_splits[2], la.pinv(vh_splits[0])).T + shift = init_shift + else: + rt_weight = np.sqrt(weights[msk])[..., None] + wsrc = rt_weight * _src + wdst = rt_weight * (_dst - init_shift) + x, *_ = la.lstsq( + np.column_stack((wsrc, np.ones(len(wsrc)))), wdst, check_finite=False + ) + affine = x[:-1] + shift = x[-1] + init_shift - if row_basis: - affine = affine.T - shift = shift[:, 0] + transformed = src[msk] @ affine + shift + shift += get_shift(transformed, dst[msk], **kwargs) return affine, shift @@ -243,7 +228,6 @@ def apply_transform( src: NDArray[np.floating], transform: NDArray[np.floating], shift: NDArray[np.floating], - row_basis: bool = True, ) -> NDArray[np.floating]: """ Apply a transformation to a set of points. @@ -252,18 +236,13 @@ def apply_transform( ---------- src : NDArray[np.floating] The points to transform. - Should have shape (ndim, npoints) or (npoints, ndim). + Should have shape (npoints, ndim). transform: NDArray[np.floating] The transformation matrix. Should have shape (ndim, ndim). shift : NDArray[np.floating] The shift to apply after the affine tranrform. Should have shape (ndim,). - row_basis : bool, default: True - Whether or not the input and output need to be transposed. - This is the case when src is (npoints, ndim). - By default the function will try to figure this out in its own, - this is only used in the case where it can't because src is (ndim, ndim). Returns ------- @@ -287,10 +266,7 @@ def apply_transform( if len(src_shape) != 2: raise ValueError(f"src should be a 2d array, not {len(src.shape)}d") - if row_basis: - transformed = src @ transform + shift - else: - transformed = transform @ src + shift + transformed = src @ transform + shift return transformed