Lints #9

Open · wants to merge 8 commits into base: master
2 changes: 1 addition & 1 deletion Makefile
@@ -53,7 +53,7 @@ clean:
# Static Checks #
# ------------------------ #

-py-files := $(shell find . -name '*.py')
+py-files := $(shell find . -not -path '*/.*' -name '*.py')

format:
@black $(py-files)
11 changes: 10 additions & 1 deletion README.md
@@ -11,7 +11,16 @@ git clone https://github.com/kscalelabs/stompy-live
&& pip install -e .
```
# From ManiSkill2 Docs for downloading assets:
-Some environments require downloading assets. You can download all the assets by ` python -m mani_skill.utils.download_asset all ` or download task-specific assets by ` python -m mani_skill.utils.download_asset ${ENV_ID} ` . The assets will be downloaded to ` ./data/ ` by default, and you can also use the environment variable ` MS2_ASSET_DIR `to specify this destination.
+Some environments require downloading assets. You can download all of the assets with `python -m mani_skill.utils.download_asset all`, or download task-specific assets with `python -m mani_skill.utils.download_asset ENV_ID`. The assets are downloaded to `./data/` by default; you can also set the environment variable `MS2_ASSET_DIR` to specify a different destination.

You will want to download, at a minimum:

- ycb
- ReplicaCAD
- AI2THOR

by substituting each of these dataset names for `ENV_ID` in the command above (spelled out in the example below).

# Tests:

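Spelled out, the three download invocations the new README text asks for (assuming the downloader accepts these dataset names verbatim):

```
python -m mani_skill.utils.download_asset ycb
python -m mani_skill.utils.download_asset ReplicaCAD
python -m mani_skill.utils.download_asset AI2THOR
```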
5 changes: 4 additions & 1 deletion stompy_live/agents/franka/franka_arm.py
@@ -1,12 +1,15 @@
"""Agent for the Panda arm. Code largely ported from ManiSkill."""

import numpy as np
import torch
from torch import nn
from torch.distributions import Normal

from stompy_live.agents.layer_init import layer_init


class Agent(nn.Module):
-    def __init__(self, envs):
+    def __init__(self, envs) -> None:
super().__init__()
self.critic = nn.Sequential(
layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 256)),
33 changes: 18 additions & 15 deletions stompy_live/agents/franka/franka_arm_vision.py
@@ -2,19 +2,21 @@
import torch
from torch import nn
from torch.distributions import Normal

from stompy_live.agents.layer_init import layer_init


class NatureCNN(nn.Module):
-    def __init__(self, sample_obs):
+    def __init__(self, sample_obs) -> None:
super().__init__()

extractors = {}

self.out_features = 0
feature_size = 256
-        in_channels=sample_obs["rgb"].shape[-1]
-        image_size=(sample_obs["rgb"].shape[1], sample_obs["rgb"].shape[2])
-        state_size=sample_obs["state"].shape[-1]
+        in_channels = sample_obs["rgb"].shape[-1]
+        state_size = sample_obs["state"].shape[-1]

        # here we use a NatureCNN architecture to process images, but any architecture is permissible here
cnn = nn.Sequential(
@@ -26,20 +28,16 @@ def __init__(self, sample_obs):
padding=0,
),
nn.ReLU(),
-            nn.Conv2d(
-                in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=0
-            ),
+            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=0),
nn.ReLU(),
-            nn.Conv2d(
-                in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0
-            ),
+            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0),
nn.ReLU(),
nn.Flatten(),
)

# to easily figure out the dimensions after flattening, we pass a test tensor
with torch.no_grad():
-            n_flatten = cnn(sample_obs["rgb"].float().permute(0,3,1,2).cpu()).shape[1]
+            n_flatten = cnn(sample_obs["rgb"].float().permute(0, 3, 1, 2).cpu()).shape[1]
fc = nn.Sequential(nn.Linear(n_flatten, feature_size), nn.ReLU())
extractors["rgb"] = nn.Sequential(cnn, fc)
self.out_features += feature_size
@@ -56,13 +54,14 @@ def forward(self, observations) -> torch.Tensor:
for key, extractor in self.extractors.items():
obs = observations[key]
if key == "rgb":
-                obs = obs.float().permute(0,3,1,2)
+                obs = obs.float().permute(0, 3, 1, 2)
obs = obs / 255
encoded_tensor_list.append(extractor(obs))
return torch.cat(encoded_tensor_list, dim=1)


class Agent(nn.Module):
-    def __init__(self, envs, sample_obs):
+    def __init__(self, envs, sample_obs) -> None:
super().__init__()
self.feature_net = NatureCNN(sample_obs=sample_obs)
# latent_size = np.array(envs.unwrapped.single_observation_space.shape).prod()
@@ -75,14 +74,17 @@ def __init__(self, envs, sample_obs):
self.actor_mean = nn.Sequential(
layer_init(nn.Linear(latent_size, 512)),
nn.ReLU(inplace=True),
-            layer_init(nn.Linear(512, np.prod(envs.unwrapped.single_action_space.shape)), std=0.01*np.sqrt(2)),
+            layer_init(nn.Linear(512, np.prod(envs.unwrapped.single_action_space.shape)), std=0.01 * np.sqrt(2)),
)
self.actor_logstd = nn.Parameter(torch.ones(1, np.prod(envs.unwrapped.single_action_space.shape)) * -0.5)

def get_features(self, x):
return self.feature_net(x)

def get_value(self, x):
x = self.feature_net(x)
return self.critic(x)

def get_action(self, x, deterministic=False):
x = self.feature_net(x)
action_mean = self.actor_mean(x)
@@ -92,6 +94,7 @@ def get_action(self, x, deterministic=False):
action_std = torch.exp(action_logstd)
probs = Normal(action_mean, action_std)
return probs.sample()

def get_action_and_value(self, x, action=None):
x = self.feature_net(x)
action_mean = self.actor_mean(x)
@@ -100,4 +103,4 @@ def get_action_and_value(self, x, action=None):
probs = Normal(action_mean, action_std)
if action is None:
action = probs.sample()
        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
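A minimal sketch of how this vision agent is constructed and queried, assuming a vectorized ManiSkill env whose dict observations carry "rgb" and "state" keys (the env factory below is a hypothetical placeholder, not part of this PR):

```
import torch

from stompy_live.agents.franka.franka_arm_vision import Agent

envs = make_rgbd_vector_env()  # hypothetical helper: any RGBD vector env, e.g. wrapped with FlattenRGBDObservationWrapper
sample_obs, _ = envs.reset()

agent = Agent(envs, sample_obs=sample_obs)  # sample_obs is only used to size the CNN's flatten layer
with torch.no_grad():
    action = agent.get_action(sample_obs)  # one action per parallel env
```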
3 changes: 2 additions & 1 deletion stompy_live/agents/layer_init.py
@@ -1,7 +1,8 @@
import numpy as np
import torch


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
torch.nn.init.orthogonal_(layer.weight, std)
torch.nn.init.constant_(layer.bias, bias_const)
    return layer
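For reference, this is how the helper is used elsewhere in this PR (the layer shapes here are arbitrary):

```
import numpy as np
from torch import nn

from stompy_live.agents.layer_init import layer_init

hidden = layer_init(nn.Linear(256, 64))  # orthogonal weights with gain sqrt(2), zero bias
head = layer_init(nn.Linear(64, 8), std=0.01 * np.sqrt(2))  # smaller gain for a policy head
```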
12 changes: 7 additions & 5 deletions stompy_live/agents/stompy/stompy.py
@@ -2,18 +2,21 @@

import numpy as np
import sapien
-from mani_skill import PACKAGE_ASSET_DIR
from mani_skill.agents.base_agent import BaseAgent, Keyframe
from mani_skill.agents.controllers import PDJointPosControllerConfig
from mani_skill.agents.registration import register_agent
from mani_skill.sensors.camera import CameraConfig
from transforms3d import euler

+from stompy_live.utils.config import get_model_dir

-@register_agent() # uncomment this if you want to register the agent so you can instantiate it by ID when creating environments
+
+@register_agent(
+    "stompy_latest"
+)  # registering the agent lets you instantiate it by ID when creating environments
class Stompy(BaseAgent):
-    uid = "stompy"
-    urdf_path = f"{PACKAGE_ASSET_DIR}/robots/stompy/robot.urdf"
+    uid = "stompy_latest"
+    urdf_path = f"{get_model_dir()}/7DOF_NEWEST/robot_7dof_arm_merged_simplified.urdf"
urdf_config = dict(
_materials=dict(gripper=dict(static_friction=2.0, dynamic_friction=2.0, restitution=0.0)),
link=dict(
@@ -116,7 +119,6 @@ def _sensor_configs(self) -> None:
fov=1.57,
near=0.01,
far=100,
-                entity_uid="link_head_1_head_1", # mount cameras relative to existing link IDs as so
)
]

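A sketch of what the registration enables: once this module is imported, the agent can be selected by its uid when building an environment. The env ID comes from the client script later in this PR; `robot_uids` is assumed to be the ManiSkill argument for choosing the agent:

```
import gymnasium as gym

import stompy_live.agents.stompy.stompy  # noqa: F401  # importing runs @register_agent("stompy_latest")
from stompy_live.envs.stompy_env import SceneManipulationEnv  # noqa: F401

env = gym.make("New-SceneManipulation-v1", robot_uids="stompy_latest", obs_mode="state")
```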
16 changes: 8 additions & 8 deletions stompy_live/apis/franka_vla.py
@@ -1,5 +1,4 @@
"""
deploy.py
"""deploy.py.

Code taken from https://github.com/openvla/openvla/blob/main/vla-scripts/deploy.py.

@@ -66,11 +65,12 @@ def get_openvla_prompt(instruction: str, openvla_path: Union[str, Path]) -> str:

# === Server Interface ===
class OpenVLAServer:
-    def __init__(self, openvla_path: Union[str, Path], attn_implementation: Optional[str] = "flash_attention_2") -> Path:
-        """
-        A simple server for OpenVLA models; exposes `/act` to predict an action for a given image + instruction.
-        => Takes in {"image": np.ndarray, "instruction": str, "unnorm_key": Optional[str]}
-        => Returns {"action": np.ndarray}
+    def __init__(
+        self, openvla_path: Union[str, Path], attn_implementation: Optional[str] = "flash_attention_2"
+    ) -> None:
+        """A simple server for OpenVLA models; exposes `/act` to predict an action for a given image + instruction.
+        => Takes in {"image": np.ndarray, "instruction": str, "unnorm_key": Optional[str]}
+        => Returns {"action": np.ndarray}.
"""
self.openvla_path, self.attn_implementation = openvla_path, attn_implementation
self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
@@ -144,4 +144,4 @@ def deploy(cfg: DeployConfig) -> None:


if __name__ == "__main__":
    deploy()
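To make the `/act` contract above concrete, a minimal client sketch. The host and port are assumptions (deploy() takes them via its DeployConfig); `json_numpy.patch()` lets plain requests calls carry numpy arrays, mirroring the client script later in this PR:

```
import json_numpy
import numpy as np
import requests

json_numpy.patch()  # make json (and therefore requests) numpy-aware

image = np.zeros((224, 224, 3), dtype=np.uint8)  # placeholder camera frame
resp = requests.post(
    "http://localhost:8000/act",  # assumed address
    json={"image": image, "instruction": "pick up the cube"},
)
action = resp.json()  # per the docstring: the predicted action, possibly wrapped as {"action": ...}
```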
11 changes: 11 additions & 0 deletions stompy_live/clients/README
@@ -1,3 +1,14 @@
These are the maniskill clients. They interface with the APIs running the
models, take the output of the APIs, and render it via maniskill. Then they get
an image from maniskill and feed it back to the API, repeating the cycle.

Stompy
------

To set up stompy.py, download

https://drive.google.com/drive/folders/1dNL8i4sfu5N6ojUMOb9YDcCMRc3jRXg2

and unzip it in stompy_live/assets.

Then replace "fused" with "meshes/fused" in the main urdf file
(robot_7dof_arm_merged_simplified.urdf) so the mesh paths resolve; a sketch of
this step follows below.
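A minimal sketch of that last step, assuming the unzipped folder lands where the agent's urdf_path expects it (run once only, since the replacement is not idempotent):

```
from pathlib import Path

# Assumed path, matching the urdf_path in stompy_live/agents/stompy/stompy.py.
urdf = Path("stompy_live/assets/7DOF_NEWEST/robot_7dof_arm_merged_simplified.urdf")
urdf.write_text(urdf.read_text().replace("fused", "meshes/fused"))
```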
25 changes: 14 additions & 11 deletions stompy_live/clients/franka_vla.py
@@ -1,17 +1,20 @@
-from twitch.client import message_queue, init
+import queue
+import time
from threading import Thread
-import time, queue
-import requests

import gymnasium as gym
-from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper, FlattenRGBDObservationWrapper
+import json_numpy
+import requests
import torch
-import stompy_live.envs.franka_push_cube # noqa: F401
+from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper, FlattenRGBDObservationWrapper

+import stompy_live.envs.franka_push_cube  # noqa: F401
+from twitch.client import init, message_queue

-import json_numpy
json_numpy.patch()
-import pygame
-import numpy as np
import cv2
+import numpy as np
+import pygame

window = pygame.display.set_mode((1024, 1024))

@@ -53,14 +56,14 @@

surface = pygame.surfarray.make_surface(upsized_image)
window.blit(surface, (0, 0))

pygame.display.update()

obs, reward, terminated, truncated, info = envs.step(torch.tensor(action, dtype=torch.float))

done = terminated or truncated

except queue.Empty:
time.sleep(1)
except KeyboardInterrupt:
        break
31 changes: 31 additions & 0 deletions stompy_live/clients/stompy.py
@@ -0,0 +1,31 @@
"""Stompy client. Just performs random actions for now."""

import gymnasium as gym
import torch
from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper

from stompy_live.envs.stompy_env import SceneManipulationEnv # noqa: F401

# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env_kwargs = dict(obs_mode="state", control_mode="pd_joint_delta_pos", render_mode="human", sim_backend="gpu")
env = gym.make("New-SceneManipulation-v1", **env_kwargs, scene_builder_cls="ai2thor")
if isinstance(env.action_space, gym.spaces.Dict):
env = FlattenActionSpaceWrapper(env)
assert isinstance(env.single_action_space, gym.spaces.Box), "only continuous action space is supported"

while True:
obs, info = env.reset()
done = False
total_reward = 0

while not done:
        # Placeholder until the model API is wired up: sample a random action
with torch.no_grad():
action = env.action_space.sample()

obs, reward, terminated, truncated, info = env.step(action)

done = terminated or truncated
total_reward += reward
env.render()