diff --git a/omnigibson/utils/transform_utils.py b/omnigibson/utils/transform_utils.py index 985716e9d..a89a25491 100644 --- a/omnigibson/utils/transform_utils.py +++ b/omnigibson/utils/transform_utils.py @@ -44,7 +44,7 @@ @torch.compile -def _copysign(a, b): +def copysign(a, b): # type: (float, torch.Tensor) -> torch.Tensor a = torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0]) return torch.abs(a) * torch.sign(b) @@ -588,7 +588,7 @@ def quat2euler(q): roll = torch.atan2(sinr_cosp, cosr_cosp) # pitch (y-axis rotation) sinp = 2.0 * (q[:, qw] * q[:, qy] - q[:, qz] * q[:, qx]) - pitch = torch.where(torch.abs(sinp) >= 1, _copysign(math.pi / 2.0, sinp), torch.asin(sinp)) + pitch = torch.where(torch.abs(sinp) >= 1, copysign(math.pi / 2.0, sinp), torch.asin(sinp)) # yaw (z-axis rotation) siny_cosp = 2.0 * (q[:, qw] * q[:, qz] + q[:, qx] * q[:, qy]) cosy_cosp = q[:, qw] * q[:, qw] + q[:, qx] * q[:, qx] - q[:, qy] * q[:, qy] - q[:, qz] * q[:, qz] diff --git a/omnigibson/utils/transform_utils_np.py b/omnigibson/utils/transform_utils_np.py index b530f9a49..ddf2ac0c3 100644 --- a/omnigibson/utils/transform_utils_np.py +++ b/omnigibson/utils/transform_utils_np.py @@ -50,6 +50,11 @@ _TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items()) +def copysign(a, b): + a = np.array(a).repeat(b.shape[0]) + return np.abs(a) * np.sign(b) + + def anorm(x, axis=None, keepdims=False): """Compute L2 norms alogn specified axes.""" return np.linalg.norm(x, axis=axis, keepdims=keepdims)