From 3ca731a276b1c14b8919cb7051c8a008549fd0fe Mon Sep 17 00:00:00 2001 From: Sheng Zhong Date: Wed, 28 Aug 2024 01:19:52 -0400 Subject: [PATCH] Move serial chain testing to separate file --- tests/test_inverse_kinematics.py | 86 +---------------------------- tests/test_jacobian.py | 2 +- tests/test_serial_chain_creation.py | 86 +++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 84 deletions(-) create mode 100644 tests/test_serial_chain_creation.py diff --git a/tests/test_inverse_kinematics.py b/tests/test_inverse_kinematics.py index 105b8c2..95c006a 100644 --- a/tests/test_inverse_kinematics.py +++ b/tests/test_inverse_kinematics.py @@ -10,9 +10,8 @@ import pybullet as p import pybullet_data -visualize = True +visualize = False -TEST_DIR = os.path.dirname(__file__) def _make_robot_translucent(robot_id, alpha=0.4): def make_transparent(link): @@ -207,87 +206,8 @@ def test_ik_in_place_no_err(): assert torch.allclose(sol.err_rot[0], torch.zeros(1, device=device), atol=1e-6) -def test_extract_serial_chain_from_tree(): - pytorch_seed.seed(2) - device = "cuda" if torch.cuda.is_available() else "cpu" - # device = "cpu" - urdf = "widowx/wx250s.urdf" - full_urdf = os.path.join(TEST_DIR, urdf) - chain = pk.build_chain_from_urdf(open(full_urdf, mode="rb").read()) - # full frames - full_frame_expected = """ -base_link -└── shoulder_link - └── upper_arm_link - └── upper_forearm_link - └── lower_forearm_link - └── wrist_link - └── gripper_link - └── ee_arm_link - ├── gripper_prop_link - └── gripper_bar_link - └── fingers_link - ├── left_finger_link - ├── right_finger_link - └── ee_gripper_link - """ - full_frame = chain.print_link_tree() - assert full_frame_expected.strip() == full_frame.strip() - - chain = pk.SerialChain(chain, "ee_gripper_link", "base_link") - serial_frame = chain.print_link_tree() - chain = chain.to(device=device) - - # full chain should have DOF = 8, however since we are creating just a serial chain to ee_gripper_link, should be 6 - dof = len(chain.get_joints(exclude_fixed=True)) - assert dof == 6 - - # robot frame - pos = torch.tensor([0.0, 0.0, 0.0], device=device) - rot = torch.tensor([0.0, 0.0, 0.0], device=device) - rob_tf = pk.Transform3d(pos=pos, rot=rot, device=device) - - # world frame goal - M = 1000 - # generate random goal joint angles (so these are all achievable) - # use the joint limits to generate random joint angles - lim = torch.tensor(chain.get_joint_limits(), device=device) - goal_q = torch.rand(M, 7, device=device) * (lim[1] - lim[0]) + lim[0] - - # get ee pose (in robot frame) - goal_in_rob_frame_tf = chain.forward_kinematics(goal_q) - - # transform to world frame for visualization - goal_tf = rob_tf.compose(goal_in_rob_frame_tf) - goal = goal_tf.get_matrix() - goal_pos = goal[..., :3, 3] - goal_rot = pk.matrix_to_euler_angles(goal[..., :3, :3], "XYZ") - - num_retries = 10 - ik = pk.PseudoInverseIK(chain, max_iterations=30, num_retries=num_retries, - joint_limits=lim.T, - early_stopping_any_converged=True, - early_stopping_no_improvement="all", - # line_search=pk.BacktrackingLineSearch(max_lr=0.2), - debug=False, - lr=0.2) - - # do IK - timer_start = timer() - sol = ik.solve(goal_in_rob_frame_tf) - timer_end = timer() - print("IK took %f seconds" % (timer_end - timer_start)) - print("IK converged number: %d / %d" % (sol.converged.sum(), sol.converged.numel())) - print("IK took %d iterations" % sol.iterations) - print("IK solved %d / %d goals" % (sol.converged_any.sum(), M)) - - # check that solving again produces the same solutions - sol_again = ik.solve(goal_in_rob_frame_tf) - assert torch.allclose(sol.solutions, sol_again.solutions) - assert torch.allclose(sol.converged, sol_again.converged) if __name__ == "__main__": - # test_jacobian_follower() - # test_ik_in_place_no_err() - test_extract_serial_chain_from_tree() + test_jacobian_follower() + test_ik_in_place_no_err() \ No newline at end of file diff --git a/tests/test_jacobian.py b/tests/test_jacobian.py index 599f26f..960c722 100644 --- a/tests/test_jacobian.py +++ b/tests/test_jacobian.py @@ -182,7 +182,7 @@ def get_pt(th): try: import functorch ft_start = timer() - grad_func = functorch.vmap(functorch.jacrev(get_pt)) + grad_func = torch.vmap(functorch.jacrev(get_pt)) j3 = grad_func(th).squeeze(1) ft_end = timer() assert torch.allclose(j1_, j3, atol=1e-6) diff --git a/tests/test_serial_chain_creation.py b/tests/test_serial_chain_creation.py new file mode 100644 index 0000000..d7b0347 --- /dev/null +++ b/tests/test_serial_chain_creation.py @@ -0,0 +1,86 @@ +import os +from timeit import default_timer as timer + +import torch + +import pytorch_kinematics as pk + +TEST_DIR = os.path.dirname(__file__) + + +def test_extract_serial_chain_from_tree(): + urdf = "widowx/wx250s.urdf" + full_urdf = os.path.join(TEST_DIR, urdf) + chain = pk.build_chain_from_urdf(open(full_urdf, mode="rb").read()) + # full frames + full_frame_expected = """ +base_link +└── shoulder_link + └── upper_arm_link + └── upper_forearm_link + └── lower_forearm_link + └── wrist_link + └── gripper_link + └── ee_arm_link + ├── gripper_prop_link + └── gripper_bar_link + └── fingers_link + ├── left_finger_link + ├── right_finger_link + └── ee_gripper_link + """ + full_frame = chain.print_link_tree() + assert full_frame_expected.strip() == full_frame.strip() + + serial_chain = pk.SerialChain(chain, "ee_gripper_link", "base_link") + serial_frame_expected = """ +base_link +└── shoulder_link + └── upper_arm_link + └── upper_forearm_link + └── lower_forearm_link + └── wrist_link + └── gripper_link + └── ee_arm_link + └── gripper_bar_link + └── fingers_link + └── ee_gripper_link + """ + serial_frame = serial_chain.print_link_tree() + assert serial_frame_expected.strip() == serial_frame.strip() + + # full chain should have DOF = 8, however since we are creating just a serial chain to ee_gripper_link, should be 6 + assert chain.n_joints == 8 + assert serial_chain.n_joints == 6 + + serial_chain = pk.SerialChain(chain, "gripper_prop_link", "base_link") + serial_frame_expected = """ +base_link +└── shoulder_link + └── upper_arm_link + └── upper_forearm_link + └── lower_forearm_link + └── wrist_link + └── gripper_link + └── ee_arm_link + └── gripper_prop_link + """ + serial_frame = serial_chain.print_link_tree() + assert serial_frame_expected.strip() == serial_frame.strip() + + serial_chain = pk.SerialChain(chain, "ee_gripper_link", "gripper_link") + serial_frame_expected = """ + gripper_link +└── ee_arm_link + └── gripper_bar_link + └── fingers_link + └── ee_gripper_link + """ + serial_frame = serial_chain.print_link_tree() + assert serial_frame_expected.strip() == serial_frame.strip() + # only gripper_link is the parent frame of a joint in this serial chain + assert serial_chain.n_joints == 1 + + +if __name__ == "__main__": + test_extract_serial_chain_from_tree()