-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
187 changed files
with
6,965 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/home/david.bertoin/.mujoco/mujoco210/bin | ||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia | ||
python src/train.py --algorithm sac --seed $1 --task_name reach --train_steps 250k; | ||
|
||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/home/david.bertoin/.mujoco/mujoco210/bin | ||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia | ||
python src/train.py --algorithm sac --seed $1 --task_name hammerall --train_steps 250k; | ||
|
||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/home/david.bertoin/.mujoco/mujoco210/bin | ||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia | ||
python src/train.py --algorithm sac --seed $1 --task_name push --train_steps 250k; | ||
|
||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/home/david.bertoin/.mujoco/mujoco210/bin | ||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia | ||
python src/train.py --algorithm sac --seed $1 --task_name pegbox --train_steps 250k; | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/home/david.bertoin/.mujoco/mujoco210/bin | ||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia | ||
python src/train.py --algorithm sgsac --seed $1 --task_name reach --train_steps 250k; | ||
|
||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/home/david.bertoin/.mujoco/mujoco210/bin | ||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia | ||
python src/train.py --algorithm sgsac --seed $1 --task_name pegbox --train_steps 250k; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/home/david.bertoin/.mujoco/mujoco210/bin | ||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia | ||
python src/train.py --algorithm soda --seed $1 --task_name reach --train_steps 250k; | ||
|
||
|
||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/home/david.bertoin/.mujoco/mujoco210/bin | ||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia | ||
python src/train.py --algorithm soda --seed $1 --task_name pegbox --train_steps 250k; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/home/david.bertoin/.mujoco/mujoco210/bin | ||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia | ||
python src/train.py --algorithm svea --seed $1 --task_name reach --train_steps 250k; | ||
|
||
|
||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/data/home/david.bertoin/.mujoco/mujoco210/bin | ||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia | ||
python src/train.py --algorithm svea --seed $1 --task_name pegbox --train_steps 250k; | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
name: dmcgb | ||
channels: | ||
- defaults | ||
dependencies: | ||
- python=3.7.6 | ||
- cudatoolkit=9.2 | ||
- absl-py | ||
- pyparsing | ||
- pip | ||
- pip: | ||
- numpy==1.19.5 | ||
- torch==1.7.1 | ||
- torchvision==0.8.2 | ||
- pillow==6.2.0 | ||
- termcolor | ||
- imageio | ||
- imageio-ffmpeg | ||
- opencv-python | ||
- xmltodict | ||
- tqdm | ||
- einops | ||
- kornia |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
{ | ||
"datasets": [ | ||
"/opt/mehdi_zouitine/distracting_dataset/" | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
cd src/env/dm_control | ||
pip install -e . | ||
|
||
cd ../dmc2gym | ||
pip install -e . | ||
|
||
cd ../../.. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
cd $HOME/.mujoco | ||
wget https://roboti.us/download/mujoco200_linux.zip | ||
unzip mujoco200_linux.zip | ||
wget https://roboti.us/file/mjkey.txt | ||
mv mjkey.txt mujoco200_linux/bin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
export MJKEY_PATH=$HOME/.mujoco/mujoco200_linux/bin/mjkey.txt | ||
export MUJOCO_GL=egl |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from copy import deepcopy | ||
import utils | ||
import algorithms.modules as m | ||
from algorithms.sac import SAC | ||
|
||
|
||
class CURL(SAC): | ||
def __init__(self, obs_shape, action_shape, args): | ||
super().__init__(obs_shape, action_shape, args) | ||
self.aux_update_freq = args.aux_update_freq | ||
|
||
self.curl_head = m.CURLHead(self.critic.encoder).cuda() | ||
|
||
self.curl_optimizer = torch.optim.Adam( | ||
self.curl_head.parameters(), lr=args.aux_lr, betas=(args.aux_beta, 0.999) | ||
) | ||
self.train() | ||
|
||
def train(self, training=True): | ||
super().train(training) | ||
if hasattr(self, 'curl_head'): | ||
self.curl_head.train(training) | ||
|
||
def update_curl(self, x, x_pos, L=None, step=None): | ||
assert x.size(-1) == 84 and x_pos.size(-1) == 84 | ||
|
||
z_a = self.curl_head.encoder(x) | ||
with torch.no_grad(): | ||
z_pos = self.critic_target.encoder(x_pos) | ||
|
||
logits = self.curl_head.compute_logits(z_a, z_pos) | ||
labels = torch.arange(logits.shape[0]).long().cuda() | ||
curl_loss = F.cross_entropy(logits, labels) | ||
|
||
self.curl_optimizer.zero_grad() | ||
curl_loss.backward() | ||
self.curl_optimizer.step() | ||
if L is not None: | ||
L.log('train/aux_loss', curl_loss, step) | ||
|
||
def update(self, replay_buffer, L, step): | ||
obs, action, reward, next_obs, not_done, pos = replay_buffer.sample_curl() | ||
|
||
self.update_critic(obs, action, reward, next_obs, not_done, L, step) | ||
|
||
if step % self.actor_update_freq == 0: | ||
self.update_actor_and_alpha(obs, L, step) | ||
|
||
if step % self.critic_target_update_freq == 0: | ||
self.soft_update_critic_target() | ||
|
||
if step % self.aux_update_freq == 0: | ||
self.update_curl(obs, pos, L, step) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from copy import deepcopy | ||
import utils | ||
import algorithms.modules as m | ||
from algorithms.sac import SAC | ||
|
||
|
||
class DrQ(SAC): # [K=1, M=1] | ||
def __init__(self, obs_shape, action_shape, args): | ||
super().__init__(obs_shape, action_shape, args) | ||
|
||
def update(self, replay_buffer, L, step): | ||
obs, action, reward, next_obs, not_done = replay_buffer.sample_drq() | ||
|
||
self.update_critic(obs, action, reward, next_obs, not_done, L, step) | ||
|
||
if step % self.actor_update_freq == 0: | ||
self.update_actor_and_alpha(obs, L, step) | ||
|
||
if step % self.critic_target_update_freq == 0: | ||
self.soft_update_critic_target() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from algorithms.sac import SAC | ||
from algorithms.rad import RAD | ||
from algorithms.curl import CURL | ||
from algorithms.pad import PAD | ||
from algorithms.soda import SODA | ||
from algorithms.drq import DrQ | ||
from algorithms.svea import SVEA | ||
from algorithms.sgsac import SGSAC | ||
|
||
algorithm = { | ||
"sac": SAC, | ||
"rad": RAD, | ||
"curl": CURL, | ||
"pad": PAD, | ||
"soda": SODA, | ||
"drq": DrQ, | ||
"svea": SVEA, | ||
"sgsac": SGSAC, | ||
} | ||
|
||
|
||
def make_agent(obs_shape, action_shape, args): | ||
return algorithm[args.algorithm](obs_shape, action_shape, args) |
Oops, something went wrong.