Skip to content

Commit

Permalink
Merge branch 'asset-conversion' into feat/np-opt
Browse files Browse the repository at this point in the history
# Conflicts:
#	omnigibson/controllers/multi_finger_gripper_controller.py
  • Loading branch information
cremebrule committed Dec 12, 2024
2 parents f73f9f0 + 1eb8e0e commit 64fab4f
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 1 deletion.
51 changes: 50 additions & 1 deletion omnigibson/controllers/multi_finger_gripper_controller.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import torch as th

from omnigibson.controllers import ControlType, GripperController, IsGraspingState
from omnigibson.macros import create_module_macros
from omnigibson.utils.backend_utils import _compute_backend as cb
from omnigibson.utils.processing_utils import MovingAverageFilter
from omnigibson.utils.python_utils import assert_valid_key

VALID_MODES = {
Expand Down Expand Up @@ -99,6 +102,9 @@ def __init__(
# Create other args to be filled in at runtime
self._is_grasping = IsGraspingState.FALSE

# Create ring buffer for velocity history to avoid high frequency nosie during grasp state inference
self._vel_filter = MovingAverageFilter(obs_dim=len(dof_idx), filter_width=5)

# If we're using binary signal, we override the command output limits
if mode == "binary":
command_output_limits = (-1.0, 1.0)
Expand All @@ -121,9 +127,17 @@ def reset(self):
# Call super first
super().reset()

# Reset the filter
self._vel_filter.reset()

# reset grasping state
self._is_grasping = IsGraspingState.FALSE

@property
def state_size(self):
# Add state size from the control filter
return super().state_size + self._vel_filter.state_size

def _preprocess_command(self, command):
# We extend this method to make sure command is always n-dimensional
if self._mode != "independent":
Expand Down Expand Up @@ -211,6 +225,9 @@ def _update_grasping_state(self, control_dict):
joint_position: Array of current joint positions
joint_velocity: Array of current joint velocities
"""
# Update velocity history
finger_vel = self._vel_filter.estimate(control_dict["joint_velocity"][self.dof_idx])

# Calculate grasping state based on mode of this controller
# Independent mode of MultiFingerGripperController does not have any good heuristics to determine is_grasping
if self._mode == "independent":
Expand Down Expand Up @@ -242,7 +259,6 @@ def _update_grasping_state(self, control_dict):

# Otherwise, the last control signal intends to "move" the gripper
else:
finger_vel = control_dict["joint_velocity"][self.dof_idx]
min_pos = self._control_limits[ControlType.POSITION][0][self.dof_idx]
max_pos = self._control_limits[ControlType.POSITION][1][self.dof_idx]

Expand Down Expand Up @@ -299,6 +315,39 @@ def is_grasping(self):
# Return cached value
return self._is_grasping

def _dump_state(self):
# Run super first
state = super()._dump_state()

# Add filter state
state["vel_filter"] = self._vel_filter.dump_state(serialized=False)

return state

def _load_state(self, state):
# Run super first
super()._load_state(state=state)

# Also load velocity filter state
self._vel_filter.load_state(state["vel_filter"], serialized=False)

def serialize(self, state):
# Run super first
state_flat = super().serialize(state=state)

# Serialize state for this controller
return th.cat([state_flat, self._vel_filter.serialize(state=state["vel_filter"])])

def deserialize(self, state):
# Run super first
state_dict, idx = super().deserialize(state=state)

# Deserialize state for the velocity filter
state_dict["vel_filter"], deserialized_items = self._vel_filter.deserialize(state=state[idx:])
idx += deserialized_items

return state_dict, idx

@property
def control_type(self):
return ControlType.get_type(type_str=self._motor_type)
Expand Down
17 changes: 17 additions & 0 deletions omnigibson/utils/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ def deserialize(self, state):
# Default is no state, so do nothing
return dict(), 0

@property
def state_size(self):
"""
Size of the serialized state of this filter
"""
raise NotImplementedError


class MovingAverageFilter(Filter):
"""
Expand Down Expand Up @@ -99,6 +106,11 @@ def reset(self):
self.current_idx = 0
self.fully_filled = False

@property
def state_size(self):
# This is the size of the internal buffer plus the current index and fully filled single values
return th.prod(self.past_samples.shape) + 2

def _dump_state(self):
# Run super init first
state = super()._dump_state()
Expand Down Expand Up @@ -186,6 +198,11 @@ def reset(self):
self.avg *= 0.0
self.num_samples = 0

@property
def state_size(self):
# This is the size of the internal value as well as a num samples
return len(self.avg) + 1

def _dump_state(self):
# Run super init first
state = super()._dump_state()
Expand Down

0 comments on commit 64fab4f

Please sign in to comment.