Commit 66f20b9: fix observations
tusharsangam committed Feb 5, 2025
1 parent 6150f78
Showing 5 changed files with 32 additions and 14 deletions.
7 changes: 4 additions & 3 deletions skill_vla/configs/config.yaml
@@ -1,9 +1,9 @@
 VLA_CONFIGS:
-  VLA_CKPT: "/home/tushar/Desktop/spot-vla/spot-sim2real/skill_vla/weights/step70000.pt"
+  VLA_CKPT: "/home/tushar/Desktop/spot-vla/spot-sim2real/skill_vla/weights/livingroom_step70000.pt"
   VLA_CONFIG_FILE: "/home/tushar/Desktop/spot-vla/spot-sim2real/skill_vla/configs/vla_cfg.yaml"
   LANGUAGE_INSTRUCTION: "Navigate to the dresser and pick up the avocado plush toy"
 
-DEVICE: "cpu"
+DEVICE: "cuda:0"
 
 PICK_ARM_JOINT_ANGLES: [0, -160, 100, 0, 75, 0]
 JOINT_BLACKLIST: [3, 5] # joints we can't control "arm0.el0", "arm0.wr1"
@@ -13,4 +13,5 @@ ARM_UPPER_LIMITS: [45, 0, 180, 0, 90, 0]
 
 # BD params
 ARM_TRAJECTORY_TIME_IN_SECONDS: 2.0
-
+CTRL_HZ: 2.0
+MAX_EPISODE_STEPS: 200
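
For reference, CTRL_HZ and MAX_EPISODE_STEPS are plain top-level keys, so a yacs-style loader can expose them as attributes. A minimal sketch of reading this file, assuming only that construct_config (in skill_vla/utils/utils.py, not shown in this diff) behaves like a standard yacs loader; the helper name construct_config_sketch below is hypothetical:

# Hypothetical stand-in for skill_vla/utils/utils.py:construct_config; the real
# helper is not part of this diff, so this only shows the yacs-style access pattern.
from yacs.config import CfgNode as CN

def construct_config_sketch(path="skill_vla/configs/config.yaml"):
    with open(path) as f:
        cfg = CN.load_cfg(f)  # parse the YAML into an attribute-accessible CfgNode
    cfg.freeze()
    return cfg

cfg = construct_config_sketch()
print(cfg.VLA_CONFIGS.VLA_CKPT)                        # the new livingroom checkpoint path
print(cfg.DEVICE, cfg.CTRL_HZ, cfg.MAX_EPISODE_STEPS)  # cuda:0 2.0 200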
16 changes: 12 additions & 4 deletions skill_vla/envs/base_env.py
@@ -7,15 +7,18 @@
 
 # mypy: ignore-errors
 import os
+import time
 import os.path as osp
 
 import cv2
 import gym
 import numpy as np
+from typing import Any, Dict
 from spot_rl.utils.robot_subscriber import SpotRobotSubscriberMixin
 from spot_rl.utils.utils import FixSizeOrderedDict, arr2str, object_id_to_object_name
 from spot_rl.utils.utils import ros_topics as rt
-from spot_wrapper.spot import Spot
+from spot_wrapper.spot import Spot, wrap_heading
+import rospy
 
 try:
     import magnum as mn
@@ -69,7 +72,7 @@ def __init__(
     ):
         self.detections_buffer = {
             k: FixSizeOrderedDict(maxlen=DETECTIONS_BUFFER_LEN)
-            for k in ["detections", "filtered_depth", "viz"]
+            for k in ["filtered_hand_rgb", "viz"]
         }
         super().__init__(spot=spot)
         self.config = config
@@ -196,8 +199,13 @@ def get_gripper_images(self, save_image=False):
             # Return blank images if the gripper is being blocked
             blank_img = np.zeros([NEW_HEIGHT, NEW_WIDTH, 1], dtype=np.float32)
             return blank_img, blank_img.copy()
-        arm_rgb = self.msg_to_cv2(self.detections_buffer["filtered_hand_rgb"][-1])
-        arm_rgb = self.process_images(arm_rgb)
+        print('# imgs: ', len(self.detections_buffer['filtered_hand_rgb']))
+        _, msg = self.detections_buffer["filtered_hand_rgb"].popitem(last=True)
+        arm_rgb = self.msg_to_cv2(msg)
+        # arm_rgb = self.msg_to_cv2(rgb_img)
+        cv2.imwrite(f'arm_rgb_before_{int(time.time()*10000)}.png', arm_rgb)
+        # arm_rgb = self.process_images(arm_rgb)
+        # cv2.imwrite(f'arm_rgb_after_12312{int(time.time()*10000)}.png', arm_rgb)
 
         return arm_rgb
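
The key change here is that get_gripper_images now pops the newest message from the filtered_hand_rgb buffer instead of indexing it like a list. A standalone sketch of that buffer pattern, assuming only that FixSizeOrderedDict (defined in spot_rl.utils.utils, not shown here) behaves like a bounded OrderedDict; the BoundedBuffer class below is a hypothetical stand-in:

from collections import OrderedDict

class BoundedBuffer(OrderedDict):
    # Hypothetical stand-in for FixSizeOrderedDict: an OrderedDict that evicts
    # its oldest entry once it grows past maxlen.
    def __init__(self, *args, maxlen=0, **kwargs):
        self._maxlen = maxlen
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        if self._maxlen > 0 and len(self) > self._maxlen:
            self.popitem(last=False)  # drop the oldest entry

buf = BoundedBuffer(maxlen=3)
for t in range(5):
    buf[t] = f"img_msg_{t}"          # the real buffer holds ROS Image messages
_, newest = buf.popitem(last=True)   # same call the diff uses to grab the latest frame
print(newest)                        # img_msg_4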

3 changes: 2 additions & 1 deletion skill_vla/envs/pick_env.py
@@ -10,7 +10,7 @@
 import magnum as mn
 import numpy as np
 import rospy
-from spot_rl.envs.base_env import SpotBaseEnv, pad_action, rescale_actions
+from envs.base_env import SpotBaseEnv, pad_action, rescale_actions
 from spot_wrapper.spot import Spot, wrap_heading
@@ -21,6 +21,7 @@ def __init__(self, config, spot: Spot):
             spot,
         )
         self.grasp_attempted = False
+        self.initial_arm_joint_angles = np.deg2rad(self.config.PICK_ARM_JOINT_ANGLES)
 
     def reset(self, *args, **kwargs):
         # Move arm to initial configuration
2 changes: 0 additions & 2 deletions skill_vla/experiments/eval_skill_vla.py
@@ -7,12 +7,10 @@
 import numpy as np
 import torch
 from envs.pick_env import SpotPickEnv
-from omegaconf import OmegaConf
 from PIL import Image
 from spot_wrapper.spot import Spot
 from vla_policy import VLAPolicy
 from utils.utils import construct_config
-from yacs.config import CfgNode as CN
 
 def main(spot):
     config = construct_config('/home/tushar/Desktop/spot-vla/spot-sim2real/skill_vla/configs/config.yaml')
18 changes: 14 additions & 4 deletions skill_vla/vla_policy.py
@@ -34,7 +34,7 @@ def __init__(self, config, device):
         )  # This is to store the actions from action chunk
         self.depoly_one_action = True  # If we want to do MPC style -- only depoly one action from action chunk
         # TODO: expand these to multi-sensors
-        self.vla_target_image = "articulated_agent_arm_rgb"  # Target RGB
+        self.vla_target_image = "arm_rgb"  # Target RGB
         self.vla_target_proprio = "joint"  # Target proprio sensor
         # vla_target_proprio = "ee_pos"  # Target proprio sensor
         # load checkpoint
@@ -60,7 +60,8 @@ def process_rgb(rgbs, target_size):
         # Resize the image here
         rgbs_process = torch.zeros((rgbs.shape[0], 3, target_size, target_size))
         for i, rgb in enumerate(rgbs):
-            img = Image.fromarray(rgb.cpu().detach().numpy())
+            print('rgb: ', rgb.shape, rgb.dtype)
+            img = Image.fromarray(rgb)
             img = img.resize((target_size, target_size))
             img = np.array(img)
             rgb = torch.as_tensor(
@@ -165,6 +166,15 @@ def infer_action_vla_model(
         """Infer action using vla models."""
         self.observation_dict.append(observation)
         # Confirm the number of batches
+        print('observations: ', observation.keys())
+        print('observations shape 1: ', observation[self.vla_target_image].shape, observation[self.vla_target_image].dtype)
+        if len(observation[self.vla_target_image].shape) > 3:
+            observation[self.vla_target_image] = np.transpose(observation[self.vla_target_image], (3, 0, 1, 2))
+        else:
+            observation[self.vla_target_image] = np.expand_dims(observation[self.vla_target_image], axis=0)
+        observation[self.vla_target_proprio] = torch.as_tensor(observation[self.vla_target_proprio])
+        print('observations shape img 2: ', observation[self.vla_target_image].shape, observation[self.vla_target_image].dtype)
+        print('observations shape proprio 2: ', observation[self.vla_target_proprio].shape, observation[self.vla_target_proprio].dtype)
         bsz = observation[self.vla_target_image].shape[0]
 
         if len(self.observation_dict) < vla_config.cond_steps:
@@ -175,7 +185,7 @@
                     vla_config.image_size,
                 )
                 if self.vla_target_proprio == "ee_pos":
-                    prop_obs = self.observation_dict[-i - 1]["ee_pose"][:, :3]
+                    prop_obs = self.observation_dict[-i - 1][self.vla_target_proprio][:, :3]
                 else:
                     prop_obs = self.observation_dict[-i - 1][self.vla_target_proprio]
                 self.proprio[:, vla_config.cond_steps - i - 1] = prop_obs
@@ -186,7 +196,7 @@
                     vla_config.image_size,
                 )
                 if self.vla_target_proprio == "ee_pos":
-                    prop_obs = self.observation_dict[-i - 1]["ee_pose"][:, :3]
+                    prop_obs = self.observation_dict[-i - 1][self.vla_target_proprio][:, :3]
                 else:
                     prop_obs = self.observation_dict[-i - 1][self.vla_target_proprio]
