Skip to content

Commit

Permalink
add normalizing function to cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
mj-hwang committed Mar 23, 2023
1 parent b039035 commit ff0a782
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
24 changes: 24 additions & 0 deletions robosuite/environments/manipulation/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ def __init__(
use_yaw=use_yaw,
)

self.use_yaw = use_yaw

self.keypoints = self.skill.get_keypoints_dict()
self.use_aff_rewards = use_aff_rewards
self.num_skills = self.skill.n_skills
Expand Down Expand Up @@ -820,3 +822,25 @@ def _scale_params(self, action):

return np.concatenate([action[:self.num_skills], params])

def _normalize_params(self, action):
"""
Normalize raw parameter to ([-1, 1]) range.
"""
action = np.copy(action)
params = action[self.num_skills:]

normalized_params = np.copy(params)
normalized_params[0] = 2 * (params[0] - self.workspace_x[0]) / (self.workspace_x[1] - self.workspace_x[0]) - 1
normalized_params[1] = 2 * (params[1] - self.workspace_y[0]) / (self.workspace_y[1] - self.workspace_y[0]) - 1
normalized_params[2] = 2 * (params[2] - self.workspace_z[0]) / (self.workspace_z[1] - self.workspace_z[0]) - 1

if action[2] > 0: # action is push
normalized_params[3] = 2 * (params[3] - self.workspace_x[0]) / (self.workspace_x[1] - self.workspace_x[0]) - 1
normalized_params[4] = 2 * (params[4] - self.workspace_y[0]) / (self.workspace_y[1] - self.workspace_y[0]) - 1
normalized_params[5] = 2 * (params[5] - self.workspace_z[0]) / (self.workspace_z[1] - self.workspace_z[0]) - 1
normalized_params[6] = 2 * (params[6] - self.yaw_bounds[0]) / (self.yaw_bounds[1] - self.yaw_bounds[0]) - 1

else: # action is pick or place
normalized_params[3] = 2 * (params[3] - self.yaw_bounds[0]) / (self.yaw_bounds[1] - self.yaw_bounds[0]) - 1

return np.concatenate([action[:self.num_skills], normalized_params])
3 changes: 1 addition & 2 deletions robosuite/environments/manipulation/stack_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,6 @@ def _setup_observables(self):
# eef observations
@sensor(modality=modality)
def eef_xyz(obs_cache):
def eef_xyz(obs_cache):
eef_xyz = (
obs_cache[f"{pf}eef_pos"]
if f"{pf}eef_pos" in obs_cache
Expand Down Expand Up @@ -756,7 +755,7 @@ def _normalize_params(self, params):
return normalized_params

def synthetic_human_reward(self, action):
eef_pos = = self._eef_xpos
eef_pos = self._eef_xpos
grasping_A = self._check_grasp(gripper=self.robots[0].gripper, object_geoms=self.cubeA)

cubeA_pos = np.array(self.sim.data.body_xpos[self.cubeA_body_id])
Expand Down

0 comments on commit ff0a782

Please sign in to comment.