-
Notifications
You must be signed in to change notification settings - Fork 182
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
IK computation goes out of joint limits in the PDEEPos controller when using a subset of robot joints #754
Comments
Are you using the CPU or GPU simulation? It is currently a known issue that the underlying IK library, pytorch_kinematics, can sometimes generate joint solutions that go beyond the joint limits (UM-ARM-Lab/pytorch_kinematics#48), but this should only affect the GPU sim, since we use Pinocchio for the CPU sim.
Thank you very much for your prompt reply! I was using the GPU version. I just tried the CPU version, but it fails to find a solution. I thought this might be caused by the intermediate goal set by the delta action command being unreachable, so I found a fingertip position (in the robot base frame) that is reachable using only the finger joints. Then I tried using this position as the input action command, with …
Are you sure it is reachable? Is there a reproducible script somewhere?
Thank you for your reply! Here are the scripts I am using. This is the environment:
from typing import Any, Dict, Union
import numpy as np
import sapien
import torch
import torch.random
from transforms3d.euler import euler2quat
from mani_skill.agents.robots import AllegroHandRight
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.sensors.camera import CameraConfig
from mani_skill.utils import common, sapien_utils
from mani_skill.utils.building import actors
from mani_skill.utils.registration import register_env
from mani_skill.utils.scene_builder.table import TableSceneBuilder
from mani_skill.utils.structs import Pose
from mani_skill.utils.structs.types import Array, GPUMemoryConfig, SimConfig
@register_env("TestEnvHand-v1", max_episode_steps=50)
class TestEnvHand(BaseEnv):
    """Test scene for an Allegro right hand: a table, a goal marker and a free bottle.

    No success/failure criteria are defined (`evaluate` returns an empty dict)
    and the reward is identically zero; the scene exists to exercise the hand's
    controllers.
    """

    SUPPORTED_ROBOTS = ["allegro_hand_right"]
    agent: AllegroHandRight

    # Radius of the red/white goal marker built in `_load_scene`.
    goal_radius = 0.1
    # Not referenced by this class; apparently left over from the PushCube example.
    cube_half_size = 0.02

    def __init__(
        self,
        *args,
        robot_uids="allegro_hand_right",
        robot_init_qpos_noise=0.02,
        **kwargs,
    ):
        """Store hand-specific settings, then build the environment via `BaseEnv`.

        Args:
            robot_uids: Robot(s) to load. Defaulting to the Allegro right hand
                means `gym.make("TestEnvHand-v1")` uses it without extra arguments.
            robot_init_qpos_noise: Forwarded to the table scene builder in
                `_load_scene`.
        """
        self.robot_init_qpos_noise = robot_init_qpos_noise
        # Joint positions applied on every episode reset: 16 zeros, one per hand joint.
        self.robot_default_qpos = np.zeros(16)
        # The attributes above are set before BaseEnv.__init__, which may already
        # load and initialize the scene using them.
        super().__init__(*args, robot_uids=robot_uids, **kwargs)

    @property
    def _default_sim_config(self):
        # Override the default GPU memory buffer sizes for this scene.
        return SimConfig(
            gpu_memory_config=GPUMemoryConfig(
                found_lost_pairs_capacity=2**25, max_rigid_patch_count=2**18
            )
        )

    @property
    def _default_sensor_configs(self):
        # A single small 128x128 camera: lower quality, but fast to render.
        pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1])
        return [
            CameraConfig(
                "base_camera",
                pose=pose,
                width=128,
                height=128,
                fov=np.pi / 2,
                near=0.01,
                far=100,
            )
        ]

    @property
    def _default_human_render_camera_configs(self):
        # Higher-resolution (512x512) camera used only for human-facing renders,
        # e.g. render_mode="rgb_array" or env.render_rgb_array().
        pose = sapien_utils.look_at([0.6, 0.7, 0.6], [0.0, 0.0, 0.35])
        return CameraConfig(
            "render_camera", pose=pose, width=512, height=512, fov=1, near=0.01, far=100
        )

    def _load_agent(self, options: dict):
        # Place the hand at a fixed initial pose away from the other objects.
        super()._load_agent(options, sapien.Pose(p=[-0.615, 0, 0.3]))

    def _load_scene(self, options: dict):
        # Floor and table come from the prebuilt table scene builder.
        self.table_scene = TableSceneBuilder(
            env=self, robot_init_qpos_noise=self.robot_init_qpos_noise
        )
        self.table_scene.build()

        # Red/white goal marker: visual only (add_collision=False, so it does not
        # affect physics) and kinematic, so it stays exactly where it is placed.
        self.goal_region = actors.build_red_white_target(
            self.scene,
            radius=self.goal_radius,
            thickness=1e-5,
            name="goal_region",
            add_collision=False,
            body_type="kinematic",
        )

        # Bottle articulation with a free (non-fixed) root link, loaded from URDF.
        # NOTE(review): the URDF path is relative and presumably resolves against
        # the current working directory - confirm where the script is run from.
        loader = self.scene.create_urdf_loader()
        loader.fix_root_link = False
        parsed = loader.parse("mani_skill/assets/bottle/test_bottle.urdf")
        builder = parsed["articulation_builders"][0]
        builder.initial_pose = sapien.Pose(p=[0.2, 0.0, 0.5])
        self.bottle = builder.build(name="my_articulation")

    def _initialize_episode(self, env_idx: torch.Tensor, options: dict):
        # The torch.device context creates tensors on self.device (CPU or CUDA).
        with torch.device(self.device):
            b = len(env_idx)
            self.table_scene.initialize(env_idx)

            # Bottle at (0.05, 0, 0.05) with identity orientation (wxyz quaternion).
            xyz = torch.zeros((b, 3))
            xyz[..., 0] = 0.05
            xyz[..., 2] = 0.05
            q = [1, 0, 0, 0]
            obj_pose = Pose.create_from_pq(p=xyz, q=q)
            self.bottle.set_pose(obj_pose)

            # Goal marker: shifted from the bottle along +x by 0.1 + goal_radius,
            # lowered to just above the table surface (z = 1e-3), and rotated
            # 90 degrees about the y-axis so the target faces up.
            target_region_xyz = xyz + torch.tensor([0.1 + self.goal_radius, 0, 0])
            target_region_xyz[..., 2] = 1e-3
            self.goal_region.set_pose(
                Pose.create_from_pq(
                    p=target_region_xyz,
                    q=euler2quat(0, np.pi / 2, 0),
                )
            )

            self.agent.robot.set_qpos(self.robot_default_qpos)

    def evaluate(self):
        # No success/failure criteria are defined for this test scene.
        return {}

    def compute_dense_reward(self, obs: Any, action: Array, info: Dict):
        """Return a zero reward for every parallel env, shape (num_envs,)."""
        # A batched tensor (rather than the scalar 0) keeps the reward shape
        # consistent with the rest of the batched step outputs.
        return torch.zeros(self.num_envs, device=self.device)

    def compute_normalized_dense_reward(self, obs: Any, action: Array, info: Dict):
        """Return the dense reward divided by its maximum possible value."""
        max_reward = 3.0
        return self.compute_dense_reward(obs=obs, action=action, info=info) / max_reward


# I modified the `_controller_configs` method (presumably of the
# AllegroHandRight agent; shown standalone here) as follows:
def _controller_configs(self):
    """Build the Allegro hand's controller configs, keyed by control-mode name.

    Returns:
        A deep copy of the control-mode -> controller-config dict, so callers
        can modify the configs without affecting later calls.
    """
    # -------------------------------------------------------------------------- #
    # Whole-hand joint-space controllers
    # -------------------------------------------------------------------------- #
    joint_pos = PDJointPosControllerConfig(
        self.joint_names,
        None,
        None,
        self.joint_stiffness,
        self.joint_damping,
        self.joint_force_limit,
        normalize_action=False,
    )
    joint_delta_pos = PDJointPosControllerConfig(
        self.joint_names,
        -0.1,
        0.1,
        self.joint_stiffness,
        self.joint_damping,
        self.joint_force_limit,
        use_delta=True,
    )
    joint_target_delta_pos = deepcopy(joint_delta_pos)
    joint_target_delta_pos.use_target = True

    # -------------------------------------------------------------------------- #
    # Index-fingertip (task-space) controllers: IK is solved over the four
    # index-finger joints only, with link_3.0_tip as the end effector.
    # NOTE(review): as discussed in this issue, the GPU IK backend
    # (pytorch_kinematics) may return joint solutions outside the joint limits.
    # -------------------------------------------------------------------------- #
    # Delta fingertip position: per-step offsets bounded by +/-0.01 per axis.
    link_pos_lower = [-0.01, -0.01, -0.01]
    link_pos_upper = [0.01, 0.01, 0.01]
    fingertip_pos = PDEEPosControllerConfig(
        joint_names=["joint_0.0", "joint_1.0", "joint_2.0", "joint_3.0"],
        pos_lower=link_pos_lower,
        pos_upper=link_pos_upper,
        stiffness=self.joint_stiffness,
        damping=self.joint_damping,
        force_limit=self.joint_force_limit,
        ee_link="link_3.0_tip",
        urdf_path=self.urdf_path,
    )

    # Absolute fingertip pose (use_delta=False): the action *is* the target pose,
    # so it must not be normalized. With normalize_action left at its default
    # (True), actions are clipped to [-1, 1] and rescaled into the configured
    # bounds, which squashed every target into a +/-1 cm box and made it
    # unreachable. The wide bounds below only describe the action space.
    # NOTE(review): absolute EE targets are interpreted in the robot root (base)
    # frame, not the world frame - confirm, and convert world-frame targets first.
    fingertip_pose = PDEEPoseControllerConfig(
        joint_names=["joint_0.0", "joint_1.0", "joint_2.0", "joint_3.0"],
        pos_lower=-1.0,
        pos_upper=1.0,
        rot_lower=-np.pi,
        rot_upper=np.pi,
        stiffness=self.joint_stiffness,
        damping=self.joint_damping,
        force_limit=self.joint_force_limit,
        ee_link="link_3.0_tip",
        urdf_path=self.urdf_path,
        use_delta=False,
        normalize_action=False,
    )

    controller_configs = dict(
        pd_joint_delta_pos=joint_delta_pos,
        pd_joint_pos=joint_pos,
        pd_joint_target_delta_pos=joint_target_delta_pos,
        pd_fingertip_pos=fingertip_pos,
        pd_fingertip_pose=fingertip_pose,
    )
    # Make a deepcopy in case users modify any config
    return deepcopy_dict(controller_configs)


# Finally, this is the script that runs the environment. The action is the
# target fingertip pose in the world frame; this pose is reached by setting the
# index-finger joint angles to [0.0, 1.0, 1.0, 1.0].
import gymnasium as gym
import numpy as np
import torch
import sapien
from mani_skill.envs.sapien_env import BaseEnv
from mani_skill.utils.wrappers import RecordEpisode
from mani_skill.utils.geometry.rotation_conversions import quaternion_to_axis_angle
import tyro
from dataclasses import dataclass
from typing import List, Optional, Annotated, Union
@dataclass
class Args:
    """Command-line options for the fingertip-pose test script.

    The options are parsed with tyro; the string literal after each field is
    its help text.
    """

    env_id: Annotated[str, tyro.conf.arg(aliases=["-e"])] = "TestEnvHand-v1"
    """The environment ID of the task you want to simulate"""
    obs_mode: Annotated[str, tyro.conf.arg(aliases=["-o"])] = "none"
    """Observation mode"""
    robot_uids: Annotated[Optional[str], tyro.conf.arg(aliases=["-r"])] = "allegro_hand_right"
    """Robot UID(s) to use. Can be a comma separated list of UIDs or empty string to have no agents. If not given then defaults to the environments default robot"""
    sim_backend: Annotated[str, tyro.conf.arg(aliases=["-b"])] = "auto"
    """Which simulation backend to use. Can be 'auto', 'cpu', 'gpu'"""
    reward_mode: Optional[str] = None
    """Reward mode"""
    num_envs: Annotated[int, tyro.conf.arg(aliases=["-n"])] = 1
    """Number of environments to run."""
    control_mode: Annotated[Optional[str], tyro.conf.arg(aliases=["-c"])] = "pd_fingertip_pose"
    """Control mode"""
    # Pass `--render-mode human` on the command line for the interactive viewer.
    render_mode: str = "rgb_array"
    """Render mode"""
    shader: str = "default"
    """Change shader used for all cameras in the environment for rendering. Default is 'default'. Can also be 'minimal' which is very fast, 'rt' for ray tracing and generating photo-realistic renders, or 'rt-fast' for a faster but lower quality ray-traced renderer"""
    record_dir: Optional[str] = None
    """Directory to save recordings"""
    pause: Annotated[bool, tyro.conf.arg(aliases=["-p"])] = False
    """If using human render mode, auto pauses the simulation upon loading"""
    quiet: bool = True
    """Disable verbose output."""
    seed: Annotated[Optional[Union[int, List[int]]], tyro.conf.arg(aliases=["-s"])] = None
    """Seed(s) for random actions and simulator. Can be a single integer or a list of integers. Default is None (no seeds)"""
def main(args: Args):
    """Create `args.env_id` and repeatedly command one fixed fingertip target pose.

    The loop stops once any sub-environment terminates or truncates, except in
    "human" render mode, where it runs until the process is stopped. Videos are
    recorded to `args.record_dir` when it is set.
    """
    np.set_printoptions(suppress=True, precision=3)
    verbose = not args.quiet
    if isinstance(args.seed, int):
        args.seed = [args.seed]
    if args.seed is not None:
        np.random.seed(args.seed[0])

    # Render all parallel envs in one scene only for a multi-env GUI session
    # with a non-visual observation mode.
    parallel_in_single_scene = args.render_mode == "human"
    if args.render_mode == "human" and args.obs_mode in ["sensor_data", "rgb", "rgbd", "depth", "point_cloud"]:
        print("Disabling parallel single scene/GUI render as observation mode is a visual one. Change observation mode to state or state_dict to see a parallel env render")
        parallel_in_single_scene = False
    if args.render_mode == "human" and args.num_envs == 1:
        parallel_in_single_scene = False

    env_kwargs = dict(
        obs_mode=args.obs_mode,
        reward_mode=args.reward_mode,
        control_mode=args.control_mode,
        render_mode=args.render_mode,
        sensor_configs=dict(shader_pack=args.shader),
        human_render_camera_configs=dict(shader_pack=args.shader),
        viewer_camera_configs=dict(shader_pack=args.shader),
        num_envs=args.num_envs,
        sim_backend=args.sim_backend,
        enable_shadow=True,
        parallel_in_single_scene=parallel_in_single_scene,
    )
    if args.robot_uids is not None:
        env_kwargs["robot_uids"] = tuple(args.robot_uids.split(","))
    env: BaseEnv = gym.make(args.env_id, **env_kwargs)

    record_dir = args.record_dir
    if record_dir:
        record_dir = record_dir.format(env_id=args.env_id)
        env = RecordEpisode(env, record_dir, info_on_video=False, save_trajectory=False, max_steps_per_video=env._max_episode_steps)

    if verbose:
        print("Observation space", env.observation_space)
        print("Action space", env.action_space)
        if env.unwrapped.agent is not None:
            print("Control mode", env.unwrapped.control_mode)
        print("Reward mode", env.unwrapped.reward_mode)

    obs, _ = env.reset(seed=args.seed, options=dict(reconfigure=True))
    if args.seed is not None and env.action_space is not None:
        env.action_space.seed(args.seed[0])
    if args.render_mode is not None:
        viewer = env.render()
        if isinstance(viewer, sapien.utils.Viewer):
            viewer.paused = args.pause
        env.render()

    # Target index finger pose in world frame (index finger qpos = [0.0, 1.0, 1.0, 1.0]),
    # laid out as [x, y, z, quaternion]; the quaternion is presumably (w, x, y, z),
    # matching sapien.Pose.q - confirm.
    target_pose_pq = np.array([-0.5309, 0.0438, 0.3016, 0.0707, -0.0031, 0.9965, -0.0435])
    target_pos = target_pose_pq[:3]
    target_ori = quaternion_to_axis_angle(torch.tensor(target_pose_pq[3:])).numpy()
    # NOTE(review): absolute EE control (use_delta=False) appears to interpret
    # targets in the robot root frame, while this pose is in the world frame -
    # convert it with the inverse robot root pose before sending it. Also confirm
    # the rotation convention (axis-angle here) that the controller expects.
    # The commanded action is the same every step, so build it once.
    action = np.concatenate([target_pos, target_ori])

    while True:
        print("unprocessed action", action)
        obs, reward, terminated, truncated, info = env.step(action)
        if verbose:
            print("reward", reward)
            print("terminated", terminated)
            print("truncated", truncated)
            print("info", info)
        if args.render_mode is not None:
            env.render()
        # Outside the interactive GUI, stop as soon as any env episode ends.
        if args.render_mode != "human" and (terminated | truncated).any():
            break
    env.close()

    if record_dir:
        print(f"Saving video to {record_dir}")
if __name__ == "__main__":
    # Build an Args instance from the command line and run the demo with it.
    parsed_args = tyro.cli(Args)
    main(parsed_args)
# Thank you again!
Hello. I have been stuck trying to control a dexterous hand using fingertip position delta commands. The problem I am facing is that the `compute_ik` function of the `Kinematics` class returns joint configurations that are far beyond the joint limits, and I was wondering if you have suggestions for resolving this. Here is my setup:
- […] `ee_link`, and the finger joint names are passed as `joint_names`.
- […] `compute_ik` function of the `Kinematics` class.
- […] `compute_ik` function exceed the joint limits, even when the delta action is small (e.g. 0.0 in the x and y directions, and -0.005 in the z direction).

I would greatly appreciate any help. Thank you so much!
The text was updated successfully, but these errors were encountered: