Skip to content

Commit

Permalink
Generate copies of frames with only direct ancestors in SerialChain
Browse files Browse the repository at this point in the history
  • Loading branch information
LemonPi committed Aug 28, 2024
1 parent c7cda5e commit 2baf77c
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
46 changes: 24 additions & 22 deletions src/pytorch_kinematics/chain.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import lru_cache
from typing import Optional, Sequence

import copy
import numpy as np
import torch

Expand Down Expand Up @@ -430,28 +431,29 @@ class SerialChain(Chain):
"""

def __init__(self, chain, end_frame_name, root_frame_name="", **kwargs):
if root_frame_name == "":
super().__init__(chain._root, **kwargs)
else:
super().__init__(chain.find_frame(root_frame_name), **kwargs)
if self._root is None:
raise ValueError("Invalid root frame name %s." % root_frame_name)
self._serial_frames = [self._root] + self._generate_serial_chain_recurse(self._root, end_frame_name)
if self._serial_frames is None:
raise ValueError("Invalid end frame name %s." % end_frame_name)

@staticmethod
def _generate_serial_chain_recurse(root_frame, end_frame_name):
for child in root_frame.children:
if child.name == end_frame_name:
# chop off any remaining tree after end frame
child.children = []
return [child]
else:
frames = SerialChain._generate_serial_chain_recurse(child, end_frame_name)
if not frames is None:
return [child] + frames
return None
root_frame = chain._root if root_frame_name == "" else chain.find_frame(root_frame_name)
if root_frame is None:
raise ValueError("Invalid root frame name %s." % root_frame_name)
chain = Chain(root_frame, **kwargs)

# make a copy of those frames that includes only the chain up to the end effector
end_frame_idx = chain.get_frame_indices(end_frame_name)
ancestors = chain.parents_indices[end_frame_idx]

frames = []
# first pass create copies of the ancestor nodes
for idx in ancestors:
this_frame_name = chain.idx_to_frame[idx.item()]
this_frame = copy.deepcopy(chain.find_frame(this_frame_name))
if idx == end_frame_idx:
this_frame.children = []
frames.append(this_frame)
# second pass assign correct children (only the next one in the frame list)
for i in range(len(ancestors) - 1):
frames[i].children = [frames[i + 1]]

self._serial_frames = frames
super().__init__(frames[0], **kwargs)

def jacobian(self, th, locations=None, **kwargs):
if locations is not None:
Expand Down
3 changes: 1 addition & 2 deletions src/pytorch_kinematics/urdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,5 +133,4 @@ def build_serial_chain_from_urdf(data, end_link_name, root_link_name=""):
SerialChain object created from URDF.
"""
urdf_chain = build_chain_from_urdf(data)
return chain.SerialChain(urdf_chain, end_link_name,
"" if root_link_name == "" else root_link_name)
return chain.SerialChain(urdf_chain, end_link_name, root_link_name or '')

0 comments on commit 2baf77c

Please sign in to comment.