From 535d1e14de9c055802e9a547985b322f1d287b2a Mon Sep 17 00:00:00 2001 From: Sheng Zhong Date: Thu, 27 Jun 2024 20:25:37 -0400 Subject: [PATCH] Add quaternion angular distance and slerp --- src/pytorch_kinematics/transforms/__init__.py | 7 +++ src/pytorch_kinematics/transforms/math.py | 59 ++++++++++++++++++- tests/test_transform.py | 43 ++++++++++++++ 3 files changed, 108 insertions(+), 1 deletion(-) diff --git a/src/pytorch_kinematics/transforms/__init__.py b/src/pytorch_kinematics/transforms/__init__.py index 8f4d030..73a5dff 100644 --- a/src/pytorch_kinematics/transforms/__init__.py +++ b/src/pytorch_kinematics/transforms/__init__.py @@ -14,6 +14,7 @@ quaternion_raw_multiply, quaternion_to_matrix, quaternion_from_euler, + quaternion_to_axis_angle, random_quaternions, random_rotation, random_rotations, @@ -35,5 +36,11 @@ so3_rotation_angle, ) from .transform3d import Rotate, RotateAxisAngle, Scale, Transform3d, Translate +from pytorch_kinematics.transforms.math import ( + quaternion_angular_distance, + acos_linear_extrapolation, + quaternion_close, + quaternion_slerp, +) __all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/src/pytorch_kinematics/transforms/math.py b/src/pytorch_kinematics/transforms/math.py index cc10065..9a315b7 100644 --- a/src/pytorch_kinematics/transforms/math.py +++ b/src/pytorch_kinematics/transforms/math.py @@ -10,16 +10,73 @@ import torch +def quaternion_angular_distance(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor: + """ + Computes the angular distance between two quaternions. + Args: + q1: First quaternion (assume normalized). + q2: Second quaternion (assume normalized). + Returns: + Angular distance between the two quaternions. + """ + + # Compute the cosine of the angle between the two quaternions + cos_theta = torch.sum(q1 * q2, dim=-1) + # we use atan2 instead of acos for better numerical stability + cos_theta = torch.clamp(cos_theta, -1.0, 1.0) + abs_dot = torch.abs(cos_theta) + # identity sin^2(theta) = 1 - cos^2(theta) + sin_half_theta = torch.sqrt(1.0 - torch.square(abs_dot)) + theta = 2.0 * torch.atan2(sin_half_theta, abs_dot) + + # theta for the ones that are close gets 0 and we don't care about them + close = quaternion_close(q1, q2) + theta[close] = 0 + return theta + + def quaternion_close(q1: torch.Tensor, q2: torch.Tensor, eps: float = 1e-4): """ Returns true if two quaternions are close to each other. Assumes the quaternions are normalized. Based on: https://math.stackexchange.com/a/90098/516340 """ - dist = 1 - torch.square(torch.sum(q1*q2, dim=-1)) + dist = 1 - torch.square(torch.sum(q1 * q2, dim=-1)) return torch.all(dist < eps) +def quaternion_slerp(q1: torch.Tensor, q2: torch.Tensor, t: Union[float, torch.tensor]) -> torch.Tensor: + """ + Spherical linear interpolation between two quaternions. + Args: + q1: First quaternion (assume normalized). + q2: Second quaternion (assume normalized). + t: Interpolation parameter. + Returns: + Interpolated quaternion. + """ + # Compute the cosine of the angle between the two quaternions + cos_theta = torch.sum(q1 * q2, dim=-1) + + # reverse the direction of q2 if q1 and q2 are not in the same hemisphere + to_invert = cos_theta < 0 + q2[to_invert] = -q2[to_invert] + cos_theta[to_invert] = -cos_theta[to_invert] + + # If the quaternions are close, perform a linear interpolation + if torch.all(cos_theta > 1.0 - 1e-6): + return q1 + t * (q2 - q1) + + # Ensure the angle is between 0 and pi + theta = torch.acos(cos_theta) + sin_theta = torch.sin(theta) + + # Perform the interpolation + w1 = torch.sin((1.0 - t) * theta) / sin_theta + w2 = torch.sin(t * theta) / sin_theta + return w1[:, None] * q1 + w2[:, None] * q2 + + def acos_linear_extrapolation( x: torch.Tensor, bound: Union[float, Tuple[float, float]] = 1.0 - 1e-4, diff --git a/tests/test_transform.py b/tests/test_transform.py index d5e06b7..ffdcc1b 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -1,6 +1,7 @@ import torch import pytorch_kinematics.transforms as tf +import pytorch_kinematics as pk def test_transform(): @@ -106,11 +107,39 @@ def test_euler(): def test_quaternions(): + import pytorch_seed + pytorch_seed.seed(0) + n = 10 q = tf.random_quaternions(n) q_tf = tf.wxyz_to_xyzw(q) assert torch.allclose(q, tf.xyzw_to_wxyz(q_tf)) + qq = pk.standardize_quaternion(q) + assert torch.allclose(qq.norm(dim=-1), torch.ones(n)) + + # random quaternions should already be unit quaternions + assert torch.allclose(q, qq) + + # distances to themselves should be zero + d = pk.quaternion_angular_distance(q, q) + assert torch.allclose(d, torch.zeros(n)) + # q = -q + d = pk.quaternion_angular_distance(q, -q) + assert torch.allclose(d, torch.zeros(n)) + + axis = torch.tensor([0.0, 0.5, 0.5]) + axis = axis / axis.norm() + magnitudes = torch.tensor([2.32, 1.56, -0.52, 0.1]) + n = len(magnitudes) + aa_1 = axis.repeat(n, 1) + aa_2 = axis * magnitudes[:, None] + q1 = pk.axis_angle_to_quaternion(aa_1) + q2 = pk.axis_angle_to_quaternion(aa_2) + d = pk.quaternion_angular_distance(q1, q2) + expected_d = (magnitudes - 1).abs() + assert torch.allclose(d, expected_d, atol=1e-4) + def test_compose(): import torch @@ -124,6 +153,19 @@ def test_compose(): print(a2c.transform_points(torch.zeros([1, 3]))) +def test_quaternion_slerp(): + q = tf.random_quaternions(20) + q1 = q[:10] + q2 = q[10:] + t = torch.rand(10) + q_interp = pk.quaternion_slerp(q1, q2, t) + # check the distance between them is consistent + full_dist = pk.quaternion_angular_distance(q1, q2) + interp_dist = pk.quaternion_angular_distance(q1, q_interp) + # print(f"full_dist: {full_dist} interp_dist: {interp_dist} t: {t}") + assert torch.allclose(full_dist * t, interp_dist, atol=1e-5) + + if __name__ == "__main__": test_compose() test_transform() @@ -132,3 +174,4 @@ def test_compose(): test_rotate() test_euler() test_quaternions() + test_quaternion_slerp()