diff --git a/src/pytorch_kinematics/transforms/rotation_conversions.py b/src/pytorch_kinematics/transforms/rotation_conversions.py index dcc0ad4..bf75f98 100644 --- a/src/pytorch_kinematics/transforms/rotation_conversions.py +++ b/src/pytorch_kinematics/transforms/rotation_conversions.py @@ -497,8 +497,8 @@ def tensor_axis_and_angle_to_matrix(axis, theta): torch.stack([r10, r11, r12], -1), torch.stack([r20, r21, r22], -1)], -2) batch_shape = axis.shape[:-1] - mat44 = torch.eye(4, device=axis.device, dtype=axis.dtype).repeat(*batch_shape, 1, 1) - mat44[..., :3, :3] = rot + mat44 = torch.cat((rot, torch.zeros(*batch_shape, 3, 1).to(axis)), -1) + mat44 = torch.cat((mat44, torch.tensor([0.0, 0.0, 0.0, 1.0]).expand(*batch_shape, 1, 4).to(axis)), -2) return mat44