Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidBert committed Oct 3, 2022
0 parents commit 9b854e8
Show file tree
Hide file tree
Showing 520 changed files with 67,864 additions and 0 deletions.
21 changes: 21 additions & 0 deletions DMC/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Anonymous

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
53 changes: 53 additions & 0 deletions DMC/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@

# Overview
![Perf](figures/sgqnarchi.png)
![Perf](figures/sgqn_perf.png)

## Setup
We assume that you have access to a GPU with CUDA >=9.2 support. All dependencies can then be installed with the following commands:

```
conda env create -f setup/conda.yaml
conda activate dmcgb
sh setup/install_envs.sh
```

If you don't have the right mujoco version installed:
```
sh setup/install_mujoco_deps.sh
source setup/prepare_dm_control_xp.sh
```


## Datasets

```
wget http://data.csail.mit.edu/places/places365/places365standard_easyformat.tar
```
After downloading and extracting the data, add your dataset directory to the datasets list in `setup/config.cfg`.

# Run `SGQN` training:
```bash
python src/train.py --algorithm sgsac --sgqn_quantile [QUANTILE] --seed [SEED] --eval_mode video_easy --domain_name [DOMAIN] --task_name [TASK];
```
You can also train `SVEA`, `SODA`, `RAD`, or `SAC` by passing `svea`, `soda`, `rad`, or `sac` to `--algorithm`.


# DMControl Generalization Benchmark

Benchmark for generalization in continuous control from pixels, based on [DMControl](https://github.com/deepmind/dm_control).

## Test environments

The DMControl Generalization Benchmark provides two distinct benchmarks for visual generalization, *random colors* and *video backgrounds*:

![environment samples](figures/environments.png)

Both benchmarks are offered in *easy* and *hard* variants. Samples are shown below.

**video_easy**<br/>
![video_easy](figures/video_easy.png)

**video_hard**<br/>
![video_hard](figures/video_hard.png)

Binary file added DMC/figures/color_easy.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added DMC/figures/color_hard.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added DMC/figures/environments.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added DMC/figures/results_table.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added DMC/figures/sgqn_perf.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added DMC/figures/sgqnarchi.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added DMC/figures/video_easy.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added DMC/figures/video_hard.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 4 additions & 0 deletions DMC/scripts/curl.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Run src/train.py with --algorithm curl, exposing only CUDA device 0.
# The path is relative, so invoke this from the directory that contains src/.
CUDA_VISIBLE_DEVICES=0 python3 src/train.py \
--algorithm curl \
--aux_update_freq 1 \
--seed 0
3 changes: 3 additions & 0 deletions DMC/scripts/drq.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Run src/train.py with --algorithm drq, exposing only CUDA device 0.
# The path is relative, so invoke this from the directory that contains src/.
CUDA_VISIBLE_DEVICES=0 python3 src/train.py \
--algorithm drq \
--seed 0
4 changes: 4 additions & 0 deletions DMC/scripts/eval/curl.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Run src/eval.py with --algorithm curl for 100 evaluation episodes on CUDA device 0.
# The path is relative, so invoke this from the directory that contains src/.
CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \
--algorithm curl \
--eval_episodes 100 \
--seed 0
4 changes: 4 additions & 0 deletions DMC/scripts/eval/drq.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Run src/eval.py with --algorithm drq for 100 evaluation episodes on CUDA device 0.
# The path is relative, so invoke this from the directory that contains src/.
CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \
--algorithm drq \
--eval_episodes 100 \
--seed 0
6 changes: 6 additions & 0 deletions DMC/scripts/eval/pad.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Run src/eval.py with --algorithm pad for 100 evaluation episodes on CUDA device 0.
# The layer counts match the ones passed by scripts/pad.sh.
# The path is relative, so invoke this from the directory that contains src/.
CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \
--algorithm pad \
--num_shared_layers 8 \
--num_head_layers 3 \
--eval_episodes 100 \
--seed 0
4 changes: 4 additions & 0 deletions DMC/scripts/eval/rad.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Run src/eval.py with --algorithm rad for 100 evaluation episodes on CUDA device 0.
# The path is relative, so invoke this from the directory that contains src/.
CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \
--algorithm rad \
--eval_episodes 100 \
--seed 0
4 changes: 4 additions & 0 deletions DMC/scripts/eval/sac.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Run src/eval.py with --algorithm sac for 100 evaluation episodes on CUDA device 0.
# The path is relative, so invoke this from the directory that contains src/.
CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \
--algorithm sac \
--eval_episodes 100 \
--seed 0
4 changes: 4 additions & 0 deletions DMC/scripts/eval/soda.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Run src/eval.py with --algorithm soda for 100 evaluation episodes on CUDA device 0.
# The path is relative, so invoke this from the directory that contains src/.
CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \
--algorithm soda \
--eval_episodes 100 \
--seed 0
4 changes: 4 additions & 0 deletions DMC/scripts/eval/svea.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Run src/eval.py with --algorithm svea for 100 evaluation episodes on CUDA device 0.
# The path is relative, so invoke this from the directory that contains src/.
CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \
--algorithm svea \
--eval_episodes 100 \
--seed 0
5 changes: 5 additions & 0 deletions DMC/scripts/pad.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Run src/train.py with --algorithm pad (8 shared layers, 3 head layers) on CUDA device 0.
# The path is relative, so invoke this from the directory that contains src/.
CUDA_VISIBLE_DEVICES=0 python3 src/train.py \
--algorithm pad \
--num_shared_layers 8 \
--num_head_layers 3 \
--seed 0
3 changes: 3 additions & 0 deletions DMC/scripts/rad.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Run src/train.py with --algorithm rad, exposing only CUDA device 0.
# The path is relative, so invoke this from the directory that contains src/.
CUDA_VISIBLE_DEVICES=0 python3 src/train.py \
--algorithm rad \
--seed 0
3 changes: 3 additions & 0 deletions DMC/scripts/sac.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Run src/train.py with --algorithm sac, exposing only CUDA device 0.
# The path is relative, so invoke this from the directory that contains src/.
CUDA_VISIBLE_DEVICES=0 python3 src/train.py \
--algorithm sac \
--seed 0
3 changes: 3 additions & 0 deletions DMC/scripts/sgsac.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Run src/train.py with --algorithm sgsac on cartpole/swingup, exposing only CUDA device 0.
# Passes --eval_mode all and --sgqn_quantile 0.98.
# The path is relative, so invoke this from the directory that contains src/.
CUDA_VISIBLE_DEVICES=0 python3 src/train.py \
--algorithm sgsac \
--seed 0 --eval_mode all --domain_name cartpole --task_name swingup --sgqn_quantile 0.98
4 changes: 4 additions & 0 deletions DMC/scripts/soda.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Run src/train.py with --algorithm soda and --aux_lr 3e-4, exposing only CUDA device 0.
# The path is relative, so invoke this from the directory that contains src/.
CUDA_VISIBLE_DEVICES=0 python3 src/train.py \
--algorithm soda \
--aux_lr 3e-4 \
--seed 0
3 changes: 3 additions & 0 deletions DMC/scripts/svea.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Run src/train.py with --algorithm svea, exposing only CUDA device 0.
# The path is relative, so invoke this from the directory that contains src/.
CUDA_VISIBLE_DEVICES=0 python3 src/train.py \
--algorithm svea \
--seed 0
24 changes: 24 additions & 0 deletions DMC/setup/conda.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name: dmcgb
channels:
  - defaults
dependencies:
  - python=3.7.6
  - cudatoolkit=9.2
  - absl-py
  - pyparsing
  - pip
  - pip:
    - numpy==1.19.5
    - torch==1.7.1
    - torchvision==0.8.2
    - pillow==6.2.0
    - termcolor
    - imageio
    - imageio-ffmpeg
    - opencv-python
    - xmltodict
    - tqdm
    - einops
    - kornia
    - captum
    - tensorboard
5 changes: 5 additions & 0 deletions DMC/setup/config.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"datasets": [
"places365_standard/"
]
}
10 changes: 10 additions & 0 deletions DMC/setup/install_envs.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Install the two vendored environment packages in editable mode, then fetch
# the benchmark data directory. All paths are relative, so invoke this from
# the directory that contains src/.
cd src/env/dm_control
pip install -e .

cd ../dmc2gym
pip install -e .

# Back to the directory the script was started from (three levels up from src/env/dmc2gym).
cd ../../..

# Stream the repository tarball and extract only src/env/data; stripping the
# three leading path components makes it unpack as ./data, which is then moved
# into src/env/.
curl https://codeload.github.com/nicklashansen/dmcontrol-generalization-benchmark/tar.gz/main | tar -xz --strip=3 dmcontrol-generalization-benchmark-main/src/env/data
mv data src/env/
5 changes: 5 additions & 0 deletions DMC/setup/install_mujoco_deps.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Download the mujoco200_linux archive and the mjkey.txt license key into
# $HOME/.mujoco, and place the key in mujoco200_linux/bin.
#
# The target directory is created first: on a fresh machine it does not exist,
# and a failed `cd` would otherwise make the downloads land in whatever
# directory the script was started from.
mkdir -p "$HOME/.mujoco"
cd "$HOME/.mujoco" || exit 1
wget https://roboti.us/download/mujoco200_linux.zip
unzip mujoco200_linux.zip
wget https://roboti.us/file/mjkey.txt
mv mjkey.txt mujoco200_linux/bin
2 changes: 2 additions & 0 deletions DMC/setup/prepare_dm_control_xp.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Export the MuJoCo key location and select EGL as the MUJOCO_GL backend.
# NOTE: exports only persist in the shell that evaluates them. Source this file
# (`. setup/prepare_dm_control_xp.sh`) instead of executing it with `sh`,
# otherwise the variables are set in a child process that exits immediately.
export MJKEY_PATH=$HOME/.mujoco/mujoco200_linux/bin/mjkey.txt
export MUJOCO_GL=egl
57 changes: 57 additions & 0 deletions DMC/src/algorithms/curl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
import utils
import algorithms.modules as m
from algorithms.sac import SAC


class CURL(SAC):
    """SAC agent with an additional CURL auxiliary objective.

    On top of the inherited critic/actor updates, every ``aux_update_freq``
    steps the agent optimizes a cross-entropy loss over the logits produced
    by ``CURLHead.compute_logits`` for a batch ``x`` and its counterpart
    ``x_pos``.
    """

    def __init__(self, obs_shape, action_shape, args):
        """Build the SAC agent, then attach the CURL head and its optimizer.

        Args:
            obs_shape: Forwarded unchanged to ``SAC.__init__``.
            action_shape: Forwarded unchanged to ``SAC.__init__``.
            args: Configuration object; this class reads ``aux_update_freq``,
                ``aux_lr`` and ``aux_beta`` from it.
        """
        super().__init__(obs_shape, action_shape, args)
        self.aux_update_freq = args.aux_update_freq

        # The head wraps the critic's encoder object itself (not a copy).
        self.curl_head = m.CURLHead(self.critic.encoder).cuda()

        self.curl_optimizer = torch.optim.Adam(
            self.curl_head.parameters(), lr=args.aux_lr, betas=(args.aux_beta, 0.999)
        )
        self.train()

    def train(self, training=True):
        """Switch the agent, and the CURL head once it exists, to train/eval mode."""
        super().train(training)
        # The guard lets train() be called before __init__ has assigned
        # curl_head (presumably the base constructor does so -- confirm in SAC).
        if hasattr(self, 'curl_head'):
            self.curl_head.train(training)

    def update_curl(self, x, x_pos, L=None, step=None):
        """Run one optimizer step on the CURL cross-entropy loss.

        Args:
            x: Batch encoded with the head's encoder (gradients flow).
            x_pos: Batch encoded with the target critic's encoder under
                ``torch.no_grad()``.
            L: Optional logger; when given, the loss is logged under
                ``'train/aux_loss'``.
            step: Step value passed through to the logger.

        Raises:
            AssertionError: If the last dimension of ``x`` or ``x_pos`` is not 84.
        """
        assert x.size(-1) == 84 and x_pos.size(-1) == 84

        z_a = self.curl_head.encoder(x)
        with torch.no_grad():
            z_pos = self.critic_target.encoder(x_pos)

        logits = self.curl_head.compute_logits(z_a, z_pos)
        # Row i's target class is column i. Create the labels directly on the
        # logits' device instead of building them on the CPU and calling
        # .cuda(): no extra host-to-device copy, and no device mismatch if the
        # logits are not on the current CUDA device.
        labels = torch.arange(
            logits.shape[0], dtype=torch.long, device=logits.device
        )
        curl_loss = F.cross_entropy(logits, labels)

        self.curl_optimizer.zero_grad()
        curl_loss.backward()
        self.curl_optimizer.step()
        if L is not None:
            L.log('train/aux_loss', curl_loss, step)

    def update(self, replay_buffer, L, step):
        """Sample one CURL batch and run the scheduled updates for ``step``.

        The critic is updated on every call; the actor/alpha update, the soft
        target update and the CURL update each run only when ``step`` is a
        multiple of their respective frequency.
        """
        obs, action, reward, next_obs, not_done, pos = replay_buffer.sample_curl()

        self.update_critic(obs, action, reward, next_obs, not_done, L, step)

        if step % self.actor_update_freq == 0:
            self.update_actor_and_alpha(obs, L, step)

        if step % self.critic_target_update_freq == 0:
            self.soft_update_critic_target()

        if step % self.aux_update_freq == 0:
            self.update_curl(obs, pos, L, step)
24 changes: 24 additions & 0 deletions DMC/src/algorithms/drq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
import utils
import algorithms.modules as m
from algorithms.sac import SAC


class DrQ(SAC): # [K=1, M=1]
    """SAC agent whose training batches come from ``replay_buffer.sample_drq()``."""

    def __init__(self, obs_shape, action_shape, args):
        """Delegate construction entirely to ``SAC.__init__``."""
        super().__init__(obs_shape, action_shape, args)

    def update(self, replay_buffer, L, step):
        """Draw one DrQ batch and apply the updates scheduled for ``step``.

        The critic update always runs; the actor/alpha update and the soft
        target update run only on multiples of their frequencies.
        """
        batch = replay_buffer.sample_drq()
        obs, action, reward, next_obs, not_done = batch

        self.update_critic(obs, action, reward, next_obs, not_done, L, step)

        actor_remainder = step % self.actor_update_freq
        if actor_remainder == 0:
            self.update_actor_and_alpha(obs, L, step)

        target_remainder = step % self.critic_target_update_freq
        if target_remainder == 0:
            self.soft_update_critic_target()
23 changes: 23 additions & 0 deletions DMC/src/algorithms/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from algorithms.sac import SAC
from algorithms.rad import RAD
from algorithms.curl import CURL
from algorithms.pad import PAD
from algorithms.soda import SODA
from algorithms.drq import DrQ
from algorithms.svea import SVEA
from algorithms.sgsac import SGSAC

# Registry mapping an algorithm name (the value of ``args.algorithm`` used by
# ``make_agent``) to the agent class that implements it.
algorithm = {
"sac": SAC,
"rad": RAD,
"curl": CURL,
"pad": PAD,
"soda": SODA,
"drq": DrQ,
"svea": SVEA,
"sgsac": SGSAC,
}


def make_agent(obs_shape, action_shape, args):
    """Instantiate the agent class registered under ``args.algorithm``.

    Args:
        obs_shape: Forwarded unchanged to the agent constructor.
        action_shape: Forwarded unchanged to the agent constructor.
        args: Configuration object; ``args.algorithm`` selects the class and
            the whole object is forwarded to the constructor.

    Returns:
        The constructed agent instance.

    Raises:
        KeyError: If ``args.algorithm`` is not a registered name. The message
            lists the valid names.
    """
    # Only the lookup is guarded, so a KeyError raised inside an agent's
    # constructor is not mistaken for an unknown algorithm name.
    try:
        agent_cls = algorithm[args.algorithm]
    except KeyError:
        known = ", ".join(sorted(algorithm))
        raise KeyError(
            f"Unknown algorithm {args.algorithm!r}; expected one of: {known}"
        ) from None
    return agent_cls(obs_shape, action_shape, args)
Loading

0 comments on commit 9b854e8

Please sign in to comment.