diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9dd7c6d9b2..8bc6718254 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,7 +38,7 @@ repos: - id: pyupgrade args: ["--py310-plus"] # FIXME: This is a hack because Pytorch does not like: torch.Tensor | dict aliasing - exclude: "source/extensions/omni.isaac.lab/omni/isaac/lab/envs/common.py" + exclude: "source/extensions/omni.isaac.lab/omni/isaac/lab/envs/common.py|source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/image_plot.py" - repo: https://github.com/codespell-project/codespell rev: v2.2.6 hooks: diff --git a/source/extensions/omni.isaac.lab/config/extension.toml b/source/extensions/omni.isaac.lab/config/extension.toml index 9b53e611d5..2dbebe2a59 100644 --- a/source/extensions/omni.isaac.lab/config/extension.toml +++ b/source/extensions/omni.isaac.lab/config/extension.toml @@ -1,7 +1,7 @@ [package] # Note: Semantic Versioning is used: https://semver.org/ -version = "0.27.26" +version = "0.29.1" # Description title = "Isaac Lab framework for Robot Learning" diff --git a/source/extensions/omni.isaac.lab/docs/CHANGELOG.rst b/source/extensions/omni.isaac.lab/docs/CHANGELOG.rst index e250c40900..9e05e76848 100644 --- a/source/extensions/omni.isaac.lab/docs/CHANGELOG.rst +++ b/source/extensions/omni.isaac.lab/docs/CHANGELOG.rst @@ -1,6 +1,68 @@ Changelog --------- +0.29.1 (2024-12-15) +~~~~~~~~~~~~~~~~~~~ + +Changed +^^^^^^^ + +* Added a call to update articulation kinematics after reset to ensure states are updated for non-rendering sensors. Previously, some changes made during reset, such as modifying joint states, would not be reflected in the rigid body states immediately after reset. + + +0.29.0 (2024-12-15) +~~~~~~~~~~~~~~~~~~~ + +Added +^^^^^ + +* Added UI interface to the managers in the ManagerBasedEnv and ManagerBasedRLEnv classes. +* Added UI widgets for :class:`LiveLinePlot` and :class:`ImagePlot`. +* Added ``ManagerLiveVisualizer/Cfg``: Given a ManagerBase (i.e. action_manager, observation_manager, etc.) and a config file, this class creates the interface between managers and the UI. +* Added :class:`EnvLiveVisualizer`: A 'manager' of ManagerLiveVisualizer. This is added to the ManagerBasedEnv but is only called during the initialization of the managers in ``load_managers``. +* Added ``get_active_iterable_terms`` implementation methods to ActionManager, ObservationManager, CommandManager, CurriculumManager, RewardManager, and TerminationManager. This method exports the active term data and labels for each manager and is called by ManagerLiveVisualizer. +* Additions to :class:`BaseEnvWindow` and :class:`RLEnvWindow` to register ManagerLiveVisualizer UI interfaces for the chosen managers. + + +0.28.0 (2024-12-15) +~~~~~~~~~~~~~~~~~~~ + +Added +^^^^^ + +* Added observation history computation to :class:`omni.isaac.lab.managers.observation_manager.ObservationManager`. +* Added ``history_length`` and ``flatten_history_dim`` configuration parameters to :class:`omni.isaac.lab.managers.manager_term_cfg.ObservationTermCfg`. +* Added ``history_length`` and ``flatten_history_dim`` configuration parameters to :class:`omni.isaac.lab.managers.manager_term_cfg.ObservationGroupCfg`. +* Added full buffer property to :class:`omni.isaac.lab.utils.buffers.circular_buffer.CircularBuffer`. + + +0.27.29 (2024-12-15) +~~~~~~~~~~~~~~~~~~~~ + +Added +^^^^^ + +* Added action clip to all :class:`omni.isaac.lab.envs.mdp.actions`.
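As a usage illustration of the ``clip`` field introduced above (a minimal sketch; the asset name and the surrounding action configuration are hypothetical), the clip ranges are given as a dict mapping joint-name regex patterns to ``(min, max)`` bounds, which this change applies to the processed actions:

.. code-block:: python

    from omni.isaac.lab.envs import mdp

    # Hypothetical action term configuration for an articulation registered as "robot".
    # ``clip`` maps joint-name regex patterns to (min, max) bounds; the bounds are
    # applied to the processed (scaled and offset) actions via torch.clamp.
    joint_pos = mdp.JointPositionActionCfg(
        asset_name="robot",
        joint_names=[".*"],
        scale=0.5,
        clip={".*": (-1.0, 1.0)},
    )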
+ + +0.27.28 (2024-12-14) +~~~~~~~~~~~~~~~~~~~~ + +Changed +^^^^^^^ + +* Added check for error below threshold in state machines to ensure the state has been reached. + + +0.27.27 (2024-12-13) +~~~~~~~~~~~~~~~~~~~~ + +Fixed +^^^^^ + +* Fixed the shape of ``quat_w`` in the ``apply_actions`` method of :attr:`~omni.isaac.lab.env.mdp.NonHolonomicAction` (previously (N,B,4), now (N,4) since the number of root bodies B is required to be 1). Previously ``apply_actions`` errored because ``euler_xyz_from_quat`` requires inputs of shape (N,4). + + 0.27.26 (2024-12-11) ~~~~~~~~~~~~~~~~~~~~ @@ -54,7 +116,7 @@ Changed 0.27.21 (2024-12-06) -~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~ Fixed ^^^^^ diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/direct_rl_env.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/direct_rl_env.py index 32dafdef33..5de734c4eb 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/direct_rl_env.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/direct_rl_env.py @@ -273,6 +273,10 @@ def reset(self, seed: int | None = None, options: dict[str, Any] | None = None) indices = torch.arange(self.num_envs, dtype=torch.int64, device=self.device) self._reset_idx(indices) + # update articulation kinematics + self.scene.write_data_to_sim() + self.sim.forward() + # if sensors are added to the scene, make sure we render to reflect changes in reset if self.sim.has_rtx_sensors() and self.cfg.rerender_on_reset: self.sim.render() @@ -346,6 +350,9 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: reset_env_ids = self.reset_buf.nonzero(as_tuple=False).squeeze(-1) if len(reset_env_ids) > 0: self._reset_idx(reset_env_ids) + # update articulation kinematics + self.scene.write_data_to_sim() + self.sim.forward() # if sensors are added to the scene, make sure we render to reflect changes in reset if self.sim.has_rtx_sensors() and self.cfg.rerender_on_reset: self.sim.render() diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/manager_based_env.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/manager_based_env.py index 548a3d2a6d..c2a7b4116c 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/manager_based_env.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/manager_based_env.py @@ -14,6 +14,7 @@ from omni.isaac.lab.managers import ActionManager, EventManager, ObservationManager, RecorderManager from omni.isaac.lab.scene import InteractiveScene from omni.isaac.lab.sim import SimulationContext +from omni.isaac.lab.ui.widgets import ManagerLiveVisualizer from omni.isaac.lab.utils.timer import Timer from .common import VecEnvObs @@ -148,6 +149,8 @@ def __init__(self, cfg: ManagerBasedEnvCfg): # we need to do this here after all the managers are initialized # this is because they dictate the sensors and commands right now if self.sim.has_gui() and self.cfg.ui_window_class_type is not None: + # setup live visualizers + self.setup_manager_visualizers() self._window = self.cfg.ui_window_class_type(self, window_name="IsaacLab") else: # if no window, then we don't need to store the window @@ -233,6 +236,14 @@ def load_managers(self): if self.__class__ == ManagerBasedEnv and "startup" in self.event_manager.available_modes: self.event_manager.apply(mode="startup") + def setup_manager_visualizers(self): + """Creates live visualizers for manager terms.""" + + self.manager_visualizers = { + "action_manager": ManagerLiveVisualizer(manager=self.action_manager), + "observation_manager": 
ManagerLiveVisualizer(manager=self.observation_manager), + } + """ Operations - MDP. """ @@ -269,15 +280,17 @@ def reset( # reset state of scene self._reset_idx(env_ids) - self.scene.write_data_to_sim() - - # trigger recorder terms for post-reset calls - self.recorder_manager.record_post_reset(env_ids) + # update articulation kinematics + self.scene.write_data_to_sim() + self.sim.forward() # if sensors are added to the scene, make sure we render to reflect changes in reset if self.sim.has_rtx_sensors() and self.cfg.rerender_on_reset: self.sim.render() + # trigger recorder terms for post-reset calls + self.recorder_manager.record_post_reset(env_ids) + # compute observations self.obs_buf = self.observation_manager.compute() @@ -317,13 +330,16 @@ def reset_to( # set the state self.scene.reset_to(state, env_ids, is_relative=is_relative) - # trigger recorder terms for post-reset calls - self.recorder_manager.record_post_reset(env_ids) + # update articulation kinematics + self.sim.forward() # if sensors are added to the scene, make sure we render to reflect changes in reset if self.sim.has_rtx_sensors() and self.cfg.rerender_on_reset: self.sim.render() + # trigger recorder terms for post-reset calls + self.recorder_manager.record_post_reset(env_ids) + # compute observations self.obs_buf = self.observation_manager.compute() diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/manager_based_rl_env.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/manager_based_rl_env.py index 2968a20593..fda4dc11e2 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/manager_based_rl_env.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/manager_based_rl_env.py @@ -16,6 +16,7 @@ from omni.isaac.version import get_version from omni.isaac.lab.managers import CommandManager, CurriculumManager, RewardManager, TerminationManager +from omni.isaac.lab.ui.widgets import ManagerLiveVisualizer from .common import VecEnvStepReturn from .manager_based_env import ManagerBasedEnv @@ -132,6 +133,18 @@ def load_managers(self): if "startup" in self.event_manager.available_modes: self.event_manager.apply(mode="startup") + def setup_manager_visualizers(self): + """Creates live visualizers for manager terms.""" + + self.manager_visualizers = { + "action_manager": ManagerLiveVisualizer(manager=self.action_manager), + "observation_manager": ManagerLiveVisualizer(manager=self.observation_manager), + "command_manager": ManagerLiveVisualizer(manager=self.command_manager), + "termination_manager": ManagerLiveVisualizer(manager=self.termination_manager), + "reward_manager": ManagerLiveVisualizer(manager=self.reward_manager), + "curriculum_manager": ManagerLiveVisualizer(manager=self.curriculum_manager), + } + """ Operations - MDP """ @@ -204,9 +217,9 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: self.recorder_manager.record_pre_reset(reset_env_ids) self._reset_idx(reset_env_ids) - - # this is needed to make joint positions set from reset events effective + # update articulation kinematics self.scene.write_data_to_sim() + self.sim.forward() # if sensors are added to the scene, make sure we render to reflect changes in reset if self.sim.has_rtx_sensors() and self.cfg.rerender_on_reset: diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/binary_joint_actions.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/binary_joint_actions.py index 5c2ba3fa15..45a7f14f10 100644 --- 
a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/binary_joint_actions.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/binary_joint_actions.py @@ -40,9 +40,10 @@ class BinaryJointAction(ActionTerm): cfg: actions_cfg.BinaryJointActionCfg """The configuration of the action term.""" - _asset: Articulation """The articulation asset on which the action term is applied.""" + _clip: torch.Tensor + """The clip applied to the input action.""" def __init__(self, cfg: actions_cfg.BinaryJointActionCfg, env: ManagerBasedEnv) -> None: # initialize the action term @@ -83,6 +84,17 @@ def __init__(self, cfg: actions_cfg.BinaryJointActionCfg, env: ManagerBasedEnv) ) self._close_command[index_list] = torch.tensor(value_list, device=self.device) + # parse clip + if self.cfg.clip is not None: + if isinstance(cfg.clip, dict): + self._clip = torch.tensor([[-float("inf"), float("inf")]], device=self.device).repeat( + self.num_envs, self.action_dim, 1 + ) + index_list, _, value_list = string_utils.resolve_matching_names_values(self.cfg.clip, self._joint_names) + self._clip[:, index_list] = torch.tensor(value_list, device=self.device) + else: + raise ValueError(f"Unsupported clip type: {type(cfg.clip)}. Supported types are dict.") + """ Properties. """ @@ -115,6 +127,10 @@ def process_actions(self, actions: torch.Tensor): binary_mask = actions < 0 # compute the command self._processed_actions = torch.where(binary_mask, self._close_command, self._open_command) + if self.cfg.clip is not None: + self._processed_actions = torch.clamp( + self._processed_actions, min=self._clip[:, :, 0], max=self._clip[:, :, 1] + ) def reset(self, env_ids: Sequence[int] | None = None) -> None: self._raw_actions[env_ids] = 0.0 diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/joint_actions.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/joint_actions.py index ee5586b7f2..e2f95987aa 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/joint_actions.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/joint_actions.py @@ -50,6 +50,8 @@ class JointAction(ActionTerm): """The scaling factor applied to the input action.""" _offset: torch.Tensor | float """The offset applied to the input action.""" + _clip: torch.Tensor + """The clip applied to the input action.""" def __init__(self, cfg: actions_cfg.JointActionCfg, env: ManagerBasedEnv) -> None: # initialize the action term @@ -94,6 +96,16 @@ def __init__(self, cfg: actions_cfg.JointActionCfg, env: ManagerBasedEnv) -> Non self._offset[:, index_list] = torch.tensor(value_list, device=self.device) else: raise ValueError(f"Unsupported offset type: {type(cfg.offset)}. Supported types are float and dict.") + # parse clip + if self.cfg.clip is not None: + if isinstance(cfg.clip, dict): + self._clip = torch.tensor([[-float("inf"), float("inf")]], device=self.device).repeat( + self.num_envs, self.action_dim, 1 + ) + index_list, _, value_list = string_utils.resolve_matching_names_values(self.cfg.clip, self._joint_names) + self._clip[:, index_list] = torch.tensor(value_list, device=self.device) + else: + raise ValueError(f"Unsupported clip type: {type(cfg.clip)}. Supported types are dict.") """ Properties. 
@@ -120,6 +132,11 @@ def process_actions(self, actions: torch.Tensor): self._raw_actions[:] = actions # apply the affine transformations self._processed_actions = self._raw_actions * self._scale + self._offset + # clip actions + if self.cfg.clip is not None: + self._processed_actions = torch.clamp( + self._processed_actions, min=self._clip[:, :, 0], max=self._clip[:, :, 1] + ) def reset(self, env_ids: Sequence[int] | None = None) -> None: self._raw_actions[env_ids] = 0.0 diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/joint_actions_to_limits.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/joint_actions_to_limits.py index 3b31c9502a..b69ac4beda 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/joint_actions_to_limits.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/joint_actions_to_limits.py @@ -44,6 +44,8 @@ class JointPositionToLimitsAction(ActionTerm): """The articulation asset on which the action term is applied.""" _scale: torch.Tensor | float """The scaling factor applied to the input action.""" + _clip: torch.Tensor + """The clip applied to the input action.""" def __init__(self, cfg: actions_cfg.JointPositionToLimitsActionCfg, env: ManagerBasedEnv): # initialize the action term @@ -76,6 +78,16 @@ def __init__(self, cfg: actions_cfg.JointPositionToLimitsActionCfg, env: Manager self._scale[:, index_list] = torch.tensor(value_list, device=self.device) else: raise ValueError(f"Unsupported scale type: {type(cfg.scale)}. Supported types are float and dict.") + # parse clip + if self.cfg.clip is not None: + if isinstance(cfg.clip, dict): + self._clip = torch.tensor([[-float("inf"), float("inf")]], device=self.device).repeat( + self.num_envs, self.action_dim, 1 + ) + index_list, _, value_list = string_utils.resolve_matching_names_values(self.cfg.clip, self._joint_names) + self._clip[:, index_list] = torch.tensor(value_list, device=self.device) + else: + raise ValueError(f"Unsupported clip type: {type(cfg.clip)}. Supported types are dict.") """ Properties. @@ -102,6 +114,10 @@ def process_actions(self, actions: torch.Tensor): self._raw_actions[:] = actions # apply affine transformations self._processed_actions = self._raw_actions * self._scale + if self.cfg.clip is not None: + self._processed_actions = torch.clamp( + self._processed_actions, min=self._clip[:, :, 0], max=self._clip[:, :, 1] + ) # rescale the position targets if configured # this is useful when the input actions are in the range [-1, 1] if self.cfg.rescale_to_limits: diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/non_holonomic_actions.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/non_holonomic_actions.py index fc9ed89d6e..b95a058339 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/non_holonomic_actions.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/non_holonomic_actions.py @@ -11,6 +11,7 @@ import omni.log +import omni.isaac.lab.utils.string as string_utils from omni.isaac.lab.assets.articulation import Articulation from omni.isaac.lab.managers.action_manager import ActionTerm from omni.isaac.lab.utils.math import euler_xyz_from_quat @@ -59,6 +60,8 @@ class NonHolonomicAction(ActionTerm): """The scaling factor applied to the input action. Shape is (1, 2).""" _offset: torch.Tensor """The offset applied to the input action. 
Shape is (1, 2).""" + _clip: torch.Tensor + """The clip applied to the input action.""" def __init__(self, cfg: actions_cfg.NonHolonomicActionCfg, env: ManagerBasedEnv): # initialize the action term @@ -104,6 +107,16 @@ def __init__(self, cfg: actions_cfg.NonHolonomicActionCfg, env: ManagerBasedEnv) # save the scale and offset as tensors self._scale = torch.tensor(self.cfg.scale, device=self.device).unsqueeze(0) self._offset = torch.tensor(self.cfg.offset, device=self.device).unsqueeze(0) + # parse clip + if self.cfg.clip is not None: + if isinstance(cfg.clip, dict): + self._clip = torch.tensor([[-float("inf"), float("inf")]], device=self.device).repeat( + self.num_envs, self.action_dim, 1 + ) + index_list, _, value_list = string_utils.resolve_matching_names_values(self.cfg.clip, self._joint_names) + self._clip[:, index_list] = torch.tensor(value_list, device=self.device) + else: + raise ValueError(f"Unsupported clip type: {type(cfg.clip)}. Supported types are dict.") """ Properties. @@ -129,10 +142,15 @@ def process_actions(self, actions): # store the raw actions self._raw_actions[:] = actions self._processed_actions = self.raw_actions * self._scale + self._offset + # clip actions + if self.cfg.clip is not None: + self._processed_actions = torch.clamp( + self._processed_actions, min=self._clip[:, :, 0], max=self._clip[:, :, 1] + ) def apply_actions(self): # obtain current heading - quat_w = self._asset.data.body_quat_w[:, self._body_idx] + quat_w = self._asset.data.body_quat_w[:, self._body_idx].view(self.num_envs, 4) yaw_w = euler_xyz_from_quat(quat_w)[2] # compute joint velocities targets self._joint_vel_command[:, 0] = torch.cos(yaw_w) * self.processed_actions[:, 0] # x diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/task_space_actions.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/task_space_actions.py index a8a1108f50..55fd5126b8 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/task_space_actions.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/actions/task_space_actions.py @@ -12,6 +12,7 @@ import omni.log import omni.isaac.lab.utils.math as math_utils +import omni.isaac.lab.utils.string as string_utils from omni.isaac.lab.assets.articulation import Articulation from omni.isaac.lab.controllers.differential_ik import DifferentialIKController from omni.isaac.lab.managers.action_manager import ActionTerm @@ -42,6 +43,8 @@ class DifferentialInverseKinematicsAction(ActionTerm): """The articulation asset on which the action term is applied.""" _scale: torch.Tensor """The scaling factor applied to the input action. Shape is (1, action_dim).""" + _clip: torch.Tensor + """The clip applied to the input action.""" def __init__(self, cfg: actions_cfg.DifferentialInverseKinematicsActionCfg, env: ManagerBasedEnv): # initialize the action term @@ -101,6 +104,17 @@ def __init__(self, cfg: actions_cfg.DifferentialInverseKinematicsActionCfg, env: else: self._offset_pos, self._offset_rot = None, None + # parse clip + if self.cfg.clip is not None: + if isinstance(cfg.clip, dict): + self._clip = torch.tensor([[-float("inf"), float("inf")]], device=self.device).repeat( + self.num_envs, self.action_dim, 1 + ) + index_list, _, value_list = string_utils.resolve_matching_names_values(self.cfg.clip, self._joint_names) + self._clip[:, index_list] = torch.tensor(value_list, device=self.device) + else: + raise ValueError(f"Unsupported clip type: {type(cfg.clip)}. Supported types are dict.") + """ Properties. 
""" @@ -138,6 +152,10 @@ def process_actions(self, actions: torch.Tensor): # store the raw actions self._raw_actions[:] = actions self._processed_actions[:] = self.raw_actions * self._scale + if self.cfg.clip is not None: + self._processed_actions = torch.clamp( + self._processed_actions, min=self._clip[:, :, 0], max=self._clip[:, :, 1] + ) # obtain quantities from simulation ee_pos_curr, ee_quat_curr = self._compute_frame_pose() # set command into controller diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/ui/base_env_window.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/ui/base_env_window.py index 850ad0a355..4ae47df8e3 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/ui/base_env_window.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/ui/base_env_window.py @@ -16,6 +16,8 @@ import omni.usd from pxr import PhysxSchema, Sdf, Usd, UsdGeom, UsdPhysics +from omni.isaac.lab.ui.widgets import ManagerLiveVisualizer + if TYPE_CHECKING: import omni.ui @@ -57,6 +59,9 @@ def __init__(self, env: ManagerBasedEnv, window_name: str = "IsaacLab"): *self.env.scene.articulations.keys(), ] + # Listeners for environment selection changes + self._ui_listeners: list[ManagerLiveVisualizer] = [] + print("Creating window for environment.") # create window for UI self.ui_window = omni.ui.Window( @@ -80,6 +85,10 @@ def __init__(self, env: ManagerBasedEnv, window_name: str = "IsaacLab"): self._build_viewer_frame() # create collapsable frame for debug visualization self._build_debug_vis_frame() + with self.ui_window_elements["debug_frame"]: + with self.ui_window_elements["debug_vstack"]: + self._visualize_manager(title="Actions", class_name="action_manager") + self._visualize_manager(title="Observations", class_name="observation_manager") def __del__(self): """Destructor for the window.""" @@ -200,9 +209,6 @@ def _build_debug_vis_frame(self): that has it implemented. If the element does not have a debug visualization implemented, a label is created instead. """ - # import omni.isaac.ui.ui_utils as ui_utils - # import omni.ui - # create collapsable frame for debug visualization self.ui_window_elements["debug_frame"] = omni.ui.CollapsableFrame( title="Scene Debug Visualization", @@ -234,6 +240,26 @@ def _build_debug_vis_frame(self): if elem is not None: self._create_debug_vis_ui_element(name, elem) + def _visualize_manager(self, title: str, class_name: str) -> None: + """Checks if the attribute with the name 'class_name' can be visualized. If yes, create vis interface. + + Args: + title: The title of the manager visualization frame. + class_name: The name of the manager to visualize. + """ + + if hasattr(self.env, class_name) and class_name in self.env.manager_visualizers: + manager = self.env.manager_visualizers[class_name] + if hasattr(manager, "has_debug_vis_implementation"): + self._create_debug_vis_ui_element(title, manager) + else: + print( + f"ManagerLiveVisualizer cannot be created for manager: {class_name}, has_debug_vis_implementation" + " does not exist" + ) + else: + print(f"ManagerLiveVisualizer cannot be created for manager: {class_name}, Manager does not exist") + """ Custom callbacks for UI elements. """ @@ -357,6 +383,9 @@ def _set_viewer_env_index_fn(self, model: omni.ui.SimpleIntModel): raise ValueError("Viewport camera controller is not initialized! 
Please check the rendering mode.") # store the desired env index, UI is 1-indexed vcc.set_view_env_index(model.as_int - 1) + # notify additional listeners + for listener in self._ui_listeners: + listener.set_env_selection(model.as_int - 1) """ Helper functions - UI building. @@ -379,14 +408,30 @@ def _create_debug_vis_ui_element(self, name: str, elem: object): alignment=omni.ui.Alignment.LEFT_CENTER, tooltip=text, ) + has_cfg = hasattr(elem, "cfg") and elem.cfg is not None + is_checked = False + if has_cfg: + is_checked = (hasattr(elem.cfg, "debug_vis") and elem.cfg.debug_vis) or ( + hasattr(elem, "debug_vis") and elem.debug_vis + ) self.ui_window_elements[f"{name}_cb"] = SimpleCheckBox( model=omni.ui.SimpleBoolModel(), enabled=elem.has_debug_vis_implementation, - checked=elem.cfg.debug_vis if elem.cfg else False, + checked=is_checked, on_checked_fn=lambda value, e=weakref.proxy(elem): e.set_debug_vis(value), ) omni.isaac.ui.ui_utils.add_line_rect_flourish() + # Create a panel for the debug visualization + if isinstance(elem, ManagerLiveVisualizer): + self.ui_window_elements[f"{name}_panel"] = omni.ui.Frame(width=omni.ui.Fraction(1)) + if not elem.set_vis_frame(self.ui_window_elements[f"{name}_panel"]): + print(f"Frame failed to set for ManagerLiveVisualizer: {name}") + + # Add listener for environment selection changes + if isinstance(elem, ManagerLiveVisualizer): + self._ui_listeners.append(elem) + async def _dock_window(self, window_title: str): """Docks the custom UI window to the property window.""" # wait for the window to be created diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/ui/manager_based_rl_env_window.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/ui/manager_based_rl_env_window.py index 3e149e1fef..9fa69c4468 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/ui/manager_based_rl_env_window.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/envs/ui/manager_based_rl_env_window.py @@ -34,5 +34,7 @@ def __init__(self, env: ManagerBasedRLEnv, window_name: str = "IsaacLab"): with self.ui_window_elements["main_vstack"]: with self.ui_window_elements["debug_frame"]: with self.ui_window_elements["debug_vstack"]: - self._create_debug_vis_ui_element("commands", self.env.command_manager) - self._create_debug_vis_ui_element("actions", self.env.action_manager) + self._visualize_manager(title="Commands", class_name="command_manager") + self._visualize_manager(title="Rewards", class_name="reward_manager") + self._visualize_manager(title="Curriculum", class_name="curriculum_manager") + self._visualize_manager(title="Termination", class_name="termination_manager") diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/action_manager.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/action_manager.py index 2f729cde23..30ed9d41ef 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/action_manager.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/action_manager.py @@ -106,6 +106,7 @@ def set_debug_vis(self, debug_vis: bool) -> bool: # check if debug visualization is supported if not self.has_debug_vis_implementation: return False + # toggle debug visualization objects self._set_debug_vis_impl(debug_vis) # toggle debug visualization handles @@ -262,7 +263,26 @@ def has_debug_vis_implementation(self) -> bool: Operations. 
""" - def set_debug_vis(self, debug_vis: bool) -> bool: + def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]: + """Returns the active terms as iterable sequence of tuples. + + The first element of the tuple is the name of the term and the second element is the raw value(s) of the term. + + Args: + env_idx: The specific environment to pull the active terms from. + + Returns: + The active terms. + """ + terms = [] + idx = 0 + for name, term in self._terms.items(): + term_actions = self._action[env_idx, idx : idx + term.action_dim].cpu() + terms.append((name, term_actions.tolist())) + idx += term.action_dim + return terms + + def set_debug_vis(self, debug_vis: bool): """Sets whether to visualize the action data. Args: debug_vis: Whether to visualize the action data. diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/command_manager.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/command_manager.py index 50a717b6d4..0c50e7a00d 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/command_manager.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/command_manager.py @@ -296,7 +296,26 @@ def has_debug_vis_implementation(self) -> bool: Operations. """ - def set_debug_vis(self, debug_vis: bool) -> bool: + def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]: + """Returns the active terms as iterable sequence of tuples. + + The first element of the tuple is the name of the term and the second element is the raw value(s) of the term. + + Args: + env_idx: The specific environment to pull the active terms from. + + Returns: + The active terms. + """ + + terms = [] + idx = 0 + for name, term in self._terms.items(): + terms.append((name, term.command[env_idx].cpu().tolist())) + idx += term.command.shape[1] + return terms + + def set_debug_vis(self, debug_vis: bool): """Sets whether to visualize the command data. Args: diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/curriculum_manager.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/curriculum_manager.py index 92fe7e7ef7..9e316cccf8 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/curriculum_manager.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/curriculum_manager.py @@ -138,6 +138,40 @@ def compute(self, env_ids: Sequence[int] | None = None): state = term_cfg.func(self._env, env_ids, **term_cfg.params) self._curriculum_state[name] = state + def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]: + """Returns the active terms as iterable sequence of tuples. + + The first element of the tuple is the name of the term and the second element is the raw value(s) of the term. + + Args: + env_idx: The specific environment to pull the active terms from. + + Returns: + The active terms. + """ + + terms = [] + + for term_name, term_state in self._curriculum_state.items(): + if term_state is not None: + # deal with dict + data = [] + + if isinstance(term_state, dict): + # each key is a separate state to log + for key, value in term_state.items(): + if isinstance(value, torch.Tensor): + value = value.item() + terms[term_name].append(value) + else: + # log directly if not a dict + if isinstance(term_state, torch.Tensor): + term_state = term_state.item() + data.append(term_state) + terms.append((term_name, data)) + + return terms + """ Helper functions. 
""" diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/manager_base.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/manager_base.py index 4da002934f..d30545a391 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/manager_base.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/manager_base.py @@ -193,6 +193,16 @@ def find_terms(self, name_keys: str | Sequence[str]) -> list[str]: # return the matching names return string_utils.resolve_matching_names(name_keys, list_of_strings)[1] + def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]: + """Returns the active terms as iterable sequence of tuples. + + The first element of the tuple is the name of the term and the second element is the raw value(s) of the term. + + Returns: + The active terms. + """ + raise NotImplementedError + """ Implementation specific. """ diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/manager_term_cfg.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/manager_term_cfg.py index 54c0b726d0..19c54f4e70 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/manager_term_cfg.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/manager_term_cfg.py @@ -93,6 +93,9 @@ class for more details. debug_vis: bool = False """Whether to visualize debug information. Defaults to False.""" + clip: dict[str, tuple] | None = None + """Clip range for the action (dict of regex expressions). Defaults to None.""" + ## # Command manager. @@ -177,6 +180,19 @@ class ObservationTermCfg(ManagerTermBaseCfg): please make sure the length of the tuple matches the dimensions of the tensor outputted from the term. """ + history_length: int = 0 + """Number of past observations to store in the observation buffers. Defaults to 0, meaning no history. + + Observation history initializes to empty, but is filled with the first append after reset or initialization. Subsequent history + only adds a single entry to the history buffer. If flatten_history_dim is set to True, the source data of shape + (N, H, D, ...) where N is the batch dimension and H is the history length will be reshaped to a 2D tensor of shape + (N, H*D*...). Otherwise, the data will be returned as is. + """ + + flatten_history_dim: bool = True + """Whether or not the observation manager should flatten history-based observation terms to a 2D (N, D) tensor. + Defaults to True.""" + @configclass class ObservationGroupCfg: @@ -198,6 +214,22 @@ class ObservationGroupCfg: Otherwise, no corruption is applied. """ + history_length: int | None = None + """Number of past observation to store in the observation buffers for all observation terms in group. + + This parameter will override :attr:`ObservationTermCfg.history_length` if set. Defaults to None. If None, each + terms history will be controlled on a per term basis. See :class:`ObservationTermCfg` for details on history_length + implementation. + """ + + flatten_history_dim: bool = True + """Flag to flatten history-based observation terms to a 2D (num_env, D) tensor for all observation terms in group. + Defaults to True. + + This parameter will override all :attr:`ObservationTermCfg.flatten_history_dim` in the group if + ObservationGroupCfg.history_length is set. 
+ """ + ## # Event manager diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/observation_manager.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/observation_manager.py index 6bc9b0374b..1fc7660d79 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/observation_manager.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/observation_manager.py @@ -8,12 +8,14 @@ from __future__ import annotations import inspect +import numpy as np import torch from collections.abc import Sequence from prettytable import PrettyTable from typing import TYPE_CHECKING from omni.isaac.lab.utils import modifiers +from omni.isaac.lab.utils.buffers import CircularBuffer from .manager_base import ManagerBase, ManagerTermBase from .manager_term_cfg import ObservationGroupCfg, ObservationTermCfg @@ -45,6 +47,11 @@ class ObservationManager(ManagerBase): concatenated. In this case, please set the :attr:`ObservationGroupCfg.concatenate_terms` attribute in the group configuration to False. + Observations can also have history. This means a running history is updated per sim step. History can be controlled + per :class:`ObservationTermCfg` (See the :attr:`ObservationTermCfg.history_length` and + :attr:`ObservationTermCfg.flatten_history_dim`). History can also be controlled via :class:`ObservationGroupCfg` + where group configuration overwrites per term configuration if set. History follows an oldest to newest ordering. + The observation manager can be used to compute observations for all the groups or for a specific group. The observations are computed by calling the registered functions for each term in the group. The functions are called in the order of the terms in the group. The functions are expected to return a tensor with shape @@ -93,6 +100,9 @@ def __init__(self, cfg: object, env: ManagerBasedEnv): else: self._group_obs_dim[group_name] = group_term_dims + # Stores the latest observations. + self._obs_buffer: dict[str, torch.Tensor | dict[str, torch.Tensor]] | None = None + def __str__(self) -> str: """Returns: A string representation for the observation manager.""" msg = f" contains {len(self._group_obs_term_names)} groups.\n" @@ -123,6 +133,43 @@ def __str__(self) -> str: return msg + def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]: + """Returns the active terms as iterable sequence of tuples. + + The first element of the tuple is the name of the term and the second element is the raw value(s) of the term. + + Args: + env_idx: The specific environment to pull the active terms from. + + Returns: + The active terms. + """ + terms = [] + + if self._obs_buffer is None: + self.compute() + obs_buffer: dict[str, torch.Tensor | dict[str, torch.Tensor]] = self._obs_buffer + + for group_name, _ in self._group_obs_dim.items(): + if not self.group_obs_concatenate[group_name]: + for name, term in obs_buffer[group_name].items(): + terms.append((group_name + "-" + name, term[env_idx].cpu().tolist())) + continue + + idx = 0 + # add info for each term + data = obs_buffer[group_name] + for name, shape in zip( + self._group_obs_term_names[group_name], + self._group_obs_term_dim[group_name], + ): + data_length = np.prod(shape) + term = data[env_idx, idx : idx + data_length] + terms.append((group_name + "-" + name, term.cpu().tolist())) + idx += data_length + + return terms + """ Properties. 
""" @@ -174,12 +221,17 @@ def group_obs_concatenate(self) -> dict[str, bool]: def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, float]: # call all terms that are classes - for group_cfg in self._group_obs_class_term_cfgs.values(): + for group_name, group_cfg in self._group_obs_class_term_cfgs.items(): for term_cfg in group_cfg: term_cfg.func.reset(env_ids=env_ids) + # reset terms with history + for term_name in self._group_obs_term_names[group_name]: + if term_name in self._group_obs_term_history_buffer[group_name]: + self._group_obs_term_history_buffer[group_name][term_name].reset(batch_ids=env_ids) # call all modifiers that are classes for mod in self._group_obs_class_modifiers: mod.reset(env_ids=env_ids) + # nothing to log here return {} @@ -200,6 +252,9 @@ def compute(self) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: for group_name in self._group_obs_term_names: obs_buffer[group_name] = self.compute_group(group_name) # otherwise return a dict with observations of all groups + + # Cache the observations. + self._obs_buffer = obs_buffer return obs_buffer def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tensor]: @@ -248,7 +303,7 @@ def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tenso obs_terms = zip(group_term_names, self._group_obs_term_cfgs[group_name]) # evaluate terms: compute, add noise, clip, scale, custom modifiers - for name, term_cfg in obs_terms: + for term_name, term_cfg in obs_terms: # compute term's value obs: torch.Tensor = term_cfg.func(self._env, **term_cfg.params).clone() # apply post-processing @@ -261,8 +316,17 @@ def compute_group(self, group_name: str) -> torch.Tensor | dict[str, torch.Tenso obs = obs.clip_(min=term_cfg.clip[0], max=term_cfg.clip[1]) if term_cfg.scale is not None: obs = obs.mul_(term_cfg.scale) - # add value to list - group_obs[name] = obs + # Update the history buffer if observation term has history enabled + if term_cfg.history_length > 0: + self._group_obs_term_history_buffer[group_name][term_name].append(obs) + if term_cfg.flatten_history_dim: + group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer.reshape( + self._env.num_envs, -1 + ) + else: + group_obs[term_name] = self._group_obs_term_history_buffer[group_name][term_name].buffer + else: + group_obs[term_name] = obs # concatenate all observations in the group together if self._group_obs_concatenate[group_name]: @@ -283,7 +347,7 @@ def _prepare_terms(self): self._group_obs_term_cfgs: dict[str, list[ObservationTermCfg]] = dict() self._group_obs_class_term_cfgs: dict[str, list[ObservationTermCfg]] = dict() self._group_obs_concatenate: dict[str, bool] = dict() - + self._group_obs_term_history_buffer: dict[str, dict] = dict() # create a list to store modifiers that are classes # we store it as a separate list to only call reset on them and prevent unnecessary calls self._group_obs_class_modifiers: list[modifiers.ModifierBase] = list() @@ -309,6 +373,7 @@ def _prepare_terms(self): self._group_obs_term_dim[group_name] = list() self._group_obs_term_cfgs[group_name] = list() self._group_obs_class_term_cfgs[group_name] = list() + group_entry_history_buffer: dict[str, CircularBuffer] = dict() # read common config for the group self._group_obs_concatenate[group_name] = group_cfg.concatenate_terms # check if config is dict already @@ -319,7 +384,7 @@ def _prepare_terms(self): # iterate over all the terms in each group for term_name, term_cfg in group_cfg_items: # skip non-obs settings - if 
term_name in ["enable_corruption", "concatenate_terms"]: + if term_name in ["enable_corruption", "concatenate_terms", "history_length", "flatten_history_dim"]: continue # check for non config if term_cfg is None: @@ -335,12 +400,26 @@ def _prepare_terms(self): # check noise settings if not group_cfg.enable_corruption: term_cfg.noise = None + # check group history params and override terms + if group_cfg.history_length is not None: + term_cfg.history_length = group_cfg.history_length + term_cfg.flatten_history_dim = group_cfg.flatten_history_dim # add term config to list to list self._group_obs_term_names[group_name].append(term_name) self._group_obs_term_cfgs[group_name].append(term_cfg) - # call function the first time to fill up dimensions obs_dims = tuple(term_cfg.func(self._env, **term_cfg.params).shape) + # create history buffers and calculate history term dimensions + if term_cfg.history_length > 0: + group_entry_history_buffer[term_name] = CircularBuffer( + max_len=term_cfg.history_length, batch_size=self._env.num_envs, device=self._env.device + ) + old_dims = list(obs_dims) + old_dims.insert(1, term_cfg.history_length) + obs_dims = tuple(old_dims) + if term_cfg.flatten_history_dim: + obs_dims = (obs_dims[0], np.prod(obs_dims[1:])) + self._group_obs_term_dim[group_name].append(obs_dims[1:]) # if scale is set, check if single float or tuple @@ -411,3 +490,5 @@ def _prepare_terms(self): self._group_obs_class_term_cfgs[group_name].append(term_cfg) # call reset (in-case above call to get obs dims changed the state) term_cfg.func.reset() + # add history buffers for each group + self._group_obs_term_history_buffer[group_name] = group_entry_history_buffer diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/reward_manager.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/reward_manager.py index 5e17e0516e..49369ca58a 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/reward_manager.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/reward_manager.py @@ -61,6 +61,9 @@ def __init__(self, cfg: object, env: ManagerBasedRLEnv): # create buffer for managing reward per environment self._reward_buf = torch.zeros(self.num_envs, dtype=torch.float, device=self.device) + # Buffer which stores the current step reward for each term for each environment + self._step_reward = torch.zeros((self.num_envs, len(self._term_names)), dtype=torch.float, device=self.device) + def __str__(self) -> str: """Returns: A string representation for reward manager.""" msg = f" contains {len(self._term_names)} active terms.\n" @@ -148,6 +151,9 @@ def compute(self, dt: float) -> torch.Tensor: # update episodic sum self._episode_sums[name] += value + # Update current reward for this step. + self._step_reward[:, self._term_names.index(name)] = value / dt + return self._reward_buf """ @@ -186,6 +192,22 @@ def get_term_cfg(self, term_name: str) -> RewardTermCfg: # return the configuration return self._term_cfgs[self._term_names.index(term_name)] + def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]: + """Returns the active terms as iterable sequence of tuples. + + The first element of the tuple is the name of the term and the second element is the raw value(s) of the term. + + Args: + env_idx: The specific environment to pull the active terms from. + + Returns: + The active terms. 
+ """ + terms = [] + for idx, name in enumerate(self._term_names): + terms.append((name, [self._step_reward[env_idx, idx].cpu().item()])) + return terms + """ Helper functions. """ diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/termination_manager.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/termination_manager.py index 77b32f2a53..e19ed64fd2 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/termination_manager.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/managers/termination_manager.py @@ -184,6 +184,22 @@ def get_term(self, name: str) -> torch.Tensor: """ return self._term_dones[name] + def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]: + """Returns the active terms as iterable sequence of tuples. + + The first element of the tuple is the name of the term and the second element is the raw value(s) of the term. + + Args: + env_idx: The specific environment to pull the active terms from. + + Returns: + The active terms. + """ + terms = [] + for key in self._term_dones.keys(): + terms.append((key, [self._term_dones[key][env_idx].float().cpu().item()])) + return terms + """ Operations - Term settings. """ diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/sim/simulation_context.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/sim/simulation_context.py index 104afc538d..d36345a6d3 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/sim/simulation_context.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/sim/simulation_context.py @@ -416,6 +416,14 @@ def get_setting(self, name: str) -> Any: """ return self._settings.get(name) + def forward(self) -> None: + """Updates articulation kinematics and fabric for rendering.""" + if self._fabric_iface is not None: + if self.physics_sim_view is not None and self.is_playing(): + # Update the articulations' link's poses before rendering + self.physics_sim_view.update_articulations_kinematic() + self._update_fabric(0.0, 0.0) + """ Operations - Override (standalone) """ @@ -486,11 +494,7 @@ def render(self, mode: RenderMode | None = None): self.set_setting("/app/player/playSimulations", True) else: # manually flush the fabric data to update Hydra textures - if self._fabric_iface is not None: - if self.physics_sim_view is not None and self.is_playing(): - # Update the articulations' link's poses before rendering - self.physics_sim_view.update_articulations_kinematic() - self._update_fabric(0.0, 0.0) + self.forward() # render the simulation # note: we don't call super().render() anymore because they do above operation inside # and we don't want to do it twice. We may remove it once we drop support for Isaac Sim 2022.2. diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/__init__.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/__init__.py new file mode 100644 index 0000000000..8891b428a4 --- /dev/null +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2022-2024, The Isaac Lab Project Developers. +# All rights reserved. 
+# +# SPDX-License-Identifier: BSD-3-Clause + +from .image_plot import ImagePlot +from .line_plot import LiveLinePlot +from .manager_live_visualizer import ManagerLiveVisualizer +from .ui_visualizer_base import UiVisualizerBase diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/image_plot.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/image_plot.py new file mode 100644 index 0000000000..4b3cb1d6d0 --- /dev/null +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/image_plot.py @@ -0,0 +1,201 @@ +# Copyright (c) 2022-2024, The Isaac Lab Project Developers. +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import numpy as np +from matplotlib import cm +from typing import TYPE_CHECKING, Optional + +import carb +import omni +import omni.log + +from .ui_widget_wrapper import UIWidgetWrapper + +if TYPE_CHECKING: + import omni.isaac.ui + import omni.ui + + +class ImagePlot(UIWidgetWrapper): + """An image plot widget to display live data. + + It has the following Layout where the mode frame is only useful for depth images: + +-------------------------------------------------------+ + | containing_frame | + |+-----------------------------------------------------+| + | main_plot_frame | + ||+---------------------------------------------------+|| + ||| plot_frames ||| + ||| ||| + ||| ||| + ||| (Image Plot Data) ||| + ||| ||| + ||| ||| + |||+-------------------------------------------------+||| + ||| mode_frame ||| + ||| ||| + ||| [x][Absolute] [x][Grayscaled] [ ][Colorized] ||| + |+-----------------------------------------------------+| + +-------------------------------------------------------+ + + """ + + def __init__( + self, + image: Optional[np.ndarray] = None, + label: str = "", + widget_height: int = 200, + show_min_max: bool = True, + unit: tuple[float, str] = (1, ""), + ): + """Create an XY plot UI Widget with axis scaling, legends, and support for multiple plots. + + Overlapping data is most accurately plotted when centered in the frame with reasonable axis scaling. + Pressing down the mouse gives the x and y values of each function at an x coordinate. 
+ + Args: + image: Image to display + label: Short descriptive text to the left of the plot + widget_height: Height of the plot in pixels + show_min_max: Whether to show the min and max values of the image + unit: Tuple of (scale, name) for the unit of the image + """ + self._show_min_max = show_min_max + self._unit_scale = unit[0] + self._unit_name = unit[1] + + self._curr_mode = "None" + + self._has_built = False + + self._enabled = True + + self._byte_provider = omni.ui.ByteImageProvider() + if image is None: + carb.log_warn("image is NONE") + image = np.ones((480, 640, 3), dtype=np.uint8) * 255 + image[:, :, 0] = 0 + image[:, :240, 1] = 0 + + # if image is channel first, convert to channel last + if image.ndim == 3 and image.shape[0] in [1, 3, 4]: + image = np.moveaxis(image, 0, -1) + + self._aspect_ratio = image.shape[1] / image.shape[0] + self._widget_height = widget_height + self._label = label + self.update_image(image) + + plot_frame = self._create_ui_widget() + + super().__init__(plot_frame) + + def setEnabled(self, enabled: bool): + self._enabled = enabled + + def update_image(self, image: np.ndarray): + if not self._enabled: + return + + # if image is channel first, convert to channel last + if image.ndim == 3 and image.shape[0] in [1, 3, 4]: + image = np.moveaxis(image, 0, -1) + + height, width = image.shape[:2] + + if self._curr_mode == "Normalization": + image = (image - image.min()) / (image.max() - image.min()) + image = (image * 255).astype(np.uint8) + elif self._curr_mode == "Colorization": + if image.ndim == 3 and image.shape[2] == 3: + omni.log.warn("Colorization mode is only available for single channel images") + else: + image = (image - image.min()) / (image.max() - image.min()) + colormap = cm.get_cmap("jet") + if image.ndim == 3 and image.shape[2] == 1: + image = (colormap(image).squeeze(2) * 255).astype(np.uint8) + else: + image = (colormap(image) * 255).astype(np.uint8) + + # convert image to 4-channel RGBA + if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1): + image = np.dstack((image, image, image, np.full((height, width, 1), 255, dtype=np.uint8))) + + elif image.ndim == 3 and image.shape[2] == 3: + image = np.dstack((image, np.full((height, width, 1), 255, dtype=np.uint8))) + + self._byte_provider.set_bytes_data(image.flatten().data, [width, height]) + + def update_min_max(self, image: np.ndarray): + if self._show_min_max and hasattr(self, "_min_max_label"): + non_inf = image[np.isfinite(image)].flatten() + if len(non_inf) > 0: + self._min_max_label.text = self._get_unit_description( + np.min(non_inf), np.max(non_inf), np.median(non_inf) + ) + else: + self._min_max_label.text = self._get_unit_description(0, 0) + + def _create_ui_widget(self): + containing_frame = omni.ui.Frame(build_fn=self._build_widget) + return containing_frame + + def _get_unit_description(self, min_value: float, max_value: float, median_value: float = None): + return ( + f"Min: {min_value * self._unit_scale:.2f} {self._unit_name} Max:" + f" {max_value * self._unit_scale:.2f} {self._unit_name}" + + (f" Median: {median_value * self._unit_scale:.2f} {self._unit_name}" if median_value is not None else "") + ) + + def _build_widget(self): + + with omni.ui.VStack(spacing=3): + with omni.ui.HStack(): + # Write the leftmost label for what this plot is + omni.ui.Label( + self._label, width=omni.isaac.ui.ui_utils.LABEL_WIDTH, alignment=omni.ui.Alignment.LEFT_TOP + ) + with omni.ui.Frame(width=self._aspect_ratio * self._widget_height, height=self._widget_height): + self._base_plot = 
omni.ui.ImageWithProvider(self._byte_provider) + + if self._show_min_max: + self._min_max_label = omni.ui.Label(self._get_unit_description(0, 0)) + + omni.ui.Spacer(height=8) + self._mode_frame = omni.ui.Frame(build_fn=self._build_mode_frame) + + omni.ui.Spacer(width=5) + self._has_built = True + + def _build_mode_frame(self): + """Build the frame containing the mode selection for the plots. + + This is an internal function to build the frame containing the mode selection for the plots. This function + should only be called from within the build function of a frame. + + The built widget has the following layout: + +-------------------------------------------------------+ + | legends_frame | + ||+---------------------------------------------------+|| + ||| ||| + ||| [x][Series 1] [x][Series 2] [ ][Series 3] ||| + ||| ||| + |||+-------------------------------------------------+||| + |+-----------------------------------------------------+| + +-------------------------------------------------------+ + """ + with omni.ui.HStack(): + with omni.ui.HStack(): + + def _change_mode(value): + self._curr_mode = value + + omni.isaac.ui.ui_utils.dropdown_builder( + label="Mode", + type="dropdown", + items=["Original", "Normalization", "Colorization"], + tooltip="Select a mode", + on_clicked_fn=_change_mode, + ) diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/line_plot.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/line_plot.py new file mode 100644 index 0000000000..682b1f0dcf --- /dev/null +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/line_plot.py @@ -0,0 +1,603 @@ +# Copyright (c) 2022-2024, The Isaac Lab Project Developers. +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import colorsys +import numpy as np +from typing import TYPE_CHECKING + +import omni +from omni.isaac.core.simulation_context import SimulationContext + +from .ui_widget_wrapper import UIWidgetWrapper + +if TYPE_CHECKING: + import omni.isaac.ui + import omni.ui + + +class LiveLinePlot(UIWidgetWrapper): + """A 2D line plot widget to display live data. + + + This widget is used to display live data in a 2D line plot. It can be used to display multiple series + in the same plot. + + It has the following Layout: + +-------------------------------------------------------+ + | containing_frame | + |+-----------------------------------------------------+| + | main_plot_frame | + ||+---------------------------------------------------+|| + ||| plot_frames + grid lines (Z_stacked) ||| + ||| ||| + ||| ||| + ||| (Live Plot Data) ||| + ||| ||| + ||| ||| + |||+-------------------------------------------------+||| + ||| legends_frame ||| + ||| ||| + ||| [x][Series 1] [x][Series 2] [ ][Series 3] ||| + |||+-------------------------------------------------+||| + ||| limits_frame ||| + ||| ||| + ||| [Y-Limits] [min] [max] [Autoscale] ||| + |||+-------------------------------------------------+||| + ||| filter_frame ||| + ||| ||| + ||| ||| + |+-----------------------------------------------------+| + +-------------------------------------------------------+ + + """ + + def __init__( + self, + y_data: list[list[float]], + y_min: float = -10, + y_max: float = 10, + plot_height: int = 150, + show_legend: bool = True, + legends: list[str] | None = None, + max_datapoints: int = 200, + ): + """Create a new LiveLinePlot widget. + + Args: + y_data: A list of lists of floats containing the data to plot. 
Each list of floats represents a series in the plot. + y_min: The minimum y value to display. Defaults to -10. + y_max: The maximum y value to display. Defaults to 10. + plot_height: The height of the plot in pixels. Defaults to 150. + show_legend: Whether to display the legend. Defaults to True. + legends: A list of strings containing the legend labels for each series. If None, the default labels are "Series_0", "Series_1", etc. Defaults to None. + max_datapoints: The maximum number of data points to display. If the number of data points exceeds this value, the oldest data points are removed. Defaults to 200. + """ + super().__init__(self._create_ui_widget()) + self.plot_height = plot_height + self.show_legend = show_legend + self._legends = legends if legends is not None else ["Series_" + str(i) for i in range(len(y_data))] + self._y_data = y_data + self._colors = self._get_distinct_hex_colors(len(y_data)) + self._y_min = y_min if y_min is not None else -10 + self._y_max = y_max if y_max is not None else 10 + self._max_data_points = max_datapoints + self._show_legend = show_legend + self._series_visible = [True for _ in range(len(y_data))] + self._plot_frames = [] + self._plots = [] + self._plot_selected_values = [] + self._is_built = False + self._filter_frame = None + self._filter_mode = None + self._last_values = None + self._is_paused = False + + # Gets populated when widget is built + self._main_plot_frame = None + + self._autoscale_model = omni.ui.SimpleBoolModel(True) + + """Properties""" + + @property + def autoscale_mode(self) -> bool: + return self._autoscale_model.as_bool + + @property + def y_data(self) -> list[list[float]]: + """The current data in the plot.""" + return self._y_data + + @property + def y_min(self) -> float: + """The current minimum y value.""" + return self._y_min + + @property + def y_max(self) -> float: + """The current maximum y value.""" + return self._y_max + + @property + def legends(self) -> list[str]: + """The current legend labels.""" + return self._legends + + """ General Functions """ + + def clear(self): + """Clears the plot.""" + self._y_data = [[] for _ in range(len(self._y_data))] + self._last_values = None + + for plt in self._plots: + plt.set_data() + + # self._container_frame.rebuild() + + def add_datapoint(self, y_coords: list[float]): + """Add a data point to the plot. + + The data point is added to the end of the plot. If the number of data points exceeds the maximum number + of data points, the oldest data point is removed. + + ``y_coords`` is assumed to be a list of floats with the same length as the number of series in the plot. + + Args: + y_coords: A list of floats containing the y coordinates of the new data points. + """ + + for idx, y_coord in enumerate(y_coords): + + if len(self._y_data[idx]) > self._max_data_points: + self._y_data[idx] = self._y_data[idx][1:] + + if self._filter_mode == "Lowpass": + if self._last_values is not None: + alpha = 0.8 + y_coord = self._y_data[idx][-1] * alpha + y_coord * (1 - alpha) + elif self._filter_mode == "Integrate": + if self._last_values is not None: + y_coord = self._y_data[idx][-1] + y_coord + elif self._filter_mode == "Derivative": + if self._last_values is not None: + y_coord = (y_coord - self._last_values[idx]) / SimulationContext.instance().get_rendering_dt() + + self._y_data[idx].append(float(y_coord)) + + if self._main_plot_frame is None: + # Widget not built, not visible + return + + # Check if the widget has been built, i.e. the plot references have been created. 
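+        # While the widget is paused or has not been built yet, nothing is redrawn. If the number of data
+        # series no longer matches the number of plot widgets, the whole plot frame is rebuilt instead.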
+ if not self._is_built or self._is_paused: + return + + if len(self._y_data) != len(self._plots): + # Plots gotten out of sync, rebuild the widget + self._main_plot_frame.rebuild() + return + + if self.autoscale_mode: + self._rescale_btn_pressed() + + for idx, plt in enumerate(self._plots): + plt.set_data(*self._y_data[idx]) + + self._last_values = y_coords + # Autoscale the y-axis to the current data + + """ + Internal functions for building the UI. + """ + + def _build_stacked_plots(self, grid: bool = True): + """Builds multiple plots stacked on top of each other to display multiple series. + + This is an internal function to build the plots. It should not be called from outside the class and only + from within the build function of a frame. + + The built widget has the following layout: + +-------------------------------------------------------+ + | main_plot_frame | + ||+---------------------------------------------------+|| + ||| ||| + ||| y_max|*******-------------------*******| ||| + ||| |-------*****-----------**--------| ||| + ||| 0|------------**-----***-----------| ||| + ||| |--------------***----------------| ||| + ||| y_min|---------------------------------| ||| + ||| ||| + |||+-------------------------------------------------+||| + + + Args: + grid: Whether to display grid lines. Defaults to True. + """ + + # Reset lists which are populated in the build function + self._plot_frames = [] + + # Define internal builder function + def _build_single_plot(y_data: list[float], color: int, plot_idx: int): + """Build a single plot. + + This is an internal function to build a single plot with the given data and color. This function + should only be called from within the build function of a frame. + + Args: + y_data: The data to plot. + color: The color of the plot. 
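+                plot_idx: The index of the series; used to create or update the matching entry in ``self._plots``.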
+ """ + plot = omni.ui.Plot( + omni.ui.Type.LINE, + self._y_min, + self._y_max, + *y_data, + height=self.plot_height, + style={"color": color, "background_color": 0x0}, + ) + + if len(self._plots) <= plot_idx: + self._plots.append(plot) + self._plot_selected_values.append(omni.ui.SimpleStringModel("")) + else: + self._plots[plot_idx] = plot + + # Begin building the widget + with omni.ui.HStack(): + # Space to the left to add y-axis labels + omni.ui.Spacer(width=20) + + # Built plots for each time series stacked on top of each other + with omni.ui.ZStack(): + # Background rectangle + omni.ui.Rectangle( + height=self.plot_height, + style={ + "background_color": 0x0, + "border_color": omni.ui.color.white, + "border_width": 0.4, + "margin": 0.0, + }, + ) + + # Draw grid lines and labels + if grid: + # Calculate the number of grid lines to display + # Absolute range of the plot + plot_range = self._y_max - self._y_min + grid_resolution = 10 ** np.floor(np.log10(0.5 * plot_range)) + + plot_range /= grid_resolution + + # Fraction of the plot range occupied by the first and last grid line + first_space = (self._y_max / grid_resolution) - np.floor(self._y_max / grid_resolution) + last_space = np.ceil(self._y_min / grid_resolution) - self._y_min / grid_resolution + + # Number of grid lines to display + n_lines = int(plot_range - first_space - last_space) + + plot_resolution = self.plot_height / plot_range + + with omni.ui.VStack(): + omni.ui.Spacer(height=plot_resolution * first_space) + + # Draw grid lines + with omni.ui.VGrid(row_height=plot_resolution): + for grid_line_idx in range(n_lines): + # Create grid line + with omni.ui.ZStack(): + omni.ui.Line( + style={ + "color": 0xAA8A8777, + "background_color": 0x0, + "border_width": 0.4, + }, + alignment=omni.ui.Alignment.CENTER_TOP, + height=0, + ) + with omni.ui.Placer(offset_x=-20): + omni.ui.Label( + f"{(self._y_max - first_space * grid_resolution - grid_line_idx * grid_resolution):.3f}", + width=8, + height=8, + alignment=omni.ui.Alignment.RIGHT_TOP, + style={ + "color": 0xFFFFFFFF, + "font_size": 8, + }, + ) + + # Create plots for each series + for idx, (data, color) in enumerate(zip(self._y_data, self._colors)): + plot_frame = omni.ui.Frame( + build_fn=lambda y_data=data, plot_idx=idx, color=color: _build_single_plot( + y_data, color, plot_idx + ), + ) + plot_frame.visible = self._series_visible[idx] + self._plot_frames.append(plot_frame) + + # Create an invisible frame on top that will give a helpful tooltip + self._tooltip_frame = omni.ui.Plot( + height=self.plot_height, + style={"color": 0xFFFFFFFF, "background_color": 0x0}, + ) + + self._tooltip_frame.set_mouse_pressed_fn(self._mouse_moved_on_plot) + + # Create top label for the y-axis + with omni.ui.Placer(offset_x=-20, offset_y=-8): + omni.ui.Label( + f"{self._y_max:.3f}", + width=8, + height=2, + alignment=omni.ui.Alignment.LEFT_TOP, + style={"color": 0xFFFFFFFF, "font_size": 8}, + ) + + # Create bottom label for the y-axis + with omni.ui.Placer(offset_x=-20, offset_y=self.plot_height): + omni.ui.Label( + f"{self._y_min:.3f}", + width=8, + height=2, + alignment=omni.ui.Alignment.LEFT_BOTTOM, + style={"color": 0xFFFFFFFF, "font_size": 8}, + ) + + def _mouse_moved_on_plot(self, x, y, *args): + # Show a tooltip with x,y and function values + if len(self._y_data) == 0 or len(self._y_data[0]) == 0: + # There is no data in the plots, so do nothing + return + + for idx, plot in enumerate(self._plots): + x_pos = plot.screen_position_x + width = plot.computed_width + + location_x = (x - x_pos) / 
width + + data = self._y_data[idx] + n_samples = len(data) + selected_sample = int(location_x * n_samples) + value = data[selected_sample] + # save the value in scientific notation + self._plot_selected_values[idx].set_value(f"{value:.3f}") + + def _build_legends_frame(self): + """Build the frame containing the legend for the plots. + + This is an internal function to build the frame containing the legend for the plots. This function + should only be called from within the build function of a frame. + + The built widget has the following layout: + +-------------------------------------------------------+ + | legends_frame | + ||+---------------------------------------------------+|| + ||| ||| + ||| [x][Series 1] [x][Series 2] [ ][Series 3] ||| + ||| ||| + |||+-------------------------------------------------+||| + |+-----------------------------------------------------+| + +-------------------------------------------------------+ + """ + if not self._show_legend: + return + + with omni.ui.HStack(): + omni.ui.Spacer(width=32) + + # Find the longest legend to determine the width of the frame + max_legend = max([len(legend) for legend in self._legends]) + CHAR_WIDTH = 8 + with omni.ui.VGrid( + row_height=omni.isaac.ui.ui_utils.LABEL_HEIGHT, + column_width=max_legend * CHAR_WIDTH + 6, + ): + for idx in range(len(self._y_data)): + with omni.ui.HStack(): + model = omni.ui.SimpleBoolModel() + model.set_value(self._series_visible[idx]) + omni.ui.CheckBox(model=model, tooltip="", width=4) + model.add_value_changed_fn(lambda val, idx=idx: self._change_plot_visibility(idx, val.as_bool)) + omni.ui.Spacer(width=2) + with omni.ui.VStack(): + omni.ui.Label( + self._legends[idx], + width=max_legend * CHAR_WIDTH, + alignment=omni.ui.Alignment.LEFT, + style={"color": self._colors[idx], "font_size": 12}, + ) + omni.ui.StringField( + model=self._plot_selected_values[idx], + width=max_legend * CHAR_WIDTH, + alignment=omni.ui.Alignment.LEFT, + style={"color": self._colors[idx], "font_size": 10}, + read_only=True, + ) + + def _build_limits_frame(self): + """Build the frame containing the controls for the y-axis limits. + + This is an internal function to build the frame containing the controls for the y-axis limits. This function + should only be called from within the build function of a frame. 
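+        When the Autoscale checkbox is ticked, the limits are recomputed from the currently visible series
+        whenever a new data point arrives; the Re-Scale button triggers the same computation once.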
+ + The built widget has the following layout: + +-------------------------------------------------------+ + | limits_frame | + ||+---------------------------------------------------+|| + ||| ||| + ||| Limits [min] [max] [Re-Sacle] ||| + ||| Autoscale[x] ||| + ||| ------------------------------------------- ||| + |||+-------------------------------------------------+||| + """ + with omni.ui.VStack(): + with omni.ui.HStack(): + omni.ui.Label( + "Limits", + width=omni.isaac.ui.ui_utils.LABEL_WIDTH, + alignment=omni.ui.Alignment.LEFT_CENTER, + ) + + self.lower_limit_drag = omni.ui.FloatDrag(name="min", enabled=True, alignment=omni.ui.Alignment.CENTER) + y_min_model = self.lower_limit_drag.model + y_min_model.set_value(self._y_min) + y_min_model.add_value_changed_fn(lambda x: self._set_y_min(x.as_float)) + omni.ui.Spacer(width=2) + + self.upper_limit_drag = omni.ui.FloatDrag(name="max", enabled=True, alignment=omni.ui.Alignment.CENTER) + y_max_model = self.upper_limit_drag.model + y_max_model.set_value(self._y_max) + y_max_model.add_value_changed_fn(lambda x: self._set_y_max(x.as_float)) + omni.ui.Spacer(width=2) + + omni.ui.Button( + "Re-Scale", + width=omni.isaac.ui.ui_utils.BUTTON_WIDTH, + clicked_fn=self._rescale_btn_pressed, + alignment=omni.ui.Alignment.LEFT_CENTER, + style=omni.isaac.ui.ui_utils.get_style(), + ) + + omni.ui.CheckBox(model=self._autoscale_model, tooltip="", width=4) + + omni.ui.Line( + style={"color": 0x338A8777}, + width=omni.ui.Fraction(1), + alignment=omni.ui.Alignment.CENTER, + ) + + def _build_filter_frame(self): + """Build the frame containing the filter controls. + + This is an internal function to build the frame containing the filter controls. This function + should only be called from within the build function of a frame. 
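+        The filter dropdown applies a running transform to incoming data (lowpass, integration, or numerical
+        derivative); selecting a new filter clears the current data. The Play/Pause button freezes the display
+        while data continues to be collected.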
+ + The built widget has the following layout: + +-------------------------------------------------------+ + | filter_frame | + ||+---------------------------------------------------+|| + ||| ||| + ||| ||| + ||| ||| + |||+-------------------------------------------------+||| + |+-----------------------------------------------------+| + +-------------------------------------------------------+ + """ + with omni.ui.VStack(): + with omni.ui.HStack(): + + def _filter_changed(value): + self.clear() + self._filter_mode = value + + omni.isaac.ui.ui_utils.dropdown_builder( + label="Filter", + type="dropdown", + items=["None", "Lowpass", "Integrate", "Derivative"], + tooltip="Select a filter", + on_clicked_fn=_filter_changed, + ) + + def _toggle_paused(): + self._is_paused = not self._is_paused + + # Button + omni.ui.Button( + "Play/Pause", + width=omni.isaac.ui.ui_utils.BUTTON_WIDTH, + clicked_fn=_toggle_paused, + alignment=omni.ui.Alignment.LEFT_CENTER, + style=omni.isaac.ui.ui_utils.get_style(), + ) + + def _create_ui_widget(self): + """Create the full UI widget.""" + + def _build_widget(): + self._is_built = False + with omni.ui.VStack(): + self._main_plot_frame = omni.ui.Frame(build_fn=self._build_stacked_plots) + omni.ui.Spacer(height=8) + self._legends_frame = omni.ui.Frame(build_fn=self._build_legends_frame) + omni.ui.Spacer(height=8) + self._limits_frame = omni.ui.Frame(build_fn=self._build_limits_frame) + omni.ui.Spacer(height=8) + self._filter_frame = omni.ui.Frame(build_fn=self._build_filter_frame) + self._is_built = True + + containing_frame = omni.ui.Frame(build_fn=_build_widget) + + return containing_frame + + """ UI Actions Listener Functions """ + + def _change_plot_visibility(self, idx: int, visible: bool): + """Change the visibility of a plot at position idx.""" + self._series_visible[idx] = visible + self._plot_frames[idx].visible = visible + # self._main_plot_frame.rebuild() + + def _set_y_min(self, val: float): + """Update the y-axis minimum.""" + self._y_min = val + self.lower_limit_drag.model.set_value(val) + self._main_plot_frame.rebuild() + + def _set_y_max(self, val: float): + """Update the y-axis maximum.""" + self._y_max = val + self.upper_limit_drag.model.set_value(val) + self._main_plot_frame.rebuild() + + def _rescale_btn_pressed(self): + """Autoscale the y-axis to the current data.""" + if any(self._series_visible): + y_min = np.round( + min([min(y) for idx, y in enumerate(self._y_data) if self._series_visible[idx]]), + 4, + ) + y_max = np.round( + max([max(y) for idx, y in enumerate(self._y_data) if self._series_visible[idx]]), + 4, + ) + if y_min == y_max: + y_max += 1e-4 # Make sure axes don't collapse + + self._y_max = y_max + self._y_min = y_min + + if hasattr(self, "lower_limit_drag") and hasattr(self, "upper_limit_drag"): + self.lower_limit_drag.model.set_value(self._y_min) + self.upper_limit_drag.model.set_value(self._y_max) + + self._main_plot_frame.rebuild() + + """ Helper Functions """ + + def _get_distinct_hex_colors(self, num_colors) -> list[int]: + """ + This function returns a list of distinct colors for plotting. 
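+        The colors are sampled as equally spaced hues in HSV space and converted to the 0xFFBBGGRR
+        integer format used by ``omni.ui`` styles.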
+ + Args: + num_colors (int): the number of colors to generate + + Returns: + List[int]: a list of distinct colors in hexadecimal format 0xFFBBGGRR + """ + # Generate equally spaced colors in HSV space + rgb_colors = [ + colorsys.hsv_to_rgb(hue / num_colors, 0.75, 1) for hue in np.linspace(0, num_colors - 1, num_colors) + ] + # Convert to 0-255 RGB values + rgb_colors = [[int(c * 255) for c in rgb] for rgb in rgb_colors] + # Convert to 0xFFBBGGRR format + hex_colors = [0xFF * 16**6 + c[2] * 16**4 + c[1] * 16**2 + c[0] for c in rgb_colors] + return hex_colors diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/manager_live_visualizer.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/manager_live_visualizer.py new file mode 100644 index 0000000000..e6ae4f7310 --- /dev/null +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/manager_live_visualizer.py @@ -0,0 +1,297 @@ +# Copyright (c) 2022-2024, The Isaac Lab Project Developers. +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import numpy +import weakref +from dataclasses import MISSING +from typing import TYPE_CHECKING + +import carb +import omni.kit.app +from omni.isaac.core.simulation_context import SimulationContext + +from omni.isaac.lab.managers import ManagerBase +from omni.isaac.lab.utils import configclass + +from .image_plot import ImagePlot +from .line_plot import LiveLinePlot +from .ui_visualizer_base import UiVisualizerBase + +if TYPE_CHECKING: + import omni.ui + + +@configclass +class ManagerLiveVisualizerCfg: + "Configuration for ManagerLiveVisualizer" + + debug_vis: bool = False + """Flag used to set status of the live visualizers on startup. Defaults to closed.""" + manager_name: str = MISSING + """Manager name that corresponds to the manager of interest in the ManagerBasedEnv and ManagerBasedRLEnv""" + term_names: list[str] | dict[str, list[str]] | None = None + """Specific term names specified in a Manager config that are chosen to be plotted. Defaults to None. + + If None all terms will be plotted. For managers that utilize Groups (i.e. ObservationGroup) use a dictionary of + {group_names: [term_names]}. + """ + + +class ManagerLiveVisualizer(UiVisualizerBase): + """A interface object used to transfer data from a manager to a UI widget. This class handles the creation of UI + Widgets for selected terms given a ManagerLiveVisualizerCfg. + """ + + def __init__(self, manager: ManagerBase, cfg: ManagerLiveVisualizerCfg = ManagerLiveVisualizerCfg()): + """Initialize ManagerLiveVisualizer. + + Args: + manager: The manager with terms to be plotted. The manager must have a get_active_iterable_terms method. + cfg: The configuration file used to select desired manager terms to be plotted. + """ + + self._manager = manager + self.debug_vis = cfg.debug_vis + self._env_idx: int = 0 + self.cfg = cfg + self._viewer_env_idx = 0 + self._vis_frame: omni.ui.Frame + self._vis_window: omni.ui.Window + + # evaluate chosen terms if no terms provided use all available. 
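+        # Terms may be configured as a flat list of term names, or, for managers that use groups
+        # (e.g. the observation manager), as a {group_name: [term_names]} dictionary.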
+        self.term_names = []
+
+        if self.cfg.term_names is not None:
+            # extract chosen terms
+            if isinstance(self.cfg.term_names, list):
+                for term_name in self.cfg.term_names:
+                    if term_name in self._manager.active_terms:
+                        self.term_names.append(term_name)
+                    else:
+                        carb.log_error(
+                            f"ManagerVisualizer Failure: ManagerTerm ({term_name}) does not exist in"
+                            f" Manager({self.cfg.manager_name})"
+                        )
+
+            # extract chosen group-terms
+            elif isinstance(self.cfg.term_names, dict):
+                # if manager is using groups and terms are saved as a dictionary
+                if isinstance(self._manager.active_terms, dict):
+                    for group, terms in self.cfg.term_names.items():
+                        if group in self._manager.active_terms.keys():
+                            for term_name in terms:
+                                if term_name in self._manager.active_terms[group]:
+                                    self.term_names.append(f"{group}-{term_name}")
+                                else:
+                                    carb.log_error(
+                                        f"ManagerVisualizer Failure: ManagerTerm ({term_name}) does not exist in"
+                                        f" Group({group})"
+                                    )
+                        else:
+                            carb.log_error(
+                                f"ManagerVisualizer Failure: Group ({group}) does not exist in"
+                                f" Manager({self.cfg.manager_name})"
+                            )
+                else:
+                    carb.log_error(
+                        f"ManagerVisualizer Failure: Manager({self.cfg.manager_name}) does not utilize grouping of"
+                        " terms."
+                    )
+
+    #
+    # Properties
+    #
+
+    @property
+    def get_vis_frame(self) -> omni.ui.Frame:
+        """Getter for the UI Frame object tied to this visualizer."""
+        return self._vis_frame
+
+    @property
+    def get_vis_window(self) -> omni.ui.Window:
+        """Getter for the UI Window object tied to this visualizer."""
+        return self._vis_window
+
+    #
+    # Setters
+    #
+
+    def set_debug_vis(self, debug_vis: bool):
+        """External-facing function to set the debug visualization state.
+
+        Args:
+            debug_vis: Whether to enable or disable the debug visualization.
+        """
+        self._set_debug_vis_impl(debug_vis)
+
+    #
+    # Implementations
+    #
+
+    def _set_env_selection_impl(self, env_idx: int):
+        """Update the index of the selected environment to display.
+
+        Args:
+            env_idx: The index of the selected environment.
+        """
+        if env_idx > 0 and env_idx < self._manager.num_envs:
+            self._env_idx = env_idx
+        else:
+            carb.log_warn(f"Environment index is out of range (0,{self._manager.num_envs})")
+
+    def _set_vis_frame_impl(self, frame: omni.ui.Frame):
+        """Updates the assigned frame that can be used for visualizations.
+
+        Args:
+            frame: The debug visualization frame.
+        """
+        self._vis_frame = frame
+
+    def _debug_vis_callback(self, event):
+        """Callback for the debug visualization event."""
+
+        if not SimulationContext.instance().is_playing():
+            # Visualizers have not been created yet.
+            return
+
+        # get updated data and update visualization
+        for (_, term), vis in zip(
+            self._manager.get_active_iterable_terms(env_idx=self._env_idx), self._term_visualizers
+        ):
+            if isinstance(vis, LiveLinePlot):
+                vis.add_datapoint(term)
+            elif isinstance(vis, ImagePlot):
+                vis.update_image(numpy.array(term))
+
+    def _set_debug_vis_impl(self, debug_vis: bool):
+        """Set the debug visualization implementation.
+
+        Args:
+            debug_vis: Whether to enable or disable debug visualization.
+ """ + + if not hasattr(self, "_vis_frame"): + raise RuntimeError("No frame set for debug visualization.") + + # Clear internal visualizers + self._term_visualizers = [] + self._vis_frame.clear() + + if debug_vis: + # if enabled create a subscriber for the post update event if it doesn't exist + if not hasattr(self, "_debug_vis_handle") or self._debug_vis_handle is None: + app_interface = omni.kit.app.get_app_interface() + self._debug_vis_handle = app_interface.get_post_update_event_stream().create_subscription_to_pop( + lambda event, obj=weakref.proxy(self): obj._debug_vis_callback(event) + ) + else: + # if disabled remove the subscriber if it exists + if self._debug_vis_handle is not None: + self._debug_vis_handle.unsubscribe() + self._debug_vis_handle = None + + self._vis_frame.visible = False + return + + self._vis_frame.visible = True + + with self._vis_frame: + with omni.ui.VStack(): + # Add a plot in a collapsible frame for each term available + for name, term in self._manager.get_active_iterable_terms(env_idx=self._env_idx): + if name in self.term_names or len(self.term_names) == 0: + frame = omni.ui.CollapsableFrame( + name, + collapsed=False, + style={"border_color": 0xFF8A8777, "padding": 4}, + ) + with frame: + # create line plot for single or multivariable signals + len_term_shape = len(numpy.array(term).shape) + if len_term_shape <= 2: + plot = LiveLinePlot( + y_data=[[elem] for elem in term], + plot_height=150, + show_legend=True, + ) + self._term_visualizers.append(plot) + # create an image plot for 2d and greater data (i.e. mono and rgb images) + elif len_term_shape == 3: + image = ImagePlot( + image=numpy.array(term), + label=name, + ) + self._term_visualizers.append(image) + else: + carb.log_warn( + f"ManagerLiveVisualizer: Term ({name}) is not a supported data type for" + " visualization." + ) + frame.collapsed = True + + self._debug_vis = debug_vis + + +@configclass +class DefaultManagerBasedEnvLiveVisCfg: + """Default configuration to use for the ManagerBasedEnv. Each chosen manager assumes all terms will be plotted.""" + + action_live_vis = ManagerLiveVisualizerCfg(manager_name="action_manager") + observation_live_vis = ManagerLiveVisualizerCfg(manager_name="observation_manager") + + +@configclass +class DefaultManagerBasedRLEnvLiveVisCfg(DefaultManagerBasedEnvLiveVisCfg): + """Default configuration to use for the ManagerBasedRLEnv. Each chosen manager assumes all terms will be plotted.""" + + curriculum_live_vis = ManagerLiveVisualizerCfg(manager_name="curriculum_manager") + command_live_vis = ManagerLiveVisualizerCfg(manager_name="command_manager") + reward_live_vis = ManagerLiveVisualizerCfg(manager_name="reward_manager") + termination_live_vis = ManagerLiveVisualizerCfg(manager_name="termination_manager") + + +class EnvLiveVisualizer: + """A class to handle all ManagerLiveVisualizers used in an Environment.""" + + def __init__(self, cfg: object, managers: dict[str, ManagerBase]): + """Initialize the EnvLiveVisualizer. + + Args: + cfg: The configuration file containing terms of ManagerLiveVisualizers. + managers: A dictionary of labeled managers. i.e. {"manager_name",manager}. 
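+                For example (names illustrative):
+                ``{"action_manager": env.action_manager, "observation_manager": env.observation_manager}``.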
+        """
+        self.cfg = cfg
+        self.managers = managers
+        self._prepare_terms()
+
+    def _prepare_terms(self):
+        self._manager_visualizers: dict[str, ManagerLiveVisualizer] = dict()
+
+        # check if config is dict already
+        if isinstance(self.cfg, dict):
+            cfg_items = self.cfg.items()
+        else:
+            cfg_items = self.cfg.__dict__.items()
+
+        for term_name, term_cfg in cfg_items:
+            # check if term config is None
+            if term_cfg is None:
+                continue
+            # check if term config is viable
+            if isinstance(term_cfg, ManagerLiveVisualizerCfg):
+                # find appropriate manager name
+                manager = self.managers[term_cfg.manager_name]
+                self._manager_visualizers[term_cfg.manager_name] = ManagerLiveVisualizer(manager=manager, cfg=term_cfg)
+            else:
+                raise TypeError(
+                    f"Provided EnvLiveVisualizer term: '{term_name}' is not of type ManagerLiveVisualizerCfg"
+                )
+
+    @property
+    def manager_visualizers(self) -> dict[str, ManagerLiveVisualizer]:
+        """A dictionary of ManagerLiveVisualizers keyed by the name of their associated manager."""
+        return self._manager_visualizers
diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/ui_visualizer_base.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/ui_visualizer_base.py
new file mode 100644
index 0000000000..a28eb3983a
--- /dev/null
+++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/ui_visualizer_base.py
@@ -0,0 +1,148 @@
+# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+from __future__ import annotations
+
+import inspect
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import omni.ui
+
+
+class UiVisualizerBase:
+    """Base class for components that support debug visualizations which require access to some UI elements.
+
+    This class provides a set of functions that can be used to assign UI interfaces.
+
+    The following functions are provided:
+
+    * :func:`set_debug_vis`: Assigns a debug visualization interface. This function is called by the main UI
+      when the checkbox for debug visualization is toggled.
+    * :func:`set_vis_frame`: Assigns a small frame within the Isaac Lab tab that can be used to visualize debug
+      information, such as plots or images. It is called by the main UI on startup to create the frame.
+    * :func:`set_window`: Assigns the main window that is used by the main UI. This allows the user
+      to have full control over all UI elements. But be warned, with great power comes great responsibility.
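+
+    A minimal sketch of a component that only implements the frame hook (the class name is illustrative and
+    not part of the framework):
+
+    .. code-block:: python
+
+        class MyTermVisualizer(UiVisualizerBase):
+            def _set_vis_frame_impl(self, vis_frame: omni.ui.Frame):
+                # keep a handle to the frame so widgets can be built inside it later
+                self._vis_frame = vis_frame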
+ """ + + """ + Exposed Properties + """ + + @property + def has_debug_vis_implementation(self) -> bool: + """Whether the component has a debug visualization implemented.""" + # check if function raises NotImplementedError + source_code = inspect.getsource(self._set_debug_vis_impl) + return "NotImplementedError" not in source_code + + @property + def has_vis_frame_implementation(self) -> bool: + """Whether the component has a debug visualization implemented.""" + # check if function raises NotImplementedError + source_code = inspect.getsource(self._set_vis_frame_impl) + return "NotImplementedError" not in source_code + + @property + def has_window_implementation(self) -> bool: + """Whether the component has a debug visualization implemented.""" + # check if function raises NotImplementedError + source_code = inspect.getsource(self._set_window_impl) + return "NotImplementedError" not in source_code + + @property + def has_env_selection_implementation(self) -> bool: + """Whether the component has a debug visualization implemented.""" + # check if function raises NotImplementedError + source_code = inspect.getsource(self._set_env_selection_impl) + return "NotImplementedError" not in source_code + + """ + Exposed Setters + """ + + def set_env_selection(self, env_selection: int) -> bool: + """Sets the selected environment id. + + This function is called by the main UI when the user selects a different environment. + + Args: + env_selection: The currently selected environment id. + + Returns: + Whether the environment selection was successfully set. False if the component + does not support environment selection. + """ + # check if environment selection is supported + if not self.has_env_selection_implementation: + return False + # set environment selection + self._set_env_selection_impl(env_selection) + return True + + def set_window(self, window: omni.ui.Window) -> bool: + """Sets the current main ui window. + + This function is called by the main UI when the window is created. It allows the component + to add custom UI elements to the window or to control the window and its elements. + + Args: + window: The ui window. + + Returns: + Whether the window was successfully set. False if the component + does not support this functionality. + """ + # check if window is supported + if not self.has_window_implementation: + return False + # set window + self._set_window_impl(window) + return True + + def set_vis_frame(self, vis_frame: omni.ui.Frame) -> bool: + """Sets the debug visualization frame. + + This function is called by the main UI when the window is created. It allows the component + to modify a small frame within the orbit tab that can be used to visualize debug information. + + Args: + vis_frame: The debug visualization frame. + + Returns: + Whether the debug visualization frame was successfully set. False if the component + does not support debug visualization. 
+ """ + # check if debug visualization is supported + if not self.has_vis_frame_implementation: + return False + # set debug visualization frame + self._set_vis_frame_impl(vis_frame) + return True + + """ + Internal Implementation + """ + + def _set_env_selection_impl(self, env_idx: int): + """Set the environment selection.""" + raise NotImplementedError(f"Environment selection is not implemented for {self.__class__.__name__}.") + + def _set_window_impl(self, window: omni.ui.Window): + """Set the window.""" + raise NotImplementedError(f"Window is not implemented for {self.__class__.__name__}.") + + def _set_debug_vis_impl(self, debug_vis: bool): + """Set debug visualization state.""" + raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.") + + def _set_vis_frame_impl(self, vis_frame: omni.ui.Frame): + """Set debug visualization into visualization objects. + + This function is responsible for creating the visualization objects if they don't exist + and input ``debug_vis`` is True. If the visualization objects exist, the function should + set their visibility into the stage. + """ + raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.") diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/ui_widget_wrapper.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/ui_widget_wrapper.py new file mode 100644 index 0000000000..998f6c6da9 --- /dev/null +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/ui/widgets/ui_widget_wrapper.py @@ -0,0 +1,51 @@ +# Copyright (c) 2022-2024, The Isaac Lab Project Developers. +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +# This file has been adapted from _isaac_sim/exts/omni.isaac.ui/omni/isaac/ui/element_wrappers/base_ui_element_wrappers.py + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import omni + +if TYPE_CHECKING: + import omni.ui + + +class UIWidgetWrapper: + """ + Base class for creating wrappers around any subclass of omni.ui.Widget in order to provide an easy interface + for creating and managing specific types of widgets such as state buttons or file pickers. + """ + + def __init__(self, container_frame: omni.ui.Frame): + self._container_frame = container_frame + + @property + def container_frame(self) -> omni.ui.Frame: + return self._container_frame + + @property + def enabled(self) -> bool: + return self.container_frame.enabled + + @enabled.setter + def enabled(self, value: bool): + self.container_frame.enabled = value + + @property + def visible(self) -> bool: + return self.container_frame.visible + + @visible.setter + def visible(self, value: bool): + self.container_frame.visible = value + + def cleanup(self): + """ + Perform any necessary cleanup + """ + pass diff --git a/source/extensions/omni.isaac.lab/omni/isaac/lab/utils/buffers/circular_buffer.py b/source/extensions/omni.isaac.lab/omni/isaac/lab/utils/buffers/circular_buffer.py index 1617344358..197878a016 100644 --- a/source/extensions/omni.isaac.lab/omni/isaac/lab/utils/buffers/circular_buffer.py +++ b/source/extensions/omni.isaac.lab/omni/isaac/lab/utils/buffers/circular_buffer.py @@ -75,6 +75,16 @@ def current_length(self) -> torch.Tensor: """ return torch.minimum(self._num_pushes, self._max_len) + @property + def buffer(self) -> torch.Tensor: + """Complete circular buffer with most recent entry at the end and oldest entry at the beginning. 
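+
+        The buffer is returned batch-first, so ``buffer[i, -1]`` is the most recent entry appended for batch ``i``
+        and ``buffer[i, 0]`` is the oldest entry still stored.
+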
+        Returns:
+            Complete circular buffer with most recent entry at the end and oldest entry at the beginning of dimension 1. The shape is [batch_size, max_length, data.shape[1:]].
+        """
+        buf = self._buffer.clone()
+        buf = torch.roll(buf, shifts=self.max_length - self._pointer - 1, dims=0)
+        return torch.transpose(buf, dim0=0, dim1=1)
+
     """
     Operations.
     """
@@ -89,8 +99,10 @@ def reset(self, batch_ids: Sequence[int] | None = None):
         if batch_ids is None:
             batch_ids = slice(None)
         # reset the number of pushes for the specified batch indices
-        # note: we don't need to reset the buffer since it will be overwritten. The pointer handles this.
         self._num_pushes[batch_ids] = 0
+        if self._buffer is not None:
+            # set buffer at batch_id reset indices to 0.0 so that the buffer() getter returns the cleared circular buffer after reset.
+            self._buffer[:, batch_ids, :] = 0.0
 
     def append(self, data: torch.Tensor):
         """Append the data to the circular buffer.
@@ -106,7 +118,7 @@ def append(self, data: torch.Tensor):
         if data.shape[0] != self.batch_size:
             raise ValueError(f"The input data has {data.shape[0]} environments while expecting {self.batch_size}")
 
-        # at the fist call, initialize the buffer
+        # at the first call, initialize the buffer size
        if self._buffer is None:
            self._pointer = -1
            self._buffer = torch.empty((self.max_length, *data.shape), dtype=data.dtype, device=self._device)
@@ -114,7 +126,12 @@ def append(self, data: torch.Tensor):
         self._pointer = (self._pointer + 1) % self.max_length
         # add the new data to the last layer
         self._buffer[self._pointer] = data.to(self._device)
-        # increment number of number of pushes
+        # Check for batches with zero pushes and initialize all values in batch to first append
+        if 0 in self._num_pushes.tolist():
+            fill_ids = [i for i, x in enumerate(self._num_pushes.tolist()) if x == 0]
+            self._buffer[:, fill_ids, :] = data.to(self._device)[fill_ids]
+        # increment number of pushes for all batches
         self._num_pushes += 1
 
     def __getitem__(self, key: torch.Tensor) -> torch.Tensor:
diff --git a/source/extensions/omni.isaac.lab/test/managers/test_observation_manager.py b/source/extensions/omni.isaac.lab/test/managers/test_observation_manager.py
index c624fb2bd1..253fc39228 100644
--- a/source/extensions/omni.isaac.lab/test/managers/test_observation_manager.py
+++ b/source/extensions/omni.isaac.lab/test/managers/test_observation_manager.py
@@ -131,8 +131,51 @@ class SampleGroupCfg(ObservationGroupCfg):
         self.obs_man = ObservationManager(cfg, self.env)
         self.assertEqual(len(self.obs_man.active_terms["policy"]), 5)
         # print the expected string
+        obs_man_str = str(self.obs_man)
         print()
-        print(self.obs_man)
+        print(obs_man_str)
+        obs_man_str_split = obs_man_str.split("|")
+        term_1_str_index = obs_man_str_split.index(" term_1 ")
+        term_1_str_shape = obs_man_str_split[term_1_str_index + 1].strip()
+        self.assertEqual(term_1_str_shape, "(4,)")
+
+    def test_str_with_history(self):
+        """Test the string representation of the observation manager with history terms."""
+
+        TERM_1_HISTORY = 5
+
+        @configclass
+        class MyObservationManagerCfg:
+            """Test config class for observation manager."""
+
+            @configclass
+            class SampleGroupCfg(ObservationGroupCfg):
+                """Test config class for policy observation group."""
+
+                term_1 = ObservationTermCfg(func="__main__:grilled_chicken", scale=10, history_length=TERM_1_HISTORY)
+                term_2 = ObservationTermCfg(func=grilled_chicken, scale=2)
+                term_3 =
ObservationTermCfg(func=grilled_chicken_with_bbq, scale=5, params={"bbq": True}) + term_4 = ObservationTermCfg( + func=grilled_chicken_with_yoghurt, scale=1.0, params={"hot": False, "bland": 2.0} + ) + term_5 = ObservationTermCfg( + func=grilled_chicken_with_yoghurt_and_bbq, scale=1.0, params={"hot": False, "bland": 2.0} + ) + + policy: ObservationGroupCfg = SampleGroupCfg() + + # create observation manager + cfg = MyObservationManagerCfg() + self.obs_man = ObservationManager(cfg, self.env) + self.assertEqual(len(self.obs_man.active_terms["policy"]), 5) + # print the expected string + obs_man_str = str(self.obs_man) + print() + print(obs_man_str) + obs_man_str_split = obs_man_str.split("|") + term_1_str_index = obs_man_str_split.index(" term_1 ") + term_1_str_shape = obs_man_str_split[term_1_str_index + 1].strip() + self.assertEqual(term_1_str_shape, "(20,)") def test_config_equivalence(self): """Test the equivalence of observation manager created from different config types.""" @@ -304,6 +347,157 @@ class ImageCfg(ObservationGroupCfg): torch.testing.assert_close(obs_policy[:, 5:8], obs_critic[:, 0:3]) torch.testing.assert_close(obs_policy[:, 8:11], obs_critic[:, 3:6]) + def test_compute_with_history(self): + """Test the observation computation with history buffers.""" + HISTORY_LENGTH = 5 + + @configclass + class MyObservationManagerCfg: + """Test config class for observation manager.""" + + @configclass + class PolicyCfg(ObservationGroupCfg): + """Test config class for policy observation group.""" + + term_1 = ObservationTermCfg(func=grilled_chicken, history_length=HISTORY_LENGTH) + # total observation size: term_dim (4) * history_len (5) = 20 + term_2 = ObservationTermCfg(func=lin_vel_w_data) + # total observation size: term_dim (3) = 3 + + policy: ObservationGroupCfg = PolicyCfg() + + # create observation manager + cfg = MyObservationManagerCfg() + self.obs_man = ObservationManager(cfg, self.env) + # compute observation using manager + observations = self.obs_man.compute() + # obtain the group observations + obs_policy: torch.Tensor = observations["policy"] + # check the observation shape + self.assertEqual((self.env.num_envs, 23), obs_policy.shape) + # check the observation data + expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * HISTORY_LENGTH, device=self.env.device) + expected_obs_term_2_data = lin_vel_w_data(self.env) + expected_obs_data_t0 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1) + print(expected_obs_data_t0, obs_policy) + self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy)) + # test that the history buffer holds previous data + for _ in range(HISTORY_LENGTH): + observations = self.obs_man.compute() + obs_policy = observations["policy"] + expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * HISTORY_LENGTH, device=self.env.device) + expected_obs_data_t5 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1) + self.assertTrue(torch.equal(expected_obs_data_t5, obs_policy)) + # test reset + self.obs_man.reset() + observations = self.obs_man.compute() + obs_policy = observations["policy"] + self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy)) + # test reset of specific env ids + reset_env_ids = [2, 4, 16] + self.obs_man.reset(reset_env_ids) + self.assertTrue(torch.equal(expected_obs_data_t0[reset_env_ids], obs_policy[reset_env_ids])) + + def test_compute_with_2d_history(self): + """Test the observation computation with history buffers for 2D observations.""" + HISTORY_LENGTH = 5 + + 
@configclass + class MyObservationManagerCfg: + """Test config class for observation manager.""" + + @configclass + class FlattenedPolicyCfg(ObservationGroupCfg): + """Test config class for policy observation group.""" + + term_1 = ObservationTermCfg( + func=grilled_chicken_image, params={"bland": 1.0, "channel": 1}, history_length=HISTORY_LENGTH + ) + # total observation size: term_dim (128, 256) * history_len (5) = 163840 + + @configclass + class PolicyCfg(ObservationGroupCfg): + """Test config class for policy observation group.""" + + term_1 = ObservationTermCfg( + func=grilled_chicken_image, + params={"bland": 1.0, "channel": 1}, + history_length=HISTORY_LENGTH, + flatten_history_dim=False, + ) + # total observation size: (5, 128, 256, 1) + + flat_obs_policy: ObservationGroupCfg = FlattenedPolicyCfg() + policy: ObservationGroupCfg = PolicyCfg() + + # create observation manager + cfg = MyObservationManagerCfg() + self.obs_man = ObservationManager(cfg, self.env) + # compute observation using manager + observations = self.obs_man.compute() + # obtain the group observations + obs_policy_flat: torch.Tensor = observations["flat_obs_policy"] + obs_policy: torch.Tensor = observations["policy"] + # check the observation shapes + self.assertEqual((self.env.num_envs, 163840), obs_policy_flat.shape) + self.assertEqual((self.env.num_envs, HISTORY_LENGTH, 128, 256, 1), obs_policy.shape) + + def test_compute_with_group_history(self): + """Test the observation computation with group level history buffer configuration.""" + TERM_HISTORY_LENGTH = 5 + GROUP_HISTORY_LENGTH = 10 + + @configclass + class MyObservationManagerCfg: + """Test config class for observation manager.""" + + @configclass + class PolicyCfg(ObservationGroupCfg): + """Test config class for policy observation group.""" + + history_length = GROUP_HISTORY_LENGTH + # group level history length will override all terms + term_1 = ObservationTermCfg(func=grilled_chicken, history_length=TERM_HISTORY_LENGTH) + # total observation size: term_dim (4) * history_len (5) = 20 + # with override total obs size: term_dim (4) * history_len (10) = 40 + term_2 = ObservationTermCfg(func=lin_vel_w_data) + # total observation size: term_dim (3) = 3 + # with override total obs size: term_dim (3) * history_len (10) = 30 + + policy: ObservationGroupCfg = PolicyCfg() + + # create observation manager + cfg = MyObservationManagerCfg() + self.obs_man = ObservationManager(cfg, self.env) + # compute observation using manager + observations = self.obs_man.compute() + # obtain the group observations + obs_policy: torch.Tensor = observations["policy"] + # check the total observation shape + self.assertEqual((self.env.num_envs, 70), obs_policy.shape) + # check the observation data is initialized properly + expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * GROUP_HISTORY_LENGTH, device=self.env.device) + expected_obs_term_2_data = lin_vel_w_data(self.env).repeat(1, GROUP_HISTORY_LENGTH) + expected_obs_data_t0 = torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1) + self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy)) + # test that the history buffer holds previous data + for _ in range(GROUP_HISTORY_LENGTH): + observations = self.obs_man.compute() + obs_policy = observations["policy"] + expected_obs_term_1_data = torch.ones(self.env.num_envs, 4 * GROUP_HISTORY_LENGTH, device=self.env.device) + expected_obs_term_2_data = lin_vel_w_data(self.env).repeat(1, GROUP_HISTORY_LENGTH) + expected_obs_data_t10 = 
torch.concat((expected_obs_term_1_data, expected_obs_term_2_data), dim=-1) + self.assertTrue(torch.equal(expected_obs_data_t10, obs_policy)) + # test reset + self.obs_man.reset() + observations = self.obs_man.compute() + obs_policy = observations["policy"] + self.assertTrue(torch.equal(expected_obs_data_t0, obs_policy)) + # test reset of specific env ids + reset_env_ids = [2, 4, 16] + self.obs_man.reset(reset_env_ids) + self.assertTrue(torch.equal(expected_obs_data_t0[reset_env_ids], obs_policy[reset_env_ids])) + def test_invalid_observation_config(self): """Test the invalid observation config.""" diff --git a/source/extensions/omni.isaac.lab/test/sensors/test_outdated_sensor.py b/source/extensions/omni.isaac.lab/test/sensors/test_outdated_sensor.py new file mode 100644 index 0000000000..0079ece967 --- /dev/null +++ b/source/extensions/omni.isaac.lab/test/sensors/test_outdated_sensor.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024, The Isaac Lab Project Developers. +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Launch Isaac Sim Simulator first.""" + +from omni.isaac.lab.app import AppLauncher, run_tests + +# launch the simulator +app_launcher = AppLauncher(headless=True, enable_cameras=True) +simulation_app = app_launcher.app + + +"""Rest everything follows.""" + +import gymnasium as gym +import shutil +import tempfile +import torch +import unittest + +import carb +import omni.usd + +import omni.isaac.lab_tasks # noqa: F401 +from omni.isaac.lab_tasks.utils.parse_cfg import parse_env_cfg + + +class TestFrameTransformerAfterReset(unittest.TestCase): + """Test cases for checking FrameTransformer values after reset.""" + + @classmethod + def setUpClass(cls): + # this flag is necessary to prevent a bug where the simulation gets stuck randomly when running the + # test on many environments. 
+ carb_settings_iface = carb.settings.get_settings() + carb_settings_iface.set_bool("/physics/cooking/ujitsoCollisionCooking", False) + + def setUp(self): + # create a temporary directory to store the test datasets + self.temp_dir = tempfile.mkdtemp() + + def tearDown(self): + # delete the temporary directory after the test + shutil.rmtree(self.temp_dir) + + def test_action_state_reocrder_terms(self): + """Check FrameTransformer values after reset.""" + for task_name in ["Isaac-Stack-Cube-Franka-IK-Rel-v0"]: + for device in ["cuda:0", "cpu"]: + for num_envs in [1, 2]: + with self.subTest(task_name=task_name, device=device): + omni.usd.get_context().new_stage() + + # parse configuration + env_cfg = parse_env_cfg(task_name, device=device, num_envs=num_envs) + + # create environment + env = gym.make(task_name, cfg=env_cfg) + + # disable control on stop + env.unwrapped.sim._app_control_on_stop_handle = None # type: ignore + + # reset environment + obs = env.reset()[0] + + # get the end effector position after the reset + pre_reset_eef_pos = obs["policy"]["eef_pos"].clone() + print(pre_reset_eef_pos) + + # step the environment with idle actions + idle_actions = torch.zeros(env.action_space.shape, device=env.unwrapped.device) + obs = env.step(idle_actions)[0] + + # get the end effector position after the first step + post_reset_eef_pos = obs["policy"]["eef_pos"] + print(post_reset_eef_pos) + + # check if the end effector position is the same after the reset and the first step + print(torch.all(torch.isclose(pre_reset_eef_pos, post_reset_eef_pos))) + self.assertTrue(torch.all(torch.isclose(pre_reset_eef_pos, post_reset_eef_pos))) + + # close the environment + env.close() + + +if __name__ == "__main__": + run_tests() diff --git a/source/extensions/omni.isaac.lab/test/utils/test_circular_buffer.py b/source/extensions/omni.isaac.lab/test/utils/test_circular_buffer.py index 8286283a4c..57d75fc341 100644 --- a/source/extensions/omni.isaac.lab/test/utils/test_circular_buffer.py +++ b/source/extensions/omni.isaac.lab/test/utils/test_circular_buffer.py @@ -46,9 +46,31 @@ def test_reset(self): # reset the buffer self.buffer.reset() - # check if the buffer is empty + # check if the buffer has zeros entries self.assertEqual(self.buffer.current_length.tolist(), [0, 0, 0]) + def test_reset_subset(self): + """Test resetting a subset of batches in the circular buffer.""" + data1 = torch.ones((self.batch_size, 2), device=self.device) + data2 = 2.0 * data1.clone() + data3 = 3.0 * data1.clone() + self.buffer.append(data1) + self.buffer.append(data2) + # reset the buffer + reset_batch_id = 1 + self.buffer.reset(batch_ids=[reset_batch_id]) + # check that correct batch is reset + self.assertEqual(self.buffer.current_length.tolist()[reset_batch_id], 0) + # Append new set of data + self.buffer.append(data3) + # check if the correct number of entries are in each batch + expected_length = [3, 3, 3] + expected_length[reset_batch_id] = 1 + self.assertEqual(self.buffer.current_length.tolist(), expected_length) + # check that all entries of the recently reset and appended batch are equal + for i in range(self.max_len): + torch.testing.assert_close(self.buffer.buffer[reset_batch_id, 0], self.buffer.buffer[reset_batch_id, i]) + def test_append_and_retrieve(self): """Test appending and retrieving data from the circular buffer.""" # append some data @@ -121,6 +143,33 @@ def test_key_greater_than_pushes(self): retrieved_data = self.buffer[torch.tensor([5, 5, 5], device=self.device)] self.assertTrue(torch.equal(retrieved_data, 
data1))
+    def test_return_buffer_prop(self):
+        """Test retrieving the whole buffer for correct size and contents.
+        Returning the whole buffer should have the shape [batch_size,max_len,data.shape[1:]]
+        """
+        num_overflow = 2
+        for i in range(self.buffer.max_length + num_overflow):
+            data = torch.tensor([[i]], device=self.device).repeat(3, 2)
+            self.buffer.append(data)
+
+        retrieved_buffer = self.buffer.buffer
+        # check shape
+        self.assertTrue(retrieved_buffer.shape == torch.Size([self.buffer.batch_size, self.buffer.max_length, 2]))
+        # check that batch is first dimension
+        torch.testing.assert_close(retrieved_buffer[0], retrieved_buffer[1])
+        # check oldest
+        torch.testing.assert_close(
+            retrieved_buffer[:, 0], torch.tensor([[num_overflow]], device=self.device).repeat(3, 2)
+        )
+        # check most recent
+        torch.testing.assert_close(
+            retrieved_buffer[:, -1],
+            torch.tensor([[self.buffer.max_length + num_overflow - 1]], device=self.device).repeat(3, 2),
+        )
+        # check that it is returned oldest first
+        for idx in range(self.buffer.max_length - 1):
+            self.assertTrue(torch.all(torch.le(retrieved_buffer[:, idx], retrieved_buffer[:, idx + 1])))
+

 if __name__ == "__main__":
     run_tests()
diff --git a/source/standalone/environments/state_machine/lift_cube_sm.py b/source/standalone/environments/state_machine/lift_cube_sm.py
index bd14dcaf0d..03f9c16e5f 100644
--- a/source/standalone/environments/state_machine/lift_cube_sm.py
+++ b/source/standalone/environments/state_machine/lift_cube_sm.py
@@ -81,6 +81,11 @@ class PickSmWaitTime:
     LIFT_OBJECT = wp.constant(1.0)


+@wp.func
+def distance_below_threshold(current_pos: wp.vec3, desired_pos: wp.vec3, threshold: float) -> bool:
+    return wp.length(current_pos - desired_pos) < threshold
+
+
 @wp.kernel
 def infer_state_machine(
     dt: wp.array(dtype=float),
@@ -92,6 +97,7 @@ def infer_state_machine(
     des_ee_pose: wp.array(dtype=wp.transform),
     gripper_state: wp.array(dtype=float),
     offset: wp.array(dtype=wp.transform),
+    position_threshold: float,
 ):
     # retrieve thread id
     tid = wp.tid()
@@ -109,21 +115,28 @@ def infer_state_machine(
     elif state == PickSmState.APPROACH_ABOVE_OBJECT:
         des_ee_pose[tid] = wp.transform_multiply(offset[tid], object_pose[tid])
         gripper_state[tid] = GripperState.OPEN
-        # TODO: error between current and desired ee pose below threshold
-        # wait for a while
-        if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
-            # move to next state and reset wait time
-            sm_state[tid] = PickSmState.APPROACH_OBJECT
-            sm_wait_time[tid] = 0.0
+        if distance_below_threshold(
+            wp.transform_get_translation(ee_pose[tid]),
+            wp.transform_get_translation(des_ee_pose[tid]),
+            position_threshold,
+        ):
+            # wait for a while
+            if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
+                # move to next state and reset wait time
+                sm_state[tid] = PickSmState.APPROACH_OBJECT
+                sm_wait_time[tid] = 0.0
     elif state == PickSmState.APPROACH_OBJECT:
         des_ee_pose[tid] = object_pose[tid]
         gripper_state[tid] = GripperState.OPEN
-        # TODO: error between current and desired ee pose below threshold
-        # wait for a while
-        if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
-            # move to next state and reset wait time
-            sm_state[tid] = PickSmState.GRASP_OBJECT
-            sm_wait_time[tid] = 0.0
+        if distance_below_threshold(
+            wp.transform_get_translation(ee_pose[tid]),
+            wp.transform_get_translation(des_ee_pose[tid]),
+            position_threshold,
+        ):
+            if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
+                # move to next state and reset wait time
+                sm_state[tid] = PickSmState.GRASP_OBJECT
+                sm_wait_time[tid] = 0.0
     elif state == PickSmState.GRASP_OBJECT:
         des_ee_pose[tid] = object_pose[tid]
         gripper_state[tid] = GripperState.CLOSE
@@ -135,12 +148,16 @@ def infer_state_machine(
     elif state == PickSmState.LIFT_OBJECT:
         des_ee_pose[tid] = des_object_pose[tid]
         gripper_state[tid] = GripperState.CLOSE
-        # TODO: error between current and desired ee pose below threshold
-        # wait for a while
-        if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
-            # move to next state and reset wait time
-            sm_state[tid] = PickSmState.LIFT_OBJECT
-            sm_wait_time[tid] = 0.0
+        if distance_below_threshold(
+            wp.transform_get_translation(ee_pose[tid]),
+            wp.transform_get_translation(des_ee_pose[tid]),
+            position_threshold,
+        ):
+            # wait for a while
+            if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
+                # move to next state and reset wait time
+                sm_state[tid] = PickSmState.LIFT_OBJECT
+                sm_wait_time[tid] = 0.0
     # increment wait time
     sm_wait_time[tid] = sm_wait_time[tid] + dt[tid]

@@ -160,7 +177,7 @@ class PickAndLiftSm:
     5. LIFT_OBJECT: The robot lifts the object to the desired pose. This is the final state.
     """

-    def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu"):
+    def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu", position_threshold=0.01):
         """Initialize the state machine.

         Args:
@@ -172,6 +189,7 @@ def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu")
         self.dt = float(dt)
         self.num_envs = num_envs
         self.device = device
+        self.position_threshold = position_threshold
         # initialize state machine
         self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device)
         self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device)
@@ -201,7 +219,7 @@ def reset_idx(self, env_ids: Sequence[int] = None):
         self.sm_state[env_ids] = 0
         self.sm_wait_time[env_ids] = 0.0

-    def compute(self, ee_pose: torch.Tensor, object_pose: torch.Tensor, des_object_pose: torch.Tensor):
+    def compute(self, ee_pose: torch.Tensor, object_pose: torch.Tensor, des_object_pose: torch.Tensor) -> torch.Tensor:
         """Compute the desired state of the robot's end-effector and the gripper."""
         # convert all transformations from (w, x, y, z) to (x, y, z, w)
         ee_pose = ee_pose[:, [0, 1, 2, 4, 5, 6, 3]]
@@ -227,6 +245,7 @@ def compute(self, ee_pose: torch.Tensor, object_pose: torch.Tensor, des_object_p
                 self.des_ee_pose_wp,
                 self.des_gripper_state_wp,
                 self.offset_wp,
+                self.position_threshold,
             ],
             device=self.device,
         )
@@ -257,7 +276,9 @@ def main():
     desired_orientation = torch.zeros((env.unwrapped.num_envs, 4), device=env.unwrapped.device)
     desired_orientation[:, 1] = 1.0
     # create state machine
-    pick_sm = PickAndLiftSm(env_cfg.sim.dt * env_cfg.decimation, env.unwrapped.num_envs, env.unwrapped.device)
+    pick_sm = PickAndLiftSm(
+        env_cfg.sim.dt * env_cfg.decimation, env.unwrapped.num_envs, env.unwrapped.device, position_threshold=0.01
+    )

     while simulation_app.is_running():
         # run everything in inference mode
diff --git a/source/standalone/environments/state_machine/lift_teddy_bear.py b/source/standalone/environments/state_machine/lift_teddy_bear.py
index 896aa614bc..016476066a 100644
--- a/source/standalone/environments/state_machine/lift_teddy_bear.py
+++ b/source/standalone/environments/state_machine/lift_teddy_bear.py
@@ -80,6 +80,11 @@ class PickSmWaitTime:
     OPEN_GRIPPER = wp.constant(0.0)


+@wp.func
+def distance_below_threshold(current_pos: wp.vec3, desired_pos: wp.vec3, threshold: float) -> bool:
+    return wp.length(current_pos - desired_pos) < threshold
+
+
 @wp.kernel
 def infer_state_machine(
     dt: wp.array(dtype=float),
@@ -91,6 +96,7 @@ def infer_state_machine(
     des_ee_pose: wp.array(dtype=wp.transform),
     gripper_state: wp.array(dtype=float),
     offset: wp.array(dtype=wp.transform),
+    position_threshold: float,
 ):
     # retrieve thread id
     tid = wp.tid()
@@ -108,21 +114,29 @@ def infer_state_machine(
     elif state == PickSmState.APPROACH_ABOVE_OBJECT:
         des_ee_pose[tid] = wp.transform_multiply(offset[tid], object_pose[tid])
         gripper_state[tid] = GripperState.OPEN
-        # TODO: error between current and desired ee pose below threshold
-        # wait for a while
-        if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
-            # move to next state and reset wait time
-            sm_state[tid] = PickSmState.APPROACH_OBJECT
-            sm_wait_time[tid] = 0.0
+        if distance_below_threshold(
+            wp.transform_get_translation(ee_pose[tid]),
+            wp.transform_get_translation(des_ee_pose[tid]),
+            position_threshold,
+        ):
+            # wait for a while
+            if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
+                # move to next state and reset wait time
+                sm_state[tid] = PickSmState.APPROACH_OBJECT
+                sm_wait_time[tid] = 0.0
     elif state == PickSmState.APPROACH_OBJECT:
         des_ee_pose[tid] = object_pose[tid]
         gripper_state[tid] = GripperState.OPEN
-        # TODO: error between current and desired ee pose below threshold
-        # wait for a while
-        if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
-            # move to next state and reset wait time
-            sm_state[tid] = PickSmState.GRASP_OBJECT
-            sm_wait_time[tid] = 0.0
+        if distance_below_threshold(
+            wp.transform_get_translation(ee_pose[tid]),
+            wp.transform_get_translation(des_ee_pose[tid]),
+            position_threshold,
+        ):
+            # wait for a while
+            if sm_wait_time[tid] >= PickSmWaitTime.APPROACH_OBJECT:
+                # move to next state and reset wait time
+                sm_state[tid] = PickSmState.GRASP_OBJECT
+                sm_wait_time[tid] = 0.0
     elif state == PickSmState.GRASP_OBJECT:
         des_ee_pose[tid] = object_pose[tid]
         gripper_state[tid] = GripperState.CLOSE
@@ -134,12 +148,16 @@ def infer_state_machine(
     elif state == PickSmState.LIFT_OBJECT:
         des_ee_pose[tid] = des_object_pose[tid]
         gripper_state[tid] = GripperState.CLOSE
-        # TODO: error between current and desired ee pose below threshold
-        # wait for a while
-        if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
-            # move to next state and reset wait time
-            sm_state[tid] = PickSmState.OPEN_GRIPPER
-            sm_wait_time[tid] = 0.0
+        if distance_below_threshold(
+            wp.transform_get_translation(ee_pose[tid]),
+            wp.transform_get_translation(des_ee_pose[tid]),
+            position_threshold,
+        ):
+            # wait for a while
+            if sm_wait_time[tid] >= PickSmWaitTime.LIFT_OBJECT:
+                # move to next state and reset wait time
+                sm_state[tid] = PickSmState.OPEN_GRIPPER
+                sm_wait_time[tid] = 0.0
     elif state == PickSmState.OPEN_GRIPPER:
         # des_ee_pose[tid] = object_pose[tid]
         gripper_state[tid] = GripperState.OPEN
@@ -167,7 +185,7 @@ class PickAndLiftSm:
     5. LIFT_OBJECT: The robot lifts the object to the desired pose. This is the final state.
     """

-    def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu"):
+    def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu", position_threshold=0.01):
         """Initialize the state machine.

         Args:
@@ -179,6 +197,7 @@ def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu")
         self.dt = float(dt)
         self.num_envs = num_envs
         self.device = device
+        self.position_threshold = position_threshold
         # initialize state machine
         self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device)
         self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device)
@@ -234,6 +253,7 @@ def compute(self, ee_pose: torch.Tensor, object_pose: torch.Tensor, des_object_p
                 self.des_ee_pose_wp,
                 self.des_gripper_state_wp,
                 self.offset_wp,
+                self.position_threshold,
             ],
             device=self.device,
         )
diff --git a/source/standalone/environments/state_machine/open_cabinet_sm.py b/source/standalone/environments/state_machine/open_cabinet_sm.py
index ad40653fca..624defead2 100644
--- a/source/standalone/environments/state_machine/open_cabinet_sm.py
+++ b/source/standalone/environments/state_machine/open_cabinet_sm.py
@@ -83,6 +83,11 @@ class OpenDrawerSmWaitTime:
     RELEASE_HANDLE = wp.constant(0.2)


+@wp.func
+def distance_below_threshold(current_pos: wp.vec3, desired_pos: wp.vec3, threshold: float) -> bool:
+    return wp.length(current_pos - desired_pos) < threshold
+
+
 @wp.kernel
 def infer_state_machine(
     dt: wp.array(dtype=float),
@@ -95,6 +100,7 @@ def infer_state_machine(
     handle_approach_offset: wp.array(dtype=wp.transform),
     handle_grasp_offset: wp.array(dtype=wp.transform),
     drawer_opening_rate: wp.array(dtype=wp.transform),
+    position_threshold: float,
 ):
     # retrieve thread id
     tid = wp.tid()
@@ -112,21 +118,29 @@ def infer_state_machine(
     elif state == OpenDrawerSmState.APPROACH_INFRONT_HANDLE:
         des_ee_pose[tid] = wp.transform_multiply(handle_approach_offset[tid], handle_pose[tid])
         gripper_state[tid] = GripperState.OPEN
-        # TODO: error between current and desired ee pose below threshold
-        # wait for a while
-        if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_INFRONT_HANDLE:
-            # move to next state and reset wait time
-            sm_state[tid] = OpenDrawerSmState.APPROACH_HANDLE
-            sm_wait_time[tid] = 0.0
+        if distance_below_threshold(
+            wp.transform_get_translation(ee_pose[tid]),
+            wp.transform_get_translation(des_ee_pose[tid]),
+            position_threshold,
+        ):
+            # wait for a while
+            if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_INFRONT_HANDLE:
+                # move to next state and reset wait time
+                sm_state[tid] = OpenDrawerSmState.APPROACH_HANDLE
+                sm_wait_time[tid] = 0.0
     elif state == OpenDrawerSmState.APPROACH_HANDLE:
         des_ee_pose[tid] = handle_pose[tid]
         gripper_state[tid] = GripperState.OPEN
-        # TODO: error between current and desired ee pose below threshold
-        # wait for a while
-        if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_HANDLE:
-            # move to next state and reset wait time
-            sm_state[tid] = OpenDrawerSmState.GRASP_HANDLE
-            sm_wait_time[tid] = 0.0
+        if distance_below_threshold(
+            wp.transform_get_translation(ee_pose[tid]),
+            wp.transform_get_translation(des_ee_pose[tid]),
+            position_threshold,
+        ):
+            # wait for a while
+            if sm_wait_time[tid] >= OpenDrawerSmWaitTime.APPROACH_HANDLE:
+                # move to next state and reset wait time
+                sm_state[tid] = OpenDrawerSmState.GRASP_HANDLE
+                sm_wait_time[tid] = 0.0
     elif state == OpenDrawerSmState.GRASP_HANDLE:
         des_ee_pose[tid] = wp.transform_multiply(handle_grasp_offset[tid], handle_pose[tid])
         gripper_state[tid] = GripperState.CLOSE
@@ -170,7 +184,7 @@ class OpenDrawerSm:
     5. RELEASE_HANDLE: The robot releases the handle of the drawer. This is the final state.
     """

-    def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu"):
+    def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu", position_threshold=0.01):
         """Initialize the state machine.

         Args:
@@ -182,6 +196,7 @@ def __init__(self, dt: float, num_envs: int, device: torch.device | str = "cpu")
         self.dt = float(dt)
         self.num_envs = num_envs
         self.device = device
+        self.position_threshold = position_threshold
         # initialize state machine
         self.sm_dt = torch.full((self.num_envs,), self.dt, device=self.device)
         self.sm_state = torch.full((self.num_envs,), 0, dtype=torch.int32, device=self.device)
@@ -248,6 +263,7 @@ def compute(self, ee_pose: torch.Tensor, handle_pose: torch.Tensor):
                 self.handle_approach_offset_wp,
                 self.handle_grasp_offset_wp,
                 self.drawer_opening_rate_wp,
+                self.position_threshold,
             ],
             device=self.device,
         )