Skip to content

Commit

Permalink
adding translate transform for point clouds
Browse files Browse the repository at this point in the history
Signed-off-by: pushkalkatara <[email protected]>
  • Loading branch information
pushkalkatara committed May 8, 2020
1 parent c3ec548 commit c11a1bb
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 0 deletions.
97 changes: 97 additions & 0 deletions kaolin/transforms/pointcloudfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,53 @@
EPS = 1e-6


def shift(cloud: Union[torch.Tensor, PointCloud],
shf: Union[float, int, torch.Tensor],
inplace: Optional[bool] = True):
"""Shift the input pointcloud by a shift factor.
Args:
cloud (torch.Tensor or kaolin.rep.PointCloud): pointcloud (ndims >= 2).
shf (float, int, torch.Tensor): shift factor (scaler, or tensor).
inplace (bool, optional): Bool to make the transform in-place
Returns:
(torch.Tensor): shifted pointcloud pf the same shape as input.
Shape:
- cloud: :math:`(B x N x D)` (or) :math:`(N x D)`, where :math:`(B)`
is the batchsize, :math:`(N)` is the number of points per cloud,
and :math:`(D)` is the dimensionality of each cloud.
- shf: :math:`(1)` or :math:`(B)`.
Example:
>>> points = torch.rand(1000,3)
>>> points2 = shift(points, torch.FloatTensor([3]))
"""

if isinstance(cloud, np.ndarray):
cloud = torch.from_numpy(cloud)

if isinstance(shf, np.ndarray):
shf = torch.from_numpy(shf)

if isinstance(cloud, PointCloud):
cloud = cloud.points

if isinstance(shf, int) or isinstance(shf, float):
shf = torch.Tensor([shf]).to(cloud.device)

helpers._assert_tensor(cloud)
helpers._assert_tensor(shf)
helpers._assert_dim_ge(cloud, 2)
helpers._assert_gt(shf, 0.)

if not inplace:
cloud = cloud.clone()

return shf + cloud


def scale(cloud: Union[torch.Tensor, PointCloud],
scf: Union[float, int, torch.Tensor],
inplace: Optional[bool] = True):
Expand Down Expand Up @@ -74,6 +121,56 @@ def scale(cloud: Union[torch.Tensor, PointCloud],
return scf * cloud


def translate(cloud: Union[torch.Tensor, PointCloud], tranmat: torch.Tensor,
inplace: Optional[bool] = True):
"""Translate the input pointcloud by a translation matrix.
Args:
cloud (Tensor or np.array): pointcloud (ndims = 2 or 3)
tranmat (Tensor or np.array): translation matrix (1 x 3, 1 per cloud).
Returns:
cloud_tran (Tensor): Translated pointcloud of the same shape as input.
Shape:
- cloud: :math:`(B x N x 3)` (or) :math:`(N x 3)`, where :math:`(B)`
is the batchsize, :math:`(N)` is the number of points per cloud,
and :math:`(3)` is the dimensionality of each cloud.
- tranmat: :math:`(1, 3)` or :math:`(B, 1, 3)`.
Example:
>>> points = torch.rand(1000,3)
>>> t_mat = torch.rand(1,3)
>>> points2 = translate(points, t_mat)
"""
if isinstance(cloud, np.ndarray):
cloud = torch.from_numpy(cloud)
if isinstance(cloud, PointCloud):
cloud = cloud.points
if isinstance(tranmat, np.ndarray):
trainmat = torch.from_numpy(tranmat)

helpers._assert_tensor(cloud)
helpers._assert_tensor(tranmat)
helpers._assert_dim_ge(cloud, 2)
helpers._assert_dim_ge(tranmat, 2)
# Rotation matrix must have last two dimensions of shape 3.
helpers._assert_shape_eq(tranmat, (1, 3), dim=-1)
helpers._assert_shape_eq(tranmat, (1, 3), dim=-2)

if not inplace:
cloud = cloud.clone()

if tranmat.dim() == 2 and cloud.dim() == 2:
cloud = torch.add(tranmat, cloud)
else:
if tranmat.dim() == 2:
tranmat = tranmat.expand(cloud.shape[0], 1, 3)
cloud = torch.add(tranmat, cloud)

return cloud

def rotate(cloud: Union[torch.Tensor, PointCloud], rotmat: torch.Tensor,
inplace: Optional[bool] = True):
"""Rotates the the input pointcloud by a rotation matrix.
Expand Down
58 changes: 58 additions & 0 deletions kaolin/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,35 @@ def __call__(self, arr: np.ndarray):
return torch.from_numpy(arr)


class ShiftPointCloud(Transform):
r"""Shift a pointcloud with respect a fixed shift factor.
Given a shift factor `shf`, this transform will shift each point in the
pointcloud, i.e.,
``cloud = shf + cloud``
Args:
shf (int or float or torch.Tensor): Shift pofactorint by which input
clouds are to be shifted.
inplace (bool, optional): Whether or not the transformation should be
in-place (default: True).
"""

def __init__(self, shf: Union[int, float, torch.Tensor],
inplace: Optional[bool] = True):
self.shf = shf
self.inplace = inplace

def __call__(self, cloud: Union[torch.Tensor, PointCloud]):
"""
Args:
cloud (torch.Tensor or PointCloud): Pointcloud to be shifted.
Returns:
(torch.Tensor or PointCloud): Shifted pointcloud.
"""
return pcfunc.shift(cloud, shf=self.shf, inplace=self.inplace)


class ScalePointCloud(Transform):
"""Scale a pointcloud with a fixed scaling factor.
Given a scale factor `scf`, this transform will scale each point in the
Expand Down Expand Up @@ -231,6 +260,35 @@ def __call__(self, cloud: Union[torch.Tensor, PointCloud]):
return pcfunc.scale(cloud, scf=self.scf, inplace=self.inplace)


class TranslatePointCloud(Transform):
r"""Translate a pointcloud with a given translation matrix.
Given a :math:`1 \times 3` translation matrix, this transform will
translate each point in the cloud by the translation matrix specified.
Args:
tranmat (torch.Tensor): Translation matrix that specifies the translation
to be applied to the pointcloud (shape: :math:`1 \times 3`).
inplace (bool, optional): Bool to make this operation in-place.
TODO: Example.
"""

def __init__(self, tranmat: torch.Tensor, inplace: Optional[bool] = True):
self.tranmat = tranmat
self.inplace = inplace

def __call__(self, cloud: Union[torch.Tensor, PointCloud]):
"""
Args:
cloud (torch.Tensor or PointCloud): Input pointcloud to be translated.
Returns:
(torch.Tensor or PointCloud): Translated pointcloud.
"""
return pcfunc.translate(cloud, tranmat=self.tranmat, inplace=self.inplace)


class RotatePointCloud(Transform):
r"""Rotate a pointcloud with a given rotation matrix.
Given a :math:`3 \times 3` rotation matrix, this transform will rotate each
Expand Down

0 comments on commit c11a1bb

Please sign in to comment.