Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
0 parents
commit 9b854e8
Showing 520 changed files with 67,864 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2022 Anonymous | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
|
||
# Overview | ||
![Architecture](figures/sgqnarchi.png) | ||
![Perf](figures/sgqn_perf.png) | ||
|
||
## Setup | ||
We assume that you have access to a GPU with CUDA >=9.2 support. All dependencies can then be installed with the following commands: | ||
|
||
``` | ||
conda env create -f setup/conda.yml | ||
conda activate dmcgb | ||
sh setup/install_envs.sh | ||
``` | ||
|
||
If you don't have the required MuJoCo version installed: | ||
``` | ||
sh setup/install_mujoco_deps.sh | ||
sh setup/prepare_dm_control_xp.sh | ||
``` | ||
|
||
|
||
## Datasets | ||
|
||
``` | ||
wget http://data.csail.mit.edu/places/places365/places365standard_easyformat.tar | ||
``` | ||
After downloading and extracting the data, add your dataset directory to the datasets list in `setup/config.cfg`. | ||
|
||
# Run `SGQN` training: | ||
```bash | ||
python src/train.py --algorithm sgsac --sgqn_quantile [QUANTILE] --seed [SEED] --eval_mode video_easy --domain_name [DOMAIN] --task_name [TASK]; | ||
``` | ||
You can also train `SVEA`, `SODA`, `RAD`, or `SAC` by passing `svea`, `soda`, `rad`, or `sac` to `--algorithm`. | ||
|
||
|
||
# DMControl Generalization Benchmark | ||
|
||
Benchmark for generalization in continuous control from pixels, based on [DMControl](https://github.com/deepmind/dm_control). | ||
|
||
## Test environments | ||
|
||
The DMControl Generalization Benchmark provides two distinct benchmarks for visual generalization, *random colors* and *video backgrounds*: | ||
|
||
![environment samples](figures/environments.png) | ||
|
||
Both benchmarks are offered in *easy* and *hard* variants. Samples are shown below. | ||
|
||
**video_easy**<br/> | ||
![video_easy](figures/video_easy.png) | ||
|
||
**video_hard**<br/> | ||
![video_hard](figures/video_hard.png) | ||
|
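The training command in the README takes the same flags for every run, so a sweep over seeds can be scripted. The sketch below only builds the argument lists and does not launch anything. `build_train_command` is a hypothetical helper that is not part of this repository; the flag names come from the README command, and the `cartpole`/`swingup`/`0.98` values are copied from the `sgsac` training script in this commit.

```python
import shlex


def build_train_command(algorithm, seed, domain_name, task_name,
                        eval_mode="video_easy", sgqn_quantile=None):
    """Return the argv list for one training run (hypothetical helper)."""
    cmd = [
        "python", "src/train.py",
        "--algorithm", algorithm,
        "--seed", str(seed),
        "--eval_mode", eval_mode,
        "--domain_name", domain_name,
        "--task_name", task_name,
    ]
    # The quantile flag only appears in the SGQN command, so it is optional here.
    if sgqn_quantile is not None:
        cmd += ["--sgqn_quantile", str(sgqn_quantile)]
    return cmd


if __name__ == "__main__":
    for seed in range(3):
        cmd = build_train_command("sgsac", seed, "cartpole", "swingup",
                                  sgqn_quantile=0.98)
        # Print instead of executing; pass `cmd` to subprocess.run to launch.
        print(shlex.join(cmd))
```

Keeping the command as a list avoids shell-quoting problems if it is later handed to `subprocess.run`.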
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
CUDA_VISIBLE_DEVICES=0 python3 src/train.py \ | ||
--algorithm curl \ | ||
--aux_update_freq 1 \ | ||
--seed 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
CUDA_VISIBLE_DEVICES=0 python3 src/train.py \ | ||
--algorithm drq \ | ||
--seed 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \ | ||
--algorithm curl \ | ||
--eval_episodes 100 \ | ||
--seed 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \ | ||
--algorithm drq \ | ||
--eval_episodes 100 \ | ||
--seed 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \ | ||
--algorithm pad \ | ||
--num_shared_layers 8 \ | ||
--num_head_layers 3 \ | ||
--eval_episodes 100 \ | ||
--seed 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \ | ||
--algorithm rad \ | ||
--eval_episodes 100 \ | ||
--seed 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \ | ||
--algorithm sac \ | ||
--eval_episodes 100 \ | ||
--seed 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \ | ||
--algorithm soda \ | ||
--eval_episodes 100 \ | ||
--seed 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \ | ||
--algorithm svea \ | ||
--eval_episodes 100 \ | ||
--seed 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
CUDA_VISIBLE_DEVICES=0 python3 src/train.py \ | ||
--algorithm pad \ | ||
--num_shared_layers 8 \ | ||
--num_head_layers 3 \ | ||
--seed 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
CUDA_VISIBLE_DEVICES=0 python3 src/train.py \ | ||
--algorithm rad \ | ||
--seed 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
CUDA_VISIBLE_DEVICES=0 python3 src/train.py \ | ||
--algorithm sac \ | ||
--seed 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
CUDA_VISIBLE_DEVICES=0 python3 src/train.py \ | ||
--algorithm sgsac \ | ||
--seed 0 --eval_mode all --domain_name cartpole --task_name swingup --sgqn_quantile 0.98 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
CUDA_VISIBLE_DEVICES=0 python3 src/train.py \ | ||
--algorithm soda \ | ||
--aux_lr 3e-4 \ | ||
--seed 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
CUDA_VISIBLE_DEVICES=0 python3 src/train.py \ | ||
--algorithm svea \ | ||
--seed 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
name: dmcgb | ||
channels: | ||
  - defaults | ||
dependencies: | ||
  - python=3.7.6 | ||
  - cudatoolkit=9.2 | ||
  - absl-py | ||
  - pyparsing | ||
  - pip | ||
  - pip: | ||
    - numpy==1.19.5 | ||
    - torch==1.7.1 | ||
    - torchvision==0.8.2 | ||
    - pillow==6.2.0 | ||
    - termcolor | ||
    - imageio | ||
    - imageio-ffmpeg | ||
    - opencv-python | ||
    - xmltodict | ||
    - tqdm | ||
    - einops | ||
    - kornia | ||
    - captum | ||
    - tensorboard |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
{ | ||
"datasets": [ | ||
"places365_standard/" | ||
] | ||
} |
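The config above is plain JSON with a single `datasets` list, so a loader can walk the list and use the first directory that exists. The sketch below illustrates that idea only; how the repository actually reads this file is not shown in this excerpt, and `find_dataset_dir` is a hypothetical name.

```python
import json
import os
import tempfile


def find_dataset_dir(config_path):
    """Return the first directory listed under "datasets" that exists on disk."""
    with open(config_path) as f:
        config = json.load(f)
    for candidate in config["datasets"]:
        if os.path.isdir(candidate):
            return candidate
    raise FileNotFoundError(f"no dataset directory listed in {config_path} exists")


if __name__ == "__main__":
    # Build a throwaway config so the example runs without the real dataset.
    with tempfile.TemporaryDirectory() as root:
        data_dir = os.path.join(root, "places365_standard")
        os.makedirs(data_dir)
        cfg = os.path.join(root, "config.cfg")
        with open(cfg, "w") as f:
            json.dump({"datasets": [os.path.join(root, "missing"), data_dir]}, f)
        print(find_dataset_dir(cfg))
```

Listing several candidate paths lets one config file work across machines that store the data in different places.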
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
cd src/env/dm_control | ||
pip install -e . | ||
|
||
cd ../dmc2gym | ||
pip install -e . | ||
|
||
cd ../../.. | ||
|
||
curl https://codeload.github.com/nicklashansen/dmcontrol-generalization-benchmark/tar.gz/main | tar -xz --strip=3 dmcontrol-generalization-benchmark-main/src/env/data | ||
mv data src/env/ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
mkdir -p "$HOME/.mujoco" && cd "$HOME/.mujoco" | ||
wget https://roboti.us/download/mujoco200_linux.zip | ||
unzip mujoco200_linux.zip | ||
wget https://roboti.us/file/mjkey.txt | ||
mv mjkey.txt mujoco200_linux/bin |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
export MJKEY_PATH=$HOME/.mujoco/mujoco200_linux/bin/mjkey.txt | ||
export MUJOCO_GL=egl |
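Because these two variables must be exported before training starts, a quick pre-flight check can save a failed run. The following is a minimal sketch, assuming only the two variables set by this script; `check_mujoco_env` is a hypothetical helper, and the check for `egl` mirrors this script rather than every rendering backend that `dm_control` accepts.

```python
import os


def check_mujoco_env(environ=None):
    """Return a list of human-readable problems with the MuJoCo-related variables."""
    environ = os.environ if environ is None else environ
    problems = []
    key_path = environ.get("MJKEY_PATH")
    if not key_path:
        problems.append("MJKEY_PATH is not set")
    elif not os.path.isfile(key_path):
        problems.append(f"MJKEY_PATH points to a missing file: {key_path}")
    # This script selects the EGL backend; other values are flagged, not rejected.
    if environ.get("MUJOCO_GL") != "egl":
        problems.append("MUJOCO_GL is not 'egl'")
    return problems


if __name__ == "__main__":
    for problem in check_mujoco_env():
        print("warning:", problem)
```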
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from copy import deepcopy | ||
import utils | ||
import algorithms.modules as m | ||
from algorithms.sac import SAC | ||
|
||
|
||
class CURL(SAC): | ||
    def __init__(self, obs_shape, action_shape, args): | ||
        super().__init__(obs_shape, action_shape, args) | ||
        self.aux_update_freq = args.aux_update_freq | ||
|
||
        self.curl_head = m.CURLHead(self.critic.encoder).cuda() | ||
|
||
        self.curl_optimizer = torch.optim.Adam( | ||
            self.curl_head.parameters(), lr=args.aux_lr, betas=(args.aux_beta, 0.999) | ||
        ) | ||
        self.train() | ||
|
||
    def train(self, training=True): | ||
        super().train(training) | ||
        if hasattr(self, 'curl_head'): | ||
            self.curl_head.train(training) | ||
|
||
    def update_curl(self, x, x_pos, L=None, step=None): | ||
        assert x.size(-1) == 84 and x_pos.size(-1) == 84 | ||
|
||
        z_a = self.curl_head.encoder(x) | ||
        with torch.no_grad(): | ||
            z_pos = self.critic_target.encoder(x_pos) | ||
|
||
        logits = self.curl_head.compute_logits(z_a, z_pos) | ||
        labels = torch.arange(logits.shape[0]).long().cuda() | ||
        curl_loss = F.cross_entropy(logits, labels) | ||
|
||
        self.curl_optimizer.zero_grad() | ||
        curl_loss.backward() | ||
        self.curl_optimizer.step() | ||
        if L is not None: | ||
            L.log('train/aux_loss', curl_loss, step) | ||
|
||
    def update(self, replay_buffer, L, step): | ||
        obs, action, reward, next_obs, not_done, pos = replay_buffer.sample_curl() | ||
|
||
        self.update_critic(obs, action, reward, next_obs, not_done, L, step) | ||
|
||
        if step % self.actor_update_freq == 0: | ||
            self.update_actor_and_alpha(obs, L, step) | ||
|
||
        if step % self.critic_target_update_freq == 0: | ||
            self.soft_update_critic_target() | ||
|
||
        if step % self.aux_update_freq == 0: | ||
            self.update_curl(obs, pos, L, step) |
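`update_curl` turns the batch into a classification problem: row `i` of `logits` scores anchor `i` against every positive in the batch, and `labels = arange(batch_size)` marks the positive at the same index as the correct class. The sketch below reproduces that loss in plain Python so it runs without PyTorch. It uses a plain dot product as the similarity, which is an assumption: `CURLHead.compute_logits` is not shown in this excerpt and may compute the scores differently.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def contrastive_loss(anchors, positives):
    """Mean cross-entropy where the i-th positive is the target for the i-th anchor."""
    losses = []
    for i, z_a in enumerate(anchors):
        # Assumed similarity: dot product between the anchor and each positive.
        logits = [dot(z_a, z_pos) for z_pos in positives]
        # Subtract the row maximum for numerical stability before exponentiating.
        row_max = max(logits)
        log_sum_exp = row_max + math.log(sum(math.exp(l - row_max) for l in logits))
        losses.append(log_sum_exp - logits[i])
    return sum(losses) / len(losses)


if __name__ == "__main__":
    aligned = [[1.0, 0.0], [0.0, 1.0]]
    swapped = [[0.0, 1.0], [1.0, 0.0]]
    # Matching pairs should give a lower loss than mismatched pairs.
    print(contrastive_loss(aligned, aligned))
    print(contrastive_loss(aligned, swapped))
```

Every other sample in the batch acts as a negative, so no separate negative sampling is needed.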
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from copy import deepcopy | ||
import utils | ||
import algorithms.modules as m | ||
from algorithms.sac import SAC | ||
|
||
|
||
class DrQ(SAC): # [K=1, M=1] | ||
    def __init__(self, obs_shape, action_shape, args): | ||
        super().__init__(obs_shape, action_shape, args) | ||
|
||
    def update(self, replay_buffer, L, step): | ||
        obs, action, reward, next_obs, not_done = replay_buffer.sample_drq() | ||
|
||
        self.update_critic(obs, action, reward, next_obs, not_done, L, step) | ||
|
||
        if step % self.actor_update_freq == 0: | ||
            self.update_actor_and_alpha(obs, L, step) | ||
|
||
        if step % self.critic_target_update_freq == 0: | ||
            self.soft_update_critic_target() |
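The `update` method above runs the critic update on every call and gates the other two updates on `step % freq == 0`. The sketch below isolates that schedule so it can be inspected without an agent or replay buffer. `scheduled_updates` is a hypothetical helper, and the frequency of 2 used in the demo is an illustrative value, not one taken from this repository's defaults.

```python
def scheduled_updates(step, actor_update_freq, critic_target_update_freq):
    """Return the names of the updates that run at `step`, in call order."""
    updates = ["critic"]  # the critic is updated on every call
    if step % actor_update_freq == 0:
        updates.append("actor_and_alpha")
    if step % critic_target_update_freq == 0:
        updates.append("critic_target")
    return updates


if __name__ == "__main__":
    for step in range(4):
        print(step, scheduled_updates(step, actor_update_freq=2,
                                      critic_target_update_freq=2))
```

Note that step 0 satisfies both modulo checks, so all three updates run on the first call.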
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from algorithms.sac import SAC | ||
from algorithms.rad import RAD | ||
from algorithms.curl import CURL | ||
from algorithms.pad import PAD | ||
from algorithms.soda import SODA | ||
from algorithms.drq import DrQ | ||
from algorithms.svea import SVEA | ||
from algorithms.sgsac import SGSAC | ||
|
||
algorithm = { | ||
    "sac": SAC, | ||
    "rad": RAD, | ||
    "curl": CURL, | ||
    "pad": PAD, | ||
    "soda": SODA, | ||
    "drq": DrQ, | ||
    "svea": SVEA, | ||
    "sgsac": SGSAC, | ||
} | ||
|
||
|
||
def make_agent(obs_shape, action_shape, args): | ||
    return algorithm[args.algorithm](obs_shape, action_shape, args) |
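`make_agent` is a registry lookup: the `--algorithm` string selects a class from the dictionary, and the class is called with the shared constructor arguments. The sketch below shows the same pattern with two stand-in classes so it runs on its own. `BaseAgent` and `AugmentedAgent` are hypothetical, and the explicit `ValueError` for unknown names is an addition for illustration; the original raises a bare `KeyError`.

```python
from argparse import Namespace


class BaseAgent:
    """Stand-in for an agent class; not part of the repository."""

    def __init__(self, obs_shape, action_shape, args):
        self.obs_shape = obs_shape
        self.action_shape = action_shape
        self.args = args


class AugmentedAgent(BaseAgent):
    """Stand-in subclass, mirroring how the real agents extend SAC."""


# Hypothetical registry with the same shape as the `algorithm` dict above.
algorithm = {"base": BaseAgent, "augmented": AugmentedAgent}


def make_agent(obs_shape, action_shape, args):
    try:
        agent_cls = algorithm[args.algorithm]
    except KeyError:
        known = ", ".join(sorted(algorithm))
        raise ValueError(
            f"unknown algorithm {args.algorithm!r}; expected one of: {known}"
        ) from None
    return agent_cls(obs_shape, action_shape, args)


if __name__ == "__main__":
    # The shapes here are illustrative values only.
    agent = make_agent((3, 84, 84), (1,), Namespace(algorithm="augmented"))
    print(type(agent).__name__)
```

With this layout, adding an algorithm only requires a new import and one dictionary entry.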