diff --git a/omnigibson/controllers/multi_finger_gripper_controller.py b/omnigibson/controllers/multi_finger_gripper_controller.py index 864eb24d4..ab888f2fe 100644 --- a/omnigibson/controllers/multi_finger_gripper_controller.py +++ b/omnigibson/controllers/multi_finger_gripper_controller.py @@ -1,6 +1,9 @@ +import torch as th + from omnigibson.controllers import ControlType, GripperController, IsGraspingState from omnigibson.macros import create_module_macros from omnigibson.utils.backend_utils import _compute_backend as cb +from omnigibson.utils.processing_utils import MovingAverageFilter from omnigibson.utils.python_utils import assert_valid_key VALID_MODES = { @@ -99,6 +102,9 @@ def __init__( # Create other args to be filled in at runtime self._is_grasping = IsGraspingState.FALSE + # Create ring buffer for velocity history to avoid high frequency nosie during grasp state inference + self._vel_filter = MovingAverageFilter(obs_dim=len(dof_idx), filter_width=5) + # If we're using binary signal, we override the command output limits if mode == "binary": command_output_limits = (-1.0, 1.0) @@ -121,9 +127,17 @@ def reset(self): # Call super first super().reset() + # Reset the filter + self._vel_filter.reset() + # reset grasping state self._is_grasping = IsGraspingState.FALSE + @property + def state_size(self): + # Add state size from the control filter + return super().state_size + self._vel_filter.state_size + def _preprocess_command(self, command): # We extend this method to make sure command is always n-dimensional if self._mode != "independent": @@ -211,6 +225,9 @@ def _update_grasping_state(self, control_dict): joint_position: Array of current joint positions joint_velocity: Array of current joint velocities """ + # Update velocity history + finger_vel = self._vel_filter.estimate(control_dict["joint_velocity"][self.dof_idx]) + # Calculate grasping state based on mode of this controller # Independent mode of MultiFingerGripperController does not have any good heuristics to determine is_grasping if self._mode == "independent": @@ -242,7 +259,6 @@ def _update_grasping_state(self, control_dict): # Otherwise, the last control signal intends to "move" the gripper else: - finger_vel = control_dict["joint_velocity"][self.dof_idx] min_pos = self._control_limits[ControlType.POSITION][0][self.dof_idx] max_pos = self._control_limits[ControlType.POSITION][1][self.dof_idx] @@ -299,6 +315,39 @@ def is_grasping(self): # Return cached value return self._is_grasping + def _dump_state(self): + # Run super first + state = super()._dump_state() + + # Add filter state + state["vel_filter"] = self._vel_filter.dump_state(serialized=False) + + return state + + def _load_state(self, state): + # Run super first + super()._load_state(state=state) + + # Also load velocity filter state + self._vel_filter.load_state(state["vel_filter"], serialized=False) + + def serialize(self, state): + # Run super first + state_flat = super().serialize(state=state) + + # Serialize state for this controller + return th.cat([state_flat, self._vel_filter.serialize(state=state["vel_filter"])]) + + def deserialize(self, state): + # Run super first + state_dict, idx = super().deserialize(state=state) + + # Deserialize state for the velocity filter + state_dict["vel_filter"], deserialized_items = self._vel_filter.deserialize(state=state[idx:]) + idx += deserialized_items + + return state_dict, idx + @property def control_type(self): return ControlType.get_type(type_str=self._motor_type) diff --git a/omnigibson/utils/processing_utils.py b/omnigibson/utils/processing_utils.py index 0aa341c1d..c47621317 100644 --- a/omnigibson/utils/processing_utils.py +++ b/omnigibson/utils/processing_utils.py @@ -43,6 +43,13 @@ def deserialize(self, state): # Default is no state, so do nothing return dict(), 0 + @property + def state_size(self): + """ + Size of the serialized state of this filter + """ + raise NotImplementedError + class MovingAverageFilter(Filter): """ @@ -99,6 +106,11 @@ def reset(self): self.current_idx = 0 self.fully_filled = False + @property + def state_size(self): + # This is the size of the internal buffer plus the current index and fully filled single values + return th.prod(self.past_samples.shape) + 2 + def _dump_state(self): # Run super init first state = super()._dump_state() @@ -186,6 +198,11 @@ def reset(self): self.avg *= 0.0 self.num_samples = 0 + @property + def state_size(self): + # This is the size of the internal value as well as a num samples + return len(self.avg) + 1 + def _dump_state(self): # Run super init first state = super()._dump_state()