Skip to content

Commit

Permalink
fix actions
Browse files Browse the repository at this point in the history
  • Loading branch information
tusharsangam committed Feb 5, 2025
1 parent 66f20b9 commit e95f1b3
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 11 deletions.
8 changes: 8 additions & 0 deletions skill_vla/configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,21 @@ VLA_CONFIGS:

DEVICE: "cuda:0"

# Arm params
PICK_ARM_JOINT_ANGLES: [0, -160, 100, 0, 75, 0]
JOINT_BLACKLIST: [3, 5] # joints we can't control "arm0.el0", "arm0.wr1"
ARM_LOWER_LIMITS: [-45, -180, 0, 0, -90, 0]
ARM_UPPER_LIMITS: [45, 0, 180, 0, 90, 0]
MAX_JOINT_MOVEMENT: 0.08 # radians

# Nav params
MAX_LIN_DIST: 0.1
MAX_ANG_DIST: 5.73

# BD params
ARM_TRAJECTORY_TIME_IN_SECONDS: 2.0
DISABLE_OBSTACLE_AVOIDANCE: False

# Episode params
CTRL_HZ: 2.0
MAX_EPISODE_STEPS: 200
23 changes: 15 additions & 8 deletions skill_vla/envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ def __init__(
super().__init__(spot=spot)
self.config = config
self.spot = spot
self._max_lin_dist_scale = self.config.MAX_LIN_DIST
self._max_ang_dist_scale = self.config.MAX_ANG_DIST
self._max_joint_movement_scale = self.config.MAX_JOINT_MOVEMENT
self.ctrl_hz = self.config.CTRL_HZ
self.max_episode_steps = self.config.MAX_EPISODE_STEPS
self.prev_base_moved = False
self.num_steps = 0
self.should_end = False

def get_observations(self):
raise NotImplementedError
Expand Down Expand Up @@ -119,17 +127,16 @@ def process_base_action(self, base_action):
self.prev_base_moved = False
return base_action

def process_arm_action(self):
arm_action = rescale_actions(arm_action)
def process_arm_action(self, arm_action):
arm_action = rescale_actions(arm_action, action_thresh=0.003)
if np.count_nonzero(arm_action) > 0:
arm_action *= self._max_joint_movement_scale
arm_action = self.current_arm_pose + pad_action(arm_action)
arm_action = np.clip(
arm_action, self.arm_lower_limits, self.arm_upper_limits
)
else:
arm_action = None
return arm_action
arm_action_unmasked = np.array([arm_action[0], arm_action[1], 0, arm_action[2], arm_action[3], 0])
return arm_action_unmasked

def pre_step(self, action_dict):
# Update the action_dict with grasp and place flags
Expand All @@ -138,12 +145,12 @@ def pre_step(self, action_dict):
base_action = self.process_base_action(base_action)
arm_action = action_dict.get("arm_action", None)
arm_ee_action = action_dict.get("arm_ee_action", None)
arm_action = self.process_base_action(arm_action or arm_ee_action)
arm_action = self.process_arm_action(arm_action)

grasp = action_dict.get("grasp", False)
place = action_dict.get("place", False)

return arm_action, base_action
return base_action, arm_action

def post_step(self):
observations = self.get_observations()
Expand All @@ -164,7 +171,7 @@ def step(
travel_time_scale=1.0,
):
assert self.reset_ran, ".reset() must be called first!"
arm_action, base_action = self.pre_step()
base_action, arm_action = self.pre_step()

return self.post_step()

Expand Down
4 changes: 3 additions & 1 deletion skill_vla/envs/pick_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,11 @@ def reset(self, *args, **kwargs):
return observations

def step(self, action_dict: Dict[str, Any]):
arm_action, base_action = super().pre_step(
base_action, arm_action = super().pre_step(
action_dict=action_dict,
)
print('base_action: ', base_action)
print('arm_action: ', arm_action)
self.spot.set_base_vel_and_arm_pos(
*base_action,
arm_action,
Expand Down
9 changes: 7 additions & 2 deletions skill_vla/experiments/eval_skill_vla.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,13 @@ def main(spot):
done = False
policy.reset()
while not done:
action = policy.act(observations)
observations, _, done, _ = env.step(base_action=action)
action = policy.act(observations)[0]
action_dict = {"arm_action": action[:4],
"base_action": action[4:6],
"empty_action": action[-1]}
print('action: ', action)
print('action_dict: ', action_dict)
observations, _, done, _ = env.step(action_dict)


if __name__ == "__main__":
Expand Down

0 comments on commit e95f1b3

Please sign in to comment.