Added RGBD diffusion policy implementation as well as Draw Triangle and Draw SVG task #643

Open. Wants to merge 41 commits into base: main.

Commits:
5c25469
Added draw triangle with success condition
arnavg115 Oct 7, 2024
f1e4826
parallelized progress
arnavg115 Oct 10, 2024
de9b921
fixed triangle rotation issues
arnavg115 Oct 10, 2024
1189e32
clean up and format
arnavg115 Oct 14, 2024
645a74c
rgbd diffusion policy progress
arnavg115 Oct 15, 2024
2bd9cfc
diff policy rgbd cpu fixes
arnavg115 Oct 17, 2024
41bc78a
minor diff policy fixes and finished draw triangle parallelization
arnavg115 Oct 21, 2024
e322dad
Added depth arg to diff pol rgbd + formatting
arnavg115 Oct 21, 2024
46c3b22
Removed unused code
arnavg115 Oct 21, 2024
4f598bd
Made requested fixes, made bugfix to frame stack wrapper
arnavg115 Oct 22, 2024
7a52287
Edited make_env and frame_stack
arnavg115 Oct 22, 2024
7adcd5d
Added state obs to draw triangle
arnavg115 Oct 24, 2024
a209f59
Update draw_triangle max steps
arnavg115 Oct 24, 2024
b5528d6
Added DrawTriangle Docs
arnavg115 Oct 26, 2024
513435e
Fixed naming
arnavg115 Oct 28, 2024
3c0bcae
draw svg progress
arnavg115 Oct 29, 2024
186677b
fixed most issues and parallelized draw svg
arnavg115 Oct 30, 2024
518a5c5
Update draw_svg.py
arnavg115 Oct 30, 2024
c8231ef
added success condition and state based obs
arnavg115 Oct 30, 2024
a1516f0
Added discontinuous paths for draw svg
arnavg115 Oct 30, 2024
35aff3b
formatting, discontinuous state
arnavg115 Oct 31, 2024
9d70fd4
Updated run.py
arnavg115 Oct 31, 2024
ed58ce4
Merged draw svg
arnavg115 Oct 31, 2024
578105e
fixed state obs error for drawing envs
arnavg115 Nov 4, 2024
cac38a7
Changed drawsvg imports
arnavg115 Nov 7, 2024
ba99b50
Update draw_svg.py
arnavg115 Nov 7, 2024
107f6f5
Update draw_svg.py
arnavg115 Nov 8, 2024
d241419
Small bugfix
arnavg115 Nov 9, 2024
b7fe6fe
Update draw_svg.py
arnavg115 Nov 10, 2024
aa4de17
Update draw_svg.py
arnavg115 Nov 11, 2024
9a78a43
Update draw_triangle.py
arnavg115 Nov 15, 2024
78b2a49
Bugfixes, speed progress
arnavg115 Nov 16, 2024
1ffc230
success condition speed up
arnavg115 Nov 19, 2024
e8aa8d2
small fix
arnavg115 Nov 19, 2024
3bfa298
Updated draw_triangle
arnavg115 Nov 19, 2024
6aa1793
drawing env gpu bugfixes
arnavg115 Nov 19, 2024
e4124aa
diff poll rgbd fixes
arnavg115 Nov 21, 2024
f7c43d9
Synchronized fork
arnavg115 Dec 19, 2024
fa11b62
minor changes
arnavg115 Dec 19, 2024
68dee3c
fix for wandb logging
arnavg115 Dec 19, 2024
7644765
autobuild docs for DrawTriangle and SVG
arnavg115 Dec 30, 2024
66 changes: 66 additions & 0 deletions docs/source/tasks/drawing/index.md
@@ -32,6 +32,22 @@ Table of all tasks/environments in this category. Task column is the environment
<td><p>❌</p></td>
<td><p>1000</p></td>
</tr>
<tr class="row-odd">
<td><p><a href="#drawsvg-v1">DrawSVG-v1</a></p></td>
<td><div style='display:flex;gap:4px;align-items:center'><img style='min-width:min(50%, 100px);max-width:100px;height:auto' src='../../_static/env_thumbnails/DrawSVG-v1_rt_thumb_first.png' alt='DrawSVG-v1'> <img style='min-width:min(50%, 100px);max-width:100px;height:auto' src='../../_static/env_thumbnails/DrawSVG-v1_rt_thumb_last.png' alt='DrawSVG-v1'></div></td>
<td><p>❌</p></td>
<td><p>✅</p></td>
<td><p>❌</p></td>
<td><p>500</p></td>
</tr>
<tr class="row-odd">
<td><p><a href="#drawtriangle-v1">DrawTriangle-v1</a></p></td>
<td><div style='display:flex;gap:4px;align-items:center'><img style='min-width:min(50%, 100px);max-width:100px;height:auto' src='../../_static/env_thumbnails/DrawTriangle-v1_rt_thumb_first.png' alt='DrawTriangle-v1'> <img style='min-width:min(50%, 100px);max-width:100px;height:auto' src='../../_static/env_thumbnails/DrawTriangle-v1_rt_thumb_last.png' alt='DrawTriangle-v1'></div></td>
<td><p>❌</p></td>
<td><p>✅</p></td>
<td><p>❌</p></td>
<td><p>300</p></td>
</tr>
</tbody>
</table>

@@ -62,3 +78,53 @@ None
<source src="https://github.com/haosulab/ManiSkill/raw/main/figures/environment_demos/TableTopFreeDraw-v1_rt.mp4" type="video/mp4">
</video>
</div>

## DrawSVG-v1

![no-dense-reward][no-dense-reward-badge]
![sparse-reward][sparse-reward-badge]
:::{dropdown} Task Card
:icon: note
:color: primary

**Task Description:**
Instantiates a table with a white canvas on it and an SVG path specified with an outline. A robot with a stick is to draw the SVG path with a red line.

**Randomizations:**
- the goal SVG's position on the xy-plane is randomized
- the goal SVG's z-rotation is randomized in the range [0, 2 $\pi$]

**Success Conditions:**
- the points drawn by the robot are within a Euclidean distance of 0.05 m of points on the goal SVG
:::

<div style="display: flex; justify-content: center;">
<video preload="none" controls="True" width="100%" style="max-width: min(100%, 512px);" poster="../../_static/env_thumbnails/DrawSVG-v1_rt_thumb_first.png">
<source src="https://github.com/haosulab/ManiSkill/raw/main/figures/environment_demos/DrawSVG-v1_rt.mp4" type="video/mp4">
</video>
</div>
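
The randomizations listed in the task card (an xy-plane position and a z-rotation in [0, 2π]) amount to a rigid 2D transform of the goal outline. The following is a minimal sketch under stated assumptions, not the environment's implementation: `randomize_goal_pose` and the `xy_range` half-width of 0.1 m are hypothetical, since the task card does not state the sampling range for the position.

```python
import math
import random


def randomize_goal_pose(points, rng, xy_range=0.1):
    """Rotate 2D goal points about the z-axis, then shift them on the xy-plane.

    xy_range is a hypothetical half-width in meters, not a value from the PR.
    """
    theta = rng.uniform(0.0, 2.0 * math.pi)  # z-rotation angle
    dx = rng.uniform(-xy_range, xy_range)  # x translation
    dy = rng.uniform(-xy_range, xy_range)  # y translation
    c, s = math.cos(theta), math.sin(theta)
    moved = [(c * x - s * y + dx, s * x + c * y + dy) for x, y in points]
    return moved, theta, (dx, dy)


rng = random.Random(0)
triangle = [(0.0, 0.0), (0.1, 0.0), (0.05, 0.08)]
moved, theta, offset = randomize_goal_pose(triangle, rng)
```

Because the transform is rigid, distances between outline points are unchanged; only the pose of the goal on the canvas differs between episodes.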

## DrawTriangle-v1

![no-dense-reward][no-dense-reward-badge]
![sparse-reward][sparse-reward-badge]
:::{dropdown} Task Card
:icon: note
:color: primary

**Task Description:**
Instantiates a table with a white canvas on it and a goal triangle with an outline. A robot with a stick is to draw the triangle with a red line.

**Randomizations:**
- the goal triangle's position on the xy-plane is randomized
- the goal triangle's z-rotation is randomized in the range [0, 2 $\pi$]

**Success Conditions:**
- the points drawn by the robot are within a Euclidean distance of 0.05 m of points on the goal triangle
:::

<div style="display: flex; justify-content: center;">
<video preload="none" controls="True" width="100%" style="max-width: min(100%, 512px);" poster="../../_static/env_thumbnails/DrawTriangle-v1_rt_thumb_first.png">
<source src="https://github.com/haosulab/ManiSkill/raw/main/figures/environment_demos/DrawTriangle-v1_rt.mp4" type="video/mp4">
</video>
</div>
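
The success condition in the task cards compares drawn points with goal points using a 0.05 m Euclidean threshold. The sketch below shows one possible reading of that condition and is not taken from the PR's environment code: it assumes that every goal point must have a drawn point nearby and that every drawn point must lie near the goal. The helper names `min_distance` and `drawing_succeeds` are hypothetical.

```python
import math

THRESHOLD = 0.05  # meters, the value given in the task card


def min_distance(point, others):
    """Distance from point to the closest point in others."""
    return min(math.dist(point, other) for other in others)


def drawing_succeeds(drawn, goal, threshold=THRESHOLD):
    """Assumed reading: drawn and goal points must each lie near the other set."""
    if not drawn or not goal:
        return False
    goal_covered = all(min_distance(g, drawn) <= threshold for g in goal)
    on_outline = all(min_distance(d, goal) <= threshold for d in drawn)
    return goal_covered and on_outline


goal = [(0.0, 0.0), (0.1, 0.0), (0.2, 0.0)]
close = [(0.01, 0.02), (0.11, -0.01), (0.19, 0.0)]
partial = [(0.01, 0.02)]

full_result = drawing_succeeds(close, goal)
partial_result = drawing_succeeds(partial, goal)
```

Under this reading a partial drawing fails because some goal points have no drawn point within the threshold, even though every drawn point is close to the outline.
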
120 changes: 110 additions & 10 deletions examples/baselines/diffusion_policy/diffusion_policy/make_env.py
@@ -1,12 +1,69 @@
+from collections import deque
 from typing import Optional

 import gymnasium as gym
 import mani_skill.envs
+import numpy as np
+from gymnasium.spaces import Box
+from gymnasium.wrappers.frame_stack import LazyFrames
 from mani_skill.utils import gym_utils
+from mani_skill.utils.wrappers import CPUGymWrapper, FrameStack, RecordEpisode
 from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
-from mani_skill.utils.wrappers import RecordEpisode, FrameStack, CPUGymWrapper
+# from mani_skill.utils.wrappers.frame_stack import LazyFrames
+
+class DictFrameStack(FrameStack):
+    def __init__(
+        self,
+        env: gym.Env,
+        num_stack: int,
+        lz4_compress: bool = False,
+    ):
+        """Observation wrapper that stacks the observations in a rolling manner.
+
+        Args:
+            env (Env): The environment to apply the wrapper
+            num_stack (int): The number of frames to stack
+            lz4_compress (bool): Use lz4 to compress the frames internally
+        """
+        # gym.utils.RecordConstructorArgs.__init__(
+        #     self, num_stack=num_stack, lz4_compress=lz4_compress
+        # )
+        # gym.ObservationWrapper.__init__(self, env)
+        super().__init__(env, num_stack, lz4_compress)
+
+        new_observation_space = gym.spaces.Dict()
+        for k, v in self.observation_space.items():
+            low = np.repeat(v.low[np.newaxis, ...], num_stack, axis=0)
+            high = np.repeat(v.high[np.newaxis, ...], num_stack, axis=0)
+            new_observation_space[k] = Box(low=low, high=high, dtype=v.dtype)
+        self.observation_space = new_observation_space
+
+
+    def observation(self, observation):
+        """Converts the wrappers current frames to lazy frames.

-def make_eval_envs(env_id, num_envs: int, sim_backend: str, env_kwargs: dict, other_kwargs: dict, video_dir: Optional[str] = None, wrappers: list[gym.Wrapper] = []):
+        Args:
+            observation: Ignored
+
+        Returns:
+            :class:`LazyFrames` object for the wrapper's frame buffer, :attr:`self.frames`
+        """
+        assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)
+        return {
+            k: LazyFrames([x[k] for x in self.frames], self.lz4_compress)
+            for k in self.observation_space.keys()
+        }
+
+
+def make_eval_envs(
+    env_id,
+    num_envs: int,
+    sim_backend: str,
+    env_kwargs: dict,
+    other_kwargs: dict,
+    video_dir: Optional[str] = None,
+    wrappers: list[gym.Wrapper] = [],
+):
     """Create vectorized environment for evaluation and/or recording videos.
     For CPU vectorized environments only the first parallel environment is used to record videos.
     For GPU vectorized environments all parallel environments are used to record videos.
@@ -20,29 +77,72 @@ def make_eval_envs(env_id, num_envs: int, sim_backend: str, env_kwargs: dict, ot
         wrappers: the list of wrappers to apply to the environment.
     """
     if sim_backend == "cpu":
-        def cpu_make_env(env_id, seed, video_dir=None, env_kwargs = dict(), other_kwargs = dict()):
+
+        def cpu_make_env(
+            env_id, seed, video_dir=None, env_kwargs=dict(), other_kwargs=dict()
+        ):
             def thunk():
                 env = gym.make(env_id, reconfiguration_freq=1, **env_kwargs)
                 for wrapper in wrappers:
                     env = wrapper(env)
                 env = CPUGymWrapper(env, ignore_terminations=True, record_metrics=True)
                 if video_dir:
-                    env = RecordEpisode(env, output_dir=video_dir, save_trajectory=False, info_on_video=True, source_type="diffusion_policy", source_desc="diffusion_policy evaluation rollout")
-                env = gym.wrappers.FrameStack(env, other_kwargs['obs_horizon'])
+                    env = RecordEpisode(
+                        env,
+                        output_dir=video_dir,
+                        save_trajectory=False,
+                        info_on_video=True,
+                        source_type="diffusion_policy",
+                        source_desc="diffusion_policy evaluation rollout",
+                    )
+                if env_kwargs["obs_mode"] == "state":
+                    env = gym.wrappers.FrameStack(env, other_kwargs["obs_horizon"])
+                elif env_kwargs["obs_mode"] == "rgbd":
+                    env = DictFrameStack(env, other_kwargs["obs_horizon"])
                 env.action_space.seed(seed)
                 env.observation_space.seed(seed)
                 return env

             return thunk
-        vector_cls = gym.vector.SyncVectorEnv if num_envs == 1 else lambda x : gym.vector.AsyncVectorEnv(x, context="forkserver")
-        env = vector_cls([cpu_make_env(env_id, seed, video_dir if seed == 0 else None, env_kwargs, other_kwargs) for seed in range(num_envs)])
+
+        vector_cls = (
+            gym.vector.SyncVectorEnv
+            if num_envs == 1
+            else lambda x: gym.vector.AsyncVectorEnv(x, context="forkserver")
+        )
+        env = vector_cls(
+            [
+                cpu_make_env(
+                    env_id,
+                    seed,
+                    video_dir if seed == 0 else None,
+                    env_kwargs,
+                    other_kwargs,
+                )
+                for seed in range(num_envs)
+            ]
+        )
     else:
-        env = gym.make(env_id, num_envs=num_envs, sim_backend=sim_backend, reconfiguration_freq=1, **env_kwargs)
+        env = gym.make(
+            env_id,
+            num_envs=num_envs,
+            sim_backend=sim_backend,
+            reconfiguration_freq=1,
+            **env_kwargs
+        )
         max_episode_steps = gym_utils.find_max_episode_steps_value(env)
         for wrapper in wrappers:
             env = wrapper(env)
-        env = FrameStack(env, num_stack=other_kwargs['obs_horizon'])
+        env = FrameStack(env, num_stack=other_kwargs["obs_horizon"])
         if video_dir:
-            env = RecordEpisode(env, output_dir=video_dir, save_trajectory=False, save_video=True, source_type="diffusion_policy", source_desc="diffusion_policy evaluation rollout", max_steps_per_video=max_episode_steps)
+            env = RecordEpisode(
+                env,
+                output_dir=video_dir,
+                save_trajectory=False,
+                save_video=True,
+                source_type="diffusion_policy",
+                source_desc="diffusion_policy evaluation rollout",
+                max_steps_per_video=max_episode_steps,
+            )
         env = ManiSkillVectorEnv(env, ignore_terminations=True, record_metrics=True)
     return env
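
The `DictFrameStack` wrapper in this file stacks dictionary observations key by key over a rolling window of `num_stack` frames. The idea can be illustrated without Gymnasium or ManiSkill; the following is a simplified sketch, not the wrapper itself. `DictFrameBuffer` is a hypothetical class, plain lists stand in for `LazyFrames`, and filling the buffer with the first observation on reset is an assumption about the base wrapper's behavior.

```python
from collections import deque


class DictFrameBuffer:
    """Keeps the last num_stack dict observations and stacks them per key."""

    def __init__(self, num_stack):
        self.num_stack = num_stack
        self.frames = deque(maxlen=num_stack)

    def reset(self, obs):
        # Repeat the first observation so the buffer is full from the start.
        for _ in range(self.num_stack):
            self.frames.append(obs)
        return self.stacked()

    def step(self, obs):
        self.frames.append(obs)  # the oldest frame is dropped automatically
        return self.stacked()

    def stacked(self):
        assert len(self.frames) == self.num_stack
        keys = self.frames[0].keys()
        return {k: [frame[k] for frame in self.frames] for k in keys}


buffer = DictFrameBuffer(num_stack=2)
first = buffer.reset({"rgb": "img0", "state": 0.0})
second = buffer.step({"rgb": "img1", "state": 1.0})
```

Each key ends up with its own stack of length `num_stack`, which is why the wrapper's observation space repeats every sub-space's bounds along a new leading axis.
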
65 changes: 65 additions & 0 deletions examples/baselines/diffusion_policy/diffusion_policy/plain_conv.py
@@ -0,0 +1,65 @@
+import torch.nn as nn
+
+
+def make_mlp(in_channels, mlp_channels, act_builder=nn.ReLU, last_act=True):
+    c_in = in_channels
+    module_list = []
+    for idx, c_out in enumerate(mlp_channels):
+        module_list.append(nn.Linear(c_in, c_out))
+        if last_act or idx < len(mlp_channels) - 1:
+            module_list.append(act_builder())
+        c_in = c_out
+    return nn.Sequential(*module_list)
+
+
+class PlainConv(nn.Module):
+    def __init__(
+        self,
+        in_channels=3,
+        out_dim=256,
+        pool_feature_map=False,
+        last_act=True,  # True for ConvBody, False for CNN
+    ):
+        super().__init__()
+        # assume input image size is 128x128 (the non-pooled fc layer below expects 128 * 8 * 8 features)
+
+        self.out_dim = out_dim
+        self.cnn = nn.Sequential(
+            nn.Conv2d(in_channels, 16, 3, padding=1, bias=True),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(2, 2),  # [64, 64]
+            nn.Conv2d(16, 32, 3, padding=1, bias=True),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(2, 2),  # [32, 32]
+            nn.Conv2d(32, 64, 3, padding=1, bias=True),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(2, 2),  # [16, 16]
+            nn.Conv2d(64, 128, 3, padding=1, bias=True),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(2, 2),  # [8, 8]
+            nn.Conv2d(128, 128, 1, padding=0, bias=True),
+            nn.ReLU(inplace=True),
+        )
+
+        if pool_feature_map:
+            self.pool = nn.AdaptiveMaxPool2d((1, 1))
+            self.fc = make_mlp(128, [out_dim], last_act=last_act)
+        else:
+            self.pool = None
+            self.fc = make_mlp(128 * 4 * 4 * 4, [out_dim], last_act=last_act)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for name, module in self.named_modules():
+            if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
+                if module.bias is not None:
+                    nn.init.zeros_(module.bias)
+
+    def forward(self, image):
+        x = self.cnn(image)
+        if self.pool is not None:
+            x = self.pool(x)
+        x = x.flatten(1)
+        x = self.fc(x)
+        return x
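
When `pool_feature_map` is false, the input width of the `fc` layer has to match the flattened feature map, which depends on the input resolution. The arithmetic can be checked with a short sketch, assuming square inputs; `flattened_features` is a hypothetical helper and not part of the PR. It relies on the fact that the 3x3 convolutions with `padding=1` and the 1x1 convolution keep the spatial size, while each `MaxPool2d(2, 2)` halves it.

```python
def flattened_features(image_size, channels=128, num_pools=4):
    """Feature count after num_pools stride-2 max-pools on a square image."""
    side = image_size
    for _ in range(num_pools):
        side //= 2  # MaxPool2d(2, 2) halves each spatial dimension
    return channels * side * side


features_64 = flattened_features(64)    # 4x4 map after four pools
features_128 = flattened_features(128)  # 8x8 map after four pools
fc_inputs = 128 * 4 * 4 * 4             # input width used by the non-pooled fc layer
```

Under these assumptions `fc_inputs` equals `features_128`, so the non-pooled path fits 128x128 images, while a 64x64 image would produce only a quarter as many features. With `pool_feature_map=True` the adaptive pooling removes this dependence on resolution.
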