From 1eb8e0e61755124e644e6837f0619d2897fc5287 Mon Sep 17 00:00:00 2001 From: cremebrule <84cremebrule@gmail.com> Date: Thu, 12 Dec 2024 11:00:09 -0800 Subject: [PATCH] add low pass filter for multi finger gripper to reduce velocity noise during grasp state inference --- .../multi_finger_gripper_controller.py | 50 ++++++++++++++++++- omnigibson/utils/processing_utils.py | 17 +++++++ 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/omnigibson/controllers/multi_finger_gripper_controller.py b/omnigibson/controllers/multi_finger_gripper_controller.py index 35e83b48a..70dbc9681 100644 --- a/omnigibson/controllers/multi_finger_gripper_controller.py +++ b/omnigibson/controllers/multi_finger_gripper_controller.py @@ -1,8 +1,8 @@ import torch as th -import omnigibson.utils.transform_utils as T from omnigibson.controllers import ControlType, GripperController, IsGraspingState from omnigibson.macros import create_module_macros +from omnigibson.utils.processing_utils import MovingAverageFilter from omnigibson.utils.python_utils import assert_valid_key VALID_MODES = { @@ -101,6 +101,9 @@ def __init__( # Create other args to be filled in at runtime self._is_grasping = IsGraspingState.FALSE + # Create ring buffer for velocity history to avoid high frequency nosie during grasp state inference + self._vel_filter = MovingAverageFilter(obs_dim=len(dof_idx), filter_width=5) + # If we're using binary signal, we override the command output limits if mode == "binary": command_output_limits = (-1.0, 1.0) @@ -118,9 +121,17 @@ def reset(self): # Call super first super().reset() + # Reset the filter + self._vel_filter.reset() + # reset grasping state self._is_grasping = IsGraspingState.FALSE + @property + def state_size(self): + # Add state size from the control filter + return super().state_size + self._vel_filter.state_size + def _preprocess_command(self, command): # We extend this method to make sure command is always n-dimensional if self._mode != "independent": @@ -208,6 +219,9 @@ def _update_grasping_state(self, control_dict): joint_position: Array of current joint positions joint_velocity: Array of current joint velocities """ + # Update velocity history + finger_vel = self._vel_filter.estimate(control_dict["joint_velocity"][self.dof_idx]) + # Calculate grasping state based on mode of this controller # Independent mode of MultiFingerGripperController does not have any good heuristics to determine is_grasping @@ -240,7 +254,6 @@ def _update_grasping_state(self, control_dict): # Otherwise, the last control signal intends to "move" the gripper else: - finger_vel = control_dict["joint_velocity"][self.dof_idx] min_pos = self._control_limits[ControlType.POSITION][0][self.dof_idx] max_pos = self._control_limits[ControlType.POSITION][1][self.dof_idx] @@ -297,6 +310,39 @@ def is_grasping(self): # Return cached value return self._is_grasping + def _dump_state(self): + # Run super first + state = super()._dump_state() + + # Add filter state + state["vel_filter"] = self._vel_filter.dump_state(serialized=False) + + return state + + def _load_state(self, state): + # Run super first + super()._load_state(state=state) + + # Also load velocity filter state + self._vel_filter.load_state(state["vel_filter"], serialized=False) + + def serialize(self, state): + # Run super first + state_flat = super().serialize(state=state) + + # Serialize state for this controller + return th.cat([state_flat, self._vel_filter.serialize(state=state["vel_filter"])]) + + def deserialize(self, state): + # Run super first + state_dict, idx = super().deserialize(state=state) + + # Deserialize state for the velocity filter + state_dict["vel_filter"], deserialized_items = self._vel_filter.deserialize(state=state[idx:]) + idx += deserialized_items + + return state_dict, idx + @property def control_type(self): return ControlType.get_type(type_str=self._motor_type) diff --git a/omnigibson/utils/processing_utils.py b/omnigibson/utils/processing_utils.py index bb32f35f5..0f2527cd8 100644 --- a/omnigibson/utils/processing_utils.py +++ b/omnigibson/utils/processing_utils.py @@ -42,6 +42,13 @@ def deserialize(self, state): # Default is no state, so do nothing return dict(), 0 + @property + def state_size(self): + """ + Size of the serialized state of this filter + """ + raise NotImplementedError + class MovingAverageFilter(Filter): """ @@ -98,6 +105,11 @@ def reset(self): self.current_idx = 0 self.fully_filled = False + @property + def state_size(self): + # This is the size of the internal buffer plus the current index and fully filled single values + return th.prod(self.past_samples.shape) + 2 + def _dump_state(self): # Run super init first state = super()._dump_state() @@ -185,6 +197,11 @@ def reset(self): self.avg *= 0.0 self.num_samples = 0 + @property + def state_size(self): + # This is the size of the internal value as well as a num samples + return len(self.avg) + 1 + def _dump_state(self): # Run super init first state = super()._dump_state()