From e95f1b3dcb3a967be45d82587450fdaa671b0997 Mon Sep 17 00:00:00 2001 From: Tushar Sangam Date: Wed, 5 Feb 2025 15:54:26 -0800 Subject: [PATCH] fix actions --- skill_vla/configs/config.yaml | 8 ++++++++ skill_vla/envs/base_env.py | 23 +++++++++++++++-------- skill_vla/envs/pick_env.py | 4 +++- skill_vla/experiments/eval_skill_vla.py | 9 +++++++-- 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/skill_vla/configs/config.yaml b/skill_vla/configs/config.yaml index feada3ed..d2efda20 100644 --- a/skill_vla/configs/config.yaml +++ b/skill_vla/configs/config.yaml @@ -5,13 +5,21 @@ VLA_CONFIGS: DEVICE: "cuda:0" +# Arm params PICK_ARM_JOINT_ANGLES: [0, -160, 100, 0, 75, 0] JOINT_BLACKLIST: [3, 5] # joints we can't control "arm0.el0", "arm0.wr1" ARM_LOWER_LIMITS: [-45, -180, 0, 0, -90, 0] ARM_UPPER_LIMITS: [45, 0, 180, 0, 90, 0] +MAX_JOINT_MOVEMENT: 0.08 # radians +# Nav params +MAX_LIN_DIST: 0.1 +MAX_ANG_DIST: 5.73 # BD params ARM_TRAJECTORY_TIME_IN_SECONDS: 2.0 +DISABLE_OBSTACLE_AVOIDANCE: False + +# Episode params CTRL_HZ: 2.0 MAX_EPISODE_STEPS: 200 \ No newline at end of file diff --git a/skill_vla/envs/base_env.py b/skill_vla/envs/base_env.py index d1210ab9..47211cca 100644 --- a/skill_vla/envs/base_env.py +++ b/skill_vla/envs/base_env.py @@ -77,6 +77,14 @@ def __init__( super().__init__(spot=spot) self.config = config self.spot = spot + self._max_lin_dist_scale = self.config.MAX_LIN_DIST + self._max_ang_dist_scale = self.config.MAX_ANG_DIST + self._max_joint_movement_scale = self.config.MAX_JOINT_MOVEMENT + self.ctrl_hz = self.config.CTRL_HZ + self.max_episode_steps = self.config.MAX_EPISODE_STEPS + self.prev_base_moved = False + self.num_steps = 0 + self.should_end = False def get_observations(self): raise NotImplementedError @@ -119,17 +127,16 @@ def process_base_action(self, base_action): self.prev_base_moved = False return base_action - def process_arm_action(self): - arm_action = rescale_actions(arm_action) + def process_arm_action(self, arm_action): + arm_action = rescale_actions(arm_action, action_thresh=0.003) if np.count_nonzero(arm_action) > 0: arm_action *= self._max_joint_movement_scale arm_action = self.current_arm_pose + pad_action(arm_action) arm_action = np.clip( arm_action, self.arm_lower_limits, self.arm_upper_limits ) - else: - arm_action = None - return arm_action + arm_action_unmasked = np.array([arm_action[0], arm_action[1], 0, arm_action[2], arm_action[3], 0]) + return arm_action_unmasked def pre_step(self, action_dict): # Update the action_dict with grasp and place flags @@ -138,12 +145,12 @@ def pre_step(self, action_dict): base_action = self.process_base_action(base_action) arm_action = action_dict.get("arm_action", None) arm_ee_action = action_dict.get("arm_ee_action", None) - arm_action = self.process_base_action(arm_action or arm_ee_action) + arm_action = self.process_arm_action(arm_action) grasp = action_dict.get("grasp", False) place = action_dict.get("place", False) - return arm_action, base_action + return base_action, arm_action def post_step(self): observations = self.get_observations() @@ -164,7 +171,7 @@ def step( travel_time_scale=1.0, ): assert self.reset_ran, ".reset() must be called first!" - arm_action, base_action = self.pre_step() + base_action, arm_action = self.pre_step() return self.post_step() diff --git a/skill_vla/envs/pick_env.py b/skill_vla/envs/pick_env.py index b375a8ec..14c6e31c 100644 --- a/skill_vla/envs/pick_env.py +++ b/skill_vla/envs/pick_env.py @@ -47,9 +47,11 @@ def reset(self, *args, **kwargs): return observations def step(self, action_dict: Dict[str, Any]): - arm_action, base_action = super().pre_step( + base_action, arm_action = super().pre_step( action_dict=action_dict, ) + print('base_action: ', base_action) + print('arm_action: ', arm_action) self.spot.set_base_vel_and_arm_pos( *base_action, arm_action, diff --git a/skill_vla/experiments/eval_skill_vla.py b/skill_vla/experiments/eval_skill_vla.py index a734a4b3..d8e711fd 100644 --- a/skill_vla/experiments/eval_skill_vla.py +++ b/skill_vla/experiments/eval_skill_vla.py @@ -24,8 +24,13 @@ def main(spot): done = False policy.reset() while not done: - action = policy.act(observations) - observations, _, done, _ = env.step(base_action=action) + action = policy.act(observations)[0] + action_dict = {"arm_action": action[:4], + "base_action": action[4:6], + "empty_action": action[-1]} + print('action: ', action) + print('action_dict: ', action_dict) + observations, _, done, _ = env.step(action_dict) if __name__ == "__main__":