diff --git a/DMC/LICENSE b/DMC/LICENSE
new file mode 100644
index 0000000..a6418fc
--- /dev/null
+++ b/DMC/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Anonymous
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/DMC/README.md b/DMC/README.md
new file mode 100644
index 0000000..0d88872
--- /dev/null
+++ b/DMC/README.md
@@ -0,0 +1,53 @@
+
+# Overview
+![Perf](figures/sgqnarchi.png)
+![Perf](figures/sgqn_perf.png)
+
+## Setup
+We assume that you have access to a GPU with CUDA >=9.2 support. All dependencies can then be installed with the following commands:
+
+```
+conda env create -f setup/conda.yml
+conda activate dmcgb
+sh setup/install_envs.sh
+```
+
+If you don't have the right mujoco version installed:
+```
+sh setup/install_mujoco_deps.sh
+sh setup/prepare_dm_control_xp.sh
+```
+
+
+## Datasets
+
+```
+wget http://data.csail.mit.edu/places/places365/places365standard_easyformat.tar
+```
+After downloading and extracting the data, add your dataset directory to the datasets list in `setup/config.cfg`.
+
+# Run `SGQN` training:
+```bash
+python src/train.py --algorithm sgsac --sgqn_quantile [QUANTILE] --seed [SEED] --eval_mode video_easy --domain_name [DOMAIN] --task_name [TASK];
+```
+You can also run `SVEA`, `SODA`, `RAD`, `SAC`.
+
+
+# DMControl Generalization Benchmark
+
+Benchmark for generalization in continuous control from pixels, based on [DMControl](https://github.com/deepmind/dm_control).
+
+## Test environments
+
+The DMControl Generalization Benchmark provides two distinct benchmarks for visual generalization, *random colors* and *video backgrounds*:
+
+![environment samples](figures/environments.png)
+
+Both benchmarks are offered in *easy* and *hard* variants. Samples are shown below.
+
+**video_easy**
+![video_easy](figures/video_easy.png)
+
+**video_hard**
+![video_hard](figures/video_hard.png)
+
diff --git a/DMC/figures/color_easy.png b/DMC/figures/color_easy.png
new file mode 100644
index 0000000..9b30e84
Binary files /dev/null and b/DMC/figures/color_easy.png differ
diff --git a/DMC/figures/color_hard.png b/DMC/figures/color_hard.png
new file mode 100644
index 0000000..c517d55
Binary files /dev/null and b/DMC/figures/color_hard.png differ
diff --git a/DMC/figures/environments.png b/DMC/figures/environments.png
new file mode 100644
index 0000000..5563ad9
Binary files /dev/null and b/DMC/figures/environments.png differ
diff --git a/DMC/figures/results_table.png b/DMC/figures/results_table.png
new file mode 100644
index 0000000..3ef8d29
Binary files /dev/null and b/DMC/figures/results_table.png differ
diff --git a/DMC/figures/sgqn_perf.png b/DMC/figures/sgqn_perf.png
new file mode 100644
index 0000000..f4f4901
Binary files /dev/null and b/DMC/figures/sgqn_perf.png differ
diff --git a/DMC/figures/sgqnarchi.png b/DMC/figures/sgqnarchi.png
new file mode 100644
index 0000000..497f1f5
Binary files /dev/null and b/DMC/figures/sgqnarchi.png differ
diff --git a/DMC/figures/video_easy.png b/DMC/figures/video_easy.png
new file mode 100644
index 0000000..a81692c
Binary files /dev/null and b/DMC/figures/video_easy.png differ
diff --git a/DMC/figures/video_hard.png b/DMC/figures/video_hard.png
new file mode 100644
index 0000000..095aea2
Binary files /dev/null and b/DMC/figures/video_hard.png differ
diff --git a/DMC/scripts/curl.sh b/DMC/scripts/curl.sh
new file mode 100644
index 0000000..4995c07
--- /dev/null
+++ b/DMC/scripts/curl.sh
@@ -0,0 +1,4 @@
+CUDA_VISIBLE_DEVICES=0 python3 src/train.py \
+ --algorithm curl \
+ --aux_update_freq 1 \
+ --seed 0
\ No newline at end of file
diff --git a/DMC/scripts/drq.sh b/DMC/scripts/drq.sh
new file mode 100644
index 0000000..a8ffe56
--- /dev/null
+++ b/DMC/scripts/drq.sh
@@ -0,0 +1,3 @@
+CUDA_VISIBLE_DEVICES=0 python3 src/train.py \
+ --algorithm drq \
+ --seed 0
\ No newline at end of file
diff --git a/DMC/scripts/eval/curl.sh b/DMC/scripts/eval/curl.sh
new file mode 100644
index 0000000..0b6b820
--- /dev/null
+++ b/DMC/scripts/eval/curl.sh
@@ -0,0 +1,4 @@
+CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \
+ --algorithm curl \
+ --eval_episodes 100 \
+ --seed 0
\ No newline at end of file
diff --git a/DMC/scripts/eval/drq.sh b/DMC/scripts/eval/drq.sh
new file mode 100644
index 0000000..7dace5c
--- /dev/null
+++ b/DMC/scripts/eval/drq.sh
@@ -0,0 +1,4 @@
+CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \
+ --algorithm drq \
+ --eval_episodes 100 \
+ --seed 0
\ No newline at end of file
diff --git a/DMC/scripts/eval/pad.sh b/DMC/scripts/eval/pad.sh
new file mode 100644
index 0000000..38afc8f
--- /dev/null
+++ b/DMC/scripts/eval/pad.sh
@@ -0,0 +1,6 @@
+CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \
+ --algorithm pad \
+ --num_shared_layers 8 \
+ --num_head_layers 3 \
+ --eval_episodes 100 \
+ --seed 0
\ No newline at end of file
diff --git a/DMC/scripts/eval/rad.sh b/DMC/scripts/eval/rad.sh
new file mode 100644
index 0000000..70c55a7
--- /dev/null
+++ b/DMC/scripts/eval/rad.sh
@@ -0,0 +1,4 @@
+CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \
+ --algorithm rad \
+ --eval_episodes 100 \
+ --seed 0
\ No newline at end of file
diff --git a/DMC/scripts/eval/sac.sh b/DMC/scripts/eval/sac.sh
new file mode 100644
index 0000000..130dca0
--- /dev/null
+++ b/DMC/scripts/eval/sac.sh
@@ -0,0 +1,4 @@
+CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \
+ --algorithm sac \
+ --eval_episodes 100 \
+ --seed 0
\ No newline at end of file
diff --git a/DMC/scripts/eval/soda.sh b/DMC/scripts/eval/soda.sh
new file mode 100644
index 0000000..49f3acf
--- /dev/null
+++ b/DMC/scripts/eval/soda.sh
@@ -0,0 +1,4 @@
+CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \
+ --algorithm soda \
+ --eval_episodes 100 \
+ --seed 0
\ No newline at end of file
diff --git a/DMC/scripts/eval/svea.sh b/DMC/scripts/eval/svea.sh
new file mode 100644
index 0000000..dabd734
--- /dev/null
+++ b/DMC/scripts/eval/svea.sh
@@ -0,0 +1,4 @@
+CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \
+ --algorithm svea \
+ --eval_episodes 100 \
+ --seed 0
\ No newline at end of file
diff --git a/DMC/scripts/pad.sh b/DMC/scripts/pad.sh
new file mode 100644
index 0000000..02093d1
--- /dev/null
+++ b/DMC/scripts/pad.sh
@@ -0,0 +1,5 @@
+CUDA_VISIBLE_DEVICES=0 python3 src/train.py \
+ --algorithm pad \
+ --num_shared_layers 8 \
+ --num_head_layers 3 \
+ --seed 0
\ No newline at end of file
diff --git a/DMC/scripts/rad.sh b/DMC/scripts/rad.sh
new file mode 100644
index 0000000..97b5a0a
--- /dev/null
+++ b/DMC/scripts/rad.sh
@@ -0,0 +1,3 @@
+CUDA_VISIBLE_DEVICES=0 python3 src/train.py \
+ --algorithm rad \
+ --seed 0
\ No newline at end of file
diff --git a/DMC/scripts/sac.sh b/DMC/scripts/sac.sh
new file mode 100644
index 0000000..0a39f65
--- /dev/null
+++ b/DMC/scripts/sac.sh
@@ -0,0 +1,3 @@
+CUDA_VISIBLE_DEVICES=0 python3 src/train.py \
+ --algorithm sac \
+ --seed 0
\ No newline at end of file
diff --git a/DMC/scripts/sgsac.sh b/DMC/scripts/sgsac.sh
new file mode 100644
index 0000000..79ac13b
--- /dev/null
+++ b/DMC/scripts/sgsac.sh
@@ -0,0 +1,3 @@
+CUDA_VISIBLE_DEVICES=0 python3 src/train.py \
+ --algorithm sgsac \
+ --seed 0 --eval_mode all --domain_name cartpole --task_name swingup --sgqn_quantile 0.98
\ No newline at end of file
diff --git a/DMC/scripts/soda.sh b/DMC/scripts/soda.sh
new file mode 100644
index 0000000..33cca72
--- /dev/null
+++ b/DMC/scripts/soda.sh
@@ -0,0 +1,4 @@
+CUDA_VISIBLE_DEVICES=0 python3 src/train.py \
+ --algorithm soda \
+ --aux_lr 3e-4 \
+ --seed 0
\ No newline at end of file
diff --git a/DMC/scripts/svea.sh b/DMC/scripts/svea.sh
new file mode 100644
index 0000000..fbc7157
--- /dev/null
+++ b/DMC/scripts/svea.sh
@@ -0,0 +1,3 @@
+CUDA_VISIBLE_DEVICES=0 python3 src/train.py \
+ --algorithm svea \
+ --seed 0
\ No newline at end of file
diff --git a/DMC/setup/conda.yaml b/DMC/setup/conda.yaml
new file mode 100644
index 0000000..bd24780
--- /dev/null
+++ b/DMC/setup/conda.yaml
@@ -0,0 +1,24 @@
+name: dmcgb
+channels:
+ - defaults
+dependencies:
+ - python=3.7.6
+ - cudatoolkit=9.2
+ - absl-py
+ - pyparsing
+ - pip
+ - pip:
+ - numpy==1.19.5
+ - torch==1.7.1
+ - torchvision==0.8.2
+ - pillow==6.2.0
+ - termcolor
+ - imageio
+ - imageio-ffmpeg
+ - opencv-python
+ - xmltodict
+ - tqdm
+ - einops
+ - kornia
+ - captum
+ - tensorboard
diff --git a/DMC/setup/config.cfg b/DMC/setup/config.cfg
new file mode 100644
index 0000000..71f5502
--- /dev/null
+++ b/DMC/setup/config.cfg
@@ -0,0 +1,5 @@
+{
+ "datasets": [
+ "places365_standard/"
+ ]
+}
diff --git a/DMC/setup/install_envs.sh b/DMC/setup/install_envs.sh
new file mode 100644
index 0000000..8ba6329
--- /dev/null
+++ b/DMC/setup/install_envs.sh
@@ -0,0 +1,10 @@
+cd src/env/dm_control
+pip install -e .
+
+cd ../dmc2gym
+pip install -e .
+
+cd ../../..
+
+curl https://codeload.github.com/nicklashansen/dmcontrol-generalization-benchmark/tar.gz/main | tar -xz --strip=3 dmcontrol-generalization-benchmark-main/src/env/data
+mv data src/env/
\ No newline at end of file
diff --git a/DMC/setup/install_mujoco_deps.sh b/DMC/setup/install_mujoco_deps.sh
new file mode 100644
index 0000000..b784d5c
--- /dev/null
+++ b/DMC/setup/install_mujoco_deps.sh
@@ -0,0 +1,5 @@
+cd $HOME/.mujoco
+wget https://roboti.us/download/mujoco200_linux.zip
+unzip mujoco200_linux.zip
+wget https://roboti.us/file/mjkey.txt
+mv mjkey.txt mujoco200_linux/bin
\ No newline at end of file
diff --git a/DMC/setup/prepare_dm_control_xp.sh b/DMC/setup/prepare_dm_control_xp.sh
new file mode 100644
index 0000000..7094a47
--- /dev/null
+++ b/DMC/setup/prepare_dm_control_xp.sh
@@ -0,0 +1,2 @@
+export MJKEY_PATH=$HOME/.mujoco/mujoco200_linux/bin/mjkey.txt
+export MUJOCO_GL=egl
\ No newline at end of file
diff --git a/DMC/src/algorithms/curl.py b/DMC/src/algorithms/curl.py
new file mode 100644
index 0000000..32b9466
--- /dev/null
+++ b/DMC/src/algorithms/curl.py
@@ -0,0 +1,57 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from copy import deepcopy
+import utils
+import algorithms.modules as m
+from algorithms.sac import SAC
+
+
+class CURL(SAC):
+ def __init__(self, obs_shape, action_shape, args):
+ super().__init__(obs_shape, action_shape, args)
+ self.aux_update_freq = args.aux_update_freq
+
+ self.curl_head = m.CURLHead(self.critic.encoder).cuda()
+
+ self.curl_optimizer = torch.optim.Adam(
+ self.curl_head.parameters(), lr=args.aux_lr, betas=(args.aux_beta, 0.999)
+ )
+ self.train()
+
+ def train(self, training=True):
+ super().train(training)
+ if hasattr(self, 'curl_head'):
+ self.curl_head.train(training)
+
+ def update_curl(self, x, x_pos, L=None, step=None):
+ assert x.size(-1) == 84 and x_pos.size(-1) == 84
+
+ z_a = self.curl_head.encoder(x)
+ with torch.no_grad():
+ z_pos = self.critic_target.encoder(x_pos)
+
+ logits = self.curl_head.compute_logits(z_a, z_pos)
+ labels = torch.arange(logits.shape[0]).long().cuda()
+ curl_loss = F.cross_entropy(logits, labels)
+
+ self.curl_optimizer.zero_grad()
+ curl_loss.backward()
+ self.curl_optimizer.step()
+ if L is not None:
+ L.log('train/aux_loss', curl_loss, step)
+
+ def update(self, replay_buffer, L, step):
+ obs, action, reward, next_obs, not_done, pos = replay_buffer.sample_curl()
+
+ self.update_critic(obs, action, reward, next_obs, not_done, L, step)
+
+ if step % self.actor_update_freq == 0:
+ self.update_actor_and_alpha(obs, L, step)
+
+ if step % self.critic_target_update_freq == 0:
+ self.soft_update_critic_target()
+
+ if step % self.aux_update_freq == 0:
+ self.update_curl(obs, pos, L, step)
diff --git a/DMC/src/algorithms/drq.py b/DMC/src/algorithms/drq.py
new file mode 100644
index 0000000..51999f8
--- /dev/null
+++ b/DMC/src/algorithms/drq.py
@@ -0,0 +1,24 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from copy import deepcopy
+import utils
+import algorithms.modules as m
+from algorithms.sac import SAC
+
+
+class DrQ(SAC): # [K=1, M=1]
+ def __init__(self, obs_shape, action_shape, args):
+ super().__init__(obs_shape, action_shape, args)
+
+ def update(self, replay_buffer, L, step):
+ obs, action, reward, next_obs, not_done = replay_buffer.sample_drq()
+
+ self.update_critic(obs, action, reward, next_obs, not_done, L, step)
+
+ if step % self.actor_update_freq == 0:
+ self.update_actor_and_alpha(obs, L, step)
+
+ if step % self.critic_target_update_freq == 0:
+ self.soft_update_critic_target()
diff --git a/DMC/src/algorithms/factory.py b/DMC/src/algorithms/factory.py
new file mode 100644
index 0000000..af3b374
--- /dev/null
+++ b/DMC/src/algorithms/factory.py
@@ -0,0 +1,23 @@
+from algorithms.sac import SAC
+from algorithms.rad import RAD
+from algorithms.curl import CURL
+from algorithms.pad import PAD
+from algorithms.soda import SODA
+from algorithms.drq import DrQ
+from algorithms.svea import SVEA
+from algorithms.sgsac import SGSAC
+
+algorithm = {
+ "sac": SAC,
+ "rad": RAD,
+ "curl": CURL,
+ "pad": PAD,
+ "soda": SODA,
+ "drq": DrQ,
+ "svea": SVEA,
+ "sgsac": SGSAC,
+}
+
+
+def make_agent(obs_shape, action_shape, args):
+ return algorithm[args.algorithm](obs_shape, action_shape, args)
diff --git a/DMC/src/algorithms/modules.py b/DMC/src/algorithms/modules.py
new file mode 100644
index 0000000..8837dc2
--- /dev/null
+++ b/DMC/src/algorithms/modules.py
@@ -0,0 +1,354 @@
+from turtle import forward
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from functools import partial
+
+
+def _get_out_shape_cuda(in_shape, layers):
+ x = torch.randn(*in_shape).cuda().unsqueeze(0)
+ return layers(x).squeeze(0).shape
+
+
+def _get_out_shape(in_shape, layers):
+ x = torch.randn(*in_shape).unsqueeze(0)
+ return layers(x).squeeze(0).shape
+
+
+def gaussian_logprob(noise, log_std):
+ """Compute Gaussian log probability"""
+ residual = (-0.5 * noise.pow(2) - log_std).sum(-1, keepdim=True)
+ return residual - 0.5 * np.log(2 * np.pi) * noise.size(-1)
+
+
+def squash(mu, pi, log_pi):
+ """Apply squashing function, see appendix C from https://arxiv.org/pdf/1812.05905.pdf"""
+ mu = torch.tanh(mu)
+ if pi is not None:
+ pi = torch.tanh(pi)
+ if log_pi is not None:
+ log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True)
+ return mu, pi, log_pi
+
+
+def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
+ """Truncated normal distribution, see https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf"""
+
+ def norm_cdf(x):
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+ with torch.no_grad():
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+ tensor.erfinv_()
+ tensor.mul_(std * math.sqrt(2.0))
+ tensor.add_(mean)
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def weight_init(m):
+ """Custom weight init for Conv2D and Linear layers"""
+ if isinstance(m, nn.Linear):
+ nn.init.orthogonal_(m.weight.data)
+ if hasattr(m.bias, "data"):
+ m.bias.data.fill_(0.0)
+ elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+ # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf
+ assert m.weight.size(2) == m.weight.size(3)
+ m.weight.data.fill_(0.0)
+ if hasattr(m.bias, "data"):
+ m.bias.data.fill_(0.0)
+ mid = m.weight.size(2) // 2
+ gain = nn.init.calculate_gain("relu")
+ nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain)
+
+
+class CenterCrop(nn.Module):
+ def __init__(self, size):
+ super().__init__()
+ assert size in {84, 100}, f"unexpected size: {size}"
+ self.size = size
+
+ def forward(self, x):
+ assert x.ndim == 4, "input must be a 4D tensor"
+ if x.size(2) == self.size and x.size(3) == self.size:
+ return x
+ assert x.size(3) == 100, f"unexpected size: {x.size(3)}"
+ if self.size == 84:
+ p = 8
+ return x[:, :, p:-p, p:-p]
+
+
+class NormalizeImg(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x / 255.0
+
+
+class Flatten(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x.view(x.size(0), -1)
+
+
+class RLProjection(nn.Module):
+ def __init__(self, in_shape, out_dim):
+ super().__init__()
+ self.out_dim = out_dim
+ self.projection = nn.Sequential(
+ nn.Linear(in_shape[0], out_dim), nn.LayerNorm(out_dim), nn.Tanh()
+ )
+ self.apply(weight_init)
+
+ def forward(self, x):
+ y = self.projection(x)
+ return y
+
+
+class SODAMLP(nn.Module):
+ def __init__(self, projection_dim, hidden_dim, out_dim):
+ super().__init__()
+ self.out_dim = out_dim
+ self.mlp = nn.Sequential(
+ nn.Linear(projection_dim, hidden_dim),
+ nn.BatchNorm1d(hidden_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_dim, out_dim),
+ )
+ self.apply(weight_init)
+
+ def forward(self, x):
+ return self.mlp(x)
+
+
+class SharedCNN(nn.Module):
+ def __init__(self, obs_shape, num_layers=11, num_filters=32):
+ super().__init__()
+ assert len(obs_shape) == 3
+ self.num_layers = num_layers
+ self.num_filters = num_filters
+
+ self.layers = [
+ CenterCrop(size=84),
+ NormalizeImg(),
+ nn.Conv2d(obs_shape[0], num_filters, 3, stride=2),
+ ]
+ for _ in range(1, num_layers):
+ self.layers.append(nn.ReLU())
+ self.layers.append(nn.Conv2d(num_filters, num_filters, 3, stride=1))
+ self.layers = nn.Sequential(*self.layers)
+ self.out_shape = _get_out_shape(obs_shape, self.layers)
+ self.apply(weight_init)
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class HeadCNN(nn.Module):
+ def __init__(self, in_shape, num_layers=0, num_filters=32):
+ super().__init__()
+ self.layers = []
+ for _ in range(0, num_layers):
+ self.layers.append(nn.ReLU())
+ self.layers.append(nn.Conv2d(num_filters, num_filters, 3, stride=1))
+ self.layers.append(Flatten())
+ self.layers = nn.Sequential(*self.layers)
+ self.out_shape = _get_out_shape(in_shape, self.layers)
+ self.apply(weight_init)
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class Encoder(nn.Module):
+ def __init__(self, shared_cnn, head_cnn, projection):
+ super().__init__()
+ self.shared_cnn = shared_cnn
+ self.head_cnn = head_cnn
+ self.projection = projection
+ self.out_dim = projection.out_dim
+
+ def forward(self, x, detach=False):
+ x = self.shared_cnn(x)
+ x = self.head_cnn(x)
+ if detach:
+ x = x.detach()
+ return self.projection(x)
+
+
+class Actor(nn.Module):
+ def __init__(self, encoder, action_shape, hidden_dim, log_std_min, log_std_max):
+ super().__init__()
+ self.encoder = encoder
+ self.log_std_min = log_std_min
+ self.log_std_max = log_std_max
+ self.mlp = nn.Sequential(
+ nn.Linear(self.encoder.out_dim, hidden_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_dim, 2 * action_shape[0]),
+ )
+ self.mlp.apply(weight_init)
+
+ def forward(
+ self,
+ x,
+ compute_pi=True,
+ compute_log_pi=True,
+ detach=False,
+ compute_attrib=False,
+ ):
+ x = self.encoder(x, detach)
+ mu, log_std = self.mlp(x).chunk(2, dim=-1)
+ log_std = torch.tanh(log_std)
+ log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (
+ log_std + 1
+ )
+
+ if compute_pi:
+ std = log_std.exp()
+ noise = torch.randn_like(mu)
+ pi = mu + noise * std
+ else:
+ pi = None
+ entropy = None
+
+ if compute_log_pi:
+ log_pi = gaussian_logprob(noise, log_std)
+ else:
+ log_pi = None
+
+ mu, pi, log_pi = squash(mu, pi, log_pi)
+
+ return mu, pi, log_pi, log_std
+
+
+class QFunction(nn.Module):
+ def __init__(self, obs_dim, action_dim, hidden_dim):
+ super().__init__()
+ self.trunk = nn.Sequential(
+ nn.Linear(obs_dim + action_dim, hidden_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_dim, 1),
+ )
+ self.apply(weight_init)
+
+ def forward(self, obs, action):
+ assert obs.size(0) == action.size(0)
+ return self.trunk(torch.cat([obs, action], dim=1))
+
+
+class Critic(nn.Module):
+ def __init__(self, encoder, action_shape, hidden_dim):
+ super().__init__()
+ self.encoder = encoder
+ self.Q1 = QFunction(self.encoder.out_dim, action_shape[0], hidden_dim)
+ self.Q2 = QFunction(self.encoder.out_dim, action_shape[0], hidden_dim)
+
+ def forward(self, x, action, detach=False):
+ x = self.encoder(x, detach)
+ return self.Q1(x, action), self.Q2(x, action)
+
+
+class CURLHead(nn.Module):
+ def __init__(self, encoder):
+ super().__init__()
+ self.encoder = encoder
+ self.W = nn.Parameter(torch.rand(encoder.out_dim, encoder.out_dim))
+
+ def compute_logits(self, z_a, z_pos):
+ """
+ Uses logits trick for CURL:
+ - compute (B,B) matrix z_a (W z_pos.T)
+ - positives are all diagonal elements
+ - negatives are all other elements
+ - to compute loss use multiclass cross entropy with identity matrix for labels
+ """
+ Wz = torch.matmul(self.W, z_pos.T) # (z_dim,B)
+ logits = torch.matmul(z_a, Wz) # (B,B)
+ logits = logits - torch.max(logits, 1)[0][:, None]
+ return logits
+
+
+class InverseDynamics(nn.Module):
+ def __init__(self, encoder, action_shape, hidden_dim):
+ super().__init__()
+ self.encoder = encoder
+ self.mlp = nn.Sequential(
+ nn.Linear(2 * encoder.out_dim, hidden_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_dim, action_shape[0]),
+ )
+ self.apply(weight_init)
+
+ def forward(self, x, x_next):
+ h = self.encoder(x)
+ h_next = self.encoder(x_next)
+ joint_h = torch.cat([h, h_next], dim=1)
+ return self.mlp(joint_h)
+
+
+class SODAPredictor(nn.Module):
+ def __init__(self, encoder, hidden_dim):
+ super().__init__()
+ self.encoder = encoder
+ self.mlp = SODAMLP(encoder.out_dim, hidden_dim, encoder.out_dim)
+ self.apply(weight_init)
+
+ def forward(self, x):
+ return self.mlp(self.encoder(x))
+
+
+class AttributionDecoder(nn.Module):
+ def __init__(self,action_shape, emb_dim=100) -> None:
+ super().__init__()
+ self.proj = nn.Linear(in_features=emb_dim+action_shape, out_features=14112)
+ self.conv1 = nn.Conv2d(
+ in_channels=32, out_channels=128, kernel_size=3, padding=1
+ )
+ self.relu = nn.ReLU()
+ self.conv2 = nn.Conv2d(
+ in_channels=128, out_channels=64, kernel_size=3, padding=1
+ )
+ self.conv3 = nn.Conv2d(in_channels=64, out_channels=9, kernel_size=3, padding=1)
+
+ def forward(self, x, action):
+ x = torch.cat([x,action],dim=1)
+ x = self.proj(x).view(-1, 32, 21, 21)
+ x = self.relu(x)
+ x = self.conv1(x)
+ x = F.upsample(x, scale_factor=2)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = F.upsample(x, scale_factor=2)
+ x = self.relu(x)
+ x = self.conv3(x)
+ return x
+
+
+
+class AttributionPredictor(nn.Module):
+ def __init__(self, action_shape,encoder, emb_dim=100):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = AttributionDecoder(action_shape,encoder.out_dim)
+ self.features_decoder = nn.Sequential(
+ nn.Linear(emb_dim, 256), nn.ReLU(), nn.Linear(256, emb_dim)
+ )
+
+ def forward(self, x,action):
+ x = self.encoder(x)
+ return self.decoder(x,action)
diff --git a/DMC/src/algorithms/pad.py b/DMC/src/algorithms/pad.py
new file mode 100644
index 0000000..4a1d9a7
--- /dev/null
+++ b/DMC/src/algorithms/pad.py
@@ -0,0 +1,63 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from copy import deepcopy
+import utils
+import algorithms.modules as m
+from algorithms.sac import SAC
+
+
+class PAD(SAC):
+ def __init__(self, obs_shape, action_shape, args):
+ super().__init__(obs_shape, action_shape, args)
+ self.aux_update_freq = args.aux_update_freq
+ self.aux_lr = args.aux_lr
+ self.aux_beta = args.aux_beta
+
+ shared_cnn = self.critic.encoder.shared_cnn
+ aux_cnn = m.HeadCNN(shared_cnn.out_shape, args.num_head_layers, args.num_filters).cuda()
+ aux_encoder = m.Encoder(
+ shared_cnn,
+ aux_cnn,
+ m.RLProjection(aux_cnn.out_shape, args.projection_dim)
+ )
+ self.pad_head = m.InverseDynamics(aux_encoder, action_shape, args.hidden_dim).cuda()
+ self.init_pad_optimizer()
+ self.train()
+
+ def train(self, training=True):
+ super().train(training)
+ if hasattr(self, 'pad_head'):
+ self.pad_head.train(training)
+
+ def init_pad_optimizer(self):
+ self.pad_optimizer = torch.optim.Adam(
+ self.pad_head.parameters(), lr=self.aux_lr, betas=(self.aux_beta, 0.999)
+ )
+
+ def update_inverse_dynamics(self, obs, obs_next, action, L=None, step=None):
+ assert obs.shape[-1] == 84 and obs_next.shape[-1] == 84
+
+ pred_action = self.pad_head(obs, obs_next)
+ pad_loss = F.mse_loss(pred_action, action)
+
+ self.pad_optimizer.zero_grad()
+ pad_loss.backward()
+ self.pad_optimizer.step()
+ if L is not None:
+ L.log('train/aux_loss', pad_loss, step)
+
+ def update(self, replay_buffer, L, step):
+ obs, action, reward, next_obs, not_done = replay_buffer.sample()
+
+ self.update_critic(obs, action, reward, next_obs, not_done, L, step)
+
+ if step % self.actor_update_freq == 0:
+ self.update_actor_and_alpha(obs, L, step)
+
+ if step % self.critic_target_update_freq == 0:
+ self.soft_update_critic_target()
+
+ if step % self.aux_update_freq == 0:
+ self.update_inverse_dynamics(obs, next_obs, action, L, step)
diff --git a/DMC/src/algorithms/rad.py b/DMC/src/algorithms/rad.py
new file mode 100644
index 0000000..89b257e
--- /dev/null
+++ b/DMC/src/algorithms/rad.py
@@ -0,0 +1,13 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from copy import deepcopy
+import utils
+import algorithms.modules as m
+from algorithms.sac import SAC
+
+
+class RAD(SAC):
+ def __init__(self, obs_shape, action_shape, args):
+ super().__init__(obs_shape, action_shape, args)
diff --git a/DMC/src/algorithms/rl_utils.py b/DMC/src/algorithms/rl_utils.py
new file mode 100644
index 0000000..1bb27ad
--- /dev/null
+++ b/DMC/src/algorithms/rl_utils.py
@@ -0,0 +1,107 @@
+import torch
+from torchvision.utils import make_grid
+
+from captum.attr import GuidedBackprop, GuidedGradCam
+
+
+class HookFeatures:
+ def __init__(self, module):
+ self.feature_hook = module.register_forward_hook(self.feature_hook_fn)
+
+ def feature_hook_fn(self, module, input, output):
+ self.features = output.clone().detach()
+ self.gradient_hook = output.register_hook(self.gradient_hook_fn)
+
+ def gradient_hook_fn(self, grad):
+ self.gradients = grad
+
+ def close(self):
+ self.feature_hook.remove()
+ self.gradient_hook.remove()
+
+
+class ModelWrapper(torch.nn.Module):
+ def __init__(self, model, action=None):
+ super(ModelWrapper, self).__init__()
+ self.model = model
+ self.action = action
+
+ def forward(self, obs):
+ if self.action is None:
+ return self.model(obs)[0]
+ return self.model(obs, self.action)[0]
+
+
+def compute_guided_backprop(obs, action, model):
+ model = ModelWrapper(model, action=action)
+ gbp = GuidedBackprop(model)
+ attribution = gbp.attribute(obs)
+ return attribution
+
+def compute_guided_gradcam(obs, action, model):
+ obs.requires_grad_()
+ obs.retain_grad()
+ model = ModelWrapper(model, action=action)
+ gbp = GuidedGradCam(model,layer=model.model.encoder.head_cnn.layers)
+ attribution = gbp.attribute(obs,attribute_to_layer_input=True)
+ return attribution
+
+def compute_vanilla_grad(critic_target, obs, action):
+ obs.requires_grad_()
+ obs.retain_grad()
+ q, q2 = critic_target(obs, action.detach())
+ q.sum().backward()
+ return obs.grad
+
+
+def compute_attribution(model, obs, action=None,method="guided_backprop"):
+ if method == "guided_backprop":
+ return compute_guided_backprop(obs, action, model)
+ if method == 'guided_gradcam':
+ return compute_guided_gradcam(obs,action,model)
+ return compute_vanilla_grad(model, obs, action)
+
+
+def compute_features_attribution(critic_target, obs, action):
+ obs.requires_grad_()
+ obs.retain_grad()
+ hook = HookFeatures(critic_target.encoder)
+ q, _ = critic_target(obs, action.detach())
+ q.sum().backward()
+ features_gardients = hook.gradients
+ hook.close()
+ return obs.grad, features_gardients
+
+
+def compute_attribution_mask(obs_grad, quantile=0.95):
+ mask = []
+ for i in [0, 3, 6]:
+ attributions = obs_grad[:, i : i + 3].abs().max(dim=1)[0]
+ q = torch.quantile(attributions.flatten(1), quantile, 1)
+ mask.append((attributions >= q[:, None, None]).unsqueeze(1).repeat(1, 3, 1, 1))
+ return torch.cat(mask, dim=1)
+
+
+def make_obs_grid(obs, n=4):
+ sample = []
+ for i in range(n):
+ for j in range(0, 9, 3):
+ sample.append(obs[i, j : j + 3].unsqueeze(0))
+ sample = torch.cat(sample, 0)
+ return make_grid(sample, nrow=3) / 255.0
+
+
+def make_attribution_pred_grid(attribution_pred, n=4):
+ return make_grid(attribution_pred[:n], nrow=1)
+
+
+def make_obs_grad_grid(obs_grad, n=4):
+ sample = []
+ for i in range(n):
+ for j in range(0, 9, 3):
+ channel_attribution, _ = torch.max(obs_grad[i, j : j + 3], dim=0)
+ sample.append(channel_attribution[(None,) * 2] / channel_attribution.max())
+ sample = torch.cat(sample, 0)
+ q = torch.quantile(sample.flatten(1), 0.97, 1)
+ sample[sample <= q[:, None, None, None]] = 0
+ return make_grid(sample, nrow=3)
diff --git a/DMC/src/algorithms/sac.py b/DMC/src/algorithms/sac.py
new file mode 100644
index 0000000..38e9c66
--- /dev/null
+++ b/DMC/src/algorithms/sac.py
@@ -0,0 +1,169 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from copy import deepcopy
+import utils
+import algorithms.modules as m
+
+
+class FeaturesHook:
+ def __init__(self, module):
+ self.hook = module.register_forward_hook(self.hook_fn)
+
+ def hook_fn(self, module, input, output):
+ self.features = output
+
+ def close(self):
+ self.hook.remove()
+
+
+class SAC(object):
+ def __init__(self, obs_shape, action_shape, args):
+ self.discount = args.discount
+ self.critic_tau = args.critic_tau
+ self.encoder_tau = args.encoder_tau
+ self.actor_update_freq = args.actor_update_freq
+ self.critic_target_update_freq = args.critic_target_update_freq
+
+ shared_cnn = m.SharedCNN(
+ obs_shape, args.num_shared_layers, args.num_filters
+ ).cuda()
+ head_cnn = m.HeadCNN(
+ shared_cnn.out_shape, args.num_head_layers, args.num_filters
+ ).cuda()
+ actor_encoder = m.Encoder(
+ shared_cnn,
+ head_cnn,
+ m.RLProjection(head_cnn.out_shape, args.projection_dim),
+ )
+ critic_encoder = m.Encoder(
+ shared_cnn,
+ head_cnn,
+ m.RLProjection(head_cnn.out_shape, args.projection_dim),
+ )
+
+ self.actor = m.Actor(
+ actor_encoder,
+ action_shape,
+ args.hidden_dim,
+ args.actor_log_std_min,
+ args.actor_log_std_max,
+ ).cuda()
+ self.critic = m.Critic(critic_encoder, action_shape, args.hidden_dim).cuda()
+ self.critic_target = deepcopy(self.critic)
+
+ self.log_alpha = torch.tensor(np.log(args.init_temperature)).cuda()
+ self.log_alpha.requires_grad = True
+ self.target_entropy = -np.prod(action_shape)
+
+ self.actor_optimizer = torch.optim.Adam(
+ self.actor.parameters(), lr=args.actor_lr, betas=(args.actor_beta, 0.999)
+ )
+ self.critic_optimizer = torch.optim.Adam(
+ self.critic.parameters(), lr=args.critic_lr, betas=(args.critic_beta, 0.999),weight_decay=args.critic_weight_decay,
+ )
+ self.log_alpha_optimizer = torch.optim.Adam(
+ [self.log_alpha], lr=args.alpha_lr, betas=(args.alpha_beta, 0.999)
+ )
+
+ self.train()
+ self.critic_target.train()
+ self.hook = FeaturesHook(self.critic.encoder.head_cnn)
+
+ def train(self, training=True):
+ self.training = training
+ self.actor.train(training)
+ self.critic.train(training)
+
+ def eval(self):
+ self.train(False)
+
+ @property
+ def alpha(self):
+ return self.log_alpha.exp()
+
+ def _obs_to_input(self, obs):
+ if isinstance(obs, utils.LazyFrames):
+ _obs = np.array(obs)
+ else:
+ _obs = obs
+ _obs = torch.FloatTensor(_obs).cuda()
+ _obs = _obs.unsqueeze(0)
+ return _obs
+
+ def select_action(self, obs):
+ _obs = self._obs_to_input(obs)
+ with torch.no_grad():
+ mu, _, _, _ = self.actor(_obs, compute_pi=False, compute_log_pi=False)
+ return mu.cpu().data.numpy().flatten()
+
+ def sample_action(self, obs):
+ _obs = self._obs_to_input(obs)
+ with torch.no_grad():
+ mu, pi, _, _ = self.actor(_obs, compute_log_pi=False)
+ return pi.cpu().data.numpy().flatten()
+
+ def update_critic(self, obs, action, reward, next_obs, not_done, L=None, step=None):
+ with torch.no_grad():
+ _, policy_action, log_pi, _ = self.actor(next_obs)
+ target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
+ target_V = torch.min(target_Q1, target_Q2) - self.alpha.detach() * log_pi
+ target_Q = reward + (not_done * self.discount * target_V)
+
+ current_Q1, current_Q2 = self.critic(obs, action)
+ critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
+ current_Q2, target_Q
+ )
+ if L is not None:
+ L.log("train_critic/loss", critic_loss, step)
+
+ self.critic_optimizer.zero_grad()
+ critic_loss.backward()
+ self.critic_optimizer.step()
+
+ def update_actor_and_alpha(self, obs, L=None, step=None, update_alpha=True):
+ _, pi, log_pi, log_std = self.actor(obs, detach=True)
+ actor_Q1, actor_Q2 = self.critic(obs, pi, detach=True)
+
+ actor_Q = torch.min(actor_Q1, actor_Q2)
+ actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean()
+
+ if L is not None:
+ L.log("train_actor/loss", actor_loss, step)
+ entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi)) + log_std.sum(
+ dim=-1
+ )
+
+ self.actor_optimizer.zero_grad()
+ actor_loss.backward()
+ self.actor_optimizer.step()
+
+ if update_alpha:
+ self.log_alpha_optimizer.zero_grad()
+ alpha_loss = (self.alpha * (-log_pi - self.target_entropy).detach()).mean()
+
+ if L is not None:
+ L.log("train_alpha/loss", alpha_loss, step)
+ L.log("train_alpha/value", self.alpha, step)
+
+ alpha_loss.backward()
+ self.log_alpha_optimizer.step()
+
+ def soft_update_critic_target(self):
+ utils.soft_update_params(self.critic.Q1, self.critic_target.Q1, self.critic_tau)
+ utils.soft_update_params(self.critic.Q2, self.critic_target.Q2, self.critic_tau)
+ utils.soft_update_params(
+ self.critic.encoder, self.critic_target.encoder, self.encoder_tau
+ )
+
+ def update(self, replay_buffer, L, step):
+ obs, action, reward, next_obs, not_done = replay_buffer.sample()
+
+ self.update_critic(obs, action, reward, next_obs, not_done, L, step)
+
+ if step % self.actor_update_freq == 0:
+ self.update_actor_and_alpha(obs, L, step)
+
+ if step % self.critic_target_update_freq == 0:
+ self.soft_update_critic_target()
diff --git a/DMC/src/algorithms/sgsac.py b/DMC/src/algorithms/sgsac.py
new file mode 100644
index 0000000..9dc63aa
--- /dev/null
+++ b/DMC/src/algorithms/sgsac.py
@@ -0,0 +1,137 @@
+import os
+from copy import deepcopy
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.utils.tensorboard import SummaryWriter
+
+import utils
+import augmentations
+import algorithms.modules as m
+from algorithms.sac import SAC
+
+from .rl_utils import (
+ compute_attribution,
+ compute_attribution_mask,
+ make_attribution_pred_grid,
+ make_obs_grid,
+ make_obs_grad_grid,
+)
+import random
+
+
+class SGSAC(SAC):
+ def __init__(self, obs_shape, action_shape, args):
+ super().__init__(obs_shape, action_shape, args)
+
+ self.attribution_predictor = m.AttributionPredictor(action_shape[0],self.critic.encoder).cuda()
+ self.quantile = args.sgqn_quantile
+ self.aux_update_freq = args.aux_update_freq
+ self.consistency = args.consistency
+
+ self.aux_optimizer = torch.optim.Adam(
+ self.attribution_predictor.parameters(),
+ lr=args.aux_lr,
+ betas=(args.aux_beta, 0.999),
+ )
+
+ tb_dir = os.path.join(
+ args.log_dir,
+ args.domain_name + "_" + args.task_name,
+ args.algorithm,
+ str(args.seed),
+ "tensorboard",
+ )
+ self.writer = SummaryWriter(tb_dir)
+
+
+ def update_critic(self, obs, action, reward, next_obs, not_done, L=None, step=None):
+ with torch.no_grad():
+ _, policy_action, log_pi, _ = self.actor(next_obs)
+ target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
+ target_V = torch.min(target_Q1, target_Q2) - self.alpha.detach() * log_pi
+ target_Q = reward + (not_done * self.discount * target_V)
+
+ current_Q1, current_Q2 = self.critic(obs, action)
+
+
+ critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
+ current_Q2, target_Q
+ )
+ if self.consistency:
+ obs_grad = compute_attribution(self.critic,obs,action.detach())
+ mask = compute_attribution_mask(obs_grad,self.quantile)
+ masked_obs = obs*mask
+ masked_obs[mask<1] = random.uniform(obs.view(-1).min(),obs.view(-1).max())
+ masked_Q1,masked_Q2 = self.critic(masked_obs,action)
+ critic_loss += 0.5 *(F.mse_loss(current_Q1,masked_Q1) + F.mse_loss(current_Q2,masked_Q2))
+ if L is not None:
+ L.log("train_critic/loss", critic_loss, step)
+
+ self.critic_optimizer.zero_grad()
+ critic_loss.backward()
+ self.critic_optimizer.step()
+
+ def update_aux(self, obs, action, obs_grad, mask, step=None, L=None):
+ mask = compute_attribution_mask(obs_grad,self.quantile)
+ s_prime = augmentations.attribution_augmentation(obs.clone(), mask.float())
+
+ s_tilde = augmentations.random_overlay(obs.clone())
+ self.aux_optimizer.zero_grad()
+ pred_attrib, aux_loss = self.compute_attribution_loss(s_tilde,action, mask)
+ aux_loss.backward()
+ self.aux_optimizer.step()
+
+ if L is not None:
+ L.log("train/aux_loss", aux_loss, step)
+
+ if step % 10000 == 0:
+ self.log_tensorboard(obs, action, step, prefix="original")
+ self.log_tensorboard(s_tilde, action, step, prefix="augmented")
+ self.log_tensorboard(s_prime, action, step, prefix="super_augmented")
+
+ def log_tensorboard(self, obs, action, step, prefix="original"):
+ obs_grad = compute_attribution(self.critic, obs, action.detach())
+ mask = compute_attribution_mask(obs_grad, quantile=self.quantile)
+ attrib = self.attribution_predictor(obs.detach(),action.detach())
+ grid = make_obs_grid(obs)
+ self.writer.add_image(prefix + "/observation", grid, global_step=step)
+ grad_grid = make_obs_grad_grid(obs_grad.data.abs())
+ self.writer.add_image(prefix + "/attributions", grad_grid, global_step=step)
+ mask = torch.sigmoid(attrib)
+ mask = (mask > 0.5).float()
+ masked_obs = make_obs_grid(obs * mask)
+ self.writer.add_image(prefix + "/masked_obs{}", masked_obs, global_step=step)
+ attrib_grid = make_obs_grad_grid(torch.sigmoid(attrib))
+ self.writer.add_image(
+ prefix + "/predicted_attrib", attrib_grid, global_step=step
+ )
+ for q in [0.95, 0.975, 0.9, 0.995, 0.999]:
+ mask = compute_attribution_mask(obs_grad, quantile=q)
+ masked_obs = make_obs_grid(obs * mask)
+ self.writer.add_image(
+ prefix + "/attrib_q{}".format(q), masked_obs, global_step=step
+ )
+
+ def compute_attribution_loss(self, obs,action, mask):
+ mask = mask.float()
+ attrib = self.attribution_predictor(obs.detach(),action.detach())
+ aux_loss = F.binary_cross_entropy_with_logits(attrib, mask.detach())
+ return attrib, aux_loss
+
+
+ def update(self, replay_buffer, L, step):
+ obs, action, reward, next_obs, not_done = replay_buffer.sample_drq()
+
+ self.update_critic(obs, action, reward, next_obs, not_done, L, step)
+ obs_grad = compute_attribution(self.critic, obs, action.detach())
+ mask = compute_attribution_mask(obs_grad, quantile=self.quantile)
+
+ if step % self.actor_update_freq == 0:
+ self.update_actor_and_alpha(obs, L, step)
+
+ if step % self.critic_target_update_freq == 0:
+ self.soft_update_critic_target()
+
+ if step % self.aux_update_freq == 0:
+ self.update_aux(obs, action, obs_grad, mask, step, L)
diff --git a/DMC/src/algorithms/soda.py b/DMC/src/algorithms/soda.py
new file mode 100644
index 0000000..21dcd68
--- /dev/null
+++ b/DMC/src/algorithms/soda.py
@@ -0,0 +1,84 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from copy import deepcopy
+import utils
+import algorithms.modules as m
+from algorithms.sac import SAC
+import augmentations
+
+
+class SODA(SAC):
+ def __init__(self, obs_shape, action_shape, args):
+ super().__init__(obs_shape, action_shape, args)
+ self.aux_update_freq = args.aux_update_freq
+ self.soda_batch_size = args.soda_batch_size
+ self.soda_tau = args.soda_tau
+
+ shared_cnn = self.critic.encoder.shared_cnn
+ aux_cnn = self.critic.encoder.head_cnn
+ soda_encoder = m.Encoder(
+ shared_cnn,
+ aux_cnn,
+ m.SODAMLP(aux_cnn.out_shape[0], args.projection_dim, args.projection_dim)
+ )
+
+ self.predictor = m.SODAPredictor(soda_encoder, args.projection_dim).cuda()
+ self.predictor_target = deepcopy(self.predictor)
+
+ self.soda_optimizer = torch.optim.Adam(
+ self.predictor.parameters(), lr=args.aux_lr, betas=(args.aux_beta, 0.999)
+ )
+ self.train()
+
+ def train(self, training=True):
+ super().train(training)
+ if hasattr(self, 'soda_predictor'):
+ self.soda_predictor.train(training)
+
+ def compute_soda_loss(self, x0, x1):
+ h0 = self.predictor(x0)
+ with torch.no_grad():
+ h1 = self.predictor_target.encoder(x1)
+ h0 = F.normalize(h0, p=2, dim=1)
+ h1 = F.normalize(h1, p=2, dim=1)
+
+ return F.mse_loss(h0, h1)
+
+ def update_soda(self, replay_buffer, L=None, step=None):
+ x = replay_buffer.sample_soda(self.soda_batch_size)
+ assert x.size(-1) == 100
+
+ aug_x = x.clone()
+
+ x = augmentations.random_crop(x)
+ aug_x = augmentations.random_crop(aug_x)
+ aug_x = augmentations.random_overlay(aug_x)
+
+ soda_loss = self.compute_soda_loss(aug_x, x)
+
+ self.soda_optimizer.zero_grad()
+ soda_loss.backward()
+ self.soda_optimizer.step()
+ if L is not None:
+ L.log('train/aux_loss', soda_loss, step)
+
+ utils.soft_update_params(
+ self.predictor, self.predictor_target,
+ self.soda_tau
+ )
+
+ def update(self, replay_buffer, L, step):
+ obs, action, reward, next_obs, not_done = replay_buffer.sample()
+
+ self.update_critic(obs, action, reward, next_obs, not_done, L, step)
+
+ if step % self.actor_update_freq == 0:
+ self.update_actor_and_alpha(obs, L, step)
+
+ if step % self.critic_target_update_freq == 0:
+ self.soft_update_critic_target()
+
+ if step % self.aux_update_freq == 0:
+ self.update_soda(replay_buffer, L, step)
diff --git a/DMC/src/algorithms/svea.py b/DMC/src/algorithms/svea.py
new file mode 100644
index 0000000..e337b20
--- /dev/null
+++ b/DMC/src/algorithms/svea.py
@@ -0,0 +1,63 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from copy import deepcopy
+import utils
+import augmentations
+import algorithms.modules as m
+from algorithms.sac import SAC
+
+
+class SVEA(SAC):
+ def __init__(self, obs_shape, action_shape, args):
+ super().__init__(obs_shape, action_shape, args)
+ self.svea_alpha = args.svea_alpha
+ self.svea_beta = args.svea_beta
+
+ def update_critic(self, obs, action, reward, next_obs, not_done, L=None, step=None):
+ with torch.no_grad():
+ _, policy_action, log_pi, _ = self.actor(next_obs)
+ target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
+ target_V = torch.min(target_Q1, target_Q2) - self.alpha.detach() * log_pi
+ target_Q = reward + (not_done * self.discount * target_V)
+
+ if self.svea_alpha == self.svea_beta:
+ obs = utils.cat(obs, augmentations.random_overlay(obs.clone()))
+ action = utils.cat(action, action)
+ target_Q = utils.cat(target_Q, target_Q)
+
+ current_Q1, current_Q2 = self.critic(obs, action)
+ critic_loss = (self.svea_alpha + self.svea_beta) * (
+ F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
+ )
+ else:
+ current_Q1, current_Q2 = self.critic(obs, action)
+ critic_loss = self.svea_alpha * (
+ F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
+ )
+
+ obs_aug = augmentations.random_overlay(obs.clone())
+ current_Q1_aug, current_Q2_aug = self.critic(obs_aug, action)
+ critic_loss += self.svea_beta * (
+ F.mse_loss(current_Q1_aug, target_Q)
+ + F.mse_loss(current_Q2_aug, target_Q)
+ )
+
+ if L is not None:
+ L.log("train_critic/loss", critic_loss, step)
+
+ self.critic_optimizer.zero_grad()
+ critic_loss.backward()
+ self.critic_optimizer.step()
+
+ def update(self, replay_buffer, L, step):
+ obs, action, reward, next_obs, not_done = replay_buffer.sample_drq()
+
+ self.update_critic(obs, action, reward, next_obs, not_done, L, step)
+
+ if step % self.actor_update_freq == 0:
+ self.update_actor_and_alpha(obs, L, step)
+
+ if step % self.critic_target_update_freq == 0:
+ self.soft_update_critic_target()
diff --git a/DMC/src/arguments.py b/DMC/src/arguments.py
new file mode 100644
index 0000000..1359266
--- /dev/null
+++ b/DMC/src/arguments.py
@@ -0,0 +1,127 @@
+import argparse
+import numpy as np
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+
+ # environment
+ parser.add_argument("--domain_name", default="walker")
+ parser.add_argument("--task_name", default="walk")
+ parser.add_argument("--frame_stack", default=3, type=int)
+ parser.add_argument("--action_repeat", default=4, type=int)
+ parser.add_argument("--episode_length", default=1000, type=int)
+ parser.add_argument("--eval_mode", default="color_hard", type=str)
+
+ # agent
+ parser.add_argument("--algorithm", default="sgsac", type=str)
+ parser.add_argument("--train_steps", default="500k", type=str)
+ parser.add_argument("--discount", default=0.99, type=float)
+ parser.add_argument("--init_steps", default=1000, type=int)
+ parser.add_argument("--batch_size", default=128, type=int)
+ parser.add_argument("--hidden_dim", default=1024, type=int)
+
+ # actor
+ parser.add_argument("--actor_lr", default=1e-3, type=float)
+ parser.add_argument("--actor_beta", default=0.9, type=float)
+ parser.add_argument("--actor_log_std_min", default=-10, type=float)
+ parser.add_argument("--actor_log_std_max", default=2, type=float)
+ parser.add_argument("--actor_update_freq", default=2, type=int)
+
+ # critic
+ parser.add_argument("--critic_lr", default=1e-3, type=float)
+ parser.add_argument("--critic_beta", default=0.9, type=float)
+ parser.add_argument("--critic_tau", default=0.01, type=float)
+ parser.add_argument("--critic_target_update_freq", default=2, type=int)
+ parser.add_argument("--critic_weight_decay", default=0, type=float)
+
+
+ # architecture
+ parser.add_argument("--num_shared_layers", default=11, type=int)
+ parser.add_argument("--num_head_layers", default=0, type=int)
+ parser.add_argument("--num_filters", default=32, type=int)
+ parser.add_argument("--projection_dim", default=100, type=int)
+ parser.add_argument("--encoder_tau", default=0.05, type=float)
+
+ # entropy maximization
+ parser.add_argument("--init_temperature", default=0.1, type=float)
+ parser.add_argument("--alpha_lr", default=1e-4, type=float)
+ parser.add_argument("--alpha_beta", default=0.5, type=float)
+
+ # auxiliary tasks
+ parser.add_argument("--aux_lr", default=3e-4, type=float)
+ parser.add_argument("--aux_beta", default=0.9, type=float)
+ parser.add_argument("--aux_update_freq", default=2, type=int)
+
+ # soda
+ parser.add_argument("--soda_batch_size", default=256, type=int)
+ parser.add_argument("--soda_tau", default=0.005, type=float)
+
+ # svea
+ parser.add_argument("--svea_alpha", default=0.5, type=float)
+ parser.add_argument("--svea_beta", default=0.5, type=float)
+ parser.add_argument("--sgqn_quantile", default=0.90, type=float)
+ parser.add_argument("--svea_contrastive_coeff", default=0.1, type=float)
+ parser.add_argument("--svea_norm_coeff", default=0.1, type=float)
+ parser.add_argument("--attrib_coeff", default=0.25, type=float)
+ parser.add_argument("--consistency", default=1, type=int)
+
+ # eval
+ parser.add_argument("--save_freq", default="100k", type=str)
+ parser.add_argument("--eval_freq", default="10k", type=str)
+ parser.add_argument("--eval_episodes", default=30, type=int)
+ parser.add_argument("--distracting_cs_intensity", default=0.0, type=float)
+
+ # misc
+ parser.add_argument("--seed", default=10081, type=int)
+ parser.add_argument("--log_dir", default="logs", type=str)
+ parser.add_argument("--save_video", default=False, action="store_true")
+
+ args = parser.parse_args()
+
+ assert args.algorithm in {
+ "sac",
+ "rad",
+ "curl",
+ "pad",
+ "soda",
+ "drq",
+ "svea",
+ "saca",
+ "sacfa",
+ "sgsac",
+ }, f'specified algorithm "{args.algorithm}" is not supported'
+
+ assert args.eval_mode in {
+ "train",
+ "color_easy",
+ "color_hard",
+ "video_easy",
+ "video_hard",
+ "distracting_cs",
+ "all",
+ "none",
+ }, f'specified mode "{args.eval_mode}" is not supported'
+ assert args.seed is not None, "must provide seed for experiment"
+ assert args.log_dir is not None, "must provide a log directory for experiment"
+
+ intensities = {0.0, 0.025, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5}
+ assert (
+ args.distracting_cs_intensity in intensities
+ ), f"distracting_cs has only been implemented for intensities: {intensities}"
+
+ args.train_steps = int(args.train_steps.replace("k", "000"))
+ args.save_freq = int(args.save_freq.replace("k", "000"))
+ args.eval_freq = int(args.eval_freq.replace("k", "000"))
+
+ if args.eval_mode == "none":
+ args.eval_mode = None
+
+ if args.algorithm in {"rad", "curl", "pad", "soda"}:
+ args.image_size = 100
+ args.image_crop_size = 84
+ else:
+ args.image_size = 84
+ args.image_crop_size = 84
+
+ return args
diff --git a/DMC/src/augmentations.py b/DMC/src/augmentations.py
new file mode 100644
index 0000000..e40d20d
--- /dev/null
+++ b/DMC/src/augmentations.py
@@ -0,0 +1,253 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision.transforms as TF
+import torchvision.datasets as datasets
+import utils
+import os
+import kornia
+import random
+
+places_dataloader = None
+places_iter = None
+
+
+def _load_places(batch_size=256, image_size=84, num_workers=8, use_val=False):
+ global places_dataloader, places_iter
+ partition = "val" if use_val else "train"
+ print(f"Loading {partition} partition of places365_standard...")
+ for data_dir in utils.load_config("datasets"):
+ if os.path.exists(data_dir):
+ fp = os.path.join(data_dir, "places365_standard", partition)
+ if not os.path.exists(fp):
+ print(f"Warning: path {fp} does not exist, falling back to {data_dir}")
+ fp = data_dir
+ places_dataloader = torch.utils.data.DataLoader(
+ datasets.ImageFolder(
+ fp,
+ TF.Compose(
+ [
+ TF.RandomResizedCrop(image_size),
+ TF.RandomHorizontalFlip(),
+ TF.ToTensor(),
+ ]
+ ),
+ ),
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=num_workers,
+ pin_memory=True,
+ )
+ places_iter = iter(places_dataloader)
+ break
+ if places_iter is None:
+ raise FileNotFoundError(
+ "failed to find places365 data at any of the specified paths"
+ )
+ print("Loaded dataset from", data_dir)
+
+
+def _get_places_batch(batch_size):
+ global places_iter
+ try:
+ imgs, _ = next(places_iter)
+ if imgs.size(0) < batch_size:
+ places_iter = iter(places_dataloader)
+ imgs, _ = next(places_iter)
+ except StopIteration:
+ places_iter = iter(places_dataloader)
+ imgs, _ = next(places_iter)
+ return imgs.cuda()
+
+
+def random_overlay(x, dataset="places365_standard"):
+ """Randomly overlay an image from Places"""
+ global places_iter
+ alpha = 0.5
+
+ if dataset == "places365_standard":
+ if places_dataloader is None:
+ _load_places(batch_size=x.size(0), image_size=x.size(-1))
+ imgs = _get_places_batch(batch_size=x.size(0)).repeat(1, x.size(1) // 3, 1, 1)
+ else:
+ raise NotImplementedError(
+ f'overlay has not been implemented for dataset "{dataset}"'
+ )
+
+ return ((1 - alpha) * (x / 255.0) + (alpha) * imgs) * 255.0
+
+
+def attribution_augmentation(x, mask, dataset="places365_standard"):
+ """Complete non importnant pixels with a random image from Places"""
+ global places_iter
+
+ if dataset == "places365_standard":
+ if places_dataloader is None:
+ _load_places(batch_size=x.size(0), image_size=x.size(-1))
+ imgs = _get_places_batch(batch_size=x.size(0)).repeat(1, x.size(1) // 3, 1, 1)
+ else:
+ raise NotImplementedError(
+ f'overlay has not been implemented for dataset "{dataset}"'
+ )
+
+ # s_plus = random_conv(x) * mask
+ s_plus = x * mask
+ s_tilde = (((s_plus) / 255.0) + (imgs * (torch.ones_like(mask) - mask))) * 255.0
+ s_minus = imgs * 255
+ return s_tilde
+
+
+def paired_aug(obs, mask):
+ mask = mask.float()
+ SEMANTIC = [kornia.augmentation.RandomAffine([-45., 45.], [0.3, 0.3], [0.5, 1.5], [0., 0.15]),kornia.augmentation.RandomErasing()]
+ no_sem = lambda x : random_overlay(x)
+ sem = random.sample(SEMANTIC,k=1)[0]
+ img_out = no_sem(sem(obs))
+
+ mask_out = sem(mask, sem._params)
+ return img_out, mask_out
+
+def attribution_random_patch_augmentation(
+ x,
+ cam,
+ image_size=84,
+ output_size=4,
+ quantile=0.90,
+ patch_proba=0.7,
+ dataset="places365_standard",
+):
+
+ if dataset == "places365_standard":
+ if places_dataloader is None:
+ _load_places(batch_size=x.size(0), image_size=x.size(-1))
+ negative = _get_places_batch(batch_size=x.size(0)).repeat(
+ 1, x.size(1) // 3, 1, 1
+ )
+ else:
+ raise NotImplementedError(
+ f'overlay has not been implemented for dataset "{dataset}"'
+ )
+ cam = cam.to(x.device)
+ cam = F.adaptive_avg_pool2d(cam, output_size=output_size)
+ q = torch.quantile(cam.flatten(1), quantile, 1)
+ mask = (cam >= q[:, None, None]).long()
+ exploration_mask = torch.rand(*mask.shape).to(x.device)
+ exploration_mask[~mask] = 0
+ expl_max = torch.amax(exploration_mask.view(mask.size(0), -1), dim=1)
+ exploration_mask = (
+ exploration_mask.view(-1, mask.size(1), mask.size(2)) == expl_max[:, None, None]
+ ).long()
+ bern = torch.bernoulli(torch.ones_like(mask) * patch_proba).long().to(x.device)
+ selected_patch = (mask * bern) + exploration_mask
+ selected_patch[selected_patch > 1] = 1
+ selected_patch = F.upsample_nearest(selected_patch.float().unsqueeze(1), image_size)
+ # augmented_x = (((0.5) * (x / 255.0) + (0.5) * negative) * 255.0) * selected_patch
+ augmented_x = x * selected_patch
+ complementary_mask = ~(selected_patch.bool())
+ negative = negative * (complementary_mask.float())
+ return augmented_x + (negative * 255)
+
+
+def blending_augmentation(x, mask, overlay=True):
+ s_plus, s_minus, s_tilde = attribution_augmentation(x, mask)
+ if overlay:
+ overlay_x = (1 - 0.5) * x + (0.5) * s_minus
+ s_tilde = overlay_x * (~mask) + s_plus
+ else:
+ s_tilde = s_minus * (~mask) + s_plus
+ return s_plus, s_minus, s_tilde
+
+
+def random_conv(x):
+ """Applies a random conv2d, deviates slightly from https://arxiv.org/abs/1910.05396"""
+ n, c, h, w = x.shape
+ for i in range(n):
+ weights = torch.randn(3, 3, 3, 3).to(x.device)
+ temp_x = x[i : i + 1].reshape(-1, 3, h, w) / 255.0
+ temp_x = F.pad(temp_x, pad=[1] * 4, mode="replicate")
+ out = torch.sigmoid(F.conv2d(temp_x, weights)) * 255.0
+ total_out = out if i == 0 else torch.cat([total_out, out], axis=0)
+ return total_out.reshape(n, c, h, w)
+
+
+def batch_from_obs(obs, batch_size=32):
+ """Copy a single observation along the batch dimension"""
+ if isinstance(obs, torch.Tensor):
+ if len(obs.shape) == 3:
+ obs = obs.unsqueeze(0)
+ return obs.repeat(batch_size, 1, 1, 1)
+
+ if len(obs.shape) == 3:
+ obs = np.expand_dims(obs, axis=0)
+ return np.repeat(obs, repeats=batch_size, axis=0)
+
+
+def prepare_pad_batch(obs, next_obs, action, batch_size=32):
+ """Prepare batch for self-supervised policy adaptation at test-time"""
+ batch_obs = batch_from_obs(torch.from_numpy(obs).cuda(), batch_size)
+ batch_next_obs = batch_from_obs(torch.from_numpy(next_obs).cuda(), batch_size)
+ batch_action = torch.from_numpy(action).cuda().unsqueeze(0).repeat(batch_size, 1)
+
+ return random_crop_cuda(batch_obs), random_crop_cuda(batch_next_obs), batch_action
+
+
+def identity(x):
+ return x
+
+
+def random_shift(imgs, pad=4):
+ """Vectorized random shift, imgs: (B,C,H,W), pad: #pixels"""
+ _, _, h, w = imgs.shape
+ imgs = F.pad(imgs, (pad, pad, pad, pad), mode="replicate")
+ return kornia.augmentation.RandomCrop((h, w))(imgs)
+
+
+def random_crop(x, size=84, w1=None, h1=None, return_w1_h1=False):
+ """Vectorized CUDA implementation of random crop, imgs: (B,C,H,W), size: output size"""
+ assert (w1 is None and h1 is None) or (
+ w1 is not None and h1 is not None
+ ), "must either specify both w1 and h1 or neither of them"
+ assert isinstance(x, torch.Tensor) and x.is_cuda, "input must be CUDA tensor"
+
+ n = x.shape[0]
+ img_size = x.shape[-1]
+ crop_max = img_size - size
+
+ if crop_max <= 0:
+ if return_w1_h1:
+ return x, None, None
+ return x
+
+ x = x.permute(0, 2, 3, 1)
+
+ if w1 is None:
+ w1 = torch.LongTensor(n).random_(0, crop_max)
+ h1 = torch.LongTensor(n).random_(0, crop_max)
+
+ windows = view_as_windows_cuda(x, (1, size, size, 1))[..., 0, :, :, 0]
+ cropped = windows[torch.arange(n), w1, h1]
+
+ if return_w1_h1:
+ return cropped, w1, h1
+
+ return cropped
+
+
+def view_as_windows_cuda(x, window_shape):
+ """PyTorch CUDA-enabled implementation of view_as_windows"""
+ assert isinstance(window_shape, tuple) and len(window_shape) == len(
+ x.shape
+ ), "window_shape must be a tuple with same number of dimensions as x"
+
+ slices = tuple(slice(None, None, st) for st in torch.ones(4).long())
+ win_indices_shape = [
+ x.size(0),
+ x.size(1) - int(window_shape[1]),
+ x.size(2) - int(window_shape[2]),
+ x.size(3),
+ ]
+
+ new_shape = tuple(list(win_indices_shape) + list(window_shape))
+ strides = tuple(list(x[slices].stride()) + list(x.stride()))
+
+ return x.as_strided(new_shape, strides)
diff --git a/DMC/src/env/distracting_control/background.py b/DMC/src/env/distracting_control/background.py
new file mode 100644
index 0000000..326ef63
--- /dev/null
+++ b/DMC/src/env/distracting_control/background.py
@@ -0,0 +1,265 @@
+# coding=utf-8
+# Copyright 2021 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""A wrapper for dm_control environments which applies color distractions."""
+import os
+
+from PIL import Image
+import collections
+from dm_control.rl import control
+import numpy as np
+
+import utils
+from dm_control.mujoco.wrapper import mjbindings
+
+DAVIS17_TRAINING_VIDEOS = [
+ 'bear', 'bmx-bumps', 'boat', 'boxing-fisheye', 'breakdance-flare', 'bus',
+ 'car-turn', 'cat-girl', 'classic-car', 'color-run', 'crossing',
+ 'dance-jump', 'dancing', 'disc-jockey', 'dog-agility', 'dog-gooses',
+ 'dogs-scale', 'drift-turn', 'drone', 'elephant', 'flamingo', 'hike',
+ 'hockey', 'horsejump-low', 'kid-football', 'kite-walk', 'koala',
+ 'lady-running', 'lindy-hop', 'longboard', 'lucia', 'mallard-fly',
+ 'mallard-water', 'miami-surf', 'motocross-bumps', 'motorbike', 'night-race',
+ 'paragliding', 'planes-water', 'rallye', 'rhino', 'rollerblade',
+ 'schoolgirls', 'scooter-board', 'scooter-gray', 'sheep', 'skate-park',
+ 'snowboard', 'soccerball', 'stroller', 'stunt', 'surf', 'swing', 'tennis',
+ 'tractor-sand', 'train', 'tuk-tuk', 'upside-down', 'varanus-cage', 'walking'
+]
+DAVIS17_VALIDATION_VIDEOS = [
+ 'bike-packing', 'blackswan', 'bmx-trees', 'breakdance', 'camel',
+ 'car-roundabout', 'car-shadow', 'cows', 'dance-twirl', 'dog', 'dogs-jump',
+ 'drift-chicane', 'drift-straight', 'goat', 'gold-fish', 'horsejump-high',
+ 'india', 'judo', 'kite-surf', 'lab-coat', 'libby', 'loading', 'mbike-trick',
+ 'motocross-jump', 'paragliding-launch', 'parkour', 'pigs', 'scooter-black',
+ 'shooting', 'soapbox'
+]
+SKY_TEXTURE_INDEX = 0
+Texture = collections.namedtuple('Texture', ('size', 'address', 'textures'))
+
+
+def imread(filename):
+ img = Image.open(filename)
+ img_np = np.asarray(img)
+ return img_np
+
+
+def size_and_flatten(image, ref_height, ref_width):
+ # Resize image if necessary and flatten the result.
+ image_height, image_width = image.shape[:2]
+
+ if image_height != ref_height or image_width != ref_width:
+ image = np.asarray(Image.fromarray(image).resize(size=(ref_width, ref_height)))
+
+ #utils.save_obs(torch.from_numpy(image).permute(2,0,1), 'image')
+
+ return image.flatten(order='K')
+
+
+def blend_to_background(alpha, image, background):
+ if alpha == 1.0:
+ return image
+ elif alpha == 0.0:
+ return background
+ else:
+ return (alpha * image.astype(np.float32)
+ + (1. - alpha) * background.astype(np.float32)).astype(np.uint8)
+
+
+class DistractingBackgroundEnv(control.Environment):
+ """Environment wrapper for background visual distraction.
+
+ **NOTE**: This wrapper should be applied BEFORE the pixel wrapper to make sure
+ the background image changes are applied before rendering occurs.
+ """
+
+ def __init__(self,
+ env,
+ dataset_paths=None,
+ dataset_videos=None,
+ video_alpha=1.0,
+ ground_plane_alpha=1.0,
+ num_videos=None,
+ dynamic=False,
+ seed=None,
+ shuffle_buffer_size=None):
+
+ if not 0 <= video_alpha <= 1:
+ raise ValueError('`video_alpha` must be in the range [0, 1]')
+
+ self._env = env
+ self._video_alpha = video_alpha
+ self._ground_plane_alpha = ground_plane_alpha
+ self._random_state = np.random.RandomState(seed=seed)
+ self._dynamic = dynamic
+ self._shuffle_buffer_size = shuffle_buffer_size
+ self._background = None
+ self._current_img_index = 0
+ self._frame_count = 0
+
+ if not dataset_paths or num_videos == 0:
+ # Allow running the wrapper without backgrounds to still set the ground
+ # plane alpha value.
+ self._video_paths = []
+ print('Warning: no dataset paths and/or number of videos set to 0!')
+ else:
+ # Use all videos if no specific ones were passed.
+ if not dataset_videos:
+ _, _, filenames = next(walk(mypath))
+ dataset_videos = sorted(filenames)
+ # Replace video placeholders 'train'/'val' with the list of videos.
+ elif dataset_videos in ['train', 'training']:
+ dataset_videos = DAVIS17_TRAINING_VIDEOS
+ elif dataset_videos in ['val', 'validation']:
+ dataset_videos = DAVIS17_VALIDATION_VIDEOS
+ # Get complete paths for all videos.
+ for dataset_path in dataset_paths:
+ video_paths = [
+ os.path.join(dataset_path, subdir) for subdir in dataset_videos
+ ]
+ if len(video_paths) > 0:
+ break
+ assert len(video_paths) > 0, 'DAVIS dataset not found!'
+
+ # Optionally use only the first num_paths many paths.
+ if num_videos is not None:
+ if num_videos > len(video_paths) or num_videos < 0:
+ raise ValueError(f'`num_bakground_paths` is {num_videos} but '
+ 'should not be larger than the number of available '
+ f'background paths ({len(video_paths)}) and at '
+ 'least 0.')
+ video_paths = video_paths[:num_videos]
+
+ self._video_paths = video_paths
+
+ def reset(self):
+ """Reset the background state."""
+ self._frame_count = 0
+ time_step = self._env.reset()
+ self._reset_background()
+ return time_step
+
+ def _reset_background(self):
+ # Make grid semi-transparent.
+ if self._ground_plane_alpha is not None:
+ self._env.physics.named.model.mat_rgba['grid',
+ 'a'] = self._ground_plane_alpha
+
+ # For some reason the height of the skybox is set to 4800 by default,
+ # which does not work with new textures.
+ self._env.physics.model.tex_height[SKY_TEXTURE_INDEX] = 800
+
+ # Set the sky texture reference.
+ sky_height = self._env.physics.model.tex_height[SKY_TEXTURE_INDEX]
+ sky_width = self._env.physics.model.tex_width[SKY_TEXTURE_INDEX]
+ sky_size = sky_height * sky_width * 3
+ sky_address = self._env.physics.model.tex_adr[SKY_TEXTURE_INDEX]
+
+ sky_texture = self._env.physics.model.tex_rgb[sky_address:sky_address +
+ sky_size].astype(np.float32)
+
+ if self._video_paths:
+
+ if self._shuffle_buffer_size:
+ # Shuffle images from all videos together to get background frames.
+ file_names = [
+ os.path.join(path, fn)
+ for path in self._video_paths
+ for fn in utils.listdir(path)
+ ]
+ self._random_state.shuffle(file_names)
+ # Load only the first n images for performance reasons.
+ file_names = file_names[:self._shuffle_buffer_size]
+ images = [imread(fn) for fn in file_names]
+ else:
+ # Randomly pick a video and load all images.
+ video_path = self._random_state.choice(self._video_paths)
+ file_names = utils.listdir(video_path)
+ if not self._dynamic:
+ # Randomly pick a single static frame.
+ file_names = [self._random_state.choice(file_names)]
+ images = [imread(os.path.join(video_path, fn)) for fn in file_names]
+
+ # Pick a random starting point and steping direction.
+ self._current_img_index = self._random_state.choice(len(images))
+ self._step_direction = self._random_state.choice([-1, 1])
+
+ # Prepare images in the texture format by resizing and flattening.
+
+ # Generate image textures.
+ texturized_images = []
+ for image in images:
+ image_flattened = size_and_flatten(image, sky_height, sky_width)
+ new_texture = blend_to_background(self._video_alpha, image_flattened,
+ sky_texture)
+ texturized_images.append(new_texture)
+
+ else:
+
+ self._current_img_index = 0
+ texturized_images = [sky_texture]
+
+ self._background = Texture(sky_size, sky_address, texturized_images)
+ self._apply()
+
+ def step(self, action):
+ self._frame_count += 1
+ time_step = self._env.step(action)
+
+ if time_step.first():
+ self._reset_background()
+ return time_step
+
+ if self._dynamic and self._video_paths and self._frame_count % 2 == 0:
+ # Move forward / backward in the image sequence by updating the index.
+ self._current_img_index += self._step_direction
+
+ # Start moving forward if we are past the start of the images.
+ if self._current_img_index <= 0:
+ self._current_img_index = 0
+ self._step_direction = abs(self._step_direction)
+ # Start moving backwards if we are past the end of the images.
+ if self._current_img_index >= len(self._background.textures):
+ self._current_img_index = len(self._background.textures) - 1
+ self._step_direction = -abs(self._step_direction)
+
+ self._apply()
+ return time_step
+
+ def _apply(self):
+ """Apply the background texture to the physics."""
+
+ if self._background:
+ start = self._background.address
+ end = self._background.address + self._background.size
+ texture = self._background.textures[self._current_img_index]
+
+ self._env.physics.model.tex_rgb[start:end] = texture
+ # Upload the new texture to the GPU. Note: we need to make sure that the
+ # OpenGL context belonging to this Physics instance is the current one.
+ with self._env.physics.contexts.gl.make_current() as ctx:
+ ctx.call(
+ mjbindings.mjlib.mjr_uploadTexture,
+ self._env.physics.model.ptr,
+ self._env.physics.contexts.mujoco.ptr,
+ SKY_TEXTURE_INDEX,
+ )
+
+ # Forward property and method calls to self._env.
+ def __getattr__(self, attr):
+ if hasattr(self._env, attr):
+ return getattr(self._env, attr)
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
diff --git a/DMC/src/env/distracting_control/camera.py b/DMC/src/env/distracting_control/camera.py
new file mode 100644
index 0000000..3f91aac
--- /dev/null
+++ b/DMC/src/env/distracting_control/camera.py
@@ -0,0 +1,358 @@
+# coding=utf-8
+# Copyright 2021 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""A wrapper for dm_control environments which applies camera distractions."""
+
+import copy
+from dm_control.rl import control
+import numpy as np
+
+CAMERA_MODES = ['fixed', 'track', 'trackcom', 'targetbody', 'targetbodycom']
+
+
+def eul2mat(theta):
+ """Converts euler angles (x, y, z) to a rotation matrix."""
+
+ return np.array([[
+ np.cos(theta[1]) * np.cos(theta[2]),
+ np.sin(theta[0]) * np.sin(theta[1]) * np.cos(theta[2]) -
+ np.sin(theta[2]) * np.cos(theta[0]),
+ np.sin(theta[1]) * np.cos(theta[0]) * np.cos(theta[2]) +
+ np.sin(theta[0]) * np.sin(theta[2])
+ ],
+ [
+ np.sin(theta[2]) * np.cos(theta[1]),
+ np.sin(theta[0]) * np.sin(theta[1]) * np.sin(theta[2]) +
+ np.cos(theta[0]) * np.cos(theta[2]),
+ np.sin(theta[1]) * np.sin(theta[2]) * np.cos(theta[0]) -
+ np.sin(theta[0]) * np.cos(theta[2])
+ ],
+ [
+ -np.sin(theta[1]),
+ np.sin(theta[0]) * np.cos(theta[1]),
+ np.cos(theta[0]) * np.cos(theta[1])
+ ]])
+
+
+def _mat_from_theta(cos_theta, sin_theta, a):
+ """Builds a rotation matrix from theta and an orientation vector."""
+
+ row1 = [
+ cos_theta + a[0]**2. * (1. - cos_theta),
+ a[0] * a[1] * (1 - cos_theta) - a[2] * sin_theta,
+ a[0] * a[2] * (1 - cos_theta) + a[1] * sin_theta
+ ]
+ row2 = [
+ a[1] * a[0] * (1 - cos_theta) + a[2] * sin_theta,
+ cos_theta + a[1]**2. * (1 - cos_theta),
+ a[1] * a[2] * (1. - cos_theta) - a[0] * sin_theta
+ ]
+ row3 = [
+ a[2] * a[0] * (1. - cos_theta) - a[1] * sin_theta,
+ a[2] * a[1] * (1. - cos_theta) + a[0] * sin_theta,
+ cos_theta + (a[2]**2.) * (1. - cos_theta)
+ ]
+ return np.stack([row1, row2, row3])
+
+
+def rotvec2mat(theta, vec):
+ """Converts a rotation around a vector to a rotation matrix."""
+
+ a = vec / np.sqrt(np.sum(vec**2.))
+ sin_theta = np.sin(theta)
+ cos_theta = np.cos(theta)
+
+ return _mat_from_theta(cos_theta, sin_theta, a)
+
+
+def get_lookat_xmat_no_roll(agent_pos, camera_pos):
+ """Solves for the cam rotation centering the agent with 0 roll."""
+
+ # NOTE(austinstone): This method leads to wild oscillations around the north
+ # and south polls.
+ # For example, if agent is at (0., 0., 0.) and the camera is at (.01, 0., 1.),
+ # this will produce a yaw of 90 degrees whereas if the camera is slightly
+ # adjacent at (-.01, 0., 1.) this will produce a yaw of -90 degrees. I'm
+ # not sure what the fix is, as this seems like the behavior we want in all
+ # environments except for reacher.
+ delta_vec = agent_pos - camera_pos
+ delta_vec /= np.sqrt(np.sum(delta_vec**2.))
+ yaw = np.arctan2(delta_vec[0], delta_vec[1])
+ pitch = np.arctan2(delta_vec[2], np.sqrt(np.sum(delta_vec[:2]**2.)))
+ pitch += np.pi / 2. # Camera starts out looking at [0, 0, -1.]
+ return eul2mat([pitch, 0., -yaw]).flatten()
+
+
+def get_lookat_xmat(agent_pos, camera_pos):
+ """Solves for the cam rotation centering the agent, allowing roll."""
+
+ # Solve for the rotation which centers the agent in the scene.
+ delta_vec = agent_pos - camera_pos
+ delta_vec /= np.sqrt(np.sum(delta_vec**2.))
+ y_vec = np.array([0., 0., -1.]) # This is where the cam starts from.
+ a = np.cross(y_vec, delta_vec)
+ sin_theta = np.sqrt(np.sum(a**2.))
+ cos_theta = np.dot(delta_vec, y_vec)
+ a /= (np.sqrt(np.sum(a**2.)) + .0001)
+ return _mat_from_theta(cos_theta, sin_theta, a)
+
+
+def cart2sphere(cart):
+ r = np.sqrt(np.sum(cart**2.))
+ h_angle = np.arctan2(cart[1], cart[0])
+ v_angle = np.arctan2(np.sqrt(np.sum(cart[:2]**2.)), cart[2])
+ return np.array([r, h_angle, v_angle])
+
+
+def sphere2cart(sphere):
+ r, h_angle, v_angle = sphere
+ x = r * np.sin(v_angle) * np.cos(h_angle)
+ y = r * np.sin(v_angle) * np.sin(h_angle)
+ z = r * np.cos(v_angle)
+ return np.array([x, y, z])
+
+
+def clip_cam_position(position, min_radius, max_radius, min_h_angle,
+ max_h_angle, min_v_angle, max_v_angle):
+ new_position = [-1., -1., -1.]
+ new_position[0] = np.clip(position[0], min_radius, max_radius)
+ new_position[1] = np.clip(position[1], min_h_angle, max_h_angle)
+ new_position[2] = np.clip(position[2], min_v_angle, max_v_angle)
+ return new_position
+
+
+def get_lookat_point(physics, camera_id):
+ """Get the point that the camera is looking at.
+
+ It is assumed that the "point" the camera looks at the agent distance
+ away and projected along the camera viewing matrix.
+
+ Args:
+ physics: mujoco physics objects
+ camera_id: int
+
+ Returns:
+ position: float32 np.array of length 3
+ """
+ dist_to_agent = physics.named.data.cam_xpos[
+ camera_id] - physics.named.data.subtree_com[1]
+ dist_to_agent = np.sqrt(np.sum(dist_to_agent**2.))
+ initial_viewing_mat = copy.deepcopy(physics.named.data.cam_xmat[camera_id])
+ initial_viewing_mat = np.reshape(initial_viewing_mat, (3, 3))
+ z_vec = np.array([0., 0., -dist_to_agent])
+ rotated_vec = np.dot(initial_viewing_mat, z_vec)
+ return rotated_vec + physics.named.data.cam_xpos[camera_id]
+
+
+class DistractingCameraEnv(control.Environment):
+ """Environment wrapper for camera pose visual distraction.
+
+ **NOTE**: This wrapper should be applied BEFORE the pixel wrapper to make sure
+ the camera pose changes are applied before rendering occurs.
+ """
+
+ def __init__(self,
+ env,
+ camera_id,
+ horizontal_delta,
+ vertical_delta,
+ max_vel,
+ vel_std,
+ roll_delta,
+ max_roll_vel,
+ roll_std,
+ max_zoom_in_percent,
+ max_zoom_out_percent,
+ limit_to_upper_quadrant=False,
+ seed=None):
+ self._env = env
+ self._camera_id = camera_id
+ self._horizontal_delta = horizontal_delta
+ self._vertical_delta = vertical_delta
+
+ self._horizontal_delta = horizontal_delta
+ self._vertical_delta = vertical_delta
+ self._max_vel = max_vel
+ self._vel_std = vel_std
+ self._roll_delta = roll_delta
+ self._max_roll_vel = max_roll_vel
+ self._roll_vel_std = roll_std
+ self._max_zoom_in_percent = max_zoom_in_percent
+ self._max_zoom_out_percent = max_zoom_out_percent
+ self._limit_to_upper_quadrant = limit_to_upper_quadrant
+
+ self._random_state = np.random.RandomState(seed=seed)
+
+ # These camera state parameters will be set on the first reset call.
+ self._camera_type = None
+ self._camera_initial_lookat_point = None
+
+ self._camera_vel = None
+ self._max_h_angle = None
+ self._max_v_angle = None
+ self._min_h_angle = None
+ self._min_v_angle = None
+ self._radius = None
+ self._roll_vel = None
+ self._vel_scaling = None
+ self._frame_count = 0
+
+ def setup_camera(self):
+ """Set up camera motion ranges and state."""
+ # Define boundaries on the range of the camera motion.
+ mode = self._env._physics.model.cam_mode[0]
+
+ camera_type = CAMERA_MODES[mode]
+ assert camera_type in ['fixed', 'trackcom']
+
+ self._camera_type = camera_type
+ self._cam_initial_lookat_point = get_lookat_point(self._env.physics,
+ self._camera_id)
+
+ start_pos = copy.deepcopy(
+ self._env.physics.named.data.cam_xpos[self._camera_id])
+
+ if self._camera_type != 'fixed':
+ # Center the camera relative to the agent's center of mass.
+ start_pos -= self._env.physics.named.data.subtree_com[1]
+
+ start_r, start_h_angle, start_v_angle = cart2sphere(start_pos)
+ # Scale the velocity by the starting radius. Most environments have radius 4,
+ # but this downscales the velocity for the envs with radius < 4.
+ self._vel_scaling = start_r / 4.
+ self._max_h_angle = start_h_angle + self._horizontal_delta
+ self._min_h_angle = start_h_angle - self._horizontal_delta
+ self._max_v_angle = start_v_angle + self._vertical_delta
+ self._min_v_angle = start_v_angle - self._vertical_delta
+
+ if self._limit_to_upper_quadrant:
+ # A centered cam is at np.pi / 2.
+ self._max_v_angle = min(self._max_v_angle, np.pi / 2.)
+ self._min_v_angle = max(self._min_v_angle, 0.)
+ # A centered cam is at -np.pi / 2.
+ self._max_h_angle = min(self._max_h_angle, 0.)
+ self._min_h_angle = max(self._min_h_angle, -np.pi)
+
+ self._max_roll = self._roll_delta
+ self._min_roll = -self._roll_delta
+ self._min_radius = max(start_r - start_r * self._max_zoom_in_percent, 0.)
+ self._max_radius = start_r + start_r * self._max_zoom_out_percent
+
+ # Decide the starting position for the camera.
+ self._h_angle = self._random_state.uniform(self._min_h_angle,
+ self._max_h_angle)
+
+ self._v_angle = self._random_state.uniform(self._min_v_angle,
+ self._max_v_angle)
+
+ self._radius = self._random_state.uniform(self._min_radius,
+ self._max_radius)
+
+ self._roll = self._random_state.uniform(self._min_roll, self._max_roll)
+
+ # Decide the starting velocity for the camera.
+ vel = self._random_state.randn(3)
+ vel /= np.sqrt(np.sum(vel**2.))
+ vel *= self._random_state.uniform(0., self._max_vel)
+ self._camera_vel = vel
+ self._roll_vel = self._random_state.uniform(-self._max_roll_vel,
+ self._max_roll_vel)
+
+ def reset(self):
+ """Reset the camera state. """
+ self._frame_count = 0
+ time_step = self._env.reset()
+ self.setup_camera()
+ self._apply()
+ return time_step
+
+
+ def step(self, action):
+ self._frame_count += 1
+ time_step = self._env.step(action)
+
+ if time_step.first():
+ self.setup_camera()
+
+ if self._frame_count % 2 == 0:
+ self._apply()
+ return time_step
+
+ def _apply(self):
+ if not self._camera_type:
+ self.setup_camera()
+
+ # Random walk the velocity.
+ vel_delta = self._random_state.randn(3)
+ self._camera_vel += vel_delta * self._vel_std * self._vel_scaling
+ self._roll_vel += self._random_state.randn() * self._roll_vel_std
+
+ # Clip velocity if it gets too big.
+ vel_norm = np.sqrt(np.sum(self._camera_vel**2.))
+ if vel_norm > self._max_vel * self._vel_scaling:
+ self._camera_vel *= (self._max_vel * self._vel_scaling) / vel_norm
+
+ self._roll_vel = np.clip(self._roll_vel, -self._max_roll_vel,
+ self._max_roll_vel)
+
+ cart_cam_pos = sphere2cart([self._radius, self._h_angle, self._v_angle])
+ # Apply velocity vector to camera
+ sphere_cam_pos2 = cart2sphere(cart_cam_pos + self._camera_vel)
+ sphere_cam_pos2 = clip_cam_position(sphere_cam_pos2, self._min_radius,
+ self._max_radius, self._min_h_angle,
+ self._max_h_angle, self._min_v_angle,
+ self._max_v_angle)
+
+ self._camera_vel = sphere2cart(sphere_cam_pos2) - cart_cam_pos
+
+ self._radius, self._h_angle, self._v_angle = sphere_cam_pos2
+
+ roll2 = self._roll + self._roll_vel
+ roll2 = np.clip(roll2, self._min_roll, self._max_roll)
+
+ self._roll_vel = roll2 - self._roll
+ self._roll = roll2
+
+ cart_cam_pos = sphere2cart(sphere_cam_pos2)
+
+ if self._limit_to_upper_quadrant:
+ lookat_method = get_lookat_xmat_no_roll
+ else:
+ # This method avoids jitteriness at the pole but allows some roll
+ # in the camera matrix. This is important for reacher.
+ lookat_method = get_lookat_xmat
+
+ if self._camera_type == 'fixed':
+ lookat_mat = lookat_method(self._cam_initial_lookat_point,
+ cart_cam_pos)
+ else:
+ # Go from agent centric to world coords
+ cart_cam_pos += self._env.physics.named.data.subtree_com[1]
+ lookat_mat = lookat_method(
+ get_lookat_point(self._env.physics, self._camera_id), cart_cam_pos)
+
+ lookat_mat = np.reshape(lookat_mat, (3, 3))
+ roll_mat = rotvec2mat(self._roll, np.array([0., 0., 1.]))
+ xmat = np.dot(lookat_mat, roll_mat)
+ self._env.physics.named.data.cam_xpos[self._camera_id] = cart_cam_pos
+ self._env.physics.named.data.cam_xmat[self._camera_id] = xmat.flatten()
+
+ # Forward property and method calls to self._env.
+ def __getattr__(self, attr):
+ if hasattr(self._env, attr):
+ return getattr(self._env, attr)
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
diff --git a/DMC/src/env/distracting_control/color.py b/DMC/src/env/distracting_control/color.py
new file mode 100644
index 0000000..878c713
--- /dev/null
+++ b/DMC/src/env/distracting_control/color.py
@@ -0,0 +1,106 @@
+# coding=utf-8
+# Copyright 2021 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""A wrapper for dm_control environments which applies color distractions."""
+
+from dm_control.rl import control
+import numpy as np
+
+
+class DistractingColorEnv(control.Environment):
+ """Environment wrapper for color visual distraction.
+
+ **NOTE**: This wrapper should be applied BEFORE the pixel wrapper to make sure
+ the color changes are applied before rendering occurs.
+ """
+
+ def __init__(self, env, step_std, max_delta, seed=None):
+ """Initialize the environment wrapper.
+
+ Args:
+ env: instance of dm_control Environment to wrap with augmentations.
+ """
+ if step_std < 0:
+ raise ValueError("`step_std` must be greater than or equal to 0.")
+ if max_delta < 0:
+ raise ValueError("`max_delta` must be greater than or equal to 0.")
+
+ self._env = env
+ self._step_std = step_std
+ self._max_delta = max_delta
+ self._random_state = np.random.RandomState()
+
+ self._cam_type = None
+ self._current_rgb = None
+ self._max_rgb = None
+ self._min_rgb = None
+ self._original_rgb = None
+ self._frame_count = 0
+
+ def reset(self):
+ """Reset the distractions state."""
+ self._frame_count = 0
+ time_step = self._env.reset()
+ self._reset_color()
+ return time_step
+
+ def _reset_color(self):
+ # Save all original colors.
+ if self._original_rgb is None:
+ self._original_rgb = np.copy(self._env.physics.model.mat_rgba)[:, :3]
+ # Determine minimum and maximum rgb values.
+ self._max_rgb = np.clip(self._original_rgb + self._max_delta, 0.0, 1.0)
+ self._min_rgb = np.clip(self._original_rgb - self._max_delta, 0.0, 1.0)
+
+ # Pick random colors in the allowed ranges.
+ r = self._random_state.uniform(size=self._min_rgb.shape)
+ self._current_rgb = self._min_rgb + r * (self._max_rgb - self._min_rgb)
+
+ # Apply the color changes.
+ self._env.physics.model.mat_rgba[:, :3] = self._current_rgb
+
+ def step(self, action):
+ self._frame_count += 1
+ time_step = self._env.step(action)
+
+ if time_step.first():
+ self._reset_color()
+ return time_step
+
+ if self._frame_count % 2 == 0:
+ color_change = self._random_state.randn(*self._current_rgb.shape)
+ color_change = color_change * self._step_std
+ else:
+ color_change = 0
+
+ new_color = self._current_rgb + color_change
+
+ self._current_rgb = np.clip(
+ new_color,
+ a_min=self._min_rgb,
+ a_max=self._max_rgb,
+ )
+
+ # Apply the color changes.
+ self._env.physics.model.mat_rgba[:, :3] = self._current_rgb
+ return time_step
+
+ # Forward property and method calls to self._env.
+ def __getattr__(self, attr):
+ if hasattr(self._env, attr):
+ return getattr(self._env, attr)
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
diff --git a/DMC/src/env/distracting_control/suite.py b/DMC/src/env/distracting_control/suite.py
new file mode 100644
index 0000000..3ff0d45
--- /dev/null
+++ b/DMC/src/env/distracting_control/suite.py
@@ -0,0 +1,154 @@
+# coding=utf-8
+# Copyright 2021 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A collection of MuJoCo-based Reinforcement Learning environments.
+
+The suite provides a similar API to the original dm_control suite.
+Users can configure the distractions on top of the original tasks. The suite is
+targeted for loading environments directly with similar configurations as those
+used in the original paper. Each distraction wrapper can be used independently
+though.
+"""
+try:
+ from dm_control import suite # pylint: disable=g-import-not-at-top
+ from dm_control.suite.wrappers import pixels # pylint: disable=g-import-not-at-top
+except ImportError:
+ suite = None
+
+from env.distracting_control import background
+from env.distracting_control import camera
+from env.distracting_control import color
+from env.distracting_control import suite_utils
+
+
+def is_available():
+ return suite is not None
+
+
+def load(domain_name,
+ task_name,
+ difficulty=None,
+ dynamic=False,
+ background_dataset_paths=None,
+ background_dataset_videos="train",
+ background_kwargs=None,
+ camera_kwargs=None,
+ color_kwargs=None,
+ task_kwargs=None,
+ environment_kwargs=None,
+ visualize_reward=False,
+ render_kwargs=None,
+ pixels_only=True,
+ pixels_observation_key="pixels"):
+ """Returns an environment from a domain name, task name and optional settings.
+
+ ```python
+ env = suite.load('cartpole', 'balance')
+ ```
+
+ Note that this benchmark differs from the original implementation of Distracting Control Suite.
+ Notably, the `difficulty` argument has been changed, as well as changes to rate of change in the dynamic setting.
+ Tensorflow dependencies have also been dropped and replaced by PyTorch/NumPy equivalents.
+ The original codebase is available at https://github.com/google-research/google-research/tree/master/distracting_control.
+
+ Users can also toggle dynamic properties for distractions.
+
+ Args:
+ domain_name: A string containing the name of a domain.
+ task_name: A string containing the name of a task.
+ task_kwargs: Optional `dict` of keyword arguments for the task.
+ difficulty: Difficulty for the suite. Intensity scale, [0,1].
+ dynamic: Boolean controlling whether distractions are dynamic or static.
+ backgound_dataset_path: String to the davis directory that contains the
+ video directories.
+ background_dataset_videos: String ('train'/'val') or list of strings of the
+ DAVIS videos to be used for backgrounds.
+ background_kwargs: Dict, overwrites settings for background distractions.
+ camera_kwargs: Dict, overwrites settings for camera distractions.
+ color_kwargs: Dict, overwrites settings for color distractions.
+ task_kwargs: Dict, dm control task kwargs.
+ environment_kwargs: Optional `dict` specifying keyword arguments for the
+ environment.
+ visualize_reward: Optional `bool`. If `True`, object colours in rendered
+ frames are set to indicate the reward at each step. Default `False`.
+ render_kwargs: Dict, render kwargs for pixel wrapper.
+ pixels_only: Boolean controlling the exclusion of states in the observation.
+ pixels_observation_key: Key in the observation used for the rendered image.
+
+ Returns:
+ The requested environment.
+ """
+ if not is_available():
+ raise ImportError("dm_control module is not available. Make sure you "
+ "follow the installation instructions from the "
+ "dm_control package.")
+
+ assert isinstance(difficulty, float), f'intensity must be a float'
+ assert difficulty >= 0 and difficulty <= 1, f'intensity must be in the [0,1] interval'
+ assert str(difficulty) in suite_utils.DIFFICULTY_NUM_VIDEOS.keys(), \
+ f'intensity has only been implemented for the following values: {suite_utils.DIFFICULTY_NUM_VIDEOS.keys()}'
+ render_kwargs = render_kwargs or {}
+ if "camera_id" not in render_kwargs:
+ render_kwargs["camera_id"] = 2 if domain_name == "quadruped" else 0
+
+ env = suite.load(
+ domain_name,
+ task_name,
+ task_kwargs=task_kwargs,
+ environment_kwargs=environment_kwargs,
+ visualize_reward=visualize_reward)
+
+ # Apply background distractions.
+ if difficulty:
+ final_background_kwargs = dict()
+ num_videos = suite_utils.DIFFICULTY_NUM_VIDEOS[str(difficulty)]
+ final_background_kwargs.update(
+ suite_utils.get_background_kwargs(domain_name, num_videos, dynamic,
+ background_dataset_paths,
+ background_dataset_videos))
+ if background_kwargs:
+ # Overwrite kwargs with those passed here.
+ final_background_kwargs.update(background_kwargs)
+ env = background.DistractingBackgroundEnv(env, **final_background_kwargs)
+
+ # Apply camera distractions.
+ if difficulty:
+ final_camera_kwargs = dict(camera_id=render_kwargs["camera_id"])
+ final_camera_kwargs.update(
+ suite_utils.get_camera_kwargs(domain_name, difficulty, dynamic))
+ if camera_kwargs:
+ # Overwrite kwargs with those passed here.
+ final_camera_kwargs.update(camera_kwargs)
+ env = camera.DistractingCameraEnv(env, **final_camera_kwargs)
+
+ # Apply color distractions.
+ if difficulty:
+ final_color_kwargs = dict()
+ final_color_kwargs.update(suite_utils.get_color_kwargs(difficulty, dynamic))
+ if color_kwargs:
+ # Overwrite kwargs with those passed here.
+ final_color_kwargs.update(color_kwargs)
+ env = color.DistractingColorEnv(env, **final_color_kwargs)
+
+ # Apply Pixel wrapper after distractions. This is needed to ensure the
+ # changes from the distraction wrapper are applied to the MuJoCo environment
+ # before the rendering occurs.
+ env = pixels.Wrapper(
+ env,
+ pixels_only=pixels_only,
+ render_kwargs=render_kwargs,
+ observation_key=pixels_observation_key)
+
+ return env
diff --git a/DMC/src/env/distracting_control/suite_utils.py b/DMC/src/env/distracting_control/suite_utils.py
new file mode 100644
index 0000000..9a262ed
--- /dev/null
+++ b/DMC/src/env/distracting_control/suite_utils.py
@@ -0,0 +1,80 @@
+# coding=utf-8
+# Copyright 2021 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A collection of MuJoCo-based Reinforcement Learning environments.
+
+The suite provides a similar API to the original dm_control suite.
+Users can configure the distractions on top of the original tasks. The suite is
+targeted for loading environments directly with similar configurations as those
+used in the original paper. Each distraction wrapper can be used independently
+though.
+"""
+import numpy as np
+
+DIFFICULTY_NUM_VIDEOS = {'0.025': 2, '0.05': 2, '0.1': 4, '0.15': 6, '0.2': 8, '0.3': None, '0.4': None, '0.5': None}
+DEFAULT_BACKGROUND_PATH = "$HOME/davis/"
+
+
+def get_color_kwargs(scale, dynamic):
+ max_delta = scale
+ step_std = 0.03 * scale if dynamic else 0.0
+ return dict(max_delta=max_delta, step_std=step_std)
+
+
+def get_camera_kwargs(domain_name, scale, dynamic):
+ assert domain_name in ['reacher', 'cartpole', 'finger', 'cheetah',
+ 'ball_in_cup', 'walker']
+ assert scale >= 0.0
+ assert scale <= 1.0
+ return dict(
+ vertical_delta=np.pi / 2 * scale,
+ horizontal_delta=np.pi / 2 * scale,
+ # Limit camera to -90 / 90 degree rolls.
+ roll_delta=np.pi / 2. * scale,
+ vel_std=.1 * scale if dynamic else 0.,
+ max_vel=.4 * scale if dynamic else 0.,
+ roll_std=np.pi / 300 * scale if dynamic else 0.,
+ max_roll_vel=np.pi / 50 * scale if dynamic else 0.,
+ max_zoom_in_percent=.5 * scale,
+ max_zoom_out_percent=1.5 * scale,
+ limit_to_upper_quadrant='reacher' not in domain_name,
+ )
+
+
+def get_background_kwargs(domain_name,
+ num_videos,
+ dynamic,
+ dataset_paths,
+ dataset_videos=None,
+ shuffle=False,
+ video_alpha=1.0):
+ assert domain_name in ['reacher', 'cartpole', 'finger', 'cheetah',
+ 'ball_in_cup', 'walker']
+ if domain_name == 'reacher':
+ ground_plane_alpha = 0.0
+ elif domain_name in ['walker', 'cheetah']:
+ ground_plane_alpha = 1.0
+ else:
+ ground_plane_alpha = 0.3
+
+ return dict(
+ num_videos=num_videos,
+ video_alpha=video_alpha,
+ ground_plane_alpha=ground_plane_alpha,
+ dynamic=dynamic,
+ dataset_paths=dataset_paths,
+ dataset_videos=dataset_videos,
+ shuffle_buffer_size=100 if shuffle else None,
+ )
diff --git a/DMC/src/env/dm_control/AUTHORS b/DMC/src/env/dm_control/AUTHORS
new file mode 100644
index 0000000..e72ca79
--- /dev/null
+++ b/DMC/src/env/dm_control/AUTHORS
@@ -0,0 +1,7 @@
+# This is the list of dm_control authors for copyright purposes.
+#
+# This does not necessarily list everyone who has contributed code, since in
+# some cases, their employer may be the copyright holder. To see the full list
+# of contributors, see the revision history in source control.
+
+DeepMind Technologies
diff --git a/DMC/src/env/dm_control/CONTRIBUTING.md b/DMC/src/env/dm_control/CONTRIBUTING.md
new file mode 100644
index 0000000..ae319c7
--- /dev/null
+++ b/DMC/src/env/dm_control/CONTRIBUTING.md
@@ -0,0 +1,23 @@
+# How to Contribute
+
+We'd love to accept your patches and contributions to this project. There are
+just a few small guidelines you need to follow.
+
+## Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement. You (or your employer) retain the copyright to your contribution,
+this simply gives us permission to use and redistribute your contributions as
+part of the project. Head over to to see
+your current agreements on file or to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+## Code reviews
+
+All submissions, including submissions by project members, require review. We
+use GitHub pull requests for this purpose. Consult
+[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
+information on using pull requests.
diff --git a/DMC/src/env/dm_control/LICENSE b/DMC/src/env/dm_control/LICENSE
new file mode 100644
index 0000000..7a4a3ea
--- /dev/null
+++ b/DMC/src/env/dm_control/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/DMC/src/env/dm_control/README.md b/DMC/src/env/dm_control/README.md
new file mode 100644
index 0000000..a352c4b
--- /dev/null
+++ b/DMC/src/env/dm_control/README.md
@@ -0,0 +1,146 @@
+# `dm_control`: The DeepMind Control Suite and Package
+
+# ![all domains](all_domains.png)
+
+This package consists of the following "core" components:
+
+- [`dm_control.mujoco`]: Libraries that provide Python bindings to the MuJoCo
+ physics engine.
+
+- [`dm_control.suite`]: A set of Python Reinforcement Learning environments
+ powered by the MuJoCo physics engine.
+
+- [`dm_control.viewer`]: An interactive environment viewer.
+
+Additionally, the following components are available for the creation of more
+complex control tasks:
+
+- [`dm_control.mjcf`]: A library for composing and modifying MuJoCo MJCF
+ models in Python.
+
+- `dm_control.composer`: A library for defining rich RL environments from
+ reusable, self-contained components.
+
+- [`dm_control.locomotion`]: Additional libraries for custom tasks.
+
+- [`dm_control.locomotion.soccer`]: Multi-agent soccer tasks.
+
+If you use this package, please cite our accompanying [tech report]:
+
+```
+@techreport{deepmindcontrolsuite2018,
+ title = {Deep{Mind} Control Suite},
+ author = {Yuval Tassa and Yotam Doron and Alistair Muldal and Tom Erez
+ and Yazhe Li and Diego de Las Casas and David Budden and Abbas
+ Abdolmaleki and Josh Merel and Andrew Lefrancq and Timothy Lillicrap
+ and Martin Riedmiller},
+ year = 2018,
+ month = jan,
+ howpublished = {https://arxiv.org/abs/1801.00690},
+ url = {https://arxiv.org/abs/1801.00690},
+ volume = {abs/1504.04804},
+ institution = {DeepMind},
+}
+```
+
+## Requirements and Installation
+
+`dm_control` is regularly tested using the following platforms and Python
+versions:
+
+| | Python 2.7 | Python 3.5 |
+| ------------ | ---------- | ---------- |
+| Ubuntu 14.04 | ✓ | ✓ |
+| Ubuntu 16.04 | | ✓ |
+
+Various people have been successful in getting `dm_control` to work on other
+Linux distros, OS X, and Windows. We do not provide active support for these,
+but will endeavour to answer questions on a best-effort basis.
+
+Follow these steps to install `dm_control`:
+
+1. Download MuJoCo Pro 2.00 from the download page on the [MuJoCo website].
+ MuJoCo Pro must be installed before `dm_control`, since `dm_control`'s
+ install script generates Python [`ctypes`] bindings based on MuJoCo's header
+ files. By default, `dm_control` assumes that the MuJoCo Zip archive is
+ extracted as `~/.mujoco/mujoco200_$PLATFORM` where `$PLATFORM` is either
+ `linux`, `win64`, or `macos`.
+
+2. Install the `dm_control` Python package by running `pip install dm_control`.
+ We recommend `pip install`ing into a `virtualenv`, or with the `--user` flag
+ to avoid interfering with system packages. At installation time,
+ `dm_control` looks for the MuJoCo headers from Step 1 in
+ `~/.mujoco/mujoco200_$PLATFORM/include`, however this path can be configured
+ with the `headers-dir` command line argument.
+
+3. Install a license key for MuJoCo, required by `dm_control` at runtime. See
+ the [MuJoCo license key page] for further details. By default, `dm_control`
+ looks for the MuJoCo license key file at `~/.mujoco/mjkey.txt`.
+
+4. If the license key (e.g. `mjkey.txt`) or the shared library provided by
+ MuJoCo Pro (e.g. `libmujoco200.so` or `libmujoco200.dylib`) are installed at
+ non-default paths, specify their locations using the `MJKEY_PATH` and
+ `MJLIB_PATH` environment variables respectively. These environment variables
+ should be set to the full path to the relevant file itself, e.g.
+ `export MJLIB_PATH=/path/to/libmujoco200.so`.
+
+## Versioning
+
+`dm_control` is released on a rolling basis: the latest commit on the `master`
+branch of our GitHub repository represents our latest release. Our Python
+package is versioned `0.0.N`, where `N` is the number that appears in the
+`PiperOrigin-RevId` field of the commit message. We always ensure that `N`
+strictly increases between a parent commit and its children. We do not upload
+all versions to PyPI, and occasionally the latest version on PyPI may lag behind
+the latest commit on GitHub. Should this happen, you can still install the
+newest version available by running `pip install
+git+git://github.com/deepmind/dm_control.git`.
+
+## Rendering
+
+The MuJoCo Python bindings support three different OpenGL rendering backends:
+EGL (headless, hardware-accelerated), GLFW (windowed, hardware-accelerated), and
+OSMesa (purely software-based). At least one of these three backends must be
+available in order render through `dm_control`.
+
+* Hardware rendering with a windowing system is supported via GLFW and GLEW.
+ On Linux these can be installed using your distribution's package manager.
+ For example, on Debian and Ubuntu, this can be done by running `sudo apt-get
+ install libglfw3 libglew2.0`. Please note that:
+
+ - [`dm_control.viewer`] can only be used with GLFW.
+ - GLFW will not work on headless machines.
+
+* "Headless" hardware rendering (i.e. without a windowing system such as X11)
+ requires [EXT_platform_device] support in the EGL driver. Recent Nvidia
+ drivers support this. You will also need GLEW. On Debian and Ubuntu, this
+ can be installed via `sudo apt-get install libglew2.0`.
+
+* Software rendering requires GLX and OSMesa. On Debian and Ubuntu these can
+ be installed using `sudo apt-get install libgl1-mesa-glx libosmesa6`.
+
+By default, `dm_control` will attempt to use GLFW first, then EGL, then OSMesa.
+You can also specify a particular backend to use by setting the `MUJOCO_GL=`
+environment variable to `"glfw"`, `"egl"`, or `"osmesa"`, respectively.
+
+## Additional instructions for Homebrew users on macOS
+
+1. The above instructions using `pip` should work, provided that you use a
+ Python interpreter that is installed by Homebrew (rather than the
+ system-default one).
+
+2. Before running, the `DYLD_LIBRARY_PATH` environment variable needs to be
+ updated with the path to the GLFW library. This can be done by running
+ `export DYLD_LIBRARY_PATH=$(brew --prefix)/lib:$DYLD_LIBRARY_PATH`.
+
+[EXT_platform_device]: https://www.khronos.org/registry/EGL/extensions/EXT/EGL_EXT_platform_device.txt
+[MuJoCo license key page]: https://www.roboti.us/license.html
+[MuJoCo website]: http://www.mujoco.org/
+[tech report]: https://arxiv.org/abs/1801.00690
+[`ctypes`]: https://docs.python.org/2/library/ctypes.html
+[`dm_control.mjcf`]: dm_control/mjcf/README.md
+[`dm_control.mujoco`]: dm_control/mujoco/README.md
+[`dm_control.suite`]: dm_control/suite/README.md
+[`dm_control.viewer`]: dm_control/viewer/README.md
+[`dm_control.locomotion`]: dm_control/locomotion/README.md
+[`dm_control.locomotion.soccer`]: dm_control/locomotion/soccer/README.md
diff --git a/DMC/src/env/dm_control/all_domains.png b/DMC/src/env/dm_control/all_domains.png
new file mode 100644
index 0000000..10c22fa
Binary files /dev/null and b/DMC/src/env/dm_control/all_domains.png differ
diff --git a/DMC/src/env/dm_control/dm_control/__init__.py b/DMC/src/env/dm_control/dm_control/__init__.py
new file mode 100644
index 0000000..1ebb270
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/DMC/src/env/dm_control/dm_control/_render/__init__.py b/DMC/src/env/dm_control/dm_control/_render/__init__.py
new file mode 100644
index 0000000..c3a678b
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/_render/__init__.py
@@ -0,0 +1,89 @@
+# Copyright 2017-2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""OpenGL context management for rendering MuJoCo scenes.
+
+By default, the `Renderer` class will try to load one of the following rendering
+APIs, in descending order of priority: GLFW > EGL > OSMesa.
+
+It is also possible to select a specific backend by setting the `MUJOCO_GL=`
+environment variable to 'glfw', 'egl', or 'osmesa'.
+"""
+
+import collections
+import os
+
+from absl import logging
+from dm_control._render import constants
+
+BACKEND = os.environ.get(constants.MUJOCO_GL)
+
+
+# pylint: disable=g-import-not-at-top
+def _import_egl():
+ from dm_control._render.pyopengl.egl_renderer import EGLContext
+ return EGLContext
+
+
+def _import_glfw():
+ from dm_control._render.glfw_renderer import GLFWContext
+ return GLFWContext
+
+
+def _import_osmesa():
+ from dm_control._render.pyopengl.osmesa_renderer import OSMesaContext
+ return OSMesaContext
+# pylint: enable=g-import-not-at-top
+
+_ALL_RENDERERS = collections.OrderedDict([
+ (constants.GLFW, _import_glfw),
+ (constants.EGL, _import_egl),
+ (constants.OSMESA, _import_osmesa),
+])
+
+
+if BACKEND is not None:
+ # If a backend was specified, try importing it and error if unsuccessful.
+ try:
+ import_func = _ALL_RENDERERS[BACKEND]
+ except KeyError:
+ raise RuntimeError(
+ 'Environment variable {} must be one of {!r}: got {!r}.'
+ .format(constants.MUJOCO_GL, _ALL_RENDERERS.keys(), BACKEND))
+ logging.info('MUJOCO_GL=%s, attempting to import specified OpenGL backend.',
+ BACKEND)
+ Renderer = import_func() # pylint: disable=invalid-name
+else:
+ logging.info('MUJOCO_GL is not set, so an OpenGL backend will be chosen '
+ 'automatically.')
+ # Otherwise try importing them in descending order of priority until
+ # successful.
+ for name, import_func in _ALL_RENDERERS.items():
+ try:
+ Renderer = import_func()
+ BACKEND = name
+ logging.info('Successfully imported OpenGL backend: %s', name)
+ break
+ except ImportError:
+ logging.info('Failed to import OpenGL backend: %s', name)
+ if BACKEND is None:
+ logging.info('No OpenGL backend could be imported. Attempting to create a '
+ 'rendering context will result in a RuntimeError.')
+
+ def Renderer(*args, **kwargs): # pylint: disable=function-redefined,invalid-name
+ del args, kwargs
+ raise RuntimeError('No OpenGL rendering backend is available.')
+
+USING_GPU = BACKEND in (constants.EGL, constants.GLFW)
diff --git a/DMC/src/env/dm_control/dm_control/_render/base.py b/DMC/src/env/dm_control/dm_control/_render/base.py
new file mode 100644
index 0000000..2c9f2ca
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/_render/base.py
@@ -0,0 +1,141 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Base class for OpenGL context handlers.
+
+`ContextBase` defines a common API that OpenGL rendering contexts should conform
+to. In addition, it provides a `make_current` context manager that:
+
+1. Makes this OpenGL context current within the appropriate rendering thread.
+2. Yields an object exposing a `call` method that can be used to execute OpenGL
+ calls within the rendering thread.
+
+See the docstring for `dm_control.utils.render_executor` for further details
+regarding rendering threads.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import atexit
+import collections
+import contextlib
+import weakref
+
+from absl import logging
+from dm_control._render import executor
+import six
+
+_CURRENT_CONTEXT_FOR_THREAD = collections.defaultdict(lambda: None)
+_CURRENT_THREAD_FOR_CONTEXT = collections.defaultdict(lambda: None)
+
+
+@six.add_metaclass(abc.ABCMeta)
+class ContextBase(object):
+ """Base class for managing OpenGL contexts."""
+
+ def __init__(self,
+ max_width,
+ max_height,
+ render_executor_class=executor.RenderExecutor):
+ """Initializes this context."""
+ logging.debug('Using render executor class: %s',
+ render_executor_class.__name__)
+ self._render_executor = render_executor_class()
+ self._refcount = 0
+
+ self_weakref = weakref.ref(self)
+ def _free_at_exit():
+ if self_weakref():
+ self_weakref()._free_unconditionally() # pylint: disable=protected-access
+ atexit.register(_free_at_exit)
+
+ with self._render_executor.execution_context() as ctx:
+ ctx.call(self._platform_init, max_width, max_height)
+
+ def increment_refcount(self):
+ self._refcount += 1
+
+ def decrement_refcount(self):
+ self._refcount -= 1
+
+ @property
+ def terminated(self):
+ return self._render_executor.terminated
+
+ @property
+ def thread(self):
+ return self._render_executor.thread
+
+ def _free_on_executor_thread(self):
+ if _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread] == id(self):
+ del _CURRENT_THREAD_FOR_CONTEXT[id(self)]
+ del _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread]
+ self._platform_free()
+
+ def free(self):
+ """Frees resources associated with this context if its refcount is zero."""
+ if self._refcount == 0:
+ self._free_unconditionally()
+
+ def _free_unconditionally(self):
+ self._render_executor.terminate(self._free_on_executor_thread)
+
+ def __del__(self):
+ self._free_unconditionally()
+
+ @contextlib.contextmanager
+ def make_current(self):
+ """Context manager that makes this Renderer's OpenGL context current.
+
+ Yields:
+ An object that exposes a `call` method that can be used to call a
+ function on the dedicated rendering thread.
+
+ Raises:
+ RuntimeError: If this context is already current on another thread.
+ """
+
+ with self._render_executor.execution_context() as ctx:
+ if _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread] != id(self):
+ if _CURRENT_THREAD_FOR_CONTEXT[id(self)]:
+ raise RuntimeError(
+ 'Cannot make context {!r} current on thread {!r}: '
+ 'this context is already current on another thread {!r}.'
+ .format(self, self._render_executor.thread,
+ _CURRENT_THREAD_FOR_CONTEXT[id(self)]))
+ else:
+ current_context = (
+ _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread])
+ if current_context:
+ del _CURRENT_THREAD_FOR_CONTEXT[current_context]
+ _CURRENT_THREAD_FOR_CONTEXT[id(self)] = self._render_executor.thread
+ _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread] = id(self)
+ ctx.call(self._platform_make_current)
+ yield ctx
+
+ @abc.abstractmethod
+ def _platform_init(self, max_width, max_height):
+ """Performs an implementation-specific context initialization."""
+
+ @abc.abstractmethod
+ def _platform_make_current(self):
+ """Make the OpenGL context current on the executing thread."""
+
+ @abc.abstractmethod
+ def _platform_free(self):
+ """Performs an implementation-specific context cleanup."""
diff --git a/DMC/src/env/dm_control/dm_control/_render/base_test.py b/DMC/src/env/dm_control/dm_control/_render/base_test.py
new file mode 100644
index 0000000..409b7c3
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/_render/base_test.py
@@ -0,0 +1,146 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for the base rendering module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+# Internal dependencies.
+from absl.testing import absltest
+from dm_control._render import base
+from dm_control._render import executor
+
+WIDTH = 1024
+HEIGHT = 768
+
+
+class ContextBaseTests(absltest.TestCase):
+
+ class ContextMock(base.ContextBase):
+
+ def _platform_init(self, max_width, max_height):
+ self.init_thread = threading.current_thread()
+ self.make_current_count = 0
+ self.max_width = max_width
+ self.max_height = max_height
+ self.free_thread = None
+
+ def _platform_make_current(self):
+ self.make_current_count += 1
+ self.make_current_thread = threading.current_thread()
+
+ def _platform_free(self):
+ self.free_thread = threading.current_thread()
+
+ def setUp(self):
+ super(ContextBaseTests, self).setUp()
+ self.context = ContextBaseTests.ContextMock(WIDTH, HEIGHT)
+
+ def test_init(self):
+ self.assertIs(self.context.init_thread, self.context.thread)
+ self.assertEqual(self.context.max_width, WIDTH)
+ self.assertEqual(self.context.max_height, HEIGHT)
+
+ def test_make_current(self):
+ self.assertEqual(self.context.make_current_count, 0)
+
+ with self.context.make_current():
+ pass
+ self.assertEqual(self.context.make_current_count, 1)
+ self.assertIs(self.context.make_current_thread, self.context.thread)
+
+ # Already current, shouldn't trigger a call to `_platform_make_current`.
+ with self.context.make_current():
+ pass
+ self.assertEqual(self.context.make_current_count, 1)
+ self.assertIs(self.context.make_current_thread, self.context.thread)
+
+ def test_thread_sharing(self):
+ first_context = ContextBaseTests.ContextMock(
+ WIDTH, HEIGHT, executor.PassthroughRenderExecutor)
+ second_context = ContextBaseTests.ContextMock(
+ WIDTH, HEIGHT, executor.PassthroughRenderExecutor)
+
+ with first_context.make_current():
+ pass
+ self.assertEqual(first_context.make_current_count, 1)
+
+ with first_context.make_current():
+ pass
+ self.assertEqual(first_context.make_current_count, 1)
+
+ with second_context.make_current():
+ pass
+ self.assertEqual(second_context.make_current_count, 1)
+
+ with second_context.make_current():
+ pass
+ self.assertEqual(second_context.make_current_count, 1)
+
+ with first_context.make_current():
+ pass
+ self.assertEqual(first_context.make_current_count, 2)
+
+ with second_context.make_current():
+ pass
+ self.assertEqual(second_context.make_current_count, 2)
+
+ def test_free(self):
+ with self.context.make_current():
+ pass
+
+ thread = self.context.thread
+ self.assertIn(id(self.context), base._CURRENT_THREAD_FOR_CONTEXT)
+ self.assertIn(thread, base._CURRENT_CONTEXT_FOR_THREAD)
+
+ self.context.free()
+ self.assertIs(self.context.free_thread, thread)
+ self.assertIsNone(self.context.thread)
+
+ self.assertNotIn(id(self.context), base._CURRENT_THREAD_FOR_CONTEXT)
+ self.assertNotIn(thread, base._CURRENT_CONTEXT_FOR_THREAD)
+
+ def test_refcounting(self):
+ thread = self.context.thread
+
+ self.assertEqual(self.context._refcount, 0)
+ self.context.increment_refcount()
+ self.assertEqual(self.context._refcount, 1)
+
+ # Context should not be freed yet, since its refcount is still positive.
+ self.context.free()
+ self.assertIsNone(self.context.free_thread)
+ self.assertIs(self.context.thread, thread)
+
+ # Decrement the refcount to zero.
+ self.context.decrement_refcount()
+ self.assertEqual(self.context._refcount, 0)
+
+ # Now the context can be freed.
+ self.context.free()
+ self.assertIs(self.context.free_thread, thread)
+ self.assertIsNone(self.context.thread)
+
+ def test_del(self):
+ self.assertIsNone(self.context.free_thread)
+ self.context.__del__()
+ self.assertIsNotNone(self.context.free_thread)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/_render/constants.py b/DMC/src/env/dm_control/dm_control/_render/constants.py
new file mode 100644
index 0000000..2509dce
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/_render/constants.py
@@ -0,0 +1,31 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""String constants for the rendering module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Name of the environment variable that selects a renderer platform.
+MUJOCO_GL = 'MUJOCO_GL'
+
+# Name of the environment variable that selects a platform for PyOpenGL.
+PYOPENGL_PLATFORM = 'PYOPENGL_PLATFORM'
+
+# Renderer platform specifiers.
+OSMESA = 'osmesa'
+GLFW = 'glfw'
+EGL = 'egl'
diff --git a/DMC/src/env/dm_control/dm_control/_render/executor/__init__.py b/DMC/src/env/dm_control/dm_control/_render/executor/__init__.py
new file mode 100644
index 0000000..12a9c48
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/_render/executor/__init__.py
@@ -0,0 +1,51 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""RenderExecutor executes OpenGL rendering calls on an appropriate thread.
+
+OpenGL calls must be made on the same thread as where an OpenGL context is
+made current on. With GPU rendering, migrating OpenGL contexts between threads
+can become expensive. We provide a thread-safe executor that maintains a
+thread on which an OpenGL context can be kept permanently current, and any other
+threads that wish to render with that context will have their rendering calls
+offloaded to the dedicated thread.
+
+For single-threaded applications, set the `DISABLE_RENDER_THREAD_OFFLOADING`
+environment variable before launching the Python interpreter. This will
+eliminate the overhead of unnecessary thread-switching.
+"""
+
+# pylint: disable=g-import-not-at-top
+import os
+_OFFLOAD = not bool(os.environ.get('DISABLE_RENDER_THREAD_OFFLOADING', ''))
+del os
+
+from dm_control._render.executor.render_executor import BaseRenderExecutor
+from dm_control._render.executor.render_executor import OffloadingRenderExecutor
+from dm_control._render.executor.render_executor import PassthroughRenderExecutor
+
+_EXECUTORS = (PassthroughRenderExecutor, OffloadingRenderExecutor)
+
+try:
+ from dm_control._render.executor.native_mutex.render_executor import NativeMutexOffloadingRenderExecutor
+ _EXECUTORS += (NativeMutexOffloadingRenderExecutor,)
+except ImportError:
+ NativeMutexOffloadingRenderExecutor = None
+
+if _OFFLOAD:
+ RenderExecutor = ( # pylint: disable=invalid-name
+ NativeMutexOffloadingRenderExecutor or OffloadingRenderExecutor)
+else:
+ RenderExecutor = PassthroughRenderExecutor # pylint: disable=invalid-name
diff --git a/DMC/src/env/dm_control/dm_control/_render/executor/render_executor.py b/DMC/src/env/dm_control/dm_control/_render/executor/render_executor.py
new file mode 100644
index 0000000..a7911e7
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/_render/executor/render_executor.py
@@ -0,0 +1,222 @@
+# Copyright 2017-2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""RenderExecutors executes OpenGL rendering calls on an appropriate thread.
+
+The purpose of these classes is to ensure that OpenGL calls are made on the
+same thread as where an OpenGL context was made current.
+
+In a single-threaded setting, `PassthroughRenderExecutor` is essentially a no-op
+that executes rendering calls on the same thread. This is provided to minimize
+thread-switching overhead.
+
+In a multithreaded setting, `OffloadingRenderExecutor` maintains a separate
+dedicated thread on which the OpenGL context is created and made current. All
+subsequent rendering calls are then offloaded onto this dedicated thread.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import collections
+import contextlib
+import threading
+
+from concurrent import futures
+import six
+
+_NOT_IN_CONTEXT = 'Cannot be called outside of an `execution_context`.'
+_ALREADY_TERMINATED = 'This executor has already been terminated.'
+
+
+class _FakeLock(object):
+ """An object with the same API as `threading.Lock` but that does nothing."""
+
+ def acquire(self, blocking=True):
+ pass
+
+ def release(self):
+ pass
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ del exc_type, exc_value, traceback
+
+
+_FAKE_LOCK = _FakeLock()
+
+
+@six.add_metaclass(abc.ABCMeta)
+class BaseRenderExecutor(object):
+ """An object that manages rendering calls for an OpenGL context.
+
+ This class helps ensure that OpenGL calls are made on the correct thread. The
+ usage pattern is as follows:
+
+ ```python
+ executor = SomeRenderExecutorClass()
+ with executor.execution_context() as ctx:
+ ctx.call(an_opengl_call, arg, kwarg=foo)
+ result = ctx.call(another_opengl_call)
+ ```
+ """
+
+ def __init__(self):
+ self._locked = 0
+ self._terminated = False
+
+ def _check_locked(self):
+ if not self._locked:
+ raise RuntimeError(_NOT_IN_CONTEXT)
+
+ def _check_not_terminated(self):
+ if self._terminated:
+ raise RuntimeError(_ALREADY_TERMINATED)
+
+ @contextlib.contextmanager
+ def execution_context(self):
+ """A context manager that allows calls to be offloaded to this executor."""
+ self._check_not_terminated()
+ with self._lock_if_necessary:
+ self._locked += 1
+ yield self
+ self._locked -= 1
+
+ @property
+ def terminated(self):
+ return self._terminated
+
+ @abc.abstractproperty
+ def thread(self):
+ pass
+
+ @abc.abstractproperty
+ def _lock_if_necessary(self):
+ pass
+
+ @abc.abstractmethod
+ def call(self, *args, **kwargs):
+ pass
+
+ @abc.abstractmethod
+ def terminate(self, cleanup_callable=None):
+ pass
+
+
+class PassthroughRenderExecutor(BaseRenderExecutor):
+ """A no-op render executor that executes on the calling thread."""
+
+ def __init__(self):
+ super(PassthroughRenderExecutor, self).__init__()
+ self._mutex = threading.RLock()
+
+ @property
+ def thread(self):
+ if not self._terminated:
+ return threading.current_thread()
+ else:
+ return None
+
+ @property
+ def _lock_if_necessary(self):
+ return self._mutex
+
+ def call(self, func, *args, **kwargs):
+ self._check_locked()
+ return func(*args, **kwargs)
+
+ def terminate(self, cleanup_callable=None):
+ with self._lock_if_necessary:
+ if not self._terminated:
+ if cleanup_callable:
+ cleanup_callable()
+ self._terminated = True
+
+
+class _ThreadPoolExecutorPool(object):
+ """A pool of reusable ThreadPoolExecutors."""
+
+ def __init__(self):
+ self._deque = collections.deque()
+ self._lock = threading.Lock()
+
+ def acquire(self):
+ with self._lock:
+ if self._deque:
+ return self._deque.popleft()
+ else:
+ return futures.ThreadPoolExecutor(max_workers=1)
+
+ def release(self, thread_pool_executor):
+ with self._lock:
+ self._deque.append(thread_pool_executor)
+
+
+_THREAD_POOL_EXECUTOR_POOL = _ThreadPoolExecutorPool()
+
+
+class OffloadingRenderExecutor(BaseRenderExecutor):
+ """A render executor that executes calls on a dedicated offload thread."""
+
+ def __init__(self):
+ super(OffloadingRenderExecutor, self).__init__()
+ self._mutex = threading.RLock()
+ self._executor = _THREAD_POOL_EXECUTOR_POOL.acquire()
+ self._thread = self._executor.submit(threading.current_thread).result()
+
+ @property
+ def thread(self):
+ return self._thread
+
+ @property
+ def _lock_if_necessary(self):
+ if threading.current_thread() is self.thread:
+ # If the offload thread needs to make a call to its own executor, for
+ # example when a weakref callback is triggered during an offloaded call,
+ # then we must not try to reacquire our own lock.
+ # Otherwise, a deadlock ensues.
+ return _FAKE_LOCK
+ else:
+ return self._mutex
+
+ def call(self, func, *args, **kwargs):
+ self._check_locked()
+ return self._call_locked(func, *args, **kwargs)
+
+ def _call_locked(self, func, *args, **kwargs):
+ if threading.current_thread() is self.thread:
+ # If the offload thread needs to make a call to its own executor, for
+ # example when a weakref callback is triggered during an offloaded call,
+ # we should just directly call the function.
+ # Otherwise, a deadlock ensues.
+ return func(*args, **kwargs)
+ else:
+ return self._executor.submit(func, *args, **kwargs).result()
+
+ def terminate(self, cleanup_callable=None):
+ if self._terminated:
+ return
+ with self._lock_if_necessary:
+ if not self._terminated:
+ if cleanup_callable:
+ self._call_locked(cleanup_callable)
+ _THREAD_POOL_EXECUTOR_POOL.release(self._executor)
+ self._executor = None
+ self._thread = None
+ self._terminated = True
diff --git a/DMC/src/env/dm_control/dm_control/_render/executor/render_executor_test.py b/DMC/src/env/dm_control/dm_control/_render/executor/render_executor_test.py
new file mode 100644
index 0000000..56f68f8
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/_render/executor/render_executor_test.py
@@ -0,0 +1,210 @@
+# Copyright 2017-2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.utils.render_executor."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+import time
+import unittest
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control._render import executor
+import mock
+from six.moves import range
+
+
+def enforce_timeout(timeout):
+ def wrap(test_func):
+ def wrapped_test(self, *args, **kwargs):
+ thread = threading.Thread(
+ target=test_func, args=((self,) + args), kwargs=kwargs)
+ thread.daemon = True
+ thread.start()
+ thread.join(timeout=timeout)
+ self.assertFalse(
+ thread.is_alive(),
+ msg='Test timed out after {} seconds.'.format(timeout))
+ return wrapped_test
+ return wrap
+
+
+class RenderExecutorTest(parameterized.TestCase):
+
+ def _make_executor(self, executor_type):
+ if (executor_type == executor.NativeMutexOffloadingRenderExecutor and
+ executor_type is None):
+ raise unittest.SkipTest(
+ 'NativeMutexOffloadingRenderExecutor is not available.')
+ else:
+ return executor_type()
+
+ def test_passthrough_executor_thread(self):
+ render_executor = self._make_executor(executor.PassthroughRenderExecutor)
+ self.assertIs(render_executor.thread, threading.current_thread())
+ render_executor.terminate()
+
+ @parameterized.parameters(executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ def test_offloading_executor_thread(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ self.assertIsNot(render_executor.thread, threading.current_thread())
+ render_executor.terminate()
+
+ @parameterized.parameters(executor.PassthroughRenderExecutor,
+ executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ def test_call_on_correct_thread(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ with render_executor.execution_context() as ctx:
+ actual_executed_thread = ctx.call(threading.current_thread)
+ self.assertIs(actual_executed_thread, render_executor.thread)
+ render_executor.terminate()
+
+ @parameterized.parameters(executor.PassthroughRenderExecutor,
+ executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ def test_multithreaded(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ list_length = 5
+ shared_list = [None] * list_length
+
+ def fill_list(thread_idx):
+ def assign_value(i):
+ shared_list[i] = thread_idx
+ for _ in range(1000):
+ with render_executor.execution_context() as ctx:
+ for i in range(list_length):
+ ctx.call(assign_value, i)
+ # Other threads should be prevented from calling `assign_value` while
+ # this thread is inside the `execution_context`.
+ self.assertEqual(shared_list, [thread_idx] * list_length)
+
+ threads = [threading.Thread(target=fill_list, args=(i,)) for i in range(9)]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+
+ render_executor.terminate()
+
+ @parameterized.parameters(executor.PassthroughRenderExecutor,
+ executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ def test_exception(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ message = 'fake error'
+ def raise_value_error():
+ raise ValueError(message)
+ with render_executor.execution_context() as ctx:
+ with self.assertRaisesWithLiteralMatch(ValueError, message):
+ ctx.call(raise_value_error)
+ render_executor.terminate()
+
+ @parameterized.parameters(executor.PassthroughRenderExecutor,
+ executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ def test_terminate(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ cleanup = mock.MagicMock()
+ render_executor.terminate(cleanup)
+ cleanup.assert_called_once_with()
+
+ @parameterized.parameters(executor.PassthroughRenderExecutor,
+ executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ def test_call_outside_of_context(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ func = mock.MagicMock()
+ with self.assertRaisesWithLiteralMatch(
+ RuntimeError, executor.render_executor._NOT_IN_CONTEXT):
+ render_executor.call(func)
+ # Also test that the locked flag is properly cleared when leaving a context.
+ with render_executor.execution_context():
+ render_executor.call(lambda: None)
+ with self.assertRaisesWithLiteralMatch(
+ RuntimeError, executor.render_executor._NOT_IN_CONTEXT):
+ render_executor.call(func)
+ func.assert_not_called()
+ render_executor.terminate()
+
+ @parameterized.parameters(executor.PassthroughRenderExecutor,
+ executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ def test_call_after_terminate(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ render_executor.terminate()
+ func = mock.MagicMock()
+ with self.assertRaisesWithLiteralMatch(
+ RuntimeError, executor.render_executor._ALREADY_TERMINATED):
+ with render_executor.execution_context() as ctx:
+ ctx.call(func)
+ func.assert_not_called()
+
+ @parameterized.parameters(executor.PassthroughRenderExecutor,
+ executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ def test_locking(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ other_thread_context_entered = threading.Condition()
+ other_thread_context_done = [False]
+ def other_thread_func():
+ with render_executor.execution_context():
+ with other_thread_context_entered:
+ other_thread_context_entered.notify()
+ time.sleep(1)
+ other_thread_context_done[0] = True
+ other_thread = threading.Thread(target=other_thread_func)
+ with other_thread_context_entered:
+ other_thread.start()
+ other_thread_context_entered.wait()
+ with render_executor.execution_context():
+ self.assertTrue(
+ other_thread_context_done[0],
+ msg=('Main thread should not be able to enter the execution context '
+ 'until the other thread is done.'))
+
+ @parameterized.parameters(executor.PassthroughRenderExecutor,
+ executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ @enforce_timeout(timeout=5.)
+ def test_reentrant_locking(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ def triple_lock(render_executor):
+ with render_executor.execution_context():
+ with render_executor.execution_context():
+ with render_executor.execution_context():
+ pass
+ triple_lock(render_executor)
+
+ @parameterized.parameters(executor.PassthroughRenderExecutor,
+ executor.OffloadingRenderExecutor,
+ executor.NativeMutexOffloadingRenderExecutor)
+ @enforce_timeout(timeout=5.)
+ def test_no_deadlock_in_callbacks(self, executor_type):
+ render_executor = self._make_executor(executor_type)
+ # This test times out in the event of a deadlock.
+ def callback():
+ with render_executor.execution_context() as ctx:
+ ctx.call(lambda: None)
+ with render_executor.execution_context() as ctx:
+ ctx.call(callback)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/_render/glfw_renderer.py b/DMC/src/env/dm_control/dm_control/_render/glfw_renderer.py
new file mode 100644
index 0000000..847b1ad
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/_render/glfw_renderer.py
@@ -0,0 +1,77 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""An OpenGL renderer backed by GLFW."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+from dm_control._render import base
+from dm_control._render import executor
+import six
+
+# Re-raise any exceptions that occur during module import as `ImportError`s.
+# This simplifies the conditional imports in `render/__init__.py`.
+try:
+ import glfw # pylint: disable=g-import-not-at-top
+except (ImportError, IOError, OSError) as exc:
+ _, exc, tb = sys.exc_info()
+ six.reraise(ImportError, ImportError(str(exc)), tb)
+try:
+ glfw.init()
+except glfw.GLFWError as exc:
+ _, exc, tb = sys.exc_info()
+ six.reraise(ImportError, ImportError(str(exc)), tb)
+
+
+class GLFWContext(base.ContextBase):
+ """An OpenGL context backed by GLFW."""
+
+ def __init__(self, max_width, max_height):
+ # GLFWContext always uses `PassthroughRenderExecutor` rather than offloading
+ # rendering calls to a separate thread because GLFW can only be safely used
+ # from the main thread.
+ super(GLFWContext, self).__init__(max_width, max_height,
+ executor.PassthroughRenderExecutor)
+
+ def _platform_init(self, max_width, max_height):
+ """Initializes this context.
+
+ Args:
+ max_width: Integer specifying the maximum framebuffer width in pixels.
+ max_height: Integer specifying the maximum framebuffer height in pixels.
+ """
+ glfw.window_hint(glfw.VISIBLE, 0)
+ glfw.window_hint(glfw.DOUBLEBUFFER, 0)
+ self._context = glfw.create_window(width=max_width, height=max_height,
+ title='Invisible window', monitor=None,
+ share=None)
+ # This reference prevents `glfw.destroy_window` from being garbage-collected
+ # before the last window is destroyed, otherwise we may get
+ # `AttributeError`s when the `__del__` method is later called.
+ self._destroy_window = glfw.destroy_window
+
+ def _platform_make_current(self):
+ glfw.make_context_current(self._context)
+
+ def _platform_free(self):
+ """Frees resources associated with this context."""
+ if self._context:
+ if glfw.get_current_context() == self._context:
+ glfw.make_context_current(None)
+ self._destroy_window(self._context)
+ self._context = None
diff --git a/DMC/src/env/dm_control/dm_control/_render/glfw_renderer_test.py b/DMC/src/env/dm_control/dm_control/_render/glfw_renderer_test.py
new file mode 100644
index 0000000..2e3be3e
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/_render/glfw_renderer_test.py
@@ -0,0 +1,75 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for GLFWContext."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+# Internal dependencies.
+from absl.testing import absltest
+from dm_control import _render
+from dm_control.mujoco import wrapper
+from dm_control.mujoco.testing import decorators
+
+
+import mock # pylint: disable=g-import-not-at-top
+
+MAX_WIDTH = 1024
+MAX_HEIGHT = 1024
+
+CONTEXT_PATH = _render.__name__ + '.glfw_renderer.glfw'
+
+
+@unittest.skipUnless(
+ _render.BACKEND == _render.constants.GLFW,
+ reason='GLFW beckend not selected.')
+class GLFWContextTest(absltest.TestCase):
+
+ def test_init(self):
+ mock_context = mock.MagicMock()
+ with mock.patch(CONTEXT_PATH) as mock_glfw:
+ mock_glfw.create_window.return_value = mock_context
+ renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT)
+ self.assertIs(renderer._context, mock_context)
+
+ def test_make_current(self):
+ mock_context = mock.MagicMock()
+ with mock.patch(CONTEXT_PATH) as mock_glfw:
+ mock_glfw.create_window.return_value = mock_context
+ renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT)
+ with renderer.make_current():
+ pass
+ mock_glfw.make_context_current.assert_called_once_with(mock_context)
+
+ def test_freeing(self):
+ mock_context = mock.MagicMock()
+ with mock.patch(CONTEXT_PATH) as mock_glfw:
+ mock_glfw.create_window.return_value = mock_context
+ renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT)
+ renderer.free()
+ mock_glfw.destroy_window.assert_called_once_with(mock_context)
+ self.assertIsNone(renderer._context)
+
+ @decorators.run_threaded(num_threads=1, calls_per_thread=20)
+ def test_repeatedly_create_and_destroy_context(self):
+ renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT)
+ wrapper.MjrContext(wrapper.MjModel.from_xml_string(''), renderer)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/_render/pyopengl/__init__.py b/DMC/src/env/dm_control/dm_control/_render/pyopengl/__init__.py
new file mode 100644
index 0000000..a514c4b
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/_render/pyopengl/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/DMC/src/env/dm_control/dm_control/_render/pyopengl/egl_ext.py b/DMC/src/env/dm_control/dm_control/_render/pyopengl/egl_ext.py
new file mode 100644
index 0000000..2e373ed
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/_render/pyopengl/egl_ext.py
@@ -0,0 +1,79 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Extends OpenGL.EGL with definitions necessary for headless rendering."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ctypes
+from OpenGL.platform import ctypesloader # pylint: disable=g-bad-import-order
+try:
+ # Nvidia driver seems to need libOpenGL.so (as opposed to libGL.so)
+ # for multithreading to work properly. We load this in before everything else.
+ ctypesloader.loadLibrary(ctypes.cdll, 'OpenGL', mode=ctypes.RTLD_GLOBAL)
+except OSError:
+ pass
+
+# pylint: disable=g-import-not-at-top
+
+from OpenGL import EGL
+from OpenGL import error
+from six.moves import range
+
+
+# From the EGL_EXT_device_enumeration extension.
+PFNEGLQUERYDEVICESEXTPROC = ctypes.CFUNCTYPE(
+ EGL.EGLBoolean,
+ EGL.EGLint,
+ ctypes.POINTER(EGL.EGLDeviceEXT),
+ ctypes.POINTER(EGL.EGLint),
+)
+try:
+ _eglQueryDevicesEXT = PFNEGLQUERYDEVICESEXTPROC( # pylint: disable=invalid-name
+ EGL.eglGetProcAddress('eglQueryDevicesEXT'))
+except TypeError:
+ raise ImportError('eglQueryDevicesEXT is not available.')
+
+
+# From the EGL_EXT_platform_device extension.
+EGL_PLATFORM_DEVICE_EXT = 0x313F
+PFNEGLGETPLATFORMDISPLAYEXTPROC = ctypes.CFUNCTYPE(
+ EGL.EGLDisplay, EGL.EGLenum, ctypes.c_void_p, ctypes.POINTER(EGL.EGLint))
+try:
+ eglGetPlatformDisplayEXT = PFNEGLGETPLATFORMDISPLAYEXTPROC( # pylint: disable=invalid-name
+ EGL.eglGetProcAddress('eglGetPlatformDisplayEXT'))
+except TypeError:
+ raise ImportError('eglGetPlatformDisplayEXT is not available.')
+
+
+# Wrap raw _eglQueryDevicesEXT function into something more Pythonic.
+def eglQueryDevicesEXT(max_devices=10): # pylint: disable=invalid-name
+ devices = (EGL.EGLDeviceEXT * max_devices)()
+ num_devices = EGL.EGLint()
+ success = _eglQueryDevicesEXT(max_devices, devices, num_devices)
+ if success == EGL.EGL_TRUE:
+ return [devices[i] for i in range(num_devices.value)]
+ else:
+ raise error.GLError(err=EGL.eglGetError(),
+ baseOperation=eglQueryDevicesEXT,
+ result=success)
+
+
+# Expose everything from upstream so that
+# we can use this as a drop-in replacement for OpenGL.EGL.
+# pylint: disable=wildcard-import,g-bad-import-order
+from OpenGL.EGL import *
diff --git a/DMC/src/env/dm_control/dm_control/_render/pyopengl/egl_renderer.py b/DMC/src/env/dm_control/dm_control/_render/pyopengl/egl_renderer.py
new file mode 100644
index 0000000..249b490
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/_render/pyopengl/egl_renderer.py
@@ -0,0 +1,135 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""An OpenGL renderer backed by EGL, provided through PyOpenGL."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import atexit
+import ctypes
+import os
+
+from dm_control._render import base
+from dm_control._render import constants
+from dm_control._render import executor
+
+PYOPENGL_PLATFORM = os.environ.get(constants.PYOPENGL_PLATFORM)
+
+if not PYOPENGL_PLATFORM:
+ os.environ[constants.PYOPENGL_PLATFORM] = constants.EGL
+elif PYOPENGL_PLATFORM != constants.EGL:
+ raise ImportError(
+ 'Cannot use EGL rendering platform. '
+ 'The PYOPENGL_PLATFORM environment variable is set to {!r} '
+ '(should be either unset or {!r}).'
+ .format(PYOPENGL_PLATFORM, constants.EGL))
+
+
+# pylint: disable=g-import-not-at-top
+from dm_control._render.pyopengl import egl_ext as EGL
+from OpenGL import error
+
+
+def create_initialized_headless_egl_display():
+ """Creates an initialized EGL display directly on a device."""
+ devices = EGL.eglQueryDevicesEXT()
+ if os.environ.get("CUDA_VISIBLE_DEVICES", None) is not None:
+ devices = [devices[int(os.environ["CUDA_VISIBLE_DEVICES"])]]
+ for device in devices:
+ display = EGL.eglGetPlatformDisplayEXT(
+ EGL.EGL_PLATFORM_DEVICE_EXT, device, None)
+ if display != EGL.EGL_NO_DISPLAY and EGL.eglGetError() == EGL.EGL_SUCCESS:
+ # `eglInitialize` may or may not raise an exception on failure depending
+ # on how PyOpenGL is configured. We therefore catch a `GLError` and also
+ # manually check the output of `eglGetError()` here.
+ try:
+ initialized = EGL.eglInitialize(display, None, None)
+ except error.GLError:
+ pass
+ else:
+ if initialized == EGL.EGL_TRUE and EGL.eglGetError() == EGL.EGL_SUCCESS:
+ return display
+ return EGL.EGL_NO_DISPLAY
+
+
+EGL_DISPLAY = create_initialized_headless_egl_display()
+if EGL_DISPLAY == EGL.EGL_NO_DISPLAY:
+ raise ImportError('Cannot initialize a headless EGL display.')
+atexit.register(EGL.eglTerminate, EGL_DISPLAY)
+
+
+EGL_ATTRIBUTES = (
+ EGL.EGL_RED_SIZE, 8,
+ EGL.EGL_GREEN_SIZE, 8,
+ EGL.EGL_BLUE_SIZE, 8,
+ EGL.EGL_ALPHA_SIZE, 8,
+ EGL.EGL_DEPTH_SIZE, 24,
+ EGL.EGL_STENCIL_SIZE, 8,
+ EGL.EGL_COLOR_BUFFER_TYPE, EGL.EGL_RGB_BUFFER,
+ EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT,
+ EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_BIT,
+ EGL.EGL_NONE
+)
+
+
+class EGLContext(base.ContextBase):
+ """An OpenGL context backed by EGL."""
+
+ def __init__(self, max_width, max_height):
+ # EGLContext currently only works with `PassthroughRenderExecutor`.
+ # TODO(b/110927854) Make this work with the offloading executor.
+ super(EGLContext, self).__init__(max_width, max_height,
+ executor.PassthroughRenderExecutor)
+
+ def _platform_init(self, unused_max_width, unused_max_height):
+ """Initialization this EGL context."""
+ num_configs = ctypes.c_long()
+ config_size = 1
+ config = EGL.EGLConfig()
+ EGL.eglReleaseThread()
+ EGL.eglChooseConfig(
+ EGL_DISPLAY,
+ EGL_ATTRIBUTES,
+ ctypes.byref(config),
+ config_size,
+ num_configs)
+ if num_configs.value < 1:
+ raise RuntimeError(
+ 'EGL failed to find a framebuffer configuration that matches the '
+ 'desired attributes: {}'.format(EGL_ATTRIBUTES))
+ EGL.eglBindAPI(EGL.EGL_OPENGL_API)
+ self._context = EGL.eglCreateContext(
+ EGL_DISPLAY, config, EGL.EGL_NO_CONTEXT, None)
+ if not self._context:
+ raise RuntimeError('Cannot create an EGL context.')
+
+ def _platform_make_current(self):
+ if self._context:
+ success = EGL.eglMakeCurrent(
+ EGL_DISPLAY, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, self._context)
+ if not success:
+ raise RuntimeError('Failed to make the EGL context current.')
+
+ def _platform_free(self):
+ """Frees resources associated with this context."""
+ if self._context:
+ current_context = EGL.eglGetCurrentContext()
+ if current_context and self._context.address == current_context.address:
+ EGL.eglMakeCurrent(EGL_DISPLAY, EGL.EGL_NO_SURFACE,
+ EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT)
+ EGL.eglDestroyContext(EGL_DISPLAY, self._context)
+ self._context = None
diff --git a/DMC/src/env/dm_control/dm_control/_render/pyopengl/osmesa_renderer.py b/DMC/src/env/dm_control/dm_control/_render/pyopengl/osmesa_renderer.py
new file mode 100644
index 0000000..4c2c54a
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/_render/pyopengl/osmesa_renderer.py
@@ -0,0 +1,86 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""An OpenGL renderer backed by OSMesa."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from dm_control._render import base
+from dm_control._render import constants
+
+PYOPENGL_PLATFORM = os.environ.get(constants.PYOPENGL_PLATFORM)
+
+if not PYOPENGL_PLATFORM:
+ os.environ[constants.PYOPENGL_PLATFORM] = constants.OSMESA
+elif PYOPENGL_PLATFORM != constants.OSMESA:
+ raise ImportError(
+ 'Cannot use OSMesa rendering platform. '
+ 'The PYOPENGL_PLATFORM environment variable is set to {!r} '
+ '(should be either unset or {!r}).'
+ .format(PYOPENGL_PLATFORM, constants.OSMESA))
+
+# pylint: disable=g-import-not-at-top
+from OpenGL import GL
+from OpenGL import osmesa
+from OpenGL.GL import arrays
+
+_DEPTH_BITS = 24
+_STENCIL_BITS = 8
+_ACCUM_BITS = 0
+
+
+class OSMesaContext(base.ContextBase):
+ """An OpenGL context backed by OSMesa."""
+
+ def _platform_init(self, max_width, max_height):
+ """Initializes this OSMesa context."""
+ self._context = osmesa.OSMesaCreateContextExt(
+ osmesa.OSMESA_RGBA,
+ _DEPTH_BITS,
+ _STENCIL_BITS,
+ _ACCUM_BITS,
+ None, # sharelist
+ )
+ if not self._context:
+ raise RuntimeError('Failed to create OSMesa GL context.')
+
+ self._height = max_height
+ self._width = max_width
+
+ # Allocate a buffer to render into.
+ self._buffer = arrays.GLfloatArray.zeros((max_height, max_width, 4))
+
+ def _platform_make_current(self):
+ if self._context:
+ success = osmesa.OSMesaMakeCurrent(
+ self._context,
+ self._buffer,
+ GL.GL_FLOAT,
+ self._width,
+ self._height)
+ if not success:
+ raise RuntimeError('Failed to make OSMesa context current.')
+
+ def _platform_free(self):
+ """Frees resources associated with this context."""
+ if self._context and self._context == osmesa.OSMesaGetCurrentContext():
+ osmesa.OSMesaMakeCurrent(None, None, GL.GL_FLOAT, 0, 0)
+ osmesa.OSMesaDestroyContext(self._context)
+ self._buffer = None
+ self._context = None
diff --git a/DMC/src/env/dm_control/dm_control/_render/pyopengl/osmesa_renderer_test.py b/DMC/src/env/dm_control/dm_control/_render/pyopengl/osmesa_renderer_test.py
new file mode 100644
index 0000000..774d7b8
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/_render/pyopengl/osmesa_renderer_test.py
@@ -0,0 +1,75 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for OSMesaContext."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+# Internal dependencies.
+from absl.testing import absltest
+from dm_control import _render
+import mock
+from OpenGL import GL
+
+MAX_WIDTH = 640
+MAX_HEIGHT = 480
+
+CONTEXT_PATH = _render.__name__ + '.pyopengl.osmesa_renderer.osmesa'
+GL_ARRAYS_PATH = _render.__name__ + '.pyopengl.osmesa_renderer.arrays'
+
+
+@unittest.skipUnless(
+ _render.BACKEND == _render.constants.OSMESA,
+ reason='OSMesa backend not selected.')
+class OSMesaContextTest(absltest.TestCase):
+
+ def test_init(self):
+ mock_context = mock.MagicMock()
+ with mock.patch(CONTEXT_PATH) as mock_osmesa:
+ mock_osmesa.OSMesaCreateContextExt.return_value = mock_context
+ renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT)
+ self.assertIs(renderer._context, mock_context)
+ renderer.free()
+
+ def test_make_current(self):
+ mock_context = mock.MagicMock()
+ mock_buffer = mock.MagicMock()
+ with mock.patch(CONTEXT_PATH) as mock_osmesa:
+ with mock.patch(GL_ARRAYS_PATH) as mock_glarrays:
+ mock_osmesa.OSMesaCreateContextExt.return_value = mock_context
+ mock_glarrays.GLfloatArray.zeros.return_value = mock_buffer
+ renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT)
+ with renderer.make_current():
+ pass
+ mock_osmesa.OSMesaMakeCurrent.assert_called_once_with(
+ mock_context, mock_buffer, GL.GL_FLOAT, MAX_WIDTH, MAX_HEIGHT)
+ renderer.free()
+
+ def test_freeing(self):
+ mock_context = mock.MagicMock()
+ with mock.patch(CONTEXT_PATH) as mock_osmesa:
+ mock_osmesa.OSMesaCreateContextExt.return_value = mock_context
+ renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT)
+ renderer.free()
+ mock_osmesa.OSMesaDestroyContext.assert_called_once_with(mock_context)
+ self.assertIsNone(renderer._context)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/autowrap/__init__.py b/DMC/src/env/dm_control/dm_control/autowrap/__init__.py
new file mode 100644
index 0000000..f9817d4
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/autowrap/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
diff --git a/DMC/src/env/dm_control/dm_control/autowrap/autowrap.py b/DMC/src/env/dm_control/dm_control/autowrap/autowrap.py
new file mode 100644
index 0000000..25fcdcf
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/autowrap/autowrap.py
@@ -0,0 +1,147 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+r"""Automatically generates ctypes Python bindings for MuJoCo.
+
+Parses mjdata.h, mjmodel.h, mjrender.h, mjvisualize.h, mjxmacro.h and mujoco.h;
+generates the following Python source files:
+
+ constants.py: constants
+ enums.py: enums
+ sizes.py: size information for dynamically-shaped arrays
+ types.py: ctypes declarations for structs
+ wrappers.py: low-level Python wrapper classes for structs (these implement
+ getter/setter methods for struct members where applicable)
+ functions.py: ctypes function declarations for MuJoCo API functions
+
+Example usage:
+
+ autowrap --header_paths='/path/to/mjmodel.h /path/to/mjdata.h ...' \
+ --output_dir=/path/to/mjbindings
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import io
+import os
+
+# Internal dependencies.
+from absl import app
+from absl import flags
+from absl import logging
+from dm_control.autowrap import binding_generator
+from dm_control.autowrap import codegen_util
+
+import six
+
+_MJMODEL_H = "mjmodel.h"
+_MJXMACRO_H = "mjxmacro.h"
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_spaceseplist(
+ "header_paths", None,
+ "Space-separated list of paths to MuJoCo header files.")
+
+flags.DEFINE_string("output_dir", None,
+ "Path to output directory for wrapper source files.")
+
+
+def main(unused_argv):
+ special_header_paths = {}
+
+ # Get the path to the mjmodel and mjxmacro header files.
+ # These header files need special handling.
+ for header in (_MJMODEL_H, _MJXMACRO_H):
+ for path in FLAGS.header_paths:
+ if path.endswith(header):
+ special_header_paths[header] = path
+ break
+ if header not in special_header_paths:
+ logging.fatal("List of inputs must contain a path to %s", header)
+
+ # Make sure mjmodel.h is parsed first, since it is included by other headers.
+ srcs = codegen_util.UniqueOrderedDict()
+ sorted_header_paths = sorted(FLAGS.header_paths)
+ sorted_header_paths.remove(special_header_paths[_MJMODEL_H])
+ sorted_header_paths.insert(0, special_header_paths[_MJMODEL_H])
+ for p in sorted_header_paths:
+ with io.open(p, "r", errors="ignore") as f:
+ srcs[p] = f.read()
+
+ # consts_dict should be a codegen_util.UniqueOrderedDict.
+ # This is a temporary workaround due to the fact that the parser does not yet
+ # handle nested `#if define(predicate)` blocks, which results in some
+ # constants being parsed twice. We therefore can't enforce the uniqueness of
+ # the keys in `consts_dict`. As of MuJoCo v1.30 there is only a single problem
+ # block beginning on line 10 in mujoco.h, and a single constant is affected
+ # (MJAPI).
+ consts_dict = collections.OrderedDict()
+
+ # These are commented in `mjdata.h` but have no macros in `mjxmacro.h`.
+ hints_dict = codegen_util.UniqueOrderedDict({"buffer": ("nbuffer",),
+ "stack": ("nstack",)})
+
+ parser = binding_generator.BindingGenerator(
+ consts_dict=consts_dict, hints_dict=hints_dict)
+
+ # Parse enums.
+ for pth, src in six.iteritems(srcs):
+ if pth is not special_header_paths[_MJXMACRO_H]:
+ parser.parse_enums(src)
+
+ # Parse constants and type declarations.
+ for pth, src in six.iteritems(srcs):
+ if pth is not special_header_paths[_MJXMACRO_H]:
+ parser.parse_consts_typedefs(src)
+
+ # Get shape hints from mjxmacro.h.
+ parser.parse_hints(srcs[special_header_paths[_MJXMACRO_H]])
+
+ # Parse structs and function pointer type declarations.
+ for pth, src in six.iteritems(srcs):
+ if pth is not special_header_paths[_MJXMACRO_H]:
+ parser.parse_structs_and_function_pointer_typedefs(src)
+
+ # Parse functions.
+ for pth, src in six.iteritems(srcs):
+ if pth is not special_header_paths[_MJXMACRO_H]:
+ parser.parse_functions(src)
+
+ # Parse global strings and function pointers.
+ for pth, src in six.iteritems(srcs):
+ if pth is not special_header_paths[_MJXMACRO_H]:
+ parser.parse_global_strings(src)
+ parser.parse_function_pointers(src)
+
+ # Create the output directory if it doesn't already exist.
+ if not os.path.exists(FLAGS.output_dir):
+ os.makedirs(FLAGS.output_dir)
+
+ # Generate Python source files and write them to the output directory.
+ parser.write_consts(os.path.join(FLAGS.output_dir, "constants.py"))
+ parser.write_enums(os.path.join(FLAGS.output_dir, "enums.py"))
+ parser.write_types(os.path.join(FLAGS.output_dir, "types.py"))
+ parser.write_wrappers(os.path.join(FLAGS.output_dir, "wrappers.py"))
+ parser.write_funcs_and_globals(os.path.join(FLAGS.output_dir, "functions.py"))
+ parser.write_index_dict(os.path.join(FLAGS.output_dir, "sizes.py"))
+
+if __name__ == "__main__":
+ flags.mark_flag_as_required("header_paths")
+ flags.mark_flag_as_required("output_dir")
+ app.run(main)
diff --git a/DMC/src/env/dm_control/dm_control/autowrap/binding_generator.py b/DMC/src/env/dm_control/dm_control/autowrap/binding_generator.py
new file mode 100644
index 0000000..74bb30b
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/autowrap/binding_generator.py
@@ -0,0 +1,597 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Parses MuJoCo header files and generates Python bindings."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import pprint
+import textwrap
+
+from absl import logging
+from dm_control.autowrap import c_declarations
+from dm_control.autowrap import codegen_util
+from dm_control.autowrap import header_parsing
+import pyparsing
+import six
+
+# Absolute path to the top-level module.
+_MODULE = "dm_control.mujoco.wrapper"
+
+# Imports used in all generated source files.
+_BOILERPLATE_IMPORTS = [
+ "from __future__ import absolute_import",
+ "from __future__ import division",
+ "from __future__ import print_function\n",
+]
+
+
+class Error(Exception):
+ pass
+
+
+class BindingGenerator(object):
+ """Parses declarations from MuJoCo headers and generates Python bindings."""
+
+ def __init__(self,
+ enums_dict=None,
+ consts_dict=None,
+ typedefs_dict=None,
+ hints_dict=None,
+ types_dict=None,
+ funcs_dict=None,
+ strings_dict=None,
+ func_ptrs_dict=None,
+ index_dict=None):
+ """Constructs a new HeaderParser instance.
+
+ The optional arguments listed below can be used to passing in dict-like
+ objects specifying pre-defined declarations. By default empty
+ UniqueOrderedDicts will be instantiated and then populated according to the
+ contents of the headers.
+
+ Args:
+ enums_dict: Nested mappings from {enum_name: {member_name: value}}.
+ consts_dict: Mapping from {const_name: value}.
+ typedefs_dict: Mapping from {type_name: ctypes_typename}.
+ hints_dict: Mapping from {var_name: shape_tuple}.
+ types_dict: Mapping from {type_name: type_instance}.
+ funcs_dict: Mapping from {func_name: Function_instance}.
+ strings_dict: Mapping from {var_name: StaticStringArray_instance}.
+ func_ptrs_dict: Mapping from {var_name: FunctionPtr_instance}.
+ index_dict: Mapping from {lowercase_struct_name: {var_name: shape_tuple}}.
+ """
+ self.enums_dict = (enums_dict if enums_dict is not None
+ else codegen_util.UniqueOrderedDict())
+ self.consts_dict = (consts_dict if consts_dict is not None
+ else codegen_util.UniqueOrderedDict())
+ self.typedefs_dict = (typedefs_dict if typedefs_dict is not None
+ else codegen_util.UniqueOrderedDict())
+ self.hints_dict = (hints_dict if hints_dict is not None
+ else codegen_util.UniqueOrderedDict())
+ self.types_dict = (types_dict if types_dict is not None
+ else codegen_util.UniqueOrderedDict())
+ self.funcs_dict = (funcs_dict if funcs_dict is not None
+ else codegen_util.UniqueOrderedDict())
+ self.strings_dict = (strings_dict if strings_dict is not None
+ else codegen_util.UniqueOrderedDict())
+ self.func_ptrs_dict = (func_ptrs_dict if func_ptrs_dict is not None
+ else codegen_util.UniqueOrderedDict())
+ self.index_dict = (index_dict if index_dict is not None
+ else codegen_util.UniqueOrderedDict())
+
+ def get_consts_and_enums(self):
+ consts_and_enums = self.consts_dict.copy()
+ for enum in six.itervalues(self.enums_dict):
+ consts_and_enums.update(enum)
+ return consts_and_enums
+
+ def resolve_size(self, old_size):
+ """Resolves an array size identifier.
+
+ The following conversions will be attempted:
+
+ * If `old_size` is an integer it will be returned as-is.
+ * If `old_size` is a string of the form `"3"` it will be cast to an int.
+ * If `old_size` is a string in `self.consts_dict` then the value of the
+ constant will be returned.
+ * If `old_size` is a string of the form `"3*constant_name"` then the
+ result of `3*constant_value` will be returned.
+ * If `old_size` is a string that does not specify an int constant and
+ cannot be cast to an int (e.g. an identifier for a dynamic dimension,
+ such as `"ncontact"`) then it will be returned as-is.
+
+ Args:
+ old_size: An int or string.
+
+ Returns:
+ An int or string.
+ """
+ if isinstance(old_size, int):
+ return old_size # If it's already an int then there's nothing left to do
+ elif "*" in old_size:
+ # If it's a string specifying a product (such as "2*mjMAXLINEPNT"),
+ # recursively resolve the components to ints and calculate the result.
+ size = 1
+ for part in old_size.split("*"):
+ dim = self.resolve_size(part)
+ assert isinstance(dim, int)
+ size *= dim
+ return size
+ else:
+ # Recursively dereference any sizes declared in header macros
+ size = codegen_util.recursive_dict_lookup(old_size,
+ self.get_consts_and_enums())
+ # Try to coerce the result to an int, return a string if this fails
+ return codegen_util.try_coerce_to_num(size, try_types=(int,))
+
+ def get_shape_tuple(self, old_size, squeeze=False):
+ """Generates a shape tuple from parser results.
+
+ Args:
+ old_size: Either a `pyparsing.ParseResults`, or a valid int or string
+ input to `self.resolve_size` (see method docstring for further details).
+ squeeze: If True, any dimensions that are statically defined as 1 will be
+ removed from the shape tuple.
+
+ Returns:
+ A shape tuple containing ints for dimensions that are statically defined,
+ and string size identifiers for dimensions that can only be determined at
+ runtime.
+ """
+ if isinstance(old_size, pyparsing.ParseResults):
+ # For multi-dimensional arrays, convert each dimension separately
+ shape = tuple(self.resolve_size(dim) for dim in old_size)
+ else:
+ shape = (self.resolve_size(old_size),)
+ if squeeze:
+ shape = tuple(d for d in shape if d != 1) # Remove singleton dimensions
+ return shape
+
+ def resolve_typename(self, old_ctypes_typename):
+ """Gets a qualified ctypes typename from typedefs_dict and C_TO_CTYPES."""
+
+ # Recursively dereference any typenames declared in self.typedefs_dict
+ new_ctypes_typename = codegen_util.recursive_dict_lookup(
+ old_ctypes_typename, self.typedefs_dict)
+
+ # Try to convert to a ctypes native typename
+ new_ctypes_typename = header_parsing.C_TO_CTYPES.get(
+ new_ctypes_typename, new_ctypes_typename)
+
+ if new_ctypes_typename == old_ctypes_typename:
+ logging.warning("Could not resolve typename '%s'", old_ctypes_typename)
+
+ return new_ctypes_typename
+
+ def get_type_from_token(self, token, parent=None):
+ """Accepts a token returned by a parser, returns a subclass of CDeclBase."""
+
+ comment = codegen_util.mangle_comment(token.comment)
+ is_const = token.is_const == "const"
+
+ # An anonymous union declaration
+ if token.anonymous_union:
+ if not parent and parent.name:
+ raise Error(
+ "Anonymous unions must be members of a named struct or union.")
+
+ # Generate a name based on the name of the parent.
+ name = codegen_util.mangle_varname(parent.name + "_anon_union")
+
+ members = codegen_util.UniqueOrderedDict()
+ sub_structs = codegen_util.UniqueOrderedDict()
+ out = c_declarations.AnonymousUnion(
+ name, members, sub_structs, comment, parent)
+
+ # Add members
+ for sub_token in token.members:
+
+ # Recurse into nested structs
+ member = self.get_type_from_token(sub_token, parent=out)
+ out.members[member.name] = member
+
+ # Nested sub-structures need special treatment
+ if isinstance(member, c_declarations.Struct):
+ out.sub_structs[member.name] = member
+
+ # Add to dict of unions
+ self.types_dict[out.ctypes_typename] = out
+
+ # A struct declaration
+ elif token.members:
+
+ name = token.name
+
+ # If the name is empty, see if there is a type declaration that matches
+ # this struct's typename
+ if not name:
+ for k, v in six.iteritems(self.typedefs_dict):
+ if v == token.typename:
+ name = k
+
+ # Anonymous structs need a dummy typename
+ typename = token.typename
+ if not typename:
+ if parent:
+ typename = token.name
+ else:
+ raise Error(
+ "Anonymous structs that aren't members of a named struct are not "
+ "supported (name = '{token.name}').".format(token=token))
+
+ # Mangle the name if it contains any protected keywords
+ name = codegen_util.mangle_varname(name)
+
+ members = codegen_util.UniqueOrderedDict()
+ sub_structs = codegen_util.UniqueOrderedDict()
+ out = c_declarations.Struct(name, typename, members, sub_structs, comment,
+ parent, is_const)
+
+ # Map the old typename to the mangled typename in typedefs_dict
+ self.typedefs_dict[typename] = out.ctypes_typename
+
+ # Add members
+ for sub_token in token.members:
+
+ # Recurse into nested structs
+ member = self.get_type_from_token(sub_token, parent=out)
+ out.members[member.name] = member
+
+ # Nested sub-structures need special treatment
+ if isinstance(member, c_declarations.Struct):
+ out.sub_structs[member.name] = member
+
+ # Add to dict of structs
+ self.types_dict[out.ctypes_typename] = out
+
+ else:
+
+ name = codegen_util.mangle_varname(token.name)
+ typename = self.resolve_typename(token.typename)
+
+ # 1D array with size defined at compile time
+ if token.size:
+ shape = self.get_shape_tuple(token.size)
+ if typename in {header_parsing.NONE, header_parsing.CTYPES_CHAR}:
+ out = c_declarations.StaticPtrArray(
+ name, typename, shape, comment, parent, is_const)
+ else:
+ out = c_declarations.StaticNDArray(
+ name, typename, shape, comment, parent, is_const)
+
+ elif token.ptr:
+
+ # Pointer to a numpy-compatible type, could be an array or a scalar
+ if typename in header_parsing.CTYPES_TO_NUMPY:
+
+ # Multidimensional array (one or more dimensions might be undefined)
+ if name in self.hints_dict:
+
+ # Dynamically-sized dimensions have string identifiers
+ shape = self.hints_dict[name]
+ if any(isinstance(d, six.string_types) for d in shape):
+ out = c_declarations.DynamicNDArray(name, typename, shape,
+ comment, parent, is_const)
+ else:
+ out = c_declarations.StaticNDArray(name, typename, shape, comment,
+ parent, is_const)
+
+ # This must be a pointer to a scalar primitive
+ else:
+ out = c_declarations.ScalarPrimitivePtr(name, typename, comment,
+ parent, is_const)
+
+ # Pointer to struct or other arbitrary type
+ else:
+ out = c_declarations.ScalarPrimitivePtr(name, typename, comment,
+ parent, is_const)
+
+ # A struct we've already encountered
+ elif typename in self.types_dict:
+ s = self.types_dict[typename]
+ if isinstance(s, c_declarations.FunctionPtrTypedef):
+ out = c_declarations.FunctionPtr(
+ name, token.name, s.typename, comment)
+ else:
+ out = c_declarations.Struct(name, s.typename, s.members,
+ s.sub_structs, comment, parent)
+
+ # Presumably this is a scalar primitive
+ else:
+ out = c_declarations.ScalarPrimitive(name, typename, comment, parent,
+ is_const)
+
+ return out
+
+ # Parsing functions.
+ # ----------------------------------------------------------------------------
+
+ def parse_hints(self, xmacro_src):
+ """Parses mjxmacro.h, update self.hints_dict."""
+ parser = header_parsing.XMACRO
+ for tokens, _, _ in parser.scanString(xmacro_src):
+ for xmacro in tokens:
+ for member in xmacro.members:
+ # "Squeeze out" singleton dimensions.
+ shape = self.get_shape_tuple(member.dims, squeeze=True)
+ self.hints_dict.update({member.name: shape})
+
+ if codegen_util.is_macro_pointer(xmacro.name):
+ struct_name = codegen_util.macro_struct_name(xmacro.name)
+ if struct_name not in self.index_dict:
+ self.index_dict[struct_name] = {}
+
+ self.index_dict[struct_name].update({member.name: shape})
+
+ def parse_enums(self, src):
+ """Parses mj*.h, update self.enums_dict."""
+ parser = header_parsing.ENUM_DECL
+ for tokens, _, _ in parser.scanString(src):
+ for enum in tokens:
+ members = codegen_util.UniqueOrderedDict()
+ value = 0
+ for member in enum.members:
+ # Leftward bitshift
+ if member.bit_lshift_a:
+ value = int(member.bit_lshift_a) << int(member.bit_lshift_b)
+ # Assignment
+ elif member.value:
+ value = int(member.value)
+ # Implicit count
+ else:
+ value += 1
+ members.update({member.name: value})
+ self.enums_dict.update({enum.name: members})
+
+ def parse_consts_typedefs(self, src):
+ """Updates self.consts_dict, self.typedefs_dict."""
+ parser = (header_parsing.COND_DECL |
+ header_parsing.UNCOND_DECL)
+ for tokens, _, _ in parser.scanString(src):
+ self.recurse_into_conditionals(tokens)
+
+ def recurse_into_conditionals(self, tokens):
+ """Called recursively within nested #if(n)def... #else... #endif blocks."""
+ for token in tokens:
+ # Another nested conditional block
+ if token.predicate:
+ if (token.predicate in self.get_consts_and_enums()
+ and self.get_consts_and_enums()[token.predicate]):
+ self.recurse_into_conditionals(token.if_true)
+ else:
+ self.recurse_into_conditionals(token.if_false)
+ # One or more declarations
+ else:
+ if token.typename:
+ self.typedefs_dict.update({token.name: token.typename})
+ elif token.value:
+ value = codegen_util.try_coerce_to_num(token.value)
+ # Avoid adding function aliases.
+ if isinstance(value, six.string_types):
+ continue
+ else:
+ self.consts_dict.update({token.name: value})
+ else:
+ self.consts_dict.update({token.name: True})
+
+ def parse_structs_and_function_pointer_typedefs(self, src):
+ """Updates self.types_dict."""
+ parser = (header_parsing.NESTED_STRUCTS |
+ header_parsing.FUNCTION_PTR_TYPE_DECL)
+ for tokens, _, _ in parser.scanString(src):
+ for token in tokens:
+ if token.return_type:
+ # This is a function type declaration.
+ self.types_dict[token.typename] = c_declarations.FunctionPtrTypedef(
+ token.typename,
+ self.get_type_from_token(token.return_type),
+ tuple(self.get_type_from_token(arg) for arg in token.arguments))
+ else:
+ # This is a struct or a union.
+ self.get_type_from_token(token)
+
+ def parse_functions(self, src):
+ """Updates self.funcs_dict."""
+ parser = header_parsing.MJAPI_FUNCTION_DECL
+ for tokens, _, _ in parser.scanString(src):
+ for token in tokens:
+ name = codegen_util.mangle_varname(token.name)
+ comment = codegen_util.mangle_comment(token.comment)
+ if token.arguments:
+ args = codegen_util.UniqueOrderedDict()
+ for arg in token.arguments:
+ a = self.get_type_from_token(arg)
+ args[a.name] = a
+ else:
+ args = None
+ if token.return_value:
+ ret_val = self.get_type_from_token(token.return_value)
+ else:
+ ret_val = None
+ func = c_declarations.Function(name, args, ret_val, comment)
+ self.funcs_dict[func.name] = func
+
+ def parse_global_strings(self, src):
+ """Updates self.strings_dict."""
+ parser = header_parsing.MJAPI_STRING_ARRAY
+ for token, _, _ in parser.scanString(src):
+ name = codegen_util.mangle_varname(token.name)
+ shape = self.get_shape_tuple(token.dims)
+ self.strings_dict[name] = c_declarations.StaticStringArray(
+ name, shape, symbol_name=token.name)
+
+ def parse_function_pointers(self, src):
+ """Updates self.func_ptrs_dict."""
+ parser = header_parsing.MJAPI_FUNCTION_PTR
+ for token, _, _ in parser.scanString(src):
+ name = codegen_util.mangle_varname(token.name)
+ self.func_ptrs_dict[name] = c_declarations.FunctionPtr(
+ name, symbol_name=token.name,
+ type_name=token.typename, comment=token.comment)
+
+ # Code generation methods
+ # ----------------------------------------------------------------------------
+
+ def make_header(self, imports=()):
+ """Returns a header string for an auto-generated Python source file."""
+ docstring = textwrap.dedent("""
+ \"\"\"Automatically generated by {scriptname:}.
+
+ MuJoCo header version: {mujoco_version:}
+ \"\"\"
+ """.format(scriptname=os.path.split(__file__)[-1],
+ mujoco_version=self.consts_dict["mjVERSION_HEADER"]))
+ docstring = docstring[1:] # Strip the leading line break.
+ return "\n".join(
+ [docstring] + _BOILERPLATE_IMPORTS + list(imports) + ["\n"])
+
+ def write_consts(self, fname):
+ """Write constants."""
+ imports = [
+ "# pylint: disable=invalid-name",
+ ]
+ with open(fname, "w") as f:
+ f.write(self.make_header(imports))
+ f.write(codegen_util.comment_line("Constants") + "\n")
+ for name, value in six.iteritems(self.consts_dict):
+ f.write("{0} = {1}\n".format(name, value))
+ f.write("\n" + codegen_util.comment_line("End of generated code"))
+
+ def write_enums(self, fname):
+ """Write enum definitions."""
+ with open(fname, "w") as f:
+ imports = [
+ "import collections",
+ "# pylint: disable=invalid-name",
+ "# pylint: disable=line-too-long",
+ ]
+ f.write(self.make_header(imports))
+ f.write(codegen_util.comment_line("Enums"))
+ for enum_name, members in six.iteritems(self.enums_dict):
+ fields = ["\"{}\"".format(name) for name in six.iterkeys(members)]
+ values = [str(value) for value in six.itervalues(members)]
+ s = textwrap.dedent("""
+ {0} = collections.namedtuple(
+ "{0}",
+ [{1}]
+ )({2})
+ """).format(enum_name, ",\n ".join(fields), ", ".join(values))
+ f.write(s)
+ f.write("\n" + codegen_util.comment_line("End of generated code"))
+
+ def write_types(self, fname):
+ """Write ctypes struct and function type declarations."""
+ imports = [
+ "import ctypes",
+ ]
+ with open(fname, "w") as f:
+ f.write(self.make_header(imports))
+ f.write(codegen_util.comment_line(
+ "ctypes struct, union, and function type declarations"))
+ for type_decl in six.itervalues(self.types_dict):
+ f.write("\n" + type_decl.ctypes_decl)
+ f.write("\n" + codegen_util.comment_line("End of generated code"))
+
+ def write_wrappers(self, fname):
+ """Write wrapper classes for ctypes structs."""
+ with open(fname, "w") as f:
+ imports = [
+ "import ctypes",
+ "# pylint: disable=undefined-variable",
+ "# pylint: disable=wildcard-import",
+ "from {} import util".format(_MODULE),
+ "from {}.mjbindings.types import *".format(_MODULE),
+ ]
+ f.write(self.make_header(imports))
+ f.write(codegen_util.comment_line("Low-level wrapper classes"))
+ for type_decl in six.itervalues(self.types_dict):
+ if isinstance(type_decl, c_declarations.Struct):
+ f.write("\n" + type_decl.wrapper_class)
+ f.write("\n" + codegen_util.comment_line("End of generated code"))
+
+ def write_funcs_and_globals(self, fname):
+ """Write ctypes declarations for functions and global data."""
+ imports = [
+ "import collections",
+ "import ctypes",
+ "# pylint: disable=undefined-variable",
+ "# pylint: disable=wildcard-import",
+ "from {} import util".format(_MODULE),
+ "from {}.mjbindings.types import *".format(_MODULE),
+ "import numpy as np",
+ "# pylint: disable=line-too-long",
+ "# pylint: disable=invalid-name",
+ "# common_typos_disable",
+ ]
+ with open(fname, "w") as f:
+ f.write(self.make_header(imports))
+ f.write("mjlib = util.get_mjlib()\n")
+
+ f.write("\n" + codegen_util.comment_line("ctypes function declarations"))
+ for function in six.itervalues(self.funcs_dict):
+ f.write("\n" + function.ctypes_func_decl(cdll_name="mjlib"))
+
+ # Only require strings for UI purposes.
+ f.write("\n" + codegen_util.comment_line("String arrays") + "\n")
+ for string_arr in six.itervalues(self.strings_dict):
+ f.write(string_arr.ctypes_var_decl(cdll_name="mjlib"))
+
+ f.write("\n" + codegen_util.comment_line("Callback function pointers"))
+
+ fields = ["'_{0}'".format(func_ptr.name)
+ for func_ptr in self.func_ptrs_dict.values()]
+ values = [func_ptr.ctypes_var_decl(cdll_name="mjlib")
+ for func_ptr in self.func_ptrs_dict.values()]
+ f.write(textwrap.dedent("""
+ class _Callbacks(object):
+
+ __slots__ = [
+ {0}
+ ]
+
+ def __init__(self):
+ {1}
+ """).format(",\n ".join(fields), "\n ".join(values)))
+
+ indent = codegen_util.Indenter()
+ with indent:
+ for func_ptr in self.func_ptrs_dict.values():
+ f.write(indent(func_ptr.getters_setters_with_custom_prefix("self._")))
+
+ f.write("\n\ncallbacks = _Callbacks() # pylint: disable=invalid-name")
+ f.write("\ndel _Callbacks\n")
+
+ f.write("\n" + codegen_util.comment_line("End of generated code"))
+
+ def write_index_dict(self, fname):
+ """Write file containing array shape information for indexing."""
+ pp = pprint.PrettyPrinter()
+ output_string = pp.pformat(dict(self.index_dict))
+ indent = codegen_util.Indenter()
+ imports = [
+ "# pylint: disable=bad-continuation",
+ "# pylint: disable=line-too-long",
+ ]
+ with open(fname, "w") as f:
+ f.write(self.make_header(imports))
+ f.write("array_sizes = (\n")
+ with indent:
+ f.write(output_string)
+ f.write("\n)")
+ f.write("\n" + codegen_util.comment_line("End of generated code"))
diff --git a/DMC/src/env/dm_control/dm_control/autowrap/c_declarations.py b/DMC/src/env/dm_control/dm_control/autowrap/c_declarations.py
new file mode 100644
index 0000000..fa532ac
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/autowrap/c_declarations.py
@@ -0,0 +1,520 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Python representations of C declarations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import textwrap
+from dm_control.autowrap import codegen_util
+from dm_control.autowrap import header_parsing
+import six
+
+
+class CDeclBase(object):
+ """Base class for Python representations of C declarations."""
+
+ def __init__(self, **attrs):
+ self._attrs = attrs
+ for k, v in six.iteritems(attrs):
+ setattr(self, k, v)
+
+ def __repr__(self):
+ """Pretty string representation."""
+ attr_str = ", ".join("{0}={1!r}".format(k, v)
+ for k, v in six.iteritems(self._attrs))
+ return "{0}({1})".format(type(self).__name__, attr_str)
+
+ @property
+ def docstring(self):
+ """Auto-generate a docstring for self."""
+ return "\n".join(textwrap.wrap(self.comment, 74))
+
+ @property
+ def ctypes_typename(self):
+ """ctypes typename."""
+ return self.typename
+
+ @property
+ def ctypes_ptr(self):
+ """String representation of self as a ctypes pointer."""
+ return header_parsing.CTYPES_PTRS.get(
+ self.ctypes_typename, "ctypes.POINTER({})".format(self.ctypes_typename))
+
+ @property
+ def np_dtype(self):
+ """Get a numpy dtype name for self, fall back on self.ctypes_typename."""
+ return header_parsing.CTYPES_TO_NUMPY.get(self.ctypes_typename,
+ self.ctypes_typename)
+
+ @property
+ def np_flags(self):
+ """Tuple of strings specifying numpy.ndarray flags."""
+ return ("C", "W")
+
+
+class Struct(CDeclBase):
+ """C struct declaration."""
+
+ def __init__(self, name, typename, members, sub_structs, comment="",
+ parent=None, is_const=None):
+ super(Struct, self).__init__(name=name,
+ typename=typename,
+ members=members,
+ sub_structs=sub_structs,
+ comment=comment,
+ parent=parent,
+ is_const=is_const)
+
+ @property
+ def ctypes_decl(self):
+ """Generates a ctypes.Structure declaration for self."""
+ indent = codegen_util.Indenter()
+ lines = []
+ lines.append(textwrap.dedent("""
+ class {0.ctypes_typename}(ctypes.Structure):
+ \"\"\"{0.docstring}\"\"\"""".format(self)))
+ anonymous_fields = [member.name for member in six.itervalues(self.members)
+ if isinstance(member, AnonymousUnion)]
+ with indent:
+ if anonymous_fields:
+ lines.append(indent("_anonymous_ = ["))
+ with indent:
+ with indent:
+ for name in anonymous_fields:
+ lines.append(indent("'" + name + "',"))
+ lines.append(indent("]"))
+
+ if self.members:
+ lines.append(indent("_fields_ = ["))
+ with indent:
+ with indent:
+ for member in six.itervalues(self.members):
+ lines.append(indent(member.ctypes_field_decl + ","))
+ lines.append(indent("]\n"))
+ return "\n".join(lines)
+
+ @property
+ def ctypes_typename(self):
+ """Mangles ctypes.Structure typenames to distinguish them from wrappers."""
+ return codegen_util.mangle_struct_typename(self.typename)
+
+ @property
+ def ctypes_field_decl(self):
+ """Generates a declaration for self as a field of a ctypes.Structure."""
+ return "('{0.name}', {0.ctypes_typename})".format(self) # pylint: disable=missing-format-attribute
+
+ @property
+ def wrapper_name(self):
+ return codegen_util.camel_case(self.typename) + "Wrapper"
+
+ @property
+ def wrapper_class(self):
+ """Generates a Python class containing getter/setter methods for members."""
+ indent = codegen_util.Indenter()
+ lines = [textwrap.dedent("""
+ class {0.wrapper_name}(util.WrapperBase):
+ \"\"\"{0.docstring}\"\"\"""".format(self))]
+ with indent:
+ for member in six.itervalues(self.members):
+ if isinstance(member, AnonymousUnion):
+ for submember in six.itervalues(member.members):
+ lines.append(indent(submember.getters_setters))
+ else:
+ lines.append(indent(member.getters_setters))
+ lines.append("") # Add an extra newline at the end of the class definition.
+ return "\n".join(lines)
+
+ @property
+ def getters_setters(self):
+ """Populates a Python class with getter & setter methods for self."""
+ return textwrap.dedent("""
+ @util.CachedProperty
+ def {0.name}(self):
+ \"\"\"{0.docstring}\"\"\"
+ return {0.wrapper_name}(ctypes.pointer(self._ptr.contents.{0.name}))""" # pylint: disable=missing-format-attribute
+ .format(self))
+
+ @property
+ def arg(self):
+ """String representation of self as a ctypes function argument."""
+ return self.ctypes_typename
+
+
+class AnonymousUnion(CDeclBase):
+ """Anonymous union declaration."""
+
+ def __init__(self, name, members, sub_structs, comment="", parent=None):
+ super(AnonymousUnion, self).__init__(name=name,
+ members=members,
+ sub_structs=sub_structs,
+ comment=comment,
+ parent=parent)
+
+ @property
+ def ctypes_decl(self):
+ """Generates a ctypes.Union declaration for self."""
+ indent = codegen_util.Indenter()
+ lines = []
+ lines.append(textwrap.dedent("""
+ class {0.ctypes_typename}(ctypes.Union):
+ \"\"\"{0.docstring}\"\"\"""".format(self)))
+ with indent:
+ if self.members:
+ lines.append(indent("_fields_ = ["))
+ with indent:
+ with indent:
+ for member in six.itervalues(self.members):
+ lines.append(indent(member.ctypes_field_decl + ","))
+ lines.append(indent("]\n"))
+ return "\n".join(lines)
+
+ @property
+ def ctypes_typename(self):
+ """Mangles ctypes.Union typenames to distinguish them from wrappers."""
+ return codegen_util.mangle_struct_typename(self.name)
+
+ @property
+ def ctypes_field_decl(self):
+ """Generates a declaration for self as a field of a ctypes.Structure."""
+ return "('{0.name}', {0.ctypes_typename})".format(self) # pylint: disable=missing-format-attribute
+
+
+class ScalarPrimitive(CDeclBase):
+ """A scalar value corresponding to a C primitive type."""
+
+ def __init__(self, name, typename, comment="", parent=None, is_const=None):
+ super(ScalarPrimitive, self).__init__(name=name,
+ typename=typename,
+ comment=comment,
+ parent=parent,
+ is_const=is_const)
+
+ @property
+ def ctypes_field_decl(self):
+ """Generates a declaration for self as a field of a ctypes.Structure."""
+ return "('{0.name}', {0.ctypes_typename})".format(self) # pylint: disable=missing-format-attribute
+
+ @property
+ def getters_setters(self):
+ """Populates a Python class with getter & setter methods for self."""
+ return textwrap.dedent("""
+ @property
+ def {0.name}(self):
+ \"\"\"{0.docstring}\"\"\"
+ return self._ptr.contents.{0.name}
+
+ @{0.name}.setter
+ def {0.name}(self, value):
+ self._ptr.contents.{0.name} = value""".format(self)) # pylint: disable=missing-format-attribute
+
+ @property
+ def arg(self):
+ """String representation of self as a ctypes function argument."""
+ return self.ctypes_typename
+
+
+class ScalarPrimitivePtr(CDeclBase):
+ """Pointer to a ScalarPrimitive."""
+
+ def __init__(self, name, typename, comment="", parent=None, is_const=None):
+ super(ScalarPrimitivePtr, self).__init__(name=name,
+ typename=typename,
+ comment=comment,
+ parent=parent,
+ is_const=is_const)
+
+ @property
+ def ctypes_field_decl(self):
+ """Generates a declaration for self as a field of a ctypes.Structure."""
+ return "('{0.name}', {0.ctypes_ptr})".format(self) # pylint: disable=missing-format-attribute
+
+ @property
+ def getters_setters(self):
+ """Populates a Python class with getter & setter methods for self."""
+ return textwrap.dedent("""
+ @property
+ def {0.name}(self):
+ \"\"\"{0.docstring}\"\"\"
+ return self._ptr.contents.{0.name}
+
+ @{0.name}.setter
+ def {0.name}(self, value):
+ self._ptr.contents.{0.name} = value""".format(self)) # pylint: disable=missing-format-attribute
+
+ @property
+ def arg(self):
+ """Generates string representation of self as a ctypes function argument."""
+ # we assume that every pointer that maps to a numpy dtype corresponds to an
+ # array argument/return value
+ if self.ctypes_typename in header_parsing.CTYPES_TO_NUMPY:
+ return ("util.ndptr(dtype={0.np_dtype}, flags={0.np_flags!s})"
+ .format(self)) # pylint: disable=missing-format-attribute
+ else:
+ return self.ctypes_ptr
+
+
+class StaticPtrArray(CDeclBase):
+ """Array of arbitrary pointers whose size can be inferred from the headers."""
+
+ def __init__(self, name, typename, shape, comment="", parent=None,
+ is_const=None):
+ super(StaticPtrArray, self).__init__(name=name,
+ typename=typename,
+ shape=shape,
+ comment=comment,
+ parent=parent,
+ is_const=is_const)
+
+ @property
+ def ctypes_field_decl(self):
+ """Generates a declaration for self as a field of a ctypes.Structure."""
+ if self.typename in header_parsing.CTYPES_PTRS:
+ return "('{0.name}', {0.ctypes_ptr} * {1})".format( # pylint: disable=missing-format-attribute
+ self, " * ".join(str(d) for d in self.shape))
+ else:
+ return "('{0.name}', {0.ctypes_typename} * {1})".format( # pylint: disable=missing-format-attribute
+ self, " * ".join(str(d) for d in self.shape))
+
+ @property
+ def getters_setters(self):
+ """Populates a Python class with getter & setter methods for self."""
+ return textwrap.dedent("""
+ @property
+ def {0.name}(self):
+ \"\"\"{0.docstring}\"\"\"
+ return self._ptr.contents.{0.name}""".format(self)) # pylint: disable=missing-format-attribute
+
+ @property
+ def arg(self):
+ """Generates string representation of self as a ctypes function argument."""
+ return "{0.ctypes_typename}".format(self)
+
+
+class StaticNDArray(CDeclBase):
+ """Numeric array whose dimensions can all be inferred from the headers."""
+
+ def __init__(self, name, typename, shape, comment="", parent=None,
+ is_const=None):
+ super(StaticNDArray, self).__init__(name=name,
+ typename=typename,
+ shape=shape,
+ comment=comment,
+ parent=parent,
+ is_const=is_const)
+
+ @property
+ def ctypes_field_decl(self):
+ """Generates a declaration for self as a field of a ctypes.Structure."""
+ return "('{0.name}', {0.ctypes_typename} * ({1}))".format( # pylint: disable=missing-format-attribute
+ self, " * ".join(str(d) for d in self.shape))
+
+ @property
+ def getters_setters(self):
+ """Populates a Python class with a getter method for self (no setter)."""
+ return textwrap.dedent("""
+ @util.CachedProperty
+ def {0.name}(self):
+ \"\"\"{0.docstring}\"\"\"
+ return util.buf_to_npy(self._ptr.contents.{0.name}, {0.shape!s})""" # pylint: disable=missing-format-attribute
+ .format(self))
+
+ @property
+ def arg(self):
+ """Generates string representation of self as a ctypes function argument."""
+ return ("util.ndptr(shape={0.shape}, dtype={0.np_dtype}, " # pylint: disable=missing-format-attribute
+ "flags={0.np_flags!s})".format(self))
+
+
+class DynamicNDArray(CDeclBase):
+ """Numeric array where one or more dimensions are determined at runtime."""
+
+ def __init__(self, name, typename, shape, comment="", parent=None,
+ is_const=None):
+ super(DynamicNDArray, self).__init__(name=name,
+ typename=typename,
+ shape=shape,
+ comment=comment,
+ parent=parent,
+ is_const=is_const)
+
+ @property
+ def runtime_shape_str(self):
+ """String representation of shape tuple at runtime."""
+ rs = []
+ for d in self.shape:
+ # dynamically-sized dimension
+ if isinstance(d, six.string_types):
+ if self.parent and d in self.parent.members:
+ rs.append("self.{}".format(d))
+ else:
+ rs.append("self._model.{}".format(d))
+ # static dimension
+ else:
+ rs.append(str(d))
+ return str(tuple(rs)).replace("'", "") # strip quotes from string rep
+
+ @property
+ def ctypes_field_decl(self):
+ """Generates a declaration for self as a field of a ctypes.Structure."""
+ return "('{0.name}', {0.ctypes_ptr})".format(self) # pylint: disable=missing-format-attribute
+
+ @property
+ def getters_setters(self):
+ """Populates a Python class with a getter method for self (no setter)."""
+ return textwrap.dedent("""
+ @util.CachedProperty
+ def {0.name}(self):
+ \"\"\"{0.docstring}\"\"\"
+ return util.buf_to_npy(self._ptr.contents.{0.name},
+ {0.runtime_shape_str})""".format(self)) # pylint: disable=missing-format-attribute
+
+ @property
+ def arg(self):
+ """Generates string representation of self as a ctypes function argument."""
+ return ("util.ndptr(dtype={0.np_dtype}, flags={0.np_flags!s})"
+ .format(self)) # pylint: disable=missing-format-attribute
+
+
+class Function(CDeclBase):
+ """A function declaration including input type(s) and return type."""
+
+ def __init__(self, name, arguments, return_value, comment=""):
+ super(Function, self).__init__(name=name,
+ arguments=arguments,
+ return_value=return_value,
+ comment=comment)
+
+ def ctypes_func_decl(self, cdll_name):
+ """Generates a ctypes function declaration."""
+ indent = codegen_util.Indenter()
+ lines = []
+ lines.append("{0}.{1}.__doc__ = \"\"\"\n{2}\"\"\"".format(
+ cdll_name, self.name, self.docstring))
+ if self.arguments:
+ lines.append("{0}.{1}.argtypes = [".format(cdll_name, self.name))
+ with indent:
+ with indent:
+ lines.extend(indent(a.arg + ",")
+ for a in six.itervalues(self.arguments))
+ lines.append("]")
+ else:
+ lines.append("{0}.{1}.argtypes = None".format(cdll_name, self.name))
+ if self.return_value:
+ lines.append("{0}.{1}.restype = {2}".format(
+ cdll_name, self.name, self.return_value.arg))
+ else:
+ lines.append("{0}.{1}.restype = None".format(cdll_name, self.name))
+ lines.append("") # Force a newline after the declaration.
+ return "\n".join(lines)
+
+ @property
+ def docstring(self):
+ """Generates a docstring."""
+ indent = codegen_util.Indenter()
+ lines = textwrap.wrap(self.comment, width=80)
+ if self.arguments:
+ lines.append("\nArgs:")
+ with indent:
+ for a in six.itervalues(self.arguments):
+ s = "{a.name}: {a.arg}{const}".format(
+ a=a, const=(" " if a.is_const else ""))
+ lines.append(indent(s))
+ if self.return_value:
+ lines.append("\nReturns:")
+ with indent:
+ lines.append(indent(self.return_value.arg))
+ lines.append("") # Force a newline at the end of the docstring.
+ return "\n".join(lines)
+
+
+class StaticStringArray(CDeclBase):
+ """A string array of fixed dimensions exported by MuJoCo."""
+
+ def __init__(self, name, shape, symbol_name):
+ super(StaticStringArray, self).__init__(name=name,
+ shape=shape,
+ symbol_name=symbol_name)
+
+ def ctypes_var_decl(self, cdll_name=""):
+ """Generates a ctypes export statement."""
+
+ ptr_str = "ctypes.c_char_p"
+ for dim in self.shape[::-1]:
+ ptr_str = "({0} * {1!s})".format(ptr_str, dim)
+
+ return "{0} = {1}.in_dll({2}, {3!r})\n".format(
+ self.name, ptr_str, cdll_name, self.symbol_name)
+
+
+class FunctionPtrTypedef(CDeclBase):
+ """A type declaration for a C function pointer."""
+
+ def __init__(self, typename, return_type, argument_types):
+ super(FunctionPtrTypedef, self).__init__(
+ typename=typename, return_type=return_type,
+ argument_types=argument_types)
+
+ @property
+ def ctypes_decl(self):
+ """Generates a ctypes.CFUNCTYPE declaration for self."""
+ types = (self.return_type,) + self.argument_types
+ types_decl = ", ".join(t.arg for t in types)
+ return "{0} = ctypes.CFUNCTYPE({1})".format(self.typename, types_decl)
+
+
+class FunctionPtr(CDeclBase):
+ """A pointer to an externally defined C function."""
+
+ def __init__(self, name, symbol_name, type_name, comment=""):
+ super(FunctionPtr, self).__init__(
+ name=name, symbol_name=symbol_name,
+ type_name=type_name, comment=comment)
+
+ @property
+ def ctypes_field_decl(self):
+ """Generates a declaration for self as a field of a ctypes.Structure."""
+ return "('{0.name}', {0.type_name})".format(self) # pylint: disable=missing-format-attribute
+
+ def ctypes_var_decl(self, cdll_name=""):
+ """Generates a ctypes export statement."""
+
+ return "self._{0} = ctypes.c_void_p.in_dll({1}, {2!r})".format(
+ self.name, cdll_name, self.symbol_name)
+
+ def getters_setters_with_custom_prefix(self, prefix):
+ return textwrap.dedent("""
+ @property
+ def {0.name}(self):
+ if {1}{0.name}.value:
+ return {0.type_name}({1}{0.name}.value)
+ else:
+ return None
+
+ @{0.name}.setter
+ def {0.name}(self, value):
+ new_func_ptr, wrapped_pyfunc = util.cast_func_to_c_void_p(
+ value, {0.type_name})
+ # Prevents wrapped_pyfunc from being inadvertently garbage collected.
+ {1}{0.name}._wrapped_pyfunc = wrapped_pyfunc
+ {1}{0.name}.value = new_func_ptr.value
+ """.format(self, prefix)) # pylint: disable=missing-format-attribute
+
+ @property
+ def getters_setters(self):
+ """Populates a Python class with getter & setter methods for self."""
+ return self.getters_setters_with_custom_prefix(prefix="self._ptr.contents.")
diff --git a/DMC/src/env/dm_control/dm_control/autowrap/codegen_util.py b/DMC/src/env/dm_control/dm_control/autowrap/codegen_util.py
new file mode 100644
index 0000000..1565a57
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/autowrap/codegen_util.py
@@ -0,0 +1,154 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Misc helper functions needed by autowrap.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import keyword
+import re
+
+import six
+from six.moves import builtins
+
+_MJXMACRO_SUFFIX = "_POINTERS"
+_PYTHON_RESERVED_KEYWORDS = set(keyword.kwlist + dir(builtins))
+if not six.PY2:
+ _PYTHON_RESERVED_KEYWORDS.add("buffer")
+
+
+class Indenter(object):
+ r"""Callable context manager for tracking string indentation levels.
+
+ Args:
+ level: The initial indentation level.
+ indent_str: The string used to indent each line.
+
+ Example usage:
+
+ ```python
+ idt = Indenter()
+ s = idt("level 0\n")
+ with idt:
+ s += idt("level 1\n")
+ with idt:
+ s += idt("level 2\n")
+ s += idt("level 1 again\n")
+ s += idt("back to level 0\n")
+ print(s)
+ ```
+ """
+
+ def __init__(self, level=0, indent_str=" "):
+ self.indent_str = indent_str
+ self.level = level
+
+ def __enter__(self):
+ self.level += 1
+ return self
+
+ def __exit__(self, type_, value, traceback):
+ self.level -= 1
+
+ def __call__(self, string):
+ return indent(string, self.level, self.indent_str)
+
+
+def indent(s, n=1, indent_str=" "):
+ """Inserts `n * indent_str` at the start of each non-empty line in `s`."""
+ p = n * indent_str
+ return "".join((p + l) if l.lstrip() else l for l in s.splitlines(True))
+
+
+class UniqueOrderedDict(collections.OrderedDict):
+ """Subclass of `OrderedDict` that enforces the uniqueness of keys."""
+
+ def __setitem__(self, k, v):
+ if k in self:
+ raise ValueError("Key '{}' already exists.".format(k))
+ super(UniqueOrderedDict, self).__setitem__(k, v)
+
+
+def macro_struct_name(name, suffix=None):
+ """Converts mjxmacro struct names, e.g. "MJDATA_POINTERS" to "mjdata"."""
+ if suffix is None:
+ suffix = _MJXMACRO_SUFFIX
+ return name[:-len(suffix)].lower()
+
+
+def is_macro_pointer(name):
+ """Returns True if the mjxmacro struct name contains pointer sizes."""
+ return name.endswith(_MJXMACRO_SUFFIX)
+
+
+def mangle_varname(s):
+ """Append underscores to ensure that `s` is not a reserved Python keyword."""
+ while s in _PYTHON_RESERVED_KEYWORDS:
+ s += "_"
+ return s
+
+
+def mangle_struct_typename(s):
+ """Strip leading underscores and make uppercase."""
+ return s.lstrip("_").upper()
+
+
+def mangle_comment(s):
+ """Strip extraneous whitespace, add full-stops at end of each line."""
+ if not isinstance(s, six.string_types):
+ return "\n".join(mangle_comment(line) for line in s)
+ elif not s:
+ return "."
+ else:
+ out = "\n".join(" ".join(line.split()) for line in s.splitlines())
+ if not out.endswith("."):
+ out += "."
+ return out
+
+
+def camel_case(s):
+ """Convert a snake_case string (maybe with lowerCaseFirst) to CamelCase."""
+ tokens = re.sub(r"([A-Z])", r" \1", s.replace("_", " ")).split()
+ return "".join(w.title() for w in tokens)
+
+
+def try_coerce_to_num(s, try_types=(int, float)):
+ """Try to coerce string to Python numeric type, return None if empty."""
+ if not s:
+ return None
+ for try_type in try_types:
+ try:
+ return try_type(s.rstrip("UuFf"))
+ except (ValueError, AttributeError):
+ continue
+ return s
+
+
+def recursive_dict_lookup(key, try_dict, max_depth=10):
+ """Recursively map dictionary keys to values."""
+ if max_depth < 0:
+ raise KeyError("Maximum recursion depth exceeded")
+ while key in try_dict:
+ key = try_dict[key]
+ return recursive_dict_lookup(key, try_dict, max_depth - 1)
+ return key
+
+
+def comment_line(string, width=79, fill_char="-"):
+ """Wraps `string` in a padded comment line."""
+ return "# {0:{2}^{1}}\n".format(string, width - 2, fill_char)
diff --git a/DMC/src/env/dm_control/dm_control/autowrap/header_parsing.py b/DMC/src/env/dm_control/dm_control/autowrap/header_parsing.py
new file mode 100644
index 0000000..002f109
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/autowrap/header_parsing.py
@@ -0,0 +1,318 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""pyparsing definitions and helper functions for parsing MuJoCo headers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import pyparsing as pp
+import six
+from six.moves import map
+
+# NB: Don't enable parser memoization (`pp.ParserElement.enablePackrat()`),
+# since this results in a ~6x slowdown.
+
+
+NONE = "None"
+CTYPES_CHAR = "ctypes.c_char"
+
+C_TO_CTYPES = {
+ # integers
+ "int": "ctypes.c_int",
+ "unsigned int": "ctypes.c_uint",
+ "char": CTYPES_CHAR,
+ "unsigned char": "ctypes.c_ubyte",
+ "size_t": "ctypes.c_size_t",
+ # floats
+ "float": "ctypes.c_float",
+ "double": "ctypes.c_double",
+ # pointers
+ "void": NONE,
+}
+
+CTYPES_PTRS = {NONE: "ctypes.c_void_p"}
+
+CTYPES_TO_NUMPY = {
+ # integers
+ "ctypes.c_int": "np.intc",
+ "ctypes.c_uint": "np.uintc",
+ "ctypes.c_ubyte": "np.ubyte",
+ # floats
+ "ctypes.c_float": "np.float32",
+ "ctypes.c_double": "np.float64",
+}
+
+# Helper functions for constructing recursive parsers.
+# ------------------------------------------------------------------------------
+
+
+def _nested_scopes(opening, closing, body):
+ """Constructs a parser for (possibly nested) scopes."""
+ scope = pp.Forward()
+ scope << pp.Group( # pylint: disable=expression-not-assigned
+ opening +
+ pp.ZeroOrMore(body | scope)("members") +
+ closing)
+ return scope
+
+
+def _nested_if_else(if_, pred, else_, endif, match_if_true, match_if_false):
+ """Constructs a parser for (possibly nested) if...(else)...endif blocks."""
+ ifelse = pp.Forward()
+ ifelse << pp.Group( # pylint: disable=expression-not-assigned
+ if_ +
+ pred("predicate") +
+ pp.ZeroOrMore(match_if_true | ifelse)("if_true") +
+ pp.Optional(else_ +
+ pp.ZeroOrMore(match_if_false | ifelse)("if_false")) +
+ endif)
+ return ifelse
+
+
+# Some common string patterns to suppress.
+# ------------------------------------------------------------------------------
+(X, LPAREN, RPAREN, LBRACK, RBRACK, LBRACE, RBRACE, SEMI, COMMA, EQUAL, FSLASH,
+ BSLASH) = list(map(pp.Suppress, "X()[]{};,=/\\"))
+EOL = pp.LineEnd().suppress()
+
+# Comments, continuation.
+# ------------------------------------------------------------------------------
+COMMENT = pp.Combine(
+ pp.Suppress("//") +
+ pp.Optional(pp.White()).suppress() +
+ pp.SkipTo(pp.LineEnd()))
+
+MULTILINE_COMMENT = pp.delimitedList(
+ COMMENT.copy().setWhitespaceChars(" \t"), delim=EOL)
+
+CONTINUATION = (BSLASH + pp.LineEnd()).suppress()
+
+# Preprocessor directives.
+# ------------------------------------------------------------------------------
+DEFINE = pp.Keyword("#define").suppress()
+IFDEF = pp.Keyword("#ifdef").suppress()
+IFNDEF = pp.Keyword("#ifndef").suppress()
+ELSE = pp.Keyword("#else").suppress()
+ENDIF = pp.Keyword("#endif").suppress()
+
+# Variable names, types, literals etc.
+# ------------------------------------------------------------------------------
+NAME = pp.Word(pp.alphanums + "_")
+INT = pp.Word(pp.nums + "UuLl")
+FLOAT = pp.Word(pp.nums + ".+-EeFf")
+NUMBER = FLOAT | INT
+
+# Dimensions can be of the form `[3]`, `[constant_name]` or `[2*constant_name]`
+ARRAY_DIM = pp.Combine(
+ LBRACK +
+ (INT | NAME) +
+ pp.Optional(pp.Literal("*")) +
+ pp.Optional(INT | NAME) +
+ RBRACK)
+
+PTR = pp.Literal("*")
+EXTERN = pp.Keyword("extern")
+NATIVE_TYPENAME = pp.MatchFirst(
+ [pp.Keyword(n) for n in six.iterkeys(C_TO_CTYPES)])
+
+# Macros.
+# ------------------------------------------------------------------------------
+
+HDR_GUARD = DEFINE + "THIRD_PARTY_MUJOCO_HDRS_"
+
+# e.g. "#define mjUSEDOUBLE"
+DEF_FLAG = pp.Group(
+ DEFINE +
+ NAME("name") +
+ (COMMENT("comment") | EOL)).ignore(HDR_GUARD)
+
+# e.g. "#define mjMINVAL 1E-14 // minimum value in any denominator"
+DEF_CONST = pp.Group(
+ DEFINE +
+ NAME("name") +
+ (NUMBER | NAME)("value") +
+ (COMMENT("comment") | EOL))
+
+# e.g. "X( mjtNum*, name_textadr, ntext, 1 )"
+XMEMBER = pp.Group(
+ X +
+ LPAREN +
+ (NATIVE_TYPENAME | NAME)("typename") +
+ pp.Optional(PTR("ptr")) +
+ COMMA +
+ NAME("name") +
+ COMMA +
+ pp.delimitedList((INT | NAME), delim=COMMA)("dims") +
+ RPAREN)
+
+XMACRO = pp.Group(
+ pp.Optional(COMMENT("comment")) +
+ DEFINE +
+ NAME("name") +
+ CONTINUATION +
+ pp.delimitedList(XMEMBER, delim=CONTINUATION)("members"))
+
+
+# Type/variable declarations.
+# ------------------------------------------------------------------------------
+TYPEDEF = pp.Keyword("typedef").suppress()
+STRUCT = pp.Keyword("struct")
+UNION = pp.Keyword("union")
+ENUM = pp.Keyword("enum").suppress()
+
+# e.g. "typedef unsigned char mjtByte; // used for true/false"
+TYPE_DECL = pp.Group(
+ TYPEDEF +
+ pp.Optional(STRUCT) +
+ (NATIVE_TYPENAME | NAME)("typename") +
+ pp.Optional(PTR("ptr")) +
+ NAME("name") +
+ SEMI +
+ pp.Optional(COMMENT("comment")))
+
+# Declarations of flags/constants/types.
+UNCOND_DECL = DEF_FLAG | DEF_CONST | TYPE_DECL
+
+# Declarations inside (possibly nested) #if(n)def... #else... #endif... blocks.
+COND_DECL = _nested_if_else(IFDEF, NAME, ELSE, ENDIF, UNCOND_DECL, UNCOND_DECL)
+# Note: this doesn't work for '#if defined(FLAG)' blocks
+
+# e.g. "mjtNum gravity[3]; // gravitational acceleration"
+STRUCT_MEMBER = pp.Group(
+ pp.Optional(STRUCT("struct")) +
+ (NATIVE_TYPENAME | NAME)("typename") +
+ pp.Optional(PTR("ptr")) +
+ NAME("name") +
+ pp.ZeroOrMore(ARRAY_DIM)("size") +
+ SEMI +
+ pp.Optional(COMMENT("comment")))
+
+# Struct declaration within a union (non-nested).
+UNION_STRUCT_DECL = pp.Group(
+ STRUCT("struct") +
+ pp.Optional(NAME("typename")) +
+ pp.Optional(COMMENT("comment")) +
+ LBRACE +
+ pp.OneOrMore(STRUCT_MEMBER)("members") +
+ RBRACE +
+ pp.Optional(NAME("name")) +
+ SEMI)
+
+ANONYMOUS_UNION_DECL = pp.Group(
+ pp.Optional(MULTILINE_COMMENT("comment")) +
+ UNION("anonymous_union") +
+ LBRACE +
+ pp.OneOrMore(
+ UNION_STRUCT_DECL |
+ STRUCT_MEMBER |
+ COMMENT.suppress())("members") +
+ RBRACE +
+ SEMI)
+
+# Multiple (possibly nested) struct declarations.
+NESTED_STRUCTS = _nested_scopes(
+ opening=(STRUCT +
+ pp.Optional(NAME("typename")) +
+ pp.Optional(COMMENT("comment")) +
+ LBRACE),
+ closing=(RBRACE + pp.Optional(NAME("name")) + SEMI),
+ body=pp.OneOrMore(
+ STRUCT_MEMBER |
+ ANONYMOUS_UNION_DECL |
+ COMMENT.suppress())("members"))
+
+BIT_LSHIFT = INT("bit_lshift_a") + pp.Suppress("<<") + INT("bit_lshift_b")
+
+ENUM_LINE = pp.Group(
+ NAME("name") +
+ pp.Optional(EQUAL + (INT("value") ^ BIT_LSHIFT)) +
+ pp.Optional(COMMA) +
+ pp.Optional(COMMENT("comment")))
+
+ENUM_DECL = pp.Group(
+ TYPEDEF +
+ ENUM +
+ NAME("typename") +
+ pp.Optional(COMMENT("comment")) +
+ LBRACE +
+ pp.OneOrMore(ENUM_LINE | COMMENT.suppress())("members") +
+ RBRACE +
+ pp.Optional(NAME("name")) +
+ SEMI)
+
+# Function declarations.
+# ------------------------------------------------------------------------------
+MJAPI = pp.Keyword("MJAPI").suppress()
+CONST = pp.Keyword("const")
+VOID = pp.Group(pp.Keyword("void") + ~PTR).suppress()
+
+ARG = pp.Group(
+ pp.Optional(CONST("is_const")) +
+ (NATIVE_TYPENAME | NAME)("typename") +
+ pp.Optional(PTR("ptr")) +
+ NAME("name") +
+ pp.Optional(ARRAY_DIM("size")))
+
+RET = pp.Group(
+ pp.Optional(CONST("is_const")) +
+ (NATIVE_TYPENAME | NAME)("typename") +
+ pp.Optional(PTR("ptr")))
+
+FUNCTION_DECL = (
+ (VOID | RET("return_value")) +
+ NAME("name") +
+ LPAREN +
+ (VOID | pp.delimitedList(ARG, delim=COMMA)("arguments")) +
+ RPAREN +
+ SEMI)
+
+MJAPI_FUNCTION_DECL = pp.Group(
+ pp.Optional(MULTILINE_COMMENT("comment")) +
+ pp.LineStart() +
+ MJAPI +
+ FUNCTION_DECL)
+
+# e.g.
+# // predicate function: set enable/disable based on item category
+# typedef int (*mjfItemEnable)(int category, void* data);
+FUNCTION_PTR_TYPE_DECL = pp.Group(
+ pp.Optional(MULTILINE_COMMENT("comment")) +
+ TYPEDEF +
+ RET("return_type") +
+ LPAREN +
+ PTR +
+ NAME("typename") +
+ RPAREN +
+ LPAREN +
+ (VOID | pp.delimitedList(ARG, delim=COMMA)("arguments")) +
+ RPAREN +
+ SEMI)
+
+# Global variables.
+# ------------------------------------------------------------------------------
+
+MJAPI_STRING_ARRAY = (
+ MJAPI +
+ EXTERN +
+ CONST +
+ pp.Keyword("char") +
+ PTR +
+ NAME("name") +
+ pp.OneOrMore(ARRAY_DIM)("dims") +
+ SEMI)
+
+MJAPI_FUNCTION_PTR = MJAPI + EXTERN + NAME("typename") + NAME("name") + SEMI
diff --git a/DMC/src/env/dm_control/dm_control/composer/__init__.py b/DMC/src/env/dm_control/dm_control/composer/__init__.py
new file mode 100644
index 0000000..18d5037
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2018-2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Module containing abstract base classes for Composer environments."""
+
+from dm_control.composer.arena import Arena
+from dm_control.composer.constants import * # pylint: disable=wildcard-import
+from dm_control.composer.define import cached_property
+from dm_control.composer.define import observable
+from dm_control.composer.entity import Entity
+from dm_control.composer.entity import FreePropObservableMixin
+from dm_control.composer.entity import ModelWrapperEntity
+from dm_control.composer.entity import Observables
+from dm_control.composer.environment import Environment
+from dm_control.composer.environment import EpisodeInitializationError
+from dm_control.composer.environment import HOOK_NAMES
+from dm_control.composer.initializer import Initializer
+from dm_control.composer.robot import Robot
+from dm_control.composer.task import NullTask
+from dm_control.composer.task import Task
diff --git a/DMC/src/env/dm_control/dm_control/composer/arena.py b/DMC/src/env/dm_control/dm_control/composer/arena.py
new file mode 100644
index 0000000..1085faa
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/arena.py
@@ -0,0 +1,52 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""The base empty arena that defines global settings for Composer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from dm_control import mjcf
+from dm_control.composer import entity as entity_module
+
+_ARENA_XML_PATH = os.path.join(os.path.dirname(__file__), 'arena.xml')
+
+
+class Arena(entity_module.Entity):
+ """The base empty arena that defines global settings for Composer."""
+
+ def _build(self, name=None):
+ """Initializes this arena.
+
+ Args:
+ name: (optional) A string, the name of this arena. If `None`, use the
+ model name defined in the MJCF file.
+ """
+ self._mjcf_root = mjcf.from_path(_ARENA_XML_PATH)
+ if name:
+ self._mjcf_root.model = name
+
+ def add_free_entity(self, entity):
+ """Includes an entity in the arena as a free-moving body."""
+ frame = self.attach(entity)
+ frame.add('freejoint')
+ return frame
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
diff --git a/DMC/src/env/dm_control/dm_control/composer/arena.xml b/DMC/src/env/dm_control/dm_control/composer/arena.xml
new file mode 100644
index 0000000..e8a1774
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/arena.xml
@@ -0,0 +1,11 @@
+
+
+
+
+
+
+
+
+
+
+
diff --git a/DMC/src/env/dm_control/dm_control/composer/constants.py b/DMC/src/env/dm_control/dm_control/composer/constants.py
new file mode 100644
index 0000000..bf8616b
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/constants.py
@@ -0,0 +1,22 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Module defining constant values for Composer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+SENSOR_SITES_GROUP = 4
diff --git a/DMC/src/env/dm_control/dm_control/composer/define.py b/DMC/src/env/dm_control/dm_control/composer/define.py
new file mode 100644
index 0000000..f03f220
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/define.py
@@ -0,0 +1,65 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Decorators for Entity methods returning elements and observables."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import threading
+
+
+class cached_property(property): # pylint: disable=invalid-name
+ """A property that is evaluated only once per object instance."""
+
+ def __init__(self, func, doc=None):
+ super(cached_property, self).__init__(fget=func, doc=doc)
+ self.lock = threading.RLock()
+
+ def __get__(self, obj, cls):
+ if obj is None:
+ return self
+ name = self.fget.__name__
+ obj_dict = obj.__dict__
+ try:
+ # Try returning a precomputed value without locking first.
+ # Profiling shows that the lock takes up a non-trivial amount of time.
+ return obj_dict[name]
+ except KeyError:
+ # The value hasn't been computed, now we have to lock.
+ with self.lock:
+ try:
+ # Check again whether another thread has already computed the value.
+ return obj_dict[name]
+ except KeyError:
+ # Otherwise call the function, cache the result, and return it
+ return obj_dict.setdefault(name, self.fget(obj))
+
+
+# A decorator for base.Observables methods returning an observable. This
+# decorator should be used by abstract base classes to indicate sub-classes need
+# to implement a corresponding @observavble annotated method.
+abstract_observable = abc.abstractproperty # pylint: disable=invalid-name
+
+
+class observable(cached_property): # pylint: disable=invalid-name
+ """A decorator for base.Observables methods returning an observable.
+
+ The body of the decorated function is evaluated at Entity construction time
+ and the observable is cached.
+ """
+ pass
diff --git a/DMC/src/env/dm_control/dm_control/composer/entity.py b/DMC/src/env/dm_control/dm_control/composer/entity.py
new file mode 100644
index 0000000..c729101
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/entity.py
@@ -0,0 +1,583 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Module defining the abstract entity class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import collections
+import os
+import weakref
+
+from absl import logging
+from dm_control import mjcf
+from dm_control.composer import define
+from dm_control.mujoco.wrapper import mjbindings
+import numpy as np
+import six
+
+_OPTION_KEYS = set(['update_interval', 'buffer_size', 'delay', 'aggregator',
+ 'corruptor', 'enabled'])
+
+_NO_ATTACHMENT_FRAME = 'No attachment frame found.'
+
+
+# The component order differs from that used by the open-source `tf` package.
+def _multiply_quaternions(quat1, quat2):
+ result = np.empty_like(quat1)
+ mjbindings.mjlib.mju_mulQuat(result, quat1, quat2)
+ return result
+
+
+def _rotate_vector(vec, quat):
+ """Rotates a vector by the given quaternion."""
+ result = np.empty_like(vec)
+ mjbindings.mjlib.mju_rotVecQuat(result, vec, quat)
+ return result
+
+
+class _ObservableKeys(object):
+ """Helper object that implements the `observables.dict_keys` functionality."""
+
+ def __init__(self, entity, observables):
+ self._entity = entity
+ self._observables = observables
+
+ def __getattr__(self, name):
+ try:
+ model_identifier = self._entity.mjcf_model.full_identifier
+ except AttributeError:
+ raise ValueError('cannot retrieve the full identifier of mjcf_model')
+ return os.path.join(model_identifier, name)
+
+ def __dir__(self):
+ out = set(self._observables.keys())
+ out.update(dir(super(_ObservableKeys, self)))
+ return list(out)
+
+
+class Observables(object):
+ """Base-class for Entity observables.
+
+ Subclasses should declare getter methods annotated with @define.observable
+ decorator and returning an observable object.
+ """
+
+ def __init__(self, entity):
+ self._entity = weakref.proxy(entity)
+
+ self._observables = collections.OrderedDict()
+ self._keys_helper = _ObservableKeys(self._entity, self._observables)
+
+ # Ensure consistent ordering.
+ for attr_name in sorted(dir(type(self))):
+ type_attr = getattr(type(self), attr_name)
+ if isinstance(type_attr, define.observable):
+ self._observables[attr_name] = getattr(self, attr_name)
+
+ @property
+ def dict_keys(self):
+ return self._keys_helper
+
+ def as_dict(self, fully_qualified=True):
+ """Returns an OrderedDict of observables belonging to this Entity.
+
+ The returned observables will include any added using the _add_observable
+ method, as well as any generated by a method decorated with the
+ @define.observable annotation.
+
+ Args:
+ fully_qualified: (bool) Whether the dict keys should be prefixed with the
+ parent entity's full model identifier.
+ """
+
+ if fully_qualified:
+ # We need to make sure that this property doesn't raise an AttributeError,
+ # otherwise __getattr__ is executed and we get a very funky error.
+ try:
+ model_identifier = self._entity.mjcf_model.full_identifier
+ except AttributeError:
+ raise ValueError('cannot retrieve the full identifier of mjcf_model')
+
+ return collections.OrderedDict(
+ [(os.path.join(model_identifier, name), observable)
+ for name, observable in six.iteritems(self._observables)])
+ else:
+ # Return a copy to prevent dict being edited.
+ return self._observables.copy()
+
+ def get_observable(self, name, name_fully_qualified=False):
+ """Returns the observable with the given name.
+
+ Args:
+ name: (str) The identifier of the observable.
+ name_fully_qualified: (bool) Whether the provided name is prefixed by the
+ model's full identifier.
+ """
+
+ if name_fully_qualified:
+ try:
+ model_identifier = self._entity.mjcf_model.full_identifier
+ except AttributeError:
+ raise ValueError('cannot retrieve the full identifier of mjcf_model')
+ return self._observables[name.replace(model_identifier, '')]
+ else:
+ return self._observables[name]
+
+ def set_options(self, options):
+ """Configure Observables with an options dict.
+
+ Args:
+ options: A dict of dicts of configuration options keyed on
+ observable names, or a dict of configuration options, which will
+ propagate those options to all observables.
+ """
+ if options is None:
+ options = {}
+ elif options.keys() and set(options.keys()).issubset(_OPTION_KEYS):
+ options = dict([(key, options) for key in self._observables.keys()])
+
+ for obs_key, obs_options in six.iteritems(options):
+ try:
+ obs = self._observables[obs_key]
+ except KeyError:
+ raise KeyError('No observable with name {!r}'.format(obs_key))
+ obs.configure(**obs_options)
+
+ def enable_all(self):
+ """Enable all observables of this entity."""
+ for obs in self._observables.values():
+ obs.enabled = True
+
+ def disable_all(self):
+ """Disable all observables of this entity."""
+ for obs in self._observables.values():
+ obs.enabled = False
+
+ def add_observable(self, name, observable, enabled=True):
+ self._observables[name] = observable
+ self._observables[name].enabled = enabled
+
+
+@six.add_metaclass(abc.ABCMeta)
+class FreePropObservableMixin(object):
+ """Enforce observables of a free-moving object."""
+
+ @abc.abstractproperty
+ def position(self):
+ pass
+
+ @abc.abstractproperty
+ def orientation(self):
+ pass
+
+ @abc.abstractproperty
+ def linear_velocity(self):
+ pass
+
+ @abc.abstractproperty
+ def angular_velocity(self):
+ pass
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Entity(object):
+ """The abstract base class for an entity in a Composer environment."""
+
+ def __init__(self, *args, **kwargs):
+ """Entity constructor.
+
+ Subclasses should not override this method, instead implement a _build
+ method.
+
+ Args:
+ *args: Arguments passed through to the _build method.
+ **kwargs: Keyword arguments. Passed through to the _build method, apart
+ from the following.
+ `observable_options`: A dictionary of Observable
+ configuration options.
+ """
+ self._post_init_hooks = []
+
+ self._parent = None
+ self._attached = []
+
+ try:
+ observable_options = kwargs.pop('observable_options')
+ except KeyError:
+ observable_options = None
+
+ self._build(*args, **kwargs)
+ self._observables = self._build_observables()
+ self._observables.set_options(observable_options)
+
+ @abc.abstractmethod
+ def _build(self, *args, **kwargs):
+ """Entity initialization method to be overridden by subclasses."""
+ raise NotImplementedError
+
+ def _build_observables(self):
+ """Entity observables initialization method.
+
+ Returns:
+ An object subclassing the Observables class.
+ """
+ return Observables(self)
+
+ def iter_entities(self, exclude_self=False):
+ """An iterator that recursively iterates through all attached entities.
+
+ Args:
+ exclude_self: (optional) Whether to exclude this `Entity` itself from the
+ iterator.
+
+ Yields:
+ If `exclude_self` is `False`, the first value yielded is this Entity
+ itself. The following Entities are then yielded recursively in a
+ depth-first fashion, following the order in which the Entities are
+ attached.
+ """
+ if not exclude_self:
+ yield self
+ for attached_entity in self._attached:
+ for attached_entity_of_attached_entity in attached_entity.iter_entities():
+ yield attached_entity_of_attached_entity
+
+ @property
+ def observables(self):
+ """The observables defined by this entity."""
+ return self._observables
+
+ def initialize_episode_mjcf(self, random_state):
+ """Callback executed when the MJCF model is modified between episodes."""
+ pass
+
+ def after_compile(self, physics, random_state):
+ """Callback executed after the Mujoco Physics is recompiled."""
+ pass
+
+ def initialize_episode(self, physics, random_state):
+ """Callback executed during episode initialization."""
+ pass
+
+ def before_step(self, physics, random_state):
+ """Callback executed before an agent control step."""
+ pass
+
+ def before_substep(self, physics, random_state):
+ """Callback executed before a simulation step."""
+ pass
+
+ def after_substep(self, physics, random_state):
+ """A callback which is executed after a simulation step."""
+ pass
+
+ def after_step(self, physics, random_state):
+ """Callback executed after an agent control step."""
+ pass
+
+ @abc.abstractproperty
+ def mjcf_model(self):
+ raise NotImplementedError
+
+ def attach(self, entity, attach_site=None):
+ """Attaches an `Entity` without any additional degrees of freedom.
+
+ Args:
+ entity: The `Entity` to attach.
+ attach_site: (optional) The site to which to attach the entity's model. If
+ not set, defaults to self.attachment_site.
+
+ Returns:
+ The frame of the attached model.
+ """
+
+ if attach_site is None:
+ attach_site = self.attachment_site
+
+ frame = attach_site.attach(entity.mjcf_model)
+ self._attached.append(entity)
+ entity._parent = weakref.ref(self) # pylint: disable=protected-access
+ return frame
+
+ def detach(self):
+ """Detaches this entity if it has previously been attached."""
+ if self._parent is not None:
+ parent = self._parent()
+ if parent: # Weakref might dereference to None during garbage collection.
+ self.mjcf_model.detach()
+ parent._attached.remove(self) # pylint: disable=protected-access
+ self._parent = None
+ else:
+ raise RuntimeError('Cannot detach an entity that is not attached.')
+
+ @property
+ def parent(self):
+ """Returns the `Entity` to which this entity is attached, or `None`."""
+ return self._parent() if self._parent else None
+
+ @property
+ def attachment_site(self):
+ return self.mjcf_model
+
+ @property
+ def root_body(self):
+ if self.parent:
+ return mjcf.get_attachment_frame(self.mjcf_model)
+ else:
+ return self.mjcf_model.worldbody
+
+ def global_vector_to_local_frame(self, physics, vec_in_world_frame):
+ """Linearly transforms a world-frame vector into entity's local frame.
+
+ Note that this function does not perform an affine transformation of the
+ vector. In other words, the input vector is assumed to be specified with
+ respect to the same origin as this entity's local frame. This function
+ can also be applied to matrices whose innermost dimensions are either 2 or
+ 3. In this case, a matrix with the same leading dimensions is returned
+ where the innermost vectors are replaced by their values computed in the
+ local frame.
+
+ Args:
+ physics: An `mjcf.Physics` instance.
+ vec_in_world_frame: A NumPy array with last dimension of shape (2,) or
+ (3,) that represents a vector quantity in the world frame.
+
+ Returns:
+ The same quantity as `vec_in_world_frame` but reexpressed in this
+ entity's local frame. The returned np.array has the same shape as
+ np.asarray(vec_in_world_frame).
+
+ Raises:
+ ValueError: if `vec_in_world_frame` does not have shape ending with (2,)
+ or (3,).
+ """
+ vec_in_world_frame = np.asarray(vec_in_world_frame)
+
+ xmat = np.reshape(physics.bind(self.root_body).xmat, (3, 3))
+ # The ordering of the np.dot is such that the transformation holds for any
+ # matrix whose final dimensions are (2,) or (3,).
+ if vec_in_world_frame.shape[-1] == 2:
+ return np.dot(vec_in_world_frame, xmat[:2, :2])
+ elif vec_in_world_frame.shape[-1] == 3:
+ return np.dot(vec_in_world_frame, xmat)
+ else:
+ raise ValueError('`vec_in_world_frame` should have shape with final '
+ 'dimension 2 or 3: got {}'.format(
+ vec_in_world_frame.shape))
+
+ def global_xmat_to_local_frame(self, physics, xmat):
+ """Transforms another entity's `xmat` into this entity's local frame.
+
+ This function takes another entity's (E) xmat, which is an SO(3) matrix
+ from E's frame to the world frame, and turns it to a matrix that transforms
+ from E's frame into this entity's local frame.
+
+ Args:
+ physics: An `mjcf.Physics` instance.
+ xmat: A NumPy array of shape (3, 3) or (9,) that represents another
+ entity's xmat.
+
+ Returns:
+ The `xmat` reexpressed in this entity's local frame. The returned
+ np.array has the same shape as np.asarray(xmat).
+
+ Raises:
+ ValueError: if `xmat` does not have shape (3, 3) or (9,).
+ """
+ xmat = np.asarray(xmat)
+
+ input_shape = xmat.shape
+ if xmat.shape == (9,):
+ xmat = np.reshape(xmat, (3, 3))
+
+ self_xmat = np.reshape(physics.bind(self.root_body).xmat, (3, 3))
+ if xmat.shape == (3, 3):
+ return np.reshape(np.dot(self_xmat.T, xmat), input_shape)
+ else:
+ raise ValueError('`xmat` should have shape (3, 3) or (9,): got {}'.format(
+ xmat.shape))
+
+ def get_pose(self, physics):
+ """Get the position and orientation of this entity relative to its parent.
+
+ Note that the semantics differ slightly depending on whether or not the
+ entity has a free joint:
+
+ * If it has a free joint the position and orientation are always given in
+ global coordinates.
+ * If the entity is fixed or attached with a different joint type then the
+ position and orientation are given relative to the parent frame.
+
+ For entities that are either attached directly to the worldbody, or to other
+ entities that are positioned at the global origin (e.g. the arena) the
+ global and relative poses are equivalent.
+
+ Args:
+ physics: An instance of `mjcf.Physics`.
+
+ Returns:
+ A 2-tuple where the first entry is a (3,) numpy array representing the
+ position and the second is a (4,) numpy array representing orientation as
+ a quaternion.
+
+ Raises:
+ RuntimeError: If the entity is not attached.
+ """
+ root_joint = mjcf.get_frame_freejoint(self.mjcf_model)
+ if root_joint:
+ position = physics.bind(root_joint).qpos[:3]
+ quaternion = physics.bind(root_joint).qpos[3:]
+ else:
+ attachment_frame = mjcf.get_attachment_frame(self.mjcf_model)
+ if attachment_frame is None:
+ raise RuntimeError(_NO_ATTACHMENT_FRAME)
+ position = physics.bind(attachment_frame).pos
+ quaternion = physics.bind(attachment_frame).quat
+ return position, quaternion
+
+ def set_pose(self, physics, position=None, quaternion=None):
+ """Sets position and/or orientation of this entity relative to its parent.
+
+ If the entity is attached with a free joint, this method will set the
+ respective DoFs of the joint. If the entity is either fixed or attached with
+ a different joint type, this method will update the position and/or
+ quaternion of the attachment frame.
+
+ Note that the semantics differ slightly between the two cases: the DoFs of a
+ free body are specified in global coordinates, whereas the position of a
+ non-free body is specified in relative coordinates with respect to the
+ parent frame. However, for entities that are either attached directly to the
+ worldbody, or to other entities that are positioned at the global origin
+ (e.g. the arena), there is no difference between the two cases.
+
+ Args:
+ physics: An instance of `mjcf.Physics`.
+ position: (optional) A NumPy array of size 3.
+ quaternion: (optional) A NumPy array of size 4.
+
+ Raises:
+ RuntimeError: If the entity is not attached.
+ """
+ root_joint = mjcf.get_frame_freejoint(self.mjcf_model)
+ if root_joint:
+ if position is not None:
+ physics.bind(root_joint).qpos[:3] = position
+ if quaternion is not None:
+ physics.bind(root_joint).qpos[3:] = quaternion
+ else:
+ attachment_frame = mjcf.get_attachment_frame(self.mjcf_model)
+ if attachment_frame is None:
+ raise RuntimeError(_NO_ATTACHMENT_FRAME)
+ if position is not None:
+ physics.bind(attachment_frame).pos = position
+ if quaternion is not None:
+ normalised_quaternion = quaternion / np.linalg.norm(quaternion)
+ physics.bind(attachment_frame).quat = normalised_quaternion
+
+ def shift_pose(self,
+ physics,
+ position=None,
+ quaternion=None,
+ rotate_velocity=False):
+ """Shifts the position and/or orientation from its current configuration.
+
+ This is a convenience function that performs the same operation as
+ `set_pose`, but where the specified `position` is added to the current
+ position, and the specified `quaternion` is premultiplied to the current
+ quaternion.
+
+ Args:
+ physics: An instance of `mjcf.Physics`.
+ position: (optional) A NumPy array of size 3.
+ quaternion: (optional) A NumPy array of size 4.
+ rotate_velocity: (optional) A bool, whether to shift the current linear
+ velocity along with the pose. This will rotate the current linear
+ velocity, which is expressed relative to the world frame. The angular
+ velocity, which is expressed relative to the local frame is left
+ unchanged.
+
+ Raises:
+ RuntimeError: If the entity is not attached.
+ """
+ current_position, current_quaternion = self.get_pose(physics)
+ new_position, new_quaternion = None, None
+ if position is not None:
+ new_position = current_position + position
+ if quaternion is not None:
+ quaternion = np.array(quaternion, dtype=np.float64, copy=False)
+ new_quaternion = _multiply_quaternions(quaternion, current_quaternion)
+ root_joint = mjcf.get_frame_freejoint(self.mjcf_model)
+ if root_joint and rotate_velocity:
+ # Rotate the linear velocity. The angular velocity (qvel[3:])
+ # is left unchanged, as it is expressed in the local frame.
+ # When rotatating the body frame the angular velocity already
+ # tracks the rotation but the linear velocity does not.
+ velocity = physics.bind(root_joint).qvel[:3]
+ rotated_velocity = _rotate_vector(velocity, quaternion)
+ self.set_velocity(physics, rotated_velocity)
+ self.set_pose(physics, new_position, new_quaternion)
+
+ def set_velocity(self, physics, velocity=None, angular_velocity=None):
+ """Sets the linear velocity and/or angular velocity of this free entity.
+
+ If the entity is attached with a free joint, this method will set the
+ respective DoFs of the joint. Otherwise a warning is logged.
+
+ Args:
+ physics: An instance of `mjcf.Physics`.
+ velocity: (optional) A NumPy array of size 3 specifying the
+ linear velocity.
+ angular_velocity: (optional) A NumPy array of size 3 specifying the
+ angular velocity
+ """
+ root_joint = mjcf.get_frame_freejoint(self.mjcf_model)
+ if root_joint:
+ if velocity is not None:
+ physics.bind(root_joint).qvel[:3] = velocity
+ if angular_velocity is not None:
+ physics.bind(root_joint).qvel[3:] = angular_velocity
+ else:
+ logging.warning('Cannot set velocity on Entity with no free joint.')
+
+ def configure_joints(self, physics, position):
+ """Configures this entity's internal joints.
+
+ The default implementation of this method simply sets the `qpos` of all
+ joints in this entity to the values specified in the `position` argument.
+ Entity subclasses with actuated joints may override this method to achieve a
+ stable reconfiguration of joint positions, for example the control signal
+ of position actuators may be changed to match the new joint positions.
+
+ Args:
+ physics: An instance of `mjcf.Physics`.
+ position: The desired position of this entity's joints.
+ """
+ joints = self.mjcf_model.find_all('joint', exclude_attachments=True)
+ physics.bind(joints).qpos = position
+
+
+class ModelWrapperEntity(Entity):
+ """An entity class that wraps an MJCF model without any additional logic."""
+
+ def _build(self, mjcf_model):
+ self._mjcf_model = mjcf_model
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_model
diff --git a/DMC/src/env/dm_control/dm_control/composer/entity_test.py b/DMC/src/env/dm_control/dm_control/composer/entity_test.py
new file mode 100644
index 0000000..b27f398
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/entity_test.py
@@ -0,0 +1,435 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for composer.Entity."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+# Internal dependencies.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import mjcf
+from dm_control.composer import arena
+from dm_control.composer import define
+from dm_control.composer import entity
+from dm_control.composer.observation.observable import base as observable
+import numpy as np
+import six
+from six.moves import range
+
+_NO_ROTATION = (1, 0, 0, 0) # Tests support for non-arrays and non-floats.
+_NINETY_DEGREES_ABOUT_X = np.array(
+ [np.cos(np.pi / 4), np.sin(np.pi / 4), 0., 0.])
+_NINETY_DEGREES_ABOUT_Y = np.array(
+ [np.cos(np.pi / 4), 0., np.sin(np.pi / 4), 0.])
+_NINETY_DEGREES_ABOUT_Z = np.array(
+ [np.cos(np.pi / 4), 0., 0., np.sin(np.pi / 4)])
+_FORTYFIVE_DEGREES_ABOUT_X = np.array(
+ [np.cos(np.pi / 8), np.sin(np.pi / 8), 0., 0.])
+
+_TEST_ROTATIONS = [
+ # Triplets of original rotation, new rotation and final rotation.
+ (None, _NO_ROTATION, _NO_ROTATION),
+ (_NO_ROTATION, _NINETY_DEGREES_ABOUT_Z, _NINETY_DEGREES_ABOUT_Z),
+ (_FORTYFIVE_DEGREES_ABOUT_X, _NINETY_DEGREES_ABOUT_Y,
+ np.array([0.65328, 0.2706, 0.65328, -0.2706])),
+]
+
+
+def _param_product(**param_lists):
+ keys, values = zip(*param_lists.items())
+ for combination in itertools.product(*values):
+ yield dict(zip(keys, combination))
+
+
+class TestEntity(entity.Entity):
+ """Simple test entity that does nothing but declare some observables."""
+
+ def _build(self, name='test_entity'):
+ self._mjcf_root = mjcf.element.RootElement(model=name)
+ self._mjcf_root.worldbody.add('geom', type='sphere', size=(0.1,))
+
+ def _build_observables(self):
+ return TestEntityObservables(self)
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+
+class TestEntityObservables(entity.Observables):
+ """Trivial observables for the test entity."""
+
+ @define.observable
+ def observable0(self):
+ return observable.Generic(lambda phys: 0.0)
+
+ @define.observable
+ def observable1(self):
+ return observable.Generic(lambda phys: 1.0)
+
+
+class EntityTest(parameterized.TestCase):
+
+ def setUp(self):
+ super(EntityTest, self).setUp()
+ self.entity = TestEntity()
+
+ def testNumObservables(self):
+ """Tests that the observables dict has the right number of entries."""
+ self.assertLen(self.entity.observables.as_dict(), 2)
+
+ def testObservableNames(self):
+ """Tests that the observables dict keys correspond to the observable names.
+ """
+ obs = self.entity.observables.as_dict()
+ self.assertIn('observable0', obs)
+ self.assertIn('observable1', obs)
+
+ subentity = TestEntity(name='subentity')
+ self.entity.attach(subentity)
+ self.assertIn('subentity/observable0', subentity.observables.as_dict())
+ self.assertEqual(subentity.observables.dict_keys.observable0,
+ 'subentity/observable0')
+ self.assertIn('observable0', dir(subentity.observables.dict_keys))
+ self.assertIn('subentity/observable1', subentity.observables.as_dict())
+ self.assertEqual(subentity.observables.dict_keys.observable1,
+ 'subentity/observable1')
+ self.assertIn('observable1', dir(subentity.observables.dict_keys))
+
+ def testEnableDisableObservables(self):
+ """Test the enabling and disable functionality for observables."""
+ all_obs = self.entity.observables.as_dict()
+
+ self.entity.observables.enable_all()
+ for obs in all_obs.values():
+ self.assertTrue(obs.enabled)
+
+ self.entity.observables.disable_all()
+ for obs in all_obs.values():
+ self.assertFalse(obs.enabled)
+
+ self.entity.observables.observable0.enabled = True
+ self.assertTrue(all_obs['observable0'].enabled)
+
+ def testObservableDefaultOptions(self):
+ corruptor = lambda x: x
+ options = {
+ 'update_interval': 2,
+ 'buffer_size': 10,
+ 'delay': 1,
+ 'aggregator': 'max',
+ 'corruptor': corruptor,
+ 'enabled': True
+ }
+ self.entity.observables.set_options(options)
+
+ for obs in self.entity.observables.as_dict().values():
+ self.assertEqual(obs.update_interval, 2)
+ self.assertEqual(obs.delay, 1)
+ self.assertEqual(obs.buffer_size, 10)
+ self.assertEqual(obs.aggregator, observable.AGGREGATORS['max'])
+ self.assertEqual(obs.corruptor, corruptor)
+ self.assertTrue(obs.enabled)
+
+ def testObservablePartialDefaultOptions(self):
+ options = {'update_interval': 2, 'delay': 1}
+ self.entity.observables.set_options(options)
+
+ for obs in self.entity.observables.as_dict().values():
+ self.assertEqual(obs.update_interval, 2)
+ self.assertEqual(obs.delay, 1)
+ self.assertEqual(obs.buffer_size, None)
+ self.assertEqual(obs.aggregator, None)
+ self.assertEqual(obs.corruptor, None)
+
+ def testObservableOptionsInvalidName(self):
+ options = {'asdf': None}
+ with six.assertRaisesRegex(
+ self, KeyError, 'No observable with name \'asdf\''):
+ self.entity.observables.set_options(options)
+
+ def testObservableInvalidOptions(self):
+ options = {'observable0': {'asdf': 2}}
+ with six.assertRaisesRegex(self, AttributeError,
+ 'Cannot add attribute asdf in configure.'):
+ self.entity.observables.set_options(options)
+
+ def testObservableOptions(self):
+ options = {
+ 'observable0': {
+ 'update_interval': 2,
+ 'delay': 3
+ },
+ 'observable1': {
+ 'update_interval': 4,
+ 'delay': 5
+ }
+ }
+ self.entity.observables.set_options(options)
+ observables = self.entity.observables.as_dict()
+ self.assertEqual(observables['observable0'].update_interval, 2)
+ self.assertEqual(observables['observable0'].delay, 3)
+ self.assertEqual(observables['observable0'].buffer_size, None)
+ self.assertEqual(observables['observable0'].aggregator, None)
+ self.assertEqual(observables['observable0'].corruptor, None)
+ self.assertFalse(observables['observable0'].enabled)
+
+ self.assertEqual(observables['observable1'].update_interval, 4)
+ self.assertEqual(observables['observable1'].delay, 5)
+ self.assertEqual(observables['observable1'].buffer_size, None)
+ self.assertEqual(observables['observable1'].aggregator, None)
+ self.assertEqual(observables['observable1'].corruptor, None)
+ self.assertFalse(observables['observable1'].enabled)
+
+ def testObservableOptionsEntityConstructor(self):
+ options = {
+ 'observable0': {
+ 'update_interval': 2,
+ 'delay': 3
+ },
+ 'observable1': {
+ 'update_interval': 4,
+ 'delay': 5
+ }
+ }
+ ent = TestEntity(observable_options=options)
+ observables = ent.observables.as_dict()
+ self.assertEqual(observables['observable0'].update_interval, 2)
+ self.assertEqual(observables['observable0'].delay, 3)
+ self.assertEqual(observables['observable0'].buffer_size, None)
+ self.assertEqual(observables['observable0'].aggregator, None)
+ self.assertEqual(observables['observable0'].corruptor, None)
+ self.assertFalse(observables['observable0'].enabled)
+
+ self.assertEqual(observables['observable1'].update_interval, 4)
+ self.assertEqual(observables['observable1'].delay, 5)
+ self.assertEqual(observables['observable1'].buffer_size, None)
+ self.assertEqual(observables['observable1'].aggregator, None)
+ self.assertEqual(observables['observable1'].corruptor, None)
+ self.assertFalse(observables['observable1'].enabled)
+
+ def testObservablePartialOptions(self):
+ options = {'observable0': {'update_interval': 2, 'delay': 3}}
+ self.entity.observables.set_options(options)
+ observables = self.entity.observables.as_dict()
+ self.assertEqual(observables['observable0'].update_interval, 2)
+ self.assertEqual(observables['observable0'].delay, 3)
+ self.assertEqual(observables['observable0'].buffer_size, None)
+ self.assertEqual(observables['observable0'].aggregator, None)
+ self.assertEqual(observables['observable0'].corruptor, None)
+ self.assertFalse(observables['observable0'].enabled)
+
+ self.assertEqual(observables['observable1'].update_interval, 1)
+ self.assertEqual(observables['observable1'].delay, None)
+ self.assertEqual(observables['observable1'].buffer_size, None)
+ self.assertEqual(observables['observable1'].aggregator, None)
+ self.assertEqual(observables['observable1'].corruptor, None)
+ self.assertFalse(observables['observable1'].enabled)
+
+ def testAttach(self):
+ entities = [TestEntity() for _ in range(4)]
+ entities[0].attach(entities[1])
+ entities[1].attach(entities[2])
+ entities[0].attach(entities[3])
+
+ self.assertIsNone(entities[0].parent)
+ self.assertIs(entities[1].parent, entities[0])
+ self.assertIs(entities[2].parent, entities[1])
+ self.assertIs(entities[3].parent, entities[0])
+
+ self.assertIsNone(entities[0].mjcf_model.parent_model)
+ self.assertIs(entities[1].mjcf_model.parent_model, entities[0].mjcf_model)
+ self.assertIs(entities[2].mjcf_model.parent_model, entities[1].mjcf_model)
+ self.assertIs(entities[3].mjcf_model.parent_model, entities[0].mjcf_model)
+
+ self.assertEqual(list(entities[0].iter_entities()), entities)
+
+ def testDetach(self):
+ entities = [TestEntity() for _ in range(4)]
+ entities[0].attach(entities[1])
+ entities[1].attach(entities[2])
+ entities[0].attach(entities[3])
+
+ entities[1].detach()
+ with six.assertRaisesRegex(self, RuntimeError, 'not attached'):
+ entities[1].detach()
+
+ self.assertIsNone(entities[0].parent)
+ self.assertIsNone(entities[1].parent)
+ self.assertIs(entities[2].parent, entities[1])
+ self.assertIs(entities[3].parent, entities[0])
+
+ self.assertIsNone(entities[0].mjcf_model.parent_model)
+ self.assertIsNone(entities[1].mjcf_model.parent_model)
+ self.assertIs(entities[2].mjcf_model.parent_model, entities[1].mjcf_model)
+ self.assertIs(entities[3].mjcf_model.parent_model, entities[0].mjcf_model)
+
+ self.assertEqual(list(entities[0].iter_entities()),
+ [entities[0], entities[3]])
+
+ def testIterEntitiesExcludeSelf(self):
+ entities = [TestEntity() for _ in range(4)]
+ entities[0].attach(entities[1])
+ entities[1].attach(entities[2])
+ entities[0].attach(entities[3])
+ self.assertEqual(
+ list(entities[0].iter_entities(exclude_self=True)), entities[1:])
+
+ def testGlobalVectorToLocalFrame(self):
+ parent = TestEntity()
+ parent.mjcf_model.worldbody.add(
+ 'site', xyaxes=[0, 1, 0, -1, 0, 0]).attach(self.entity.mjcf_model)
+ physics = mjcf.Physics.from_mjcf_model(parent.mjcf_model)
+
+ # 3D vectors
+ np.testing.assert_allclose(
+ self.entity.global_vector_to_local_frame(physics, [0, 1, 0]),
+ [1, 0, 0], atol=1e-10)
+ np.testing.assert_allclose(
+ self.entity.global_vector_to_local_frame(physics, [-1, 0, 0]),
+ [0, 1, 0], atol=1e-10)
+ np.testing.assert_allclose(
+ self.entity.global_vector_to_local_frame(physics, [0, 0, 1]),
+ [0, 0, 1], atol=1e-10)
+
+ # 2D vectors; z-component is ignored
+ np.testing.assert_allclose(
+ self.entity.global_vector_to_local_frame(physics, [0, 1]),
+ [1, 0], atol=1e-10)
+ np.testing.assert_allclose(
+ self.entity.global_vector_to_local_frame(physics, [-1, 0]),
+ [0, 1], atol=1e-10)
+
+ def testGlobalMatrixToLocalFrame(self):
+ parent = TestEntity()
+ parent.mjcf_model.worldbody.add(
+ 'site', xyaxes=[0, 1, 0, -1, 0, 0]).attach(self.entity.mjcf_model)
+ physics = mjcf.Physics.from_mjcf_model(parent.mjcf_model)
+
+ rotation_atob = np.array([[0, 1, 0], [0, 0, -1], [-1, 0, 0]])
+ ego_rotation_atob = np.array([[0, 0, -1], [0, -1, 0], [-1, 0, 0]])
+
+ np.testing.assert_allclose(
+ self.entity.global_xmat_to_local_frame(physics, rotation_atob),
+ ego_rotation_atob, atol=1e-10)
+
+ flat_rotation_atob = np.reshape(rotation_atob, -1)
+ flat_rotation_ego_atob = np.reshape(ego_rotation_atob, -1)
+ np.testing.assert_allclose(
+ self.entity.global_xmat_to_local_frame(
+ physics, flat_rotation_atob),
+ flat_rotation_ego_atob, atol=1e-10)
+
+ @parameterized.parameters(*_param_product(
+ position=[None, [1., 0., -1.]],
+ quaternion=[None, _FORTYFIVE_DEGREES_ABOUT_X, _NINETY_DEGREES_ABOUT_Z],
+ freejoint=[False, True],
+ ))
+ def testSetPose(self, position, quaternion, freejoint):
+ # Setup entity.
+ test_arena = arena.Arena()
+ subentity = TestEntity(name='subentity')
+ frame = test_arena.attach(subentity)
+ if freejoint:
+ frame.add('freejoint')
+
+ physics = mjcf.Physics.from_mjcf_model(test_arena.mjcf_model)
+
+ if quaternion is None:
+ ground_truth_quat = _NO_ROTATION
+ else:
+ ground_truth_quat = quaternion
+
+ if position is None:
+ ground_truth_pos = np.zeros(shape=(3,))
+ else:
+ ground_truth_pos = position
+
+ subentity.set_pose(physics, position=position, quaternion=quaternion)
+
+ np.testing.assert_array_equal(physics.bind(frame).xpos, ground_truth_pos)
+ np.testing.assert_array_equal(physics.bind(frame).xquat, ground_truth_quat)
+
+ @parameterized.parameters(*_param_product(
+ original_position=[[-2, -1, -1.], [1., 0., -1.]],
+ position=[None, [1., 0., -1.]],
+ original_quaternion=_TEST_ROTATIONS[0],
+ quaternion=_TEST_ROTATIONS[1],
+ expected_quaternion=_TEST_ROTATIONS[2],
+ freejoint=[False, True],
+ ))
+ def testShiftPose(self, original_position, position, original_quaternion,
+ quaternion, expected_quaternion, freejoint):
+ # Setup entity.
+ test_arena = arena.Arena()
+ subentity = TestEntity(name='subentity')
+ frame = test_arena.attach(subentity)
+ if freejoint:
+ frame.add('freejoint')
+
+ physics = mjcf.Physics.from_mjcf_model(test_arena.mjcf_model)
+
+ # Set the original position
+ subentity.set_pose(
+ physics, position=original_position, quaternion=original_quaternion)
+
+ if position is None:
+ ground_truth_pos = original_position
+ else:
+ ground_truth_pos = original_position + np.array(position)
+ subentity.shift_pose(physics, position=position, quaternion=quaternion)
+ np.testing.assert_array_equal(physics.bind(frame).xpos, ground_truth_pos)
+
+ updated_quat = physics.bind(frame).xquat
+ np.testing.assert_array_almost_equal(updated_quat, expected_quaternion,
+ 1e-4)
+
+ @parameterized.parameters(False, True)
+ def testShiftPoseWithVelocity(self, rotate_velocity):
+ # Setup entity.
+ test_arena = arena.Arena()
+ subentity = TestEntity(name='subentity')
+ frame = test_arena.attach(subentity)
+ frame.add('freejoint')
+
+ physics = mjcf.Physics.from_mjcf_model(test_arena.mjcf_model)
+
+ # Set the original position
+ subentity.set_pose(physics, position=[0., 0., 0.])
+
+ # Set velocity in y dim.
+ subentity.set_velocity(physics, [0., 1., 0.])
+
+ # Rotate the entity around the z axis.
+ subentity.shift_pose(
+ physics, quaternion=[0., 0., 0., 1.], rotate_velocity=rotate_velocity)
+
+ physics.forward()
+ updated_position, _ = subentity.get_pose(physics)
+ if rotate_velocity:
+ # Should not have moved in the y dim.
+ np.testing.assert_array_almost_equal(updated_position[1], 0.)
+ else:
+ # Should not have moved in the x dim.
+ np.testing.assert_array_almost_equal(updated_position[0], 0.)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/composer/environment.py b/DMC/src/env/dm_control/dm_control/composer/environment.py
new file mode 100644
index 0000000..68e904e
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/environment.py
@@ -0,0 +1,450 @@
+# Copyright 2018-2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""RL environment classes for Composer tasks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+import weakref
+
+from absl import logging
+from dm_control import mjcf
+from dm_control.composer import observation
+from dm_control.rl import control
+import dm_env
+import numpy as np
+from six.moves import range
+
+warnings.simplefilter('always', DeprecationWarning)
+
+_STEPS_LOGGING_INTERVAL = 10000
+
+HOOK_NAMES = ('initialize_episode_mjcf',
+ 'after_compile',
+ 'initialize_episode',
+ 'before_step',
+ 'before_substep',
+ 'after_substep',
+ 'after_step')
+
+_empty_function = lambda: None
+
+
+def _empty_function_with_docstring():
+ """Some docstring."""
+
+_EMPTY_CODE = _empty_function.__code__.co_code
+_EMPTY_WITH_DOCSTRING_CODE = _empty_function_with_docstring.__code__.co_code
+
+
+def _callable_is_trivial(f):
+ return (f.__code__.co_code == _EMPTY_CODE or
+ f.__code__.co_code == _EMPTY_WITH_DOCSTRING_CODE)
+
+
+class EpisodeInitializationError(RuntimeError):
+ """Raised by a `composer.Task` when it fails to initialize an episode."""
+
+
+class _Hook(object):
+
+ __slots__ = ('entity_hooks', 'extra_hooks')
+
+ def __init__(self):
+ self.entity_hooks = []
+ self.extra_hooks = []
+
+
+class _EnvironmentHooks(object):
+ """Helper object that scans and memoizes various hooks in a task.
+
+ This object exist to ensure that we do not incur a substantial overhead in
+ calling empty entity hooks in more complicated tasks.
+ """
+
+ __slots__ = (('_task', '_episode_step_count') +
+ tuple('_' + hook_name for hook_name in HOOK_NAMES))
+
+ def __init__(self, task):
+ self._task = task
+ self._episode_step_count = 0
+ for hook_name in HOOK_NAMES:
+ slot_name = '_' + hook_name
+ setattr(self, slot_name, _Hook())
+ self.refresh_entity_hooks()
+
+ def refresh_entity_hooks(self):
+ """Scans and memoizes all non-trivial entity hooks."""
+ for hook_name in HOOK_NAMES:
+ hooks = []
+ for entity in self._task.root_entity.iter_entities():
+ entity_hook = getattr(entity, hook_name)
+ # Ignore any hook that is a no-op to avoid function call overhead.
+ if not _callable_is_trivial(entity_hook):
+ hooks.append(entity_hook)
+ getattr(self, '_' + hook_name).entity_hooks = hooks
+
+ def add_extra_hook(self, hook_name, hook_callable):
+ if hook_name not in HOOK_NAMES:
+ raise ValueError('{!r} is not a valid hook name'.format(hook_name))
+ if not callable(hook_callable):
+ raise ValueError('{!r} is not a callable'.format(hook_callable))
+ getattr(self, '_' + hook_name).extra_hooks.append(hook_callable)
+
+ def initialize_episode_mjcf(self, random_state):
+ self._task.initialize_episode_mjcf(random_state)
+ for entity_hook in self._initialize_episode_mjcf.entity_hooks:
+ entity_hook(random_state)
+ for extra_hook in self._initialize_episode_mjcf.extra_hooks:
+ extra_hook(random_state)
+
+ def after_compile(self, physics, random_state):
+ self._task.after_compile(physics, random_state)
+ for entity_hook in self._after_compile.entity_hooks:
+ entity_hook(physics, random_state)
+ for extra_hook in self._after_compile.extra_hooks:
+ extra_hook(physics, random_state)
+
+ def initialize_episode(self, physics, random_state):
+ self._episode_step_count = 0
+ self._task.initialize_episode(physics, random_state)
+ for entity_hook in self._initialize_episode.entity_hooks:
+ entity_hook(physics, random_state)
+ for extra_hook in self._initialize_episode.extra_hooks:
+ extra_hook(physics, random_state)
+
+ def before_step(self, physics, action, random_state):
+ self._episode_step_count += 1
+ if self._episode_step_count % _STEPS_LOGGING_INTERVAL == 0:
+ logging.info('The current episode has been running for %d steps.',
+ self._episode_step_count)
+ self._task.before_step(physics, action, random_state)
+ for entity_hook in self._before_step.entity_hooks:
+ entity_hook(physics, random_state)
+ for extra_hook in self._before_step.extra_hooks:
+ extra_hook(physics, action, random_state)
+
+ def before_substep(self, physics, action, random_state):
+ self._task.before_substep(physics, action, random_state)
+ for entity_hook in self._before_substep.entity_hooks:
+ entity_hook(physics, random_state)
+ for extra_hooks in self._before_substep.extra_hooks:
+ extra_hooks(physics, action, random_state)
+
+ def after_substep(self, physics, random_state):
+ self._task.after_substep(physics, random_state)
+ for entity_hook in self._after_substep.entity_hooks:
+ entity_hook(physics, random_state)
+ for extra_hook in self._after_substep.extra_hooks:
+ extra_hook(physics, random_state)
+
+ def after_step(self, physics, random_state):
+ self._task.after_step(physics, random_state)
+ for entity_hook in self._after_step.entity_hooks:
+ entity_hook(physics, random_state)
+ for extra_hook in self._after_step.extra_hooks:
+ extra_hook(physics, random_state)
+
+
+class _CommonEnvironment(object):
+ """Common components for RL environments."""
+
+ def __init__(self, task, time_limit=float('inf'), random_state=None,
+ n_sub_steps=None,
+ raise_exception_on_physics_error=True,
+ strip_singleton_obs_buffer_dim=False):
+ """Initializes an instance of `_CommonEnvironment`.
+
+ Args:
+ task: Instance of `composer.base.Task`.
+ time_limit: (optional) A float, the time limit in seconds beyond which an
+ episode is forced to terminate.
+ random_state: Optional, either an int seed or an `np.random.RandomState`
+ object. If None (default), the random number generator will self-seed
+ from a platform-dependent source of entropy.
+ n_sub_steps: (DEPRECATED) An integer, number of physics steps to take per
+ agent control step. New code should instead override the
+ `control_substep` property of the task.
+ raise_exception_on_physics_error: (optional) A boolean, indicating whether
+ `PhysicsError` should be raised as an exception. If `False`, physics
+ errors will result in the current episode being terminated with a
+ warning logged, and a new episode started.
+ strip_singleton_obs_buffer_dim: (optional) A boolean, if `True`,
+ the array shape of observations with `buffer_size == 1` will not have a
+ leading buffer dimension.
+ """
+ self._task = task
+ if not isinstance(random_state, np.random.RandomState):
+ self._random_state = np.random.RandomState(random_state)
+ else:
+ self._random_state = random_state
+ self._hooks = _EnvironmentHooks(self._task)
+ self._time_limit = time_limit
+ self._raise_exception_on_physics_error = raise_exception_on_physics_error
+ self._strip_singleton_obs_buffer_dim = strip_singleton_obs_buffer_dim
+
+ if n_sub_steps is not None:
+ warnings.simplefilter('once', DeprecationWarning)
+ warnings.warn('The `n_sub_steps` argument is deprecated. Please override '
+ 'the `control_timestep` property of the task instead.',
+ DeprecationWarning)
+ self._overridden_n_sub_steps = n_sub_steps
+
+ self._recompile_physics_and_update_observables()
+
+ def add_extra_hook(self, hook_name, hook_callable):
+ self._hooks.add_extra_hook(hook_name, hook_callable)
+
+ def _recompile_physics_and_update_observables(self):
+ """Sets up the environment for latest MJCF model from the task."""
+ self._physics_proxy = None
+ self._recompile_physics()
+ if isinstance(self._physics, weakref.ProxyType):
+ self._physics_proxy = self._physics
+ else:
+ self._physics_proxy = weakref.proxy(self._physics)
+
+ if self._overridden_n_sub_steps is not None:
+ self._n_sub_steps = self._overridden_n_sub_steps
+ else:
+ self._n_sub_steps = self._task.physics_steps_per_control_step
+
+ self._hooks.refresh_entity_hooks()
+ self._hooks.after_compile(self._physics_proxy, self._random_state)
+ self._observation_updater = self._make_observation_updater()
+ self._observation_updater.reset(self._physics_proxy, self._random_state)
+
+ def _recompile_physics(self):
+ """Creates a new Physics using the latest MJCF model from the task."""
+ if getattr(self, '_physics', None):
+ self._physics.free()
+ self._physics = mjcf.Physics.from_mjcf_model(
+ self._task.root_entity.mjcf_model)
+
+ def _make_observation_updater(self):
+ return observation.Updater(
+ self._task.observables, self._task.physics_steps_per_control_step,
+ self._strip_singleton_obs_buffer_dim)
+
+ @property
+ def physics(self):
+ """Returns a `weakref.ProxyType` pointing to the current `mjcf.Physics`.
+
+ Note that the underlying `mjcf.Physics` will be destroyed whenever the MJCF
+ model is recompiled. It is therefore unsafe for external objects to hold a
+ reference to `environment.physics`. Attempting to access attributes of a
+ dead `Physics` instance will result in a `ReferenceError`.
+ """
+ return self._physics_proxy
+
+ @property
+ def task(self):
+ return self._task
+
+ @property
+ def random_state(self):
+ return self._random_state
+
+ def control_timestep(self):
+ """Returns the interval between agent actions in seconds."""
+ if self._overridden_n_sub_steps is not None:
+ return self.physics.timestep() * self._overridden_n_sub_steps
+ else:
+ return self.task.control_timestep
+
+
+class Environment(_CommonEnvironment, dm_env.Environment):
+ """Reinforcement learning environment for Composer tasks."""
+
+ def __init__(self, task, time_limit=float('inf'), random_state=None,
+ n_sub_steps=None,
+ raise_exception_on_physics_error=True,
+ strip_singleton_obs_buffer_dim=False,
+ max_reset_attempts=1):
+ """Initializes an instance of `Environment`.
+
+ Args:
+ task: Instance of `composer.base.Task`.
+ time_limit: (optional) A float, the time limit in seconds beyond which
+ an episode is forced to terminate.
+ random_state: (optional) an int seed or `np.random.RandomState` instance.
+ n_sub_steps: (DEPRECATED) An integer, number of physics steps to take per
+ agent control step. New code should instead override the
+ `control_substep` property of the task.
+ raise_exception_on_physics_error: (optional) A boolean, indicating whether
+ `PhysicsError` should be raised as an exception. If `False`, physics
+ errors will result in the current episode being terminated with a
+ warning logged, and a new episode started.
+ strip_singleton_obs_buffer_dim: (optional) A boolean, if `True`,
+ the array shape of observations with `buffer_size == 1` will not have a
+ leading buffer dimension.
+ max_reset_attempts: (optional) Maximum number of times to try resetting
+ the environment. If an `EpisodeInitializationError` is raised
+ during this process, an environment reset is reattempted up to this
+ number of times. If this count is exceeded then the most recent
+ exception will be allowed to propagate. Defaults to 1, i.e. no failure
+ is allowed.
+ """
+ super(Environment, self).__init__(
+ task=task,
+ time_limit=time_limit,
+ random_state=random_state,
+ n_sub_steps=n_sub_steps,
+ raise_exception_on_physics_error=raise_exception_on_physics_error,
+ strip_singleton_obs_buffer_dim=strip_singleton_obs_buffer_dim)
+ self._max_reset_attempts = max_reset_attempts
+ self._reset_next_step = True
+
+ def reset(self):
+ failed_attempts = 0
+ while True:
+ try:
+ return self._reset_attempt()
+ except EpisodeInitializationError as e:
+ failed_attempts += 1
+ if failed_attempts < self._max_reset_attempts:
+ logging.error('Error during episode reset: %s', repr(e))
+ else:
+ raise
+
+ def _reset_attempt(self):
+ self._hooks.initialize_episode_mjcf(self._random_state)
+ self._recompile_physics_and_update_observables()
+ with self._physics.reset_context():
+ self._hooks.initialize_episode(self._physics_proxy, self._random_state)
+ self._observation_updater.reset(self._physics_proxy, self._random_state)
+ self._reset_next_step = False
+ return dm_env.TimeStep(
+ step_type=dm_env.StepType.FIRST,
+ reward=None,
+ discount=None,
+ observation=self._observation_updater.get_observation())
+
+ # TODO(b/129061424): Remove this method.
+ def step_spec(self):
+ """DEPRECATED: please use `reward_spec` and `discount_spec` instead."""
+ warnings.warn('`step_spec` is deprecated, please use `reward_spec` and '
+ '`discount_spec` instead.', DeprecationWarning)
+ if (self._task.get_reward_spec() is None or
+ self._task.get_discount_spec() is None):
+ raise NotImplementedError
+ return dm_env.TimeStep(
+ step_type=None,
+ reward=self._task.get_reward_spec(),
+ discount=self._task.get_discount_spec(),
+ observation=self._observation_updater.observation_spec(),
+ )
+
+ def step(self, action):
+ """Updates the environment using the action and returns a `TimeStep`."""
+ if self._reset_next_step:
+ self._reset_next_step = False
+ return self.reset()
+
+ self._hooks.before_step(self._physics_proxy, action, self._random_state)
+ self._observation_updater.prepare_for_next_control_step()
+
+ try:
+ for i in range(self._n_sub_steps):
+ self._hooks.before_substep(self._physics_proxy, action,
+ self._random_state)
+ self._physics.step()
+ self._hooks.after_substep(self._physics_proxy, self._random_state)
+ # The final observation update must happen after all the hooks in
+ # `self._hooks.after_step` is called. Otherwise, if any of these hooks
+ # modify the physics state then we might capture an observation that is
+ # inconsistent with the final physics state.
+ if i < self._n_sub_steps - 1:
+ self._observation_updater.update()
+ physics_is_divergent = False
+ except control.PhysicsError as e:
+ if not self._raise_exception_on_physics_error:
+ logging.warning(e)
+ physics_is_divergent = True
+ else:
+ raise
+
+ self._hooks.after_step(self._physics_proxy, self._random_state)
+ self._observation_updater.update()
+
+ if not physics_is_divergent:
+ reward = self._task.get_reward(self._physics_proxy)
+ discount = self._task.get_discount(self._physics_proxy)
+ terminating = (
+ self._task.should_terminate_episode(self._physics_proxy)
+ or self._physics.time() >= self._time_limit
+ )
+ else:
+ reward = 0.0
+ discount = 0.0
+ terminating = True
+
+ obs = self._observation_updater.get_observation()
+
+ if not terminating:
+ return dm_env.TimeStep(dm_env.StepType.MID, reward, discount, obs)
+ else:
+ self._reset_next_step = True
+ return dm_env.TimeStep(dm_env.StepType.LAST, reward, discount, obs)
+
+ def action_spec(self):
+ """Returns the action specification for this environment."""
+ return self._task.action_spec(self._physics_proxy)
+
+ def reward_spec(self):
+ """Describes the reward returned by this environment.
+
+ This will be the output of `self.task.reward_spec()` if it is not None,
+ otherwise it will be the default spec returned by
+ `dm_env.Environment.reward_spec()`.
+
+ Returns:
+ A `specs.Array` instance, or a nested dict, list or tuple of
+ `specs.Array`s.
+ """
+ task_reward_spec = self._task.get_reward_spec()
+ if task_reward_spec is not None:
+ return task_reward_spec
+ else:
+ return super(Environment, self).reward_spec()
+
+ def discount_spec(self):
+ """Describes the discount returned by this environment.
+
+ This will be the output of `self.task.discount_spec()` if it is not None,
+ otherwise it will be the default spec returned by
+ `dm_env.Environment.discount_spec()`.
+
+ Returns:
+ A `specs.Array` instance, or a nested dict, list or tuple of
+ `specs.Array`s.
+ """
+ task_discount_spec = self._task.get_discount_spec()
+ if task_discount_spec is not None:
+ return task_discount_spec
+ else:
+ return super(Environment, self).discount_spec()
+
+ def observation_spec(self):
+ """Returns the observation specification for this environment.
+
+ Returns:
+ An `OrderedDict` mapping observation name to `specs.Array` containing
+ observation shape and dtype.
+ """
+ return self._observation_updater.observation_spec()
diff --git a/DMC/src/env/dm_control/dm_control/composer/environment_hooks_test.py b/DMC/src/env/dm_control/dm_control/composer/environment_hooks_test.py
new file mode 100644
index 0000000..5fcebbf
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/environment_hooks_test.py
@@ -0,0 +1,46 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for Entity and Task hooks in an Environment."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+from absl.testing import absltest
+from dm_control import composer
+from dm_control.composer import hooks_test_utils
+import numpy as np
+from six.moves import range
+
+
+class EnvironmentHooksTest(hooks_test_utils.HooksTestMixin, absltest.TestCase):
+
+ def testEnvironmentHooksScheduling(self):
+ env = composer.Environment(self.task)
+ for hook_name in composer.HOOK_NAMES:
+ env.add_extra_hook(hook_name, getattr(self.extra_hooks, hook_name))
+ for _ in range(self.num_episodes):
+ with self.track_episode():
+ env.reset()
+ for _ in range(self.steps_per_episode):
+ env.step([0.1, 0.2, 0.3, 0.4])
+ np.testing.assert_array_equal(env.physics.data.ctrl,
+ [0.1, 0.2, 0.3, 0.4])
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/composer/environment_test.py b/DMC/src/env/dm_control/dm_control/composer/environment_test.py
new file mode 100644
index 0000000..835450f
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/environment_test.py
@@ -0,0 +1,106 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.composer.environment."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation import observable
+import dm_env
+import mock
+import numpy as np
+from six.moves import range
+
+
+class DummyTask(composer.NullTask):
+
+ def __init__(self):
+ null_entity = composer.ModelWrapperEntity(mjcf.RootElement())
+ super(DummyTask, self).__init__(null_entity)
+
+ @property
+ def task_observables(self):
+ time = observable.Generic(lambda physics: physics.time())
+ time.enabled = True
+ return {'time': time}
+
+
+class DummyTaskWithResetFailures(DummyTask):
+
+ def __init__(self, num_reset_failures):
+ super(DummyTaskWithResetFailures, self).__init__()
+ self.num_reset_failures = num_reset_failures
+ self.reset_counter = 0
+
+ def initialize_episode_mjcf(self, random_state):
+ self.reset_counter += 1
+
+ def initialize_episode(self, physics, random_state):
+ if self.reset_counter <= self.num_reset_failures:
+ raise composer.EpisodeInitializationError()
+
+
+class EnvironmentTest(parameterized.TestCase):
+
+ def test_failed_resets(self):
+ total_reset_failures = 5
+ env_reset_attempts = 2
+ task = DummyTaskWithResetFailures(num_reset_failures=total_reset_failures)
+ env = composer.Environment(task, max_reset_attempts=env_reset_attempts)
+ for _ in range(total_reset_failures // env_reset_attempts):
+ with self.assertRaises(composer.EpisodeInitializationError):
+ env.reset()
+ env.reset() # should not raise an exception
+ self.assertEqual(task.reset_counter, total_reset_failures + 1)
+
+ @parameterized.parameters(
+ dict(name='reward_spec', defined_in_task=True),
+ dict(name='reward_spec', defined_in_task=False),
+ dict(name='discount_spec', defined_in_task=True),
+ dict(name='discount_spec', defined_in_task=False))
+ def test_get_spec(self, name, defined_in_task):
+ task = DummyTask()
+ env = composer.Environment(task)
+ with mock.patch.object(task, 'get_' + name) as mock_task_get_spec:
+ if defined_in_task:
+ expected_spec = mock.Mock()
+ mock_task_get_spec.return_value = expected_spec
+ else:
+ expected_spec = getattr(dm_env.Environment, name)(env)
+ mock_task_get_spec.return_value = None
+ spec = getattr(env, name)()
+ mock_task_get_spec.assert_called_once_with()
+ self.assertSameStructure(spec, expected_spec)
+
+ def test_can_provide_observation(self):
+ task = DummyTask()
+ env = composer.Environment(task)
+ obs = env.reset().observation
+ self.assertLen(obs, 1)
+ np.testing.assert_array_equal(obs['time'], env.physics.time())
+ for _ in range(20):
+ obs = env.step([]).observation
+ self.assertLen(obs, 1)
+ np.testing.assert_array_equal(obs['time'], env.physics.time())
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/composer/hooks_test_utils.py b/DMC/src/env/dm_control/dm_control/composer/hooks_test_utils.py
new file mode 100644
index 0000000..f1a449f
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/hooks_test_utils.py
@@ -0,0 +1,326 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Utilities for testing environment hooks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import contextlib
+import inspect
+
+from dm_control import composer
+from dm_control import mjcf
+from six.moves import range
+
+
+def add_bodies_and_actuators(mjcf_model, num_actuators):
+ if num_actuators % 2:
+ raise ValueError('num_actuators is not a multiple of 2')
+ for _ in range(num_actuators // 2):
+ body = mjcf_model.worldbody.add('body')
+ body.add('inertial', pos=[0, 0, 0], mass=1, diaginertia=[1, 1, 1])
+ joint_x = body.add('joint', axis=[1, 0, 0])
+ mjcf_model.actuator.add('position', joint=joint_x)
+ joint_y = body.add('joint', axis=[0, 1, 0])
+ mjcf_model.actuator.add('position', joint=joint_y)
+
+
+class HooksTracker(object):
+ """Helper class for tracking call order of callbacks."""
+
+ def __init__(self, test_case, physics_timestep, control_timestep,
+ *args, **kwargs):
+ super(HooksTracker, self).__init__(*args, **kwargs)
+ self.tracked = False
+ self._test_case = test_case
+ self._call_count = collections.defaultdict(lambda: 0)
+ self._physics_timestep = physics_timestep
+ self._physics_steps_per_control_step = (
+ round(int(control_timestep / physics_timestep)))
+
+ mro = inspect.getmro(type(self))
+ self._has_super = mro[mro.index(HooksTracker) + 1] != object
+
+ def assertEqual(self, actual, expected, msg=''):
+ msg = '{}: {}: {!r} != {!r}'.format(type(self), msg, actual, expected)
+ self._test_case.assertEqual(actual, expected, msg)
+
+ def assertHooksNotCalled(self, *hook_names):
+ for hook_name in hook_names:
+ self.assertEqual(
+ self._call_count[hook_name], 0,
+ 'assertHooksNotCalled: hook_name = {!r}'.format(hook_name))
+
+ def assertHooksCalledOnce(self, *hook_names):
+ for hook_name in hook_names:
+ self.assertEqual(
+ self._call_count[hook_name], 1,
+ 'assertHooksCalledOnce: hook_name = {!r}'.format(hook_name))
+
+ def assertCompleteEpisode(self, control_steps):
+ self.assertHooksCalledOnce('initialize_episode_mjcf',
+ 'after_compile',
+ 'initialize_episode')
+ physics_steps = control_steps * self._physics_steps_per_control_step
+ self.assertEqual(self._call_count['before_step'], control_steps)
+ self.assertEqual(self._call_count['before_substep'], physics_steps)
+ self.assertEqual(self._call_count['after_substep'], physics_steps)
+ self.assertEqual(self._call_count['after_step'], control_steps)
+
+ def assertPhysicsStepCountEqual(self, physics, expected_count):
+ actual_count = int(round(physics.time() / self._physics_timestep))
+ self.assertEqual(actual_count, expected_count)
+
+ def reset_call_counts(self):
+ self._call_count = collections.defaultdict(lambda: 0)
+
+ def initialize_episode_mjcf(self, random_state):
+ """Implements `initialize_episode_mjcf` Composer callback."""
+ if self._has_super:
+ super(HooksTracker, self).initialize_episode_mjcf(random_state)
+ if not self.tracked:
+ return
+ self.assertHooksNotCalled('after_compile',
+ 'initialize_episode',
+ 'before_step',
+ 'before_substep',
+ 'after_substep',
+ 'after_step')
+ self._call_count['initialize_episode_mjcf'] += 1
+
+ def after_compile(self, physics, random_state):
+ """Implements `after_compile` Composer callback."""
+ if self._has_super:
+ super(HooksTracker, self).after_compile(physics, random_state)
+ if not self.tracked:
+ return
+ self.assertHooksCalledOnce('initialize_episode_mjcf')
+ self.assertHooksNotCalled('initialize_episode',
+ 'before_step',
+ 'before_substep',
+ 'after_substep',
+ 'after_step')
+ # Number of physics steps is always consistent with `before_substep`.
+ self.assertPhysicsStepCountEqual(physics,
+ self._call_count['before_substep'])
+ self._call_count['after_compile'] += 1
+
+ def initialize_episode(self, physics, random_state):
+ """Implements `initialize_episode` Composer callback."""
+ if self._has_super:
+ super(HooksTracker, self).initialize_episode(physics, random_state)
+ if not self.tracked:
+ return
+ self.assertHooksCalledOnce('initialize_episode_mjcf',
+ 'after_compile')
+ self.assertHooksNotCalled('before_step',
+ 'before_substep',
+ 'after_substep',
+ 'after_step')
+ # Number of physics steps is always consistent with `before_substep`.
+ self.assertPhysicsStepCountEqual(physics,
+ self._call_count['before_substep'])
+ self._call_count['initialize_episode'] += 1
+
+ def before_step(self, physics, *args):
+ """Implements `before_step` Composer callback."""
+ if self._has_super:
+ super(HooksTracker, self).before_step(physics, *args)
+ if not self.tracked:
+ return
+ self.assertHooksCalledOnce('initialize_episode_mjcf',
+ 'after_compile',
+ 'initialize_episode')
+
+ # `before_step` is only called in between complete control steps.
+ self.assertEqual(
+ self._call_count['after_step'], self._call_count['before_step'])
+
+ # Complete control steps imply complete physics steps.
+ self.assertEqual(
+ self._call_count['after_substep'], self._call_count['before_substep'])
+
+ # Number of physics steps is always consistent with `before_substep`.
+ self.assertPhysicsStepCountEqual(physics,
+ self._call_count['before_substep'])
+
+ self._call_count['before_step'] += 1
+
+ def before_substep(self, physics, *args):
+ """Implements `before_substep` Composer callback."""
+ if self._has_super:
+ super(HooksTracker, self).before_substep(physics, *args)
+ if not self.tracked:
+ return
+ self.assertHooksCalledOnce('initialize_episode_mjcf',
+ 'after_compile',
+ 'initialize_episode')
+
+ # We are inside a partial control step, so `after_step` should lag behind.
+ self.assertEqual(
+ self._call_count['after_step'], self._call_count['before_step'] - 1)
+
+ # `before_substep` is only called in between complete physics steps.
+ self.assertEqual(
+ self._call_count['after_substep'], self._call_count['before_substep'])
+
+ # Number of physics steps is always consistent with `before_substep`.
+ self.assertPhysicsStepCountEqual(
+ physics, self._call_count['before_substep'])
+
+ self._call_count['before_substep'] += 1
+
+ def after_substep(self, physics, random_state):
+ """Implements `after_substep` Composer callback."""
+ if self._has_super:
+ super(HooksTracker, self).after_substep(physics, random_state)
+ if not self.tracked:
+ return
+ self.assertHooksCalledOnce('initialize_episode_mjcf',
+ 'after_compile',
+ 'initialize_episode')
+
+ # We are inside a partial control step, so `after_step` should lag behind.
+ self.assertEqual(
+ self._call_count['after_step'], self._call_count['before_step'] - 1)
+
+ # We are inside a partial physics step, so `after_substep` should be behind.
+ self.assertEqual(self._call_count['after_substep'],
+ self._call_count['before_substep'] - 1)
+
+ # Number of physics steps is always consistent with `before_substep`.
+ self.assertPhysicsStepCountEqual(
+ physics, self._call_count['before_substep'])
+
+ self._call_count['after_substep'] += 1
+
+ def after_step(self, physics, random_state):
+ """Implements `after_step` Composer callback."""
+ if self._has_super:
+ super(HooksTracker, self).after_step(physics, random_state)
+ if not self.tracked:
+ return
+ self.assertHooksCalledOnce('initialize_episode_mjcf',
+ 'after_compile',
+ 'initialize_episode')
+
+ # We are inside a partial control step, so `after_step` should lag behind.
+ self.assertEqual(
+ self._call_count['after_step'], self._call_count['before_step'] - 1)
+
+ # `after_step` is only called in between complete physics steps.
+ self.assertEqual(
+ self._call_count['after_substep'], self._call_count['before_substep'])
+
+ # Number of physics steps is always consistent with `before_substep`.
+ self.assertPhysicsStepCountEqual(
+ physics, self._call_count['before_substep'])
+
+ # Check that the number of physics steps is consistent with control steps.
+ self.assertEqual(
+ self._call_count['before_substep'],
+ self._call_count['before_step'] * self._physics_steps_per_control_step)
+
+ self._call_count['after_step'] += 1
+
+
+class TrackedEntity(HooksTracker, composer.Entity):
+ """A `composer.Entity` that tracks call order of callbacks."""
+
+ def _build(self, name):
+ self._mjcf_root = mjcf.RootElement(model=name)
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @property
+ def name(self):
+ return self._mjcf_root.model
+
+
+class TrackedTask(HooksTracker, composer.NullTask):
+ """A `composer.Task` that tracks call order of callbacks."""
+
+ def __init__(self, physics_timestep, control_timestep, *args, **kwargs):
+ super(TrackedTask, self).__init__(physics_timestep=physics_timestep,
+ control_timestep=control_timestep,
+ *args, **kwargs)
+ self.set_timesteps(physics_timestep=physics_timestep,
+ control_timestep=control_timestep)
+ add_bodies_and_actuators(self.root_entity.mjcf_model, num_actuators=4)
+
+
+class HooksTestMixin(object):
+ """A mixin for an `absltest.TestCase` to track call order of callbacks."""
+
+ def setUp(self):
+ """Sets up the test case."""
+ super(HooksTestMixin, self).setUp()
+
+ self.num_episodes = 5
+ self.steps_per_episode = 100
+
+ self.control_timestep = 0.05
+ self.physics_timestep = 0.002
+
+ self.extra_hooks = HooksTracker(physics_timestep=self.physics_timestep,
+ control_timestep=self.control_timestep,
+ test_case=self)
+
+ self.entities = []
+ for i in range(9):
+ self.entities.append(TrackedEntity(name='entity_{}'.format(i),
+ physics_timestep=self.physics_timestep,
+ control_timestep=self.control_timestep,
+ test_case=self))
+
+ ########################################
+ # Make the following entity hierarchy #
+ # 0 #
+ # 1 2 3 #
+ # 4 5 6 7 #
+ # 8 #
+ ########################################
+
+ self.entities[4].attach(self.entities[8])
+ self.entities[1].attach(self.entities[4])
+ self.entities[1].attach(self.entities[5])
+ self.entities[0].attach(self.entities[1])
+
+ self.entities[2].attach(self.entities[6])
+ self.entities[2].attach(self.entities[7])
+ self.entities[0].attach(self.entities[2])
+
+ self.entities[0].attach(self.entities[3])
+
+ self.task = TrackedTask(root_entity=self.entities[0],
+ physics_timestep=self.physics_timestep,
+ control_timestep=self.control_timestep,
+ test_case=self)
+
+ @contextlib.contextmanager
+ def track_episode(self):
+ tracked_objects = [self.task, self.extra_hooks] + self.entities
+ for obj in tracked_objects:
+ obj.reset_call_counts()
+ obj.tracked = True
+ yield
+ for obj in tracked_objects:
+ obj.assertCompleteEpisode(self.steps_per_episode)
+ obj.tracked = False
diff --git a/DMC/src/env/dm_control/dm_control/composer/initializer.py b/DMC/src/env/dm_control/dm_control/composer/initializer.py
new file mode 100644
index 0000000..e355125
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/initializer.py
@@ -0,0 +1,33 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Module defining the abstract initializer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+import six
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Initializer(object):
+ """The abstract base class for an initializer."""
+
+ @abc.abstractmethod
+ def __call__(self, physics, random_state):
+ raise NotImplementedError
diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/__init__.py b/DMC/src/env/dm_control/dm_control/composer/observation/__init__.py
new file mode 100644
index 0000000..dd9cbfc
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/observation/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Multi-rate observation and buffering framework for Composer environments."""
+
+from dm_control.composer.observation import observable
+from dm_control.composer.observation.obs_buffer import Buffer
+from dm_control.composer.observation.updater import DEFAULT_BUFFER_SIZE
+from dm_control.composer.observation.updater import DEFAULT_DELAY
+from dm_control.composer.observation.updater import DEFAULT_UPDATE_INTERVAL
+from dm_control.composer.observation.updater import Updater
diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/fake_physics.py b/DMC/src/env/dm_control/dm_control/composer/observation/fake_physics.py
new file mode 100644
index 0000000..e9cc6c6
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/observation/fake_physics.py
@@ -0,0 +1,76 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A fake Physics class for unit testing observation framework."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+
+from dm_control.composer.observation import observable
+from dm_control.rl import control
+import numpy as np
+
+
+class FakePhysics(control.Physics):
+ """A fake Physics class for unit testing observation framework."""
+
+ def __init__(self):
+ self._step_counter = 0
+ self._observables = {
+ 'twice': observable.Generic(FakePhysics.twice),
+ 'repeated': observable.Generic(FakePhysics.repeated, update_interval=5),
+ 'matrix': observable.Generic(FakePhysics.matrix, update_interval=3)
+ }
+
+ def step(self, sub_steps=1):
+ self._step_counter += 1
+
+ @property
+ def observables(self):
+ return self._observables
+
+ def twice(self):
+ return 2*self._step_counter
+
+ def repeated(self):
+ return [self._step_counter, self._step_counter]
+
+ def sqrt(self):
+ return np.sqrt(self._step_counter)
+
+ def matrix(self):
+ return [[self._step_counter] * 3] * 2
+
+ def time(self):
+ return self._step_counter
+
+ def timestep(self):
+ return 1.0
+
+ def set_control(self, ctrl):
+ pass
+
+ def reset(self):
+ self._step_counter = 0
+
+ def after_reset(self):
+ pass
+
+ @contextlib.contextmanager
+ def suppress_physics_errors(self):
+ yield
diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/obs_buffer.py b/DMC/src/env/dm_control/dm_control/composer/observation/obs_buffer.py
new file mode 100644
index 0000000..876c0fa
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/observation/obs_buffer.py
@@ -0,0 +1,251 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+""""An object that manages the buffering and delaying of observation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import numpy as np
+import six
+from six.moves import range
+
+
+class InFlightObservation(object):
+ """Represents a delayed observation that may not have arrived yet.
+
+ Attributes:
+ arrival: The time at which this observation will be delivered.
+ timestamp: The time at which this observation was made.
+ delay: The amount of delay between the time at which this observation was
+ made and the time at which it is delivered.
+ value: The value of this observation.
+ """
+
+ __slots__ = ('arrival', 'timestamp', 'delay', 'value')
+
+ def __init__(self, timestamp, delay, value):
+ self.arrival = timestamp + delay
+ self.timestamp = timestamp
+ self.delay = delay
+ self.value = value
+
+ def __lt__(self, other):
+ # This is implemented to facilitate sorting.
+ return self.arrival < other.arrival
+
+
+class Buffer(object):
+ """An object that manages the buffering and delaying of observation."""
+
+ def __init__(self, buffer_size, shape, dtype, pad_value=0,
+ strip_singleton_buffer_dim=False):
+ """Initializes this observation buffer.
+
+ Args:
+ buffer_size: The size of the buffer returned by `read`. Note
+ that this does *not* affect size of the internal buffer held by this
+ object, which always grow as large as is necessary in the presence of
+ large delays.
+ shape: The shape of a single observation held by this buffer, which can
+ either be a single integer or an iterable of integers. The shape of the
+ buffer returned by `read` will then be
+ `(buffer_size, shape[0], ..., shape[n])`, unless `buffer_size == 1`
+ and `strip_singleton_buffer_dim == True`.
+ dtype: The NumPy dtype of observation entries.
+ pad_value: (optional) The value that is used to pad the buffer returned
+ by `read` when the number of observation entries is less
+ then `buffer_size`. Specifically, the buffer will be padded by
+ `np.full(shape, pad_value, dtype)`.
+ strip_singleton_buffer_dim: (optional) A boolean, if `True` and
+ `buffer_size == 1` then the leading dimension will not be added to the
+ shape of the array returned by `read`.
+ """
+ self._buffer_size = buffer_size
+ try:
+ shape = tuple(shape)
+ except TypeError:
+ if isinstance(shape, int):
+ shape = (shape,)
+ else:
+ raise
+
+ self._has_buffer_dim = not (strip_singleton_buffer_dim and buffer_size == 1)
+ if self._has_buffer_dim:
+ self._buffered_shape = (buffer_size,) + shape
+ else:
+ self._buffered_shape = shape
+ self._dtype = dtype
+
+ # The "arrived" deque contains entries that are due to be delivered now.
+ # This deque should never grow beyond buffer_size.
+ self._arrived_deque = collections.deque(maxlen=buffer_size)
+ for _ in range(buffer_size):
+ self._arrived_deque.append(
+ InFlightObservation(-np.inf, 0, np.full(shape, pad_value, dtype)))
+
+ # The "pending" deque contains entries that are stored for future delivery.
+ # This deque can grow arbitrarily large in presence of long delays.
+ self._pending_deque = collections.deque()
+
+ def _update_arrived_deque(self, timestamp):
+ while self._pending_deque and self._pending_deque[0].arrival <= timestamp:
+ self._arrived_deque.append(self._pending_deque.popleft())
+
+ @property
+ def shape(self):
+ return self._buffered_shape
+
+ @property
+ def dtype(self):
+ return self._dtype
+
+ def insert(self, timestamp, delay, value):
+ """Inserts a new observation to the buffer.
+
+ This function implicitly updates the internal "clock" of this buffer to
+ the timestamp of the new observation, and the internal buffer is trimmed
+ accordingly, i.e. at most `buffer_size` items whose delayed arrival time
+ preceeds `timestamp` are kept.
+
+ Args:
+ timestamp: The time at which this observation was made.
+ delay: The amount of delay between the time at which this observation was
+ made and the time at which it is delivered.
+ value: The value of this observation.
+
+ Raises:
+ ValueError: if `delay` is negative.
+ """
+ self._update_arrived_deque(timestamp)
+ new_obs = InFlightObservation(timestamp, delay, np.array(value))
+ arrival = new_obs.arrival
+ if delay == 0:
+ # No delay, so the new observation is due for immediate delivery.
+ # Add it to the arrived deque.
+ self._arrived_deque.append(new_obs)
+ elif delay > 0:
+ if not self._pending_deque or arrival > self._pending_deque[-1].arrival:
+ # New observation's arrival time is monotonic.
+ # Technically, we can handle this in the general code branch below,
+ # but since this is assumed to be the "typical" case, the special
+ # handling here saves us from repeatedly allocating and deallocating
+ # an empty temporary deque.
+ self._pending_deque.append(new_obs)
+ else:
+ # General, out-of-order observation.
+ arriving_after_new_obs = collections.deque()
+ while self._pending_deque and arrival < self._pending_deque[-1].arrival:
+ arriving_after_new_obs.appendleft(self._pending_deque.pop())
+ self._pending_deque.append(new_obs)
+ for existing_obs in arriving_after_new_obs:
+ self._pending_deque.append(existing_obs)
+ else:
+ raise ValueError('`delay` should not be negative: '
+ 'got {!r}'.format(delay))
+
+ def read(self, current_time):
+ """Reads the content of the buffer at the given timestamp."""
+ self._update_arrived_deque(current_time)
+ if self._has_buffer_dim:
+ out = np.empty(self._buffered_shape, dtype=self._dtype)
+ for i, obs in enumerate(self._arrived_deque):
+ out[i] = obs.value
+ else:
+ out = self._arrived_deque[0].value.copy()
+ return out
+
+ def drop_unobserved_upcoming_items(self, observation_schedule, read_interval):
+ """Plans an optimal observation schedule for an upcoming control period.
+
+ This function determines which of the proposed upcoming observations will
+ never in fact be delivered and removes them from the observation schedule.
+
+ We assume that observations will only be queried at times that are integer
+ multiples of `read_interval`. If more observations are generated during
+ the upcoming control step than the `buffer_size` of this `Buffer`
+ then of those new observations will never be required. This function takes
+ into account the delayed arrival time and existing buffered items in the
+ planning process.
+
+ Args:
+ observation_schedule: An list of `(timestamp, delay)` tuples, where
+ `timestamp` is the time at which the observation value will be produced,
+ and `delay` is the amount of time the observation will be delayed by.
+ This list will be modified in place.
+ read_interval: The time interval between successive calls to `read`.
+ We assume that observations will only be queried at times that are
+ integer multiples of `read_interval`.
+ """
+ # Private deques to simulate what the deques will look like in the future,
+ # according to the proposed upcoming observation schedule.
+ future_arrived_deque = collections.deque()
+ future_pending_deque = collections.deque()
+
+ # Take existing buffered observations into account when planning the
+ # upcoming schedule.
+ def get_next_existing_timestamp():
+ for obs in reversed(self._pending_deque):
+ yield InFlightObservation(obs.timestamp, obs.delay, None)
+ while True:
+ yield InFlightObservation(-np.inf, 0, None)
+ existing_timestamp_iter = get_next_existing_timestamp()
+ existing_timestamp = six.next(existing_timestamp_iter)
+
+ # Build the simulated state of the pending deque at the end of the proposed
+ # schedule.
+ sorted_schedule = sorted([InFlightObservation(time[0], time[1], None)
+ for time in observation_schedule])
+ for new_timestamp in reversed(sorted_schedule):
+ # We don't need to worry about any existing item that are delivered before
+ # the first new item, since those are purged independently of our
+ # proposed new observations.
+ while existing_timestamp.arrival > new_timestamp.arrival:
+ future_pending_deque.appendleft(existing_timestamp)
+ existing_timestamp = six.next(existing_timestamp_iter)
+ future_pending_deque.appendleft(new_timestamp)
+
+ # Find the next timestep at which `read` is called.
+ first_proposed_timestamp = min(t for t, _ in observation_schedule)
+ next_read_time = read_interval * int(np.ceil(
+ first_proposed_timestamp // read_interval))
+
+ # Build the simulated state of the arrived deque at each subsequent
+ # control steps.
+ while future_pending_deque:
+ # Keep track of observations that are delivered for the first time
+ # during this control timestep.
+ newly_arrived = collections.deque()
+ while (future_pending_deque and
+ future_pending_deque[0].arrival <= next_read_time):
+ # `fake_observation` is an `InFlightObservation` without `value`.
+ fake_observation = future_pending_deque.popleft()
+ future_arrived_deque.append(fake_observation)
+ newly_arrived.append(fake_observation)
+ while len(future_arrived_deque) > self._buffer_size:
+ stale = future_arrived_deque.popleft()
+ # Newly-arrived items that become immediately stale are never actually
+ # delivered.
+ if newly_arrived and stale == newly_arrived[0]:
+ newly_arrived.popleft()
+ # `stale` might either be one of the existing pending observations or
+ # from the proposed schedule.
+ if stale.timestamp >= first_proposed_timestamp:
+ observation_schedule.remove((stale.timestamp, stale.delay))
+
+ next_read_time += read_interval
diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/obs_buffer_test.py b/DMC/src/env/dm_control/dm_control/composer/observation/obs_buffer_test.py
new file mode 100644
index 0000000..73841fb
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/observation/obs_buffer_test.py
@@ -0,0 +1,83 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for observation.obs_buffer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.composer.observation import obs_buffer
+import numpy as np
+from six.moves import range
+
+
+def _generate_constant_schedule(update_timestep, delay,
+ control_timestep, n_observed_steps):
+ first = update_timestep
+ last = control_timestep * n_observed_steps + 1
+ return [(i, delay) for i in range(first, last, update_timestep)]
+
+
+class BufferTest(parameterized.TestCase):
+
+ def testOutOfOrderArrival(self):
+ buf = obs_buffer.Buffer(buffer_size=3, shape=(), dtype=np.float)
+ buf.insert(timestamp=0, delay=4, value=1)
+ buf.insert(timestamp=1, delay=2, value=2)
+ buf.insert(timestamp=2, delay=3, value=3)
+ np.testing.assert_array_equal(buf.read(current_time=2), [0., 0., 0.])
+ np.testing.assert_array_equal(buf.read(current_time=3), [0., 0., 2.])
+ np.testing.assert_array_equal(buf.read(current_time=4), [0., 2., 1.])
+ np.testing.assert_array_equal(buf.read(current_time=5), [2., 1., 3.])
+ np.testing.assert_array_equal(buf.read(current_time=6), [2., 1., 3.])
+
+ @parameterized.parameters(((3, 3),), ((),))
+ def testStripSingletonDimension(self, shape):
+ buf = obs_buffer.Buffer(buffer_size=1, shape=shape, dtype=np.float,
+ strip_singleton_buffer_dim=True)
+ expected_value = np.full(shape, 42, dtype=np.float)
+ buf.insert(timestamp=0, delay=0, value=expected_value)
+ np.testing.assert_array_equal(buf.read(current_time=1), expected_value)
+
+ def testPlanToSingleUndelayedObservation(self):
+ buf = obs_buffer.Buffer(
+ buffer_size=1, shape=(), dtype=np.float)
+ control_timestep = 20
+ observation_schedule = _generate_constant_schedule(
+ update_timestep=1, delay=0,
+ control_timestep=control_timestep, n_observed_steps=1)
+ buf.drop_unobserved_upcoming_items(
+ observation_schedule, read_interval=control_timestep)
+ self.assertEqual(observation_schedule, [(20, 0)])
+
+ def testPlanTwoStepsAhead(self):
+ buf = obs_buffer.Buffer(
+ buffer_size=1, shape=(), dtype=np.float)
+ control_timestep = 5
+ observation_schedule = _generate_constant_schedule(
+ update_timestep=2, delay=3,
+ control_timestep=control_timestep, n_observed_steps=2)
+ buf.drop_unobserved_upcoming_items(
+ observation_schedule, read_interval=control_timestep)
+ self.assertEqual(observation_schedule, [(2, 3), (6, 3), (10, 3)])
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/observable/__init__.py b/DMC/src/env/dm_control/dm_control/composer/observation/observable/__init__.py
new file mode 100644
index 0000000..cd61957
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/observation/observable/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Module for observables in the Composer library."""
+
+from dm_control.composer.observation.observable.base import Generic
+from dm_control.composer.observation.observable.base import MujocoCamera
+from dm_control.composer.observation.observable.base import MujocoFeature
+from dm_control.composer.observation.observable.base import Observable
+
+from dm_control.composer.observation.observable.mjcf import MJCFCamera
+from dm_control.composer.observation.observable.mjcf import MJCFFeature
diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/observable/base.py b/DMC/src/env/dm_control/dm_control/composer/observation/observable/base.py
new file mode 100644
index 0000000..5b8cddc
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/observation/observable/base.py
@@ -0,0 +1,318 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Classes representing observables."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import functools
+
+from dm_env import specs
+import numpy as np
+import six
+
+
+def _make_aggregator(np_reducer_func, bounds_preserving):
+ result = functools.partial(np_reducer_func, axis=0)
+ setattr(result, 'bounds_reserving', bounds_preserving)
+ return result
+
+
+AGGREGATORS = {
+ 'min': _make_aggregator(np.min, True),
+ 'max': _make_aggregator(np.max, True),
+ 'mean': _make_aggregator(np.mean, True),
+ 'median': _make_aggregator(np.median, True),
+ 'sum': _make_aggregator(np.sum, False),
+}
+
+
+def _get_aggregator(name_or_callable):
+ """Returns aggregator from predefined set by name, else returns callable."""
+ if name_or_callable is None:
+ return None
+ elif not callable(name_or_callable):
+ try:
+ return AGGREGATORS[name_or_callable]
+ except KeyError:
+ raise KeyError('Unrecognized aggregator name: {!r}. Valid names: {}.'
+ .format(name_or_callable, AGGREGATORS.keys()))
+ else:
+ return name_or_callable
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Observable(object):
+ """Abstract base class for an observable."""
+
+ def __init__(self, update_interval, buffer_size, delay,
+ aggregator, corruptor):
+ self._update_interval = update_interval
+ self._buffer_size = buffer_size
+ self._delay = delay
+ self._aggregator = _get_aggregator(aggregator)
+ self._corruptor = corruptor
+ self._enabled = False
+
+ @property
+ def update_interval(self):
+ return self._update_interval
+
+ @update_interval.setter
+ def update_interval(self, value):
+ self._update_interval = value
+
+ @property
+ def buffer_size(self):
+ return self._buffer_size
+
+ @buffer_size.setter
+ def buffer_size(self, value):
+ self._buffer_size = value
+
+ @property
+ def delay(self):
+ return self._delay
+
+ @delay.setter
+ def delay(self, value):
+ self._delay = value
+
+ @property
+ def aggregator(self):
+ return self._aggregator
+
+ @aggregator.setter
+ def aggregator(self, value):
+ self._aggregator = _get_aggregator(value)
+
+ @property
+ def corruptor(self):
+ return self._corruptor
+
+ @corruptor.setter
+ def corruptor(self, value):
+ self._corruptor = value
+
+ @property
+ def enabled(self):
+ return self._enabled
+
+ @enabled.setter
+ def enabled(self, value):
+ self._enabled = value
+
+ @property
+ def array_spec(self):
+ """The `ArraySpec` which describes observation arrays from this observable.
+
+ If this property is `None`, then the specification should be inferred by
+ actually retrieving an observation from this observable.
+ """
+ return None
+
+ @abc.abstractmethod
+ def _callable(self, physics):
+ pass
+
+ def observation_callable(self, physics, random_state=None):
+ """A callable which returns a (potentially corrupted) observation."""
+ raw_callable = self._callable(physics)
+ if self._corruptor:
+ def _corrupted():
+ return self._corruptor(raw_callable(), random_state=random_state)
+ return _corrupted
+ else:
+ return raw_callable
+
+ def __call__(self, physics, random_state=None):
+ """Convenience function to just call an observable."""
+ return self.observation_callable(physics, random_state)()
+
+ def configure(self, **kwargs):
+ """Sets multiple attributes of this observable.
+
+ Args:
+ **kwargs: The keyword argument names correspond to the attributes
+ being modified.
+ Raises:
+ AttributeError: If kwargs contained an attribute not in the observable.
+ """
+ for key, value in six.iteritems(kwargs):
+ if not hasattr(self, key):
+ raise AttributeError('Cannot add attribute %s in configure.' % key)
+ self.__setattr__(key, value)
+
+
+class Generic(Observable):
+ """A generic observable defined via a callable."""
+
+ def __init__(self, raw_observation_callable, update_interval=1,
+ buffer_size=None, delay=None,
+ aggregator=None, corruptor=None):
+ """Initializes this observable.
+
+ Args:
+ raw_observation_callable: A callable which accepts a single argument of
+ type `control.base.Physics` and returns the observation value.
+ update_interval: (optional) An integer, number of simulation steps between
+ successive updates to the value of this observable.
+ buffer_size: (optional) The maximum size of the returned buffer.
+ This option is only relevant when used in conjunction with an
+ `observation.Updater`. If None, `observation.DEFAULT_BUFFER_SIZE` will
+ be used.
+ delay: (optional) Number of additional simulation steps that must be
+ taken before an observation is returned. This option is only relevant
+ when used in conjunction with an`observation.Updater`. If None,
+ `observation.DEFAULT_DELAY` will be used.
+ aggregator: (optional) Name of an item in `AGGREGATORS` or a callable that
+ performs a reduction operation over the first dimension of the buffered
+ observation before it is returned. A value of `None` means that no
+ aggregation will be performed and the whole buffer will be returned.
+ corruptor: (optional) A callable which takes a single observation as
+ an argument, modifies it, and returns it. An example use case for this
+ is to add random noise to the observation. When used in a
+ `BufferedWrapper`, the corruptor is applied to the observation before
+ it is added to the buffer. In particular, this means that the aggregator
+ operates on corrupted observations.
+ """
+ self._raw_callable = raw_observation_callable
+ super(Generic, self).__init__(
+ update_interval, buffer_size, delay, aggregator, corruptor)
+
+ def _callable(self, physics):
+ return lambda: self._raw_callable(physics)
+
+
+class MujocoFeature(Observable):
+ """An observable corresponding to a named MuJoCo feature."""
+
+ def __init__(self, kind, feature_name, update_interval=1,
+ buffer_size=None, delay=None,
+ aggregator=None, corruptor=None):
+ """Initializes this observable.
+
+ Args:
+ kind: A string corresponding to a field name in MuJoCo's mjData struct.
+ feature_name: A string, or list of strings, or a callable returning
+ either, corresponding to the name(s) of an entity in the
+ MuJoCo XML model.
+ update_interval: (optional) An integer, number of simulation steps between
+ successive updates to the value of this observable.
+ buffer_size: (optional) The maximum size of the returned buffer.
+ This option is only relevant when used in conjunction with an
+ `observation.Updater`. If None, `observation.DEFAULT_BUFFER_SIZE` will
+ be used.
+ delay: (optional) Number of additional simulation steps that must be
+ taken before an observation is returned. This option is only relevant
+ when used in conjunction with an`observation.Updater`. If None,
+ `observation.DEFAULT_DELAY` will be used.
+ aggregator: (optional) Name of an item in `AGGREGATORS` or a callable that
+ performs a reduction operation over the first dimension of the buffered
+ observation before it is returned. A value of `None` means that no
+ aggregation will be performed and the whole buffer will be returned.
+ corruptor: (optional) A callable which takes a single observation as
+ an argument, modifies it, and returns it. An example use case for this
+ is to add random noise to the observation. When used in a
+ `BufferedWrapper`, the corruptor is applied to the observation before
+ it is added to the buffer. In particular, this means that the aggregator
+ operates on corrupted observations.
+ """
+ self._kind = kind
+ self._feature_name = feature_name
+ super(MujocoFeature, self).__init__(
+ update_interval, buffer_size, delay, aggregator, corruptor)
+
+ def _callable(self, physics):
+ named_indexer_for_kind = physics.named.data.__getattribute__(self._kind)
+ if callable(self._feature_name):
+ return lambda: named_indexer_for_kind[self._feature_name()]
+ else:
+ return lambda: named_indexer_for_kind[self._feature_name]
+
+
+class MujocoCamera(Observable):
+ """An observable corresponding to a MuJoCo camera."""
+
+ def __init__(self, camera_name, height=240, width=320, update_interval=1,
+ buffer_size=None, delay=None,
+ aggregator=None, corruptor=None, depth=False):
+ """Initializes this observable.
+
+ Args:
+ camera_name: A string corresponding to the name of a camera in the
+ MuJoCo XML model.
+ height: (optional) An integer, the height of the rendered image.
+ width: (optional) An integer, the width of the rendered image.
+ update_interval: (optional) An integer, number of simulation steps between
+ successive updates to the value of this observable.
+ buffer_size: (optional) The maximum size of the returned buffer.
+ This option is only relevant when used in conjunction with an
+ `observation.Updater`. If None, `observation.DEFAULT_BUFFER_SIZE` will
+ be used.
+ delay: (optional) Number of additional simulation steps that must be
+ taken before an observation is returned. This option is only relevant
+ when used in conjunction with an`observation.Updater`. If None,
+ `observation.DEFAULT_DELAY` will be used.
+ aggregator: (optional) Name of an item in `AGGREGATORS` or a callable that
+ performs a reduction operation over the first dimension of the buffered
+ observation before it is returned. A value of `None` means that no
+ aggregation will be performed and the whole buffer will be returned.
+ corruptor: (optional) A callable which takes a single observation as
+ an argument, modifies it, and returns it. An example use case for this
+ is to add random noise to the observation. When used in a
+ `BufferedWrapper`, the corruptor is applied to the observation before
+ it is added to the buffer. In particular, this means that the aggregator
+ operates on corrupted observations.
+ depth: (optional) A boolean. If `True`, renders a depth image (1-channel)
+ instead of RGB (3-channel).
+ """
+ self._camera_name = camera_name
+ self._height = height
+ self._width = width
+
+ self._n_channels = 1 if depth else 3
+ self._dtype = np.float32 if depth else np.uint8
+ self._depth = depth
+ super(MujocoCamera, self).__init__(
+ update_interval, buffer_size, delay, aggregator, corruptor)
+
+ @property
+ def height(self):
+ return self._height
+
+ @height.setter
+ def height(self, value):
+ self._height = value
+
+ @property
+ def width(self):
+ return self._width
+
+ @width.setter
+ def width(self, value):
+ self._width = value
+
+ @property
+ def array_spec(self):
+ return specs.Array(
+ shape=(self._height, self._width, self._n_channels), dtype=self._dtype)
+
+ def _callable(self, physics):
+ return lambda: physics.render( # pylint: disable=g-long-lambda
+ self._height, self._width, self._camera_name, depth=self._depth)
diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/observable/base_test.py b/DMC/src/env/dm_control/dm_control/composer/observation/observable/base_test.py
new file mode 100644
index 0000000..0f519c2
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/observation/observable/base_test.py
@@ -0,0 +1,154 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for observable."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from dm_control import mujoco
+from dm_control.composer.observation import fake_physics
+from dm_control.composer.observation.observable import base
+import numpy as np
+import six
+
+
+_MJCF = """
+
+
+
+
+
+
+
+
+
+
+
+"""
+
+
+class _FakeBaseObservable(base.Observable):
+
+ def _callable(self, physics):
+ pass
+
+
+class ObservableTest(absltest.TestCase):
+
+ def testBaseProperties(self):
+ fake_observable = _FakeBaseObservable(update_interval=42,
+ buffer_size=5,
+ delay=10,
+ aggregator=None,
+ corruptor=None)
+ self.assertEqual(fake_observable.update_interval, 42)
+ self.assertEqual(fake_observable.buffer_size, 5)
+ self.assertEqual(fake_observable.delay, 10)
+
+ fake_observable.update_interval = 48
+ self.assertEqual(fake_observable.update_interval, 48)
+
+ fake_observable.buffer_size = 7
+ self.assertEqual(fake_observable.buffer_size, 7)
+
+ fake_observable.delay = 13
+ self.assertEqual(fake_observable.delay, 13)
+
+ enabled = not fake_observable.enabled
+ fake_observable.enabled = not fake_observable.enabled
+ self.assertEqual(fake_observable.enabled, enabled)
+
+ def testGeneric(self):
+ physics = fake_physics.FakePhysics()
+ repeated_observable = base.Generic(
+ fake_physics.FakePhysics.repeated, update_interval=42)
+ repeated_observation = repeated_observable.observation_callable(physics)()
+ self.assertEqual(repeated_observable.update_interval, 42)
+ np.testing.assert_array_equal(repeated_observation, [0, 0])
+
+ def testMujocoFeature(self):
+ physics = mujoco.Physics.from_xml_string(_MJCF)
+
+ hinge_observable = base.MujocoFeature(
+ kind='qpos', feature_name='my_hinge')
+ hinge_observation = hinge_observable.observation_callable(physics)()
+ np.testing.assert_array_equal(
+ hinge_observation, physics.named.data.qpos['my_hinge'])
+
+ box_observable = base.MujocoFeature(
+ kind='geom_xpos', feature_name='small_sphere', update_interval=5)
+ box_observation = box_observable.observation_callable(physics)()
+ self.assertEqual(box_observable.update_interval, 5)
+ np.testing.assert_array_equal(
+ box_observation, physics.named.data.geom_xpos['small_sphere'])
+
+ observable_from_callable = base.MujocoFeature(
+ kind='geom_xpos', feature_name=lambda: ['my_box', 'small_sphere'])
+ observation_from_callable = (
+ observable_from_callable.observation_callable(physics)())
+ np.testing.assert_array_equal(
+ observation_from_callable,
+ physics.named.data.geom_xpos[['my_box', 'small_sphere']])
+
+ def testMujocoCamera(self):
+ physics = mujoco.Physics.from_xml_string(_MJCF)
+
+ camera_observable = base.MujocoCamera(
+ camera_name='world', height=480, width=640, update_interval=7)
+ self.assertEqual(camera_observable.update_interval, 7)
+ camera_observation = camera_observable.observation_callable(physics)()
+ np.testing.assert_array_equal(
+ camera_observation, physics.render(480, 640, 'world'))
+ self.assertEqual(camera_observation.shape,
+ camera_observable.array_spec.shape)
+ self.assertEqual(camera_observation.dtype,
+ camera_observable.array_spec.dtype)
+
+ camera_observable.height = 300
+ camera_observable.width = 400
+ camera_observation = camera_observable.observation_callable(physics)()
+ self.assertEqual(camera_observable.height, 300)
+ self.assertEqual(camera_observable.width, 400)
+ np.testing.assert_array_equal(
+ camera_observation, physics.render(300, 400, 'world'))
+ self.assertEqual(camera_observation.shape,
+ camera_observable.array_spec.shape)
+ self.assertEqual(camera_observation.dtype,
+ camera_observable.array_spec.dtype)
+
+ def testCorruptor(self):
+ physics = fake_physics.FakePhysics()
+ def add_twelve(old_value, random_state):
+ del random_state # Unused.
+ return [x + 12 for x in old_value]
+ repeated_observable = base.Generic(
+ fake_physics.FakePhysics.repeated, corruptor=add_twelve)
+ corrupted = repeated_observable.observation_callable(
+ physics=physics, random_state=None)()
+ np.testing.assert_array_equal(corrupted, [12, 12])
+
+ def testInvalidAggregatorName(self):
+ name = 'invalid_name'
+ with six.assertRaisesRegex(self, KeyError, 'Unrecognized aggregator name'):
+ _ = _FakeBaseObservable(update_interval=3, buffer_size=2, delay=1,
+ aggregator=name, corruptor=None)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/observable/mjcf.py b/DMC/src/env/dm_control/dm_control/composer/observation/observable/mjcf.py
new file mode 100644
index 0000000..84993c5
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/observation/observable/mjcf.py
@@ -0,0 +1,257 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Observables that are defined in terms of MJCF elements."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from dm_control import mjcf
+from dm_control.composer.observation.observable import base
+from dm_env import specs
+import numpy as np
+
+
+_BOTH_SEGMENTATION_AND_DEPTH_ENABLED = (
+ '`segmentation` and `depth` cannot both be `True`.')
+
+
+def _check_mjcf_element(obj):
+ if not isinstance(obj, mjcf.Element):
+ raise ValueError(
+ 'expected an `mjcf.Element`, got type {}: {}'.format(type(obj), obj))
+
+
+def _check_mjcf_element_iterable(obj_iterable):
+ if not isinstance(obj_iterable, collections.Iterable):
+ obj_iterable = (obj_iterable,)
+ for obj in obj_iterable:
+ _check_mjcf_element(obj)
+
+
+class MJCFFeature(base.Observable):
+ """An observable corresponding to an element in an MJCF model."""
+
+ def __init__(self, kind, mjcf_element, update_interval=1,
+ buffer_size=None, delay=None,
+ aggregator=None, corruptor=None, index=None):
+ """Initializes this observable.
+
+ Args:
+ kind: The name of an attribute of a bound `mjcf.Physics` instance. See the
+ docstring for `mjcf.Physics.bind()` for examples showing this syntax.
+ mjcf_element: An `mjcf.Element`, or iterable of `mjcf.Element`.
+ update_interval: (optional) An integer, number of simulation steps between
+ successive updates to the value of this observable.
+ buffer_size: (optional) The maximum size of the returned buffer.
+ This option is only relevant when used in conjunction with an
+ `observation.Updater`. If None, `observation.DEFAULT_BUFFER_SIZE` will
+ be used.
+ delay: (optional) Number of additional simulation steps that must be
+ taken before an observation is returned. This option is only relevant
+ when used in conjunction with an`observation.Updater`. If None,
+ `observation.DEFAULT_DELAY` will be used.
+ aggregator: (optional) Name of an item in `AGGREGATORS` or a callable that
+ performs a reduction operation over the first dimension of the buffered
+ observation before it is returned. A value of `None` means that no
+ aggregation will be performed and the whole buffer will be returned.
+ corruptor: (optional) A callable which takes a single observation as
+ an argument, modifies it, and returns it. An example use case for this
+ is to add random noise to the observation. When used in a
+ `BufferedWrapper`, the corruptor is applied to the observation before
+ it is added to the buffer. In particular, this means that the aggregator
+ operates on corrupted observations.
+ index: (optional) An index that is to be applied to an array attribute
+ to pick out a slice or particular items. As a syntactic sugar,
+ `MJCFFeature` also implements `__getitem__` that returns a copy of the
+ same observable with an index applied.
+
+ Raises:
+ ValueError: if `mjcf_element` is not an `mjcf.Element`.
+ """
+ _check_mjcf_element_iterable(mjcf_element)
+ self._kind = kind
+ self._mjcf_element = mjcf_element
+ self._index = index
+ super(MJCFFeature, self).__init__(
+ update_interval, buffer_size, delay, aggregator, corruptor)
+
+ def _callable(self, physics):
+ binding = physics.bind(self._mjcf_element)
+ if self._index is not None:
+ return lambda: getattr(binding, self._kind)[self._index]
+ else:
+ return lambda: getattr(binding, self._kind)
+
+ def __getitem__(self, key):
+ if self._index is not None:
+ raise NotImplementedError(
+ 'slicing an already-sliced MJCFFeature observable is not supported')
+ return MJCFFeature(self._kind, self._mjcf_element, self._update_interval,
+ self._buffer_size, self._delay, self._aggregator,
+ self._corruptor, key)
+
+
+class MJCFCamera(base.Observable):
+ """An observable corresponding to a camera in an MJCF model."""
+
+ def __init__(self,
+ mjcf_element,
+ height=240,
+ width=320,
+ update_interval=1,
+ buffer_size=None,
+ delay=None,
+ aggregator=None,
+ corruptor=None,
+ depth=False,
+ segmentation=False,
+ scene_option=None):
+ """Initializes this observable.
+
+ Args:
+ mjcf_element: A `mjcf.Element`.
+ height: (optional) An integer, the height of the rendered image.
+ width: (optional) An integer, the width of the rendered image.
+ update_interval: (optional) An integer, number of simulation steps between
+ successive updates to the value of this observable.
+ buffer_size: (optional) The maximum size of the returned buffer.
+ This option is only relevant when used in conjunction with an
+ `observation.Updater`. If None, `observation.DEFAULT_BUFFER_SIZE` will
+ be used.
+ delay: (optional) Number of additional simulation steps that must be
+ taken before an observation is returned. This option is only relevant
+ when used in conjunction with an`observation.Updater`. If None,
+ `observation.DEFAULT_DELAY` will be used.
+ aggregator: (optional) Name of an item in `AGGREGATORS` or a callable that
+ performs a reduction operation over the first dimension of the buffered
+ observation before it is returned. A value of `None` means that no
+ aggregation will be performed and the whole buffer will be returned.
+ corruptor: (optional) A callable which takes a single observation as
+ an argument, modifies it, and returns it. An example use case for this
+ is to add random noise to the observation. When used in a
+ `BufferedWrapper`, the corruptor is applied to the observation before
+ it is added to the buffer. In particular, this means that the aggregator
+ operates on corrupted observations.
+ depth: (optional) A boolean. If `True`, renders a depth image (1-channel)
+ instead of RGB (3-channel).
+ segmentation: (optional) A boolean. If `True`, renders a segmentation mask
+ (2-channel, int32) labeling the objects in the scene with their
+ (mjModel ID, mjtObj enum object type) pair. Background pixels are
+ set to (-1, -1).
+ scene_option: An optional `wrapper.MjvOption` instance that can be used to
+ render the scene with custom visualization options. If None then the
+ default options will be used.
+
+ Raises:
+ ValueError: if `mjcf_element` is not a element.
+ ValueError: if segmentation and depth flags are both set to True.
+ """
+ _check_mjcf_element(mjcf_element)
+ if mjcf_element.tag != 'camera':
+ raise ValueError(
+ 'expected a element: got {}'.format(mjcf_element))
+ self._mjcf_element = mjcf_element
+ self._height = height
+ self._width = width
+
+ if segmentation and depth:
+ raise ValueError(_BOTH_SEGMENTATION_AND_DEPTH_ENABLED)
+ if segmentation:
+ self._dtype = np.int32
+ self._n_channels = 2
+ elif depth:
+ self._dtype = np.float32
+ self._n_channels = 1
+ else:
+ self._dtype = np.uint8
+ self._n_channels = 3
+ self._depth = depth
+ self._segmentation = segmentation
+ self._scene_option = scene_option
+ super(MJCFCamera, self).__init__(
+ update_interval, buffer_size, delay, aggregator, corruptor)
+
+ @property
+ def height(self):
+ return self._height
+
+ @height.setter
+ def height(self, value):
+ self._height = value
+
+ @property
+ def width(self):
+ return self._width
+
+ @width.setter
+ def width(self, value):
+ self._width = value
+
+ @property
+ def depth(self):
+ return self._depth
+
+ @depth.setter
+ def depth(self, value):
+ self._depth = value
+
+ @property
+ def segmentation(self):
+ return self._segmentation
+
+ @segmentation.setter
+ def segmentation(self, value):
+ self._segmentation = value
+
+ @property
+ def array_spec(self):
+ if self._depth:
+ # Note that these are loose bounds - the exact bounds are given by:
+ # extent*(znear, zfar), however the values of these parameters are unknown
+ # since we don't have access to the compiled model within this method.
+ minimum = 0.0
+ maximum = np.inf
+ elif self._segmentation:
+ # -1 denotes background pixels. See dm_control.mujoco.Camera.render for
+ # further details.
+ minimum = -1
+ maximum = np.iinfo(self._dtype).max
+ else:
+ minimum = np.iinfo(self._dtype).min
+ maximum = np.iinfo(self._dtype).max
+
+ return specs.BoundedArray(
+ minimum=minimum,
+ maximum=maximum,
+ shape=(self._height, self._width, self._n_channels),
+ dtype=self._dtype)
+
+ def _callable(self, physics):
+
+ def get_observation():
+ pixels = physics.render(
+ height=self._height,
+ width=self._width,
+ camera_id=self._mjcf_element.full_identifier,
+ depth=self._depth,
+ segmentation=self._segmentation,
+ scene_option=self._scene_option)
+ return np.atleast_3d(pixels)
+
+ return get_observation
diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/observable/mjcf_test.py b/DMC/src/env/dm_control/dm_control/composer/observation/observable/mjcf_test.py
new file mode 100644
index 0000000..9150fa9
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/observation/observable/mjcf_test.py
@@ -0,0 +1,181 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for mjcf observables."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import mjcf
+from dm_control.composer.observation.observable import mjcf as mjcf_observable
+from dm_env import specs
+import numpy as np
+import six
+
+_MJCF = """
+
+
+
+
+
+
+
+
+
+
+
+"""
+
+
+class ObservableTest(parameterized.TestCase):
+
+ def testMJCFFeature(self):
+ mjcf_root = mjcf.from_xml_string(_MJCF)
+ physics = mjcf.Physics.from_mjcf_model(mjcf_root)
+
+ my_hinge = mjcf_root.find('joint', 'my_hinge')
+ hinge_observable = mjcf_observable.MJCFFeature(
+ kind='qpos', mjcf_element=my_hinge)
+ hinge_observation = hinge_observable.observation_callable(physics)()
+ np.testing.assert_array_equal(
+ hinge_observation, physics.named.data.qpos[my_hinge.full_identifier])
+
+ small_sphere = mjcf_root.find('geom', 'small_sphere')
+ sphere_observable = mjcf_observable.MJCFFeature(
+ kind='xpos', mjcf_element=small_sphere, update_interval=5)
+ sphere_observation = sphere_observable.observation_callable(physics)()
+ self.assertEqual(sphere_observable.update_interval, 5)
+ np.testing.assert_array_equal(
+ sphere_observation, physics.named.data.geom_xpos[
+ small_sphere.full_identifier])
+
+ my_box = mjcf_root.find('geom', 'my_box')
+ list_observable = mjcf_observable.MJCFFeature(
+ kind='xpos', mjcf_element=[my_box, small_sphere])
+ list_observation = (
+ list_observable.observation_callable(physics)())
+ np.testing.assert_array_equal(
+ list_observation,
+ physics.named.data.geom_xpos[[my_box.full_identifier,
+ small_sphere.full_identifier]])
+
+ with six.assertRaisesRegex(self, ValueError, 'expected an `mjcf.Element`'):
+ mjcf_observable.MJCFFeature('qpos', 'my_hinge')
+ with six.assertRaisesRegex(self, ValueError, 'expected an `mjcf.Element`'):
+ mjcf_observable.MJCFFeature('geom_xpos', [my_box, 'small_sphere'])
+
+ def testMJCFFeatureIndex(self):
+ mjcf_root = mjcf.from_xml_string(_MJCF)
+ physics = mjcf.Physics.from_mjcf_model(mjcf_root)
+
+ small_sphere = mjcf_root.find('geom', 'small_sphere')
+ sphere_xmat = np.array(
+ physics.named.data.geom_xmat[small_sphere.full_identifier])
+
+ observable_xrow = mjcf_observable.MJCFFeature(
+ 'xmat', small_sphere, index=[1, 3, 5, 7])
+ np.testing.assert_array_equal(
+ observable_xrow.observation_callable(physics)(),
+ sphere_xmat[[1, 3, 5, 7]])
+
+ observable_yyzz = mjcf_observable.MJCFFeature('xmat', small_sphere)[2:6]
+ np.testing.assert_array_equal(
+ observable_yyzz.observation_callable(physics)(), sphere_xmat[2:6])
+
+ def testMJCFCamera(self):
+ mjcf_root = mjcf.from_xml_string(_MJCF)
+ physics = mjcf.Physics.from_mjcf_model(mjcf_root)
+
+ camera = mjcf_root.find('camera', 'world')
+ camera_observable = mjcf_observable.MJCFCamera(
+ mjcf_element=camera, height=480, width=640, update_interval=7)
+ self.assertEqual(camera_observable.update_interval, 7)
+ camera_observation = camera_observable.observation_callable(physics)()
+ np.testing.assert_array_equal(
+ camera_observation, physics.render(480, 640, 'world'))
+ self.assertEqual(camera_observation.shape,
+ camera_observable.array_spec.shape)
+ self.assertEqual(camera_observation.dtype,
+ camera_observable.array_spec.dtype)
+
+ camera_observable.height = 300
+ camera_observable.width = 400
+ camera_observation = camera_observable.observation_callable(physics)()
+ self.assertEqual(camera_observable.height, 300)
+ self.assertEqual(camera_observable.width, 400)
+ np.testing.assert_array_equal(
+ camera_observation, physics.render(300, 400, 'world'))
+ self.assertEqual(camera_observation.shape,
+ camera_observable.array_spec.shape)
+ self.assertEqual(camera_observation.dtype,
+ camera_observable.array_spec.dtype)
+
+ with six.assertRaisesRegex(self, ValueError, 'expected an `mjcf.Element`'):
+ mjcf_observable.MJCFCamera('world')
+ with six.assertRaisesRegex(self, ValueError, 'expected an `mjcf.Element`'):
+ mjcf_observable.MJCFCamera([camera])
+ with six.assertRaisesRegex(self, ValueError, 'expected a '):
+ mjcf_observable.MJCFCamera(mjcf_root.find('body', 'body'))
+
+ @parameterized.parameters(
+ dict(camera_type='rgb', channels=3, dtype=np.uint8,
+ minimum=0, maximum=255),
+ dict(camera_type='depth', channels=1, dtype=np.float32,
+ minimum=0., maximum=np.inf),
+ dict(camera_type='segmentation', channels=2, dtype=np.int32,
+ minimum=-1, maximum=np.iinfo(np.int32).max),
+ )
+ def testMJCFCameraSpecs(self, camera_type, channels, dtype, minimum, maximum):
+ width = 640
+ height = 480
+ shape = (height, width, channels)
+ expected_spec = specs.BoundedArray(
+ shape=shape, dtype=dtype, minimum=minimum, maximum=maximum)
+ mjcf_root = mjcf.from_xml_string(_MJCF)
+ camera = mjcf_root.find('camera', 'world')
+ observable_kwargs = {} if camera_type == 'rgb' else {camera_type: True}
+ camera_observable = mjcf_observable.MJCFCamera(
+ mjcf_element=camera, height=height, width=width, update_interval=7,
+ **observable_kwargs)
+ self.assertEqual(camera_observable.array_spec, expected_spec)
+
+ def testMJCFSegCamera(self):
+ mjcf_root = mjcf.from_xml_string(_MJCF)
+ physics = mjcf.Physics.from_mjcf_model(mjcf_root)
+ camera = mjcf_root.find('camera', 'world')
+ camera_observable = mjcf_observable.MJCFCamera(
+ mjcf_element=camera, height=480, width=640, update_interval=7,
+ segmentation=True)
+ self.assertEqual(camera_observable.update_interval, 7)
+ camera_observation = camera_observable.observation_callable(physics)()
+ np.testing.assert_array_equal(
+ camera_observation,
+ physics.render(480, 640, 'world', segmentation=True))
+ camera_observable.array_spec.validate(camera_observation)
+
+ def testErrorIfSegmentationAndDepthBothEnabled(self):
+ camera = mjcf.from_xml_string(_MJCF).find('camera', 'world')
+ with self.assertRaisesWithLiteralMatch(
+ ValueError, mjcf_observable._BOTH_SEGMENTATION_AND_DEPTH_ENABLED):
+ mjcf_observable.MJCFCamera(mjcf_element=camera, segmentation=True,
+ depth=True)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/updater.py b/DMC/src/env/dm_control/dm_control/composer/observation/updater.py
new file mode 100644
index 0000000..937745a
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/observation/updater.py
@@ -0,0 +1,302 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""An object that creates and updates buffers for enabled observables."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+from absl import logging
+
+from dm_control.composer.observation import obs_buffer
+from dm_env import specs
+import numpy as np
+import six
+from six.moves import range
+
+DEFAULT_BUFFER_SIZE = 1
+DEFAULT_UPDATE_INTERVAL = 1
+DEFAULT_DELAY = 0
+
+
+class _EnabledObservable(object):
+ """Encapsulates an enabled observable, its buffer, and its update schedule."""
+
+ __slots__ = ('observable', 'observation_callable',
+ 'buffer', 'update_schedule')
+
+ def __init__(self, observable, physics, random_state,
+ strip_singleton_buffer_dim):
+ self.observable = observable
+ self.observation_callable = (
+ observable.observation_callable(physics, random_state))
+
+ obs_spec = self.observable.array_spec
+ if obs_spec is None:
+ # We take an observation to determine the shape and dtype of the array.
+ # This occurs outside of an episode and doesn't affect environment
+ # behavior. At this point the physics state is not guaranteed to be valid,
+ # so we might get a `PhysicsError` if the observation callable calls
+ # `physics.forward`. We suppress such errors since they do not matter as
+ # far as the shape and dtype of the observation are concerned.
+ with physics.suppress_physics_errors():
+ obs_array = self.observation_callable()
+ obs_array = np.asarray(obs_array)
+ obs_spec = specs.Array(shape=obs_array.shape, dtype=obs_array.dtype)
+ self.buffer = obs_buffer.Buffer(
+ buffer_size=(observable.buffer_size or DEFAULT_BUFFER_SIZE),
+ shape=obs_spec.shape, dtype=obs_spec.dtype,
+ strip_singleton_buffer_dim=strip_singleton_buffer_dim)
+ self.update_schedule = collections.deque()
+
+
+def _call_if_callable(arg):
+ if callable(arg):
+ return arg()
+ else:
+ return arg
+
+
+def _validate_structure(structure):
+ """Validates the structure of the given observables collection.
+
+ The collection must either be a dict, or a (list or tuple) of dicts.
+
+ Args:
+ structure: A candidate collection of observables.
+
+ Returns:
+ A boolean that is `True` if `structure` is either a list or a tuple, or
+ `False` otherwise.
+
+ Raises:
+ ValueError: If `structure` is neither a dict nor a (list or tuple) of dicts.
+ """
+ is_nested = isinstance(structure, (list, tuple))
+ if is_nested:
+ is_valid = all(isinstance(obj, dict) for obj in structure)
+ else:
+ is_valid = isinstance(structure, dict)
+ if not is_valid:
+ raise ValueError(
+ '`observables` should be a dict, or a (list or tuple) of dicts'
+ ': got {}'.format(structure))
+ return is_nested
+
+
+class Updater(object):
+ """Creates and updates buffers for enabled observables."""
+
+ def __init__(self, observables, physics_steps_per_control_step=1,
+ strip_singleton_buffer_dim=False):
+ self._physics_steps_per_control_step = physics_steps_per_control_step
+ self._strip_singleton_buffer_dim = strip_singleton_buffer_dim
+ self._step_counter = 0
+ self._observables = observables
+ self._is_nested = _validate_structure(observables)
+ self._enabled_structure = None
+ self._enabled_list = None
+
+ def reset(self, physics, random_state):
+ """Resets this updater's state."""
+
+ def make_buffers_dict(observables):
+ """Makes observable states in a dict."""
+ # Use `type(observables)` so that our output structure respects the
+ # original dict subclass (e.g. OrderedDict).
+ out_dict = type(observables)()
+ for key, value in six.iteritems(observables):
+ if value.enabled:
+ out_dict[key] = _EnabledObservable(value, physics, random_state,
+ self._strip_singleton_buffer_dim)
+ return out_dict
+
+ if self._is_nested:
+ self._enabled_structure = type(self._observables)(
+ make_buffers_dict(obs_dict) for obs_dict in self._observables)
+ self._enabled_list = []
+ for enabled_dict in self._enabled_structure:
+ self._enabled_list.extend(enabled_dict.values())
+ else:
+ self._enabled_structure = make_buffers_dict(self._observables)
+ self._enabled_list = self._enabled_structure.values()
+
+ self._step_counter = 0
+ for enabled in self._enabled_list:
+ first_delay = _call_if_callable(enabled.observable.delay or DEFAULT_DELAY)
+ enabled.buffer.insert(
+ 0, first_delay,
+ enabled.observation_callable())
+
+ def observation_spec(self):
+ """The observation specification for this environment.
+
+ Returns a dict mapping the names of enabled observations to their
+ corresponding `Array` or `BoundedArray` specs.
+
+ If an obs has a BoundedArray spec, but uses an aggregator that
+ does not preserve those bounds (such as `sum`), it will be mapped to an
+ (unbounded) `Array` spec. If using a bounds-preserving custom aggregator
+ `my_agg`, give it an attribute `my_agg.preserves_bounds = True` to indicate
+ to this method that it is bounds-preserving.
+
+ The returned specification is only valid as of the previous call
+ to `reset`. In particular, it is an error to call this function before
+ the first call to `reset`.
+
+ Returns:
+ A dict mapping observation name to `Array` or `BoundedArray` spec
+ containing the observation shape and dtype, and possibly bounds.
+
+ Raises:
+ RuntimeError: If this method is called before `reset` has been called.
+ """
+ if self._enabled_structure is None:
+ raise RuntimeError('`reset` must be called before `observation_spec`.')
+
+ def make_observation_spec_dict(enabled_dict):
+ """Makes a dict of enabled observation specs from of observables."""
+ out_dict = type(enabled_dict)()
+ for name, enabled in six.iteritems(enabled_dict):
+
+ if isinstance(enabled.observable.array_spec, specs.BoundedArray):
+ bounds = (enabled.observable.array_spec.minimum,
+ enabled.observable.array_spec.maximum)
+ else:
+ bounds = None
+
+ if enabled.observable.aggregator:
+ aggregator = enabled.observable.aggregator
+ aggregated = aggregator(np.zeros(enabled.buffer.shape,
+ dtype=enabled.buffer.dtype))
+ shape = aggregated.shape
+ dtype = aggregated.dtype
+
+ # Ditch bounds if the aggregator isn't known to be bounds-preserving.
+ if bounds:
+ if not hasattr(aggregator, 'preserves_bounds'):
+ logging.warning('Ignoring the bounds of this observable\'s spec, '
+ 'as its aggregator method has no boolean '
+ '`preserves_bounds` attrubute.')
+ bounds = None
+ elif not aggregator.preserves_bounds:
+ bounds = None
+ else:
+ shape = enabled.buffer.shape
+ dtype = enabled.buffer.dtype
+
+ if bounds:
+ spec = specs.BoundedArray(minimum=bounds[0],
+ maximum=bounds[1],
+ shape=shape,
+ dtype=dtype,
+ name=name)
+ else:
+ spec = specs.Array(shape=shape, dtype=dtype, name=name)
+
+ out_dict[name] = spec
+ return out_dict
+
+ if self._is_nested:
+ enabled_specs = type(self._enabled_structure)(
+ make_observation_spec_dict(enabled_dict)
+ for enabled_dict in self._enabled_structure)
+ else:
+ enabled_specs = make_observation_spec_dict(self._enabled_structure)
+
+ return enabled_specs
+
+ def prepare_for_next_control_step(self):
+ """Simulates the next control step and optimizes the update schedule."""
+ if self._enabled_structure is None:
+ raise RuntimeError('`reset` must be called before `before_step`.')
+ for enabled in self._enabled_list:
+ update_interval = (
+ enabled.observable.update_interval or DEFAULT_UPDATE_INTERVAL)
+ delay = enabled.observable.delay or DEFAULT_DELAY
+ buffer_size = enabled.observable.buffer_size or DEFAULT_BUFFER_SIZE
+
+ if (update_interval == DEFAULT_UPDATE_INTERVAL and delay == DEFAULT_DELAY
+ and buffer_size < self._physics_steps_per_control_step):
+ for i in reversed(range(buffer_size)):
+ next_step = (
+ self._step_counter + self._physics_steps_per_control_step - i)
+ next_delay = DEFAULT_DELAY
+ enabled.update_schedule.append((next_step, next_delay))
+ else:
+ if enabled.update_schedule:
+ last_scheduled_step = enabled.update_schedule[-1][0]
+ else:
+ last_scheduled_step = self._step_counter
+ max_step = self._step_counter + 2 * self._physics_steps_per_control_step
+ while last_scheduled_step < max_step:
+ next_update_interval = _call_if_callable(update_interval)
+ next_step = last_scheduled_step + next_update_interval
+ next_delay = _call_if_callable(delay)
+ enabled.update_schedule.append((next_step, next_delay))
+ last_scheduled_step = next_step
+ # Optimize the schedule by planning ahead and dropping unseen entries.
+ enabled.buffer.drop_unobserved_upcoming_items(
+ enabled.update_schedule, self._physics_steps_per_control_step)
+
+ def update(self):
+ if self._enabled_structure is None:
+ raise RuntimeError('`reset` must be called before `after_substep`.')
+ self._step_counter += 1
+ for enabled in self._enabled_list:
+ if (enabled.update_schedule and
+ enabled.update_schedule[0][0] == self._step_counter):
+ timestamp, delay = enabled.update_schedule.popleft()
+ enabled.buffer.insert(
+ timestamp, delay,
+ enabled.observation_callable())
+
+ def get_observation(self):
+ """Gets the current observation.
+
+ The returned observation is only valid as of the previous call
+ to `reset`. In particular, it is an error to call this function before
+ the first call to `reset`.
+
+ Returns:
+ A dict, or list of dicts, or tuple of dicts, of observation values.
+ The returned structure corresponds to the structure of the `observables`
+ that was given at initialization time.
+
+ Raises:
+ RuntimeError: If this method is called before `reset` has been called.
+ """
+ if self._enabled_structure is None:
+ raise RuntimeError('`reset` must be called before `observation`.')
+
+ def aggregate_dict(enabled_dict):
+ out_dict = type(enabled_dict)()
+ for name, enabled in six.iteritems(enabled_dict):
+ if enabled.observable.aggregator:
+ aggregated = enabled.observable.aggregator(
+ enabled.buffer.read(self._step_counter))
+ else:
+ aggregated = enabled.buffer.read(self._step_counter)
+ out_dict[name] = aggregated
+ return out_dict
+
+ if self._is_nested:
+ return type(self._enabled_structure)(
+ aggregate_dict(enabled_dict)
+ for enabled_dict in self._enabled_structure)
+ else:
+ return aggregate_dict(self._enabled_structure)
diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/updater_test.py b/DMC/src/env/dm_control/dm_control/composer/observation/updater_test.py
new file mode 100644
index 0000000..5dd04f7
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/observation/updater_test.py
@@ -0,0 +1,279 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for observation.observation_updater."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import itertools
+import math
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.composer.observation import fake_physics
+from dm_control.composer.observation import observable
+from dm_control.composer.observation import updater
+from dm_env import specs
+import numpy as np
+import six
+from six.moves import range
+
+
+class DeterministicSequence(object):
+
+ def __init__(self, sequence):
+ self._iter = itertools.cycle(sequence)
+
+ def __call__(self, random_state=None):
+ del random_state # unused
+ return six.next(self._iter)
+
+
+class BoundedGeneric(observable.Generic):
+
+ def __init__(self, raw_observation_callable, minimum, maximum, **kwargs):
+ super(BoundedGeneric, self).__init__(
+ raw_observation_callable=raw_observation_callable,
+ **kwargs)
+ self._bounds = (minimum, maximum)
+
+ @property
+ def array_spec(self):
+ datum = np.array(self(None, None))
+ return specs.BoundedArray(shape=datum.shape,
+ dtype=datum.dtype,
+ minimum=self._bounds[0],
+ maximum=self._bounds[1])
+
+
+class UpdaterTest(parameterized.TestCase):
+
+ @parameterized.parameters(list, tuple)
+ def testNestedSpecsAndValues(self, list_or_tuple):
+ observables = list_or_tuple((
+ {'one': observable.Generic(lambda _: 1.),
+ 'two': observable.Generic(lambda _: [2, 2]),
+ }, collections.OrderedDict([
+ ('three', observable.Generic(lambda _: np.full((2, 2), 3))),
+ ('four', observable.Generic(lambda _: [4.])),
+ ('five', observable.Generic(lambda _: 5)),
+ ('six', BoundedGeneric(lambda _: [2, 2], 1, 4)),
+ ('seven', BoundedGeneric(lambda _: 2, 1, 4, aggregator='sum')),
+ ])
+ ))
+
+ observables[0]['two'].enabled = True
+ observables[1]['three'].enabled = True
+ observables[1]['five'].enabled = True
+ observables[1]['six'].enabled = True
+ observables[1]['seven'].enabled = True
+
+ observation_updater = updater.Updater(observables)
+ observation_updater.reset(physics=fake_physics.FakePhysics(),
+ random_state=None)
+
+ def make_spec(obs):
+ array = np.array(obs.observation_callable(None, None)())
+ shape = array.shape if obs.aggregator else (1,) + array.shape
+
+ if (isinstance(obs, BoundedGeneric) and
+ obs.aggregator is not observable.base.AGGREGATORS['sum']):
+ return specs.BoundedArray(shape=shape,
+ dtype=array.dtype,
+ minimum=obs.array_spec.minimum,
+ maximum=obs.array_spec.maximum)
+ else:
+ return specs.Array(shape=shape, dtype=array.dtype)
+
+ expected_specs = list_or_tuple((
+ {'two': make_spec(observables[0]['two'])},
+ collections.OrderedDict([
+ ('three', make_spec(observables[1]['three'])),
+ ('five', make_spec(observables[1]['five'])),
+ ('six', make_spec(observables[1]['six'])),
+ ('seven', make_spec(observables[1]['seven'])),
+ ])
+ ))
+
+ actual_specs = observation_updater.observation_spec()
+ self.assertIs(type(actual_specs), type(expected_specs))
+ for actual_dict, expected_dict in zip(actual_specs, expected_specs):
+ self.assertIs(type(actual_dict), type(expected_dict))
+ self.assertEqual(actual_dict, expected_dict)
+
+ def make_value(obs):
+ value = obs(physics=None, random_state=None)
+ if obs.aggregator:
+ return value
+ else:
+ value = np.array(value)
+ value = value[np.newaxis, ...]
+ return value
+
+ expected_values = list_or_tuple((
+ {'two': make_value(observables[0]['two'])},
+ collections.OrderedDict([
+ ('three', make_value(observables[1]['three'])),
+ ('five', make_value(observables[1]['five'])),
+ ('six', make_value(observables[1]['six'])),
+ ('seven', make_value(observables[1]['seven'])),
+ ])
+ ))
+
+ actual_values = observation_updater.get_observation()
+ self.assertIs(type(actual_values), type(expected_values))
+ for actual_dict, expected_dict in zip(actual_values, expected_values):
+ self.assertIs(type(actual_dict), type(expected_dict))
+ self.assertLen(actual_dict, len(expected_dict))
+ for actual, expected in zip(six.iteritems(actual_dict),
+ six.iteritems(expected_dict)):
+ actual_name, actual_value = actual
+ expected_name, expected_value = expected
+ self.assertEqual(actual_name, expected_name)
+ np.testing.assert_array_equal(actual_value, expected_value)
+
+ def assertCorrectSpec(
+ self, spec, expected_shape, expected_dtype, expected_name):
+ self.assertEqual(spec.shape, expected_shape)
+ self.assertEqual(spec.dtype, expected_dtype)
+ self.assertEqual(spec.name, expected_name)
+
+ def testObservationSpecInference(self):
+ physics = fake_physics.FakePhysics()
+ physics.observables['repeated'].buffer_size = 5
+ physics.observables['matrix'].buffer_size = 4
+ physics.observables['sqrt'] = observable.Generic(
+ fake_physics.FakePhysics.sqrt, buffer_size=3)
+
+ for obs in six.itervalues(physics.observables):
+ obs.enabled = True
+
+ observation_updater = updater.Updater(physics.observables)
+ observation_updater.reset(physics=physics, random_state=None)
+
+ spec = observation_updater.observation_spec()
+ self.assertCorrectSpec(spec['repeated'], (5, 2), np.int, 'repeated')
+ self.assertCorrectSpec(spec['matrix'], (4, 2, 3), np.int, 'matrix')
+ self.assertCorrectSpec(spec['sqrt'], (3,), np.float, 'sqrt')
+
+ def testObservation(self):
+ physics = fake_physics.FakePhysics()
+ physics.observables['repeated'].buffer_size = 5
+ physics.observables['matrix'].delay = 1
+ physics.observables['sqrt'] = observable.Generic(
+ fake_physics.FakePhysics.sqrt, update_interval=7,
+ buffer_size=3, delay=2)
+ for obs in six.itervalues(physics.observables):
+ obs.enabled = True
+ with physics.reset_context():
+ pass
+
+ physics_steps_per_control_step = 5
+ observation_updater = updater.Updater(
+ physics.observables, physics_steps_per_control_step)
+ observation_updater.reset(physics=physics, random_state=None)
+
+ for control_step in range(0, 200):
+ observation_updater.prepare_for_next_control_step()
+ for _ in range(physics_steps_per_control_step):
+ physics.step()
+ observation_updater.update()
+
+ step_counter = (control_step + 1) * physics_steps_per_control_step
+
+ observation = observation_updater.get_observation()
+ def assert_correct_buffer(obs_name, expected_callable,
+ observation=observation,
+ step_counter=step_counter):
+ update_interval = (physics.observables[obs_name].update_interval
+ or updater.DEFAULT_UPDATE_INTERVAL)
+ buffer_size = (physics.observables[obs_name].buffer_size
+ or updater.DEFAULT_BUFFER_SIZE)
+ delay = (physics.observables[obs_name].delay
+ or updater.DEFAULT_DELAY)
+
+ # The final item in the buffer is the current time, less the delay,
+ # rounded _down_ to the nearest multiple of the update interval.
+ end = update_interval * int(
+ math.floor((step_counter - delay) / update_interval))
+
+ # Figure out the first item in the buffer by working backwards from
+ # the final item in multiples of the update interval.
+ start = end - (buffer_size - 1) * update_interval
+
+ # Clamp both the start and end step number below by zero.
+ buffer_range = range(max(0, start), max(0, end + 1), update_interval)
+
+ # Arrays with expected shapes, filled with expected default values.
+ expected_value_spec = observation_updater.observation_spec()[obs_name]
+ expected_values = np.zeros(shape=expected_value_spec.shape,
+ dtype=expected_value_spec.dtype)
+
+ # The arrays are filled from right to left, such that the most recent
+ # entry is the rightmost one, and any padding is on the left.
+ for index, timestamp in enumerate(reversed(buffer_range)):
+ expected_values[-(index+1)] = expected_callable(timestamp)
+
+ np.testing.assert_array_equal(observation[obs_name], expected_values)
+
+ assert_correct_buffer('twice', lambda x: 2*x)
+ assert_correct_buffer('matrix', lambda x: [[x]*3]*2)
+ assert_correct_buffer('repeated', lambda x: [x, x])
+ assert_correct_buffer('sqrt', np.sqrt)
+
+ def testVariableRatesAndDelays(self):
+ physics = fake_physics.FakePhysics()
+ physics.observables['time'] = observable.Generic(
+ lambda physics: physics.time(),
+ buffer_size=3,
+ # observations produced on step numbers 20*N + [0, 3, 5, 8, 11, 15, 16]
+ update_interval=DeterministicSequence([3, 2, 3, 3, 4, 1, 4]),
+ # observations arrive on step numbers 20*N + [3, 8, 7, 12, 11, 17, 20]
+ delay=DeterministicSequence([3, 5, 2, 5, 1, 2, 4]))
+ physics.observables['time'].enabled = True
+
+ physics_steps_per_control_step = 10
+ observation_updater = updater.Updater(
+ physics.observables, physics_steps_per_control_step)
+ observation_updater.reset(physics=physics, random_state=None)
+
+ # Run through a few cycles of the variation sequences to make sure that
+ # cross-control-boundary behaviour is correct.
+ for i in range(5):
+ observation_updater.prepare_for_next_control_step()
+ for _ in range(physics_steps_per_control_step):
+ physics.step()
+ observation_updater.update()
+ np.testing.assert_array_equal(
+ observation_updater.get_observation()['time'],
+ 20*i + np.array([0, 5, 3]))
+
+ observation_updater.prepare_for_next_control_step()
+ for _ in range(physics_steps_per_control_step):
+ physics.step()
+ observation_updater.update()
+ # Note that #11 is dropped since it arrives after #8,
+ # whose large delay caused it to cross the control step boundary at #10.
+ np.testing.assert_array_equal(
+ observation_updater.get_observation()['time'],
+ 20*i + np.array([8, 15, 16]))
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/composer/robot.py b/DMC/src/env/dm_control/dm_control/composer/robot.py
new file mode 100644
index 0000000..fa4dfad
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/robot.py
@@ -0,0 +1,38 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Module defining the abstract robot class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+from dm_control.composer import entity
+import numpy as np
+import six
+
+DOWN_QUATERNION = np.array([0., 0.70710678118, 0.70710678118, 0.])
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Robot(entity.Entity):
+ """The abstract base class for robots."""
+
+ @abc.abstractproperty
+ def actuators(self):
+ """Returns the actuator elements of the robot."""
+ raise NotImplementedError
diff --git a/DMC/src/env/dm_control/dm_control/composer/task.py b/DMC/src/env/dm_control/dm_control/composer/task.py
new file mode 100644
index 0000000..606a109
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/task.py
@@ -0,0 +1,332 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Abstract base class for a Composer task."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import collections
+import copy
+import sys
+
+from dm_control import mujoco
+from dm_env import specs
+import six
+from six.moves import range
+
+
+def _check_timesteps_divisible(control_timestep, physics_timestep):
+ num_steps = control_timestep / physics_timestep
+ rounded_num_steps = int(round(num_steps))
+ if abs(num_steps - rounded_num_steps) > 1e-6:
+ raise ValueError(
+ 'Control timestep should be an integer multiple of physics timestep'
+ ': got {!r} and {!r}'.format(control_timestep, physics_timestep))
+ return rounded_num_steps
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Task(object):
+ """Abstract base class for a Composer task."""
+
+ @abc.abstractproperty
+ def root_entity(self):
+ """A `base.Entity` instance for this task."""
+ raise NotImplementedError
+
+ def iter_entities(self):
+ return self.root_entity.iter_entities()
+
+ @property
+ def observables(self):
+ """An OrderedDict of `control.Observable` instances for this task.
+
+ Task subclasses should generally NOT override this property.
+
+ This property is automatically computed by combining the observables dict
+ provided by each `Entity` present in this task, and any additional
+ observables returned via the `task_observables` property.
+
+ To provide an observable to an agent, the task code should either set
+ `enabled` property of an `Entity`-bound observable to `True`, or override
+ the `task_observables` property to provide additional observables not bound
+ to an `Entity`.
+
+ Returns:
+ An `collections.OrderedDict` mapping strings to instances of
+ `control.Observable`.
+ """
+ # Make a shallow copy of the OrderedDict, not the Observables themselves.
+ observables = copy.copy(self.task_observables)
+ for entity in self.root_entity.iter_entities():
+ observables.update(entity.observables.as_dict())
+ return observables
+
+ @property
+ def task_observables(self):
+ """An OrderedDict of task-specific `control.Observable` instances.
+
+ A task should override this property if it wants to provide additional
+ observables to the agent that are not already provided by any `Entity` that
+ forms part of the task's model. For example, this may be used to provide
+ observations that is derived from relative poses between two entities.
+
+ Returns:
+ An `collections.OrderedDict` mapping strings to instances of
+ `control.Observable`.
+ """
+ return collections.OrderedDict()
+
+ def after_compile(self, physics, random_state):
+ """A callback which is executed after the Mujoco Physics is recompiled.
+
+ Args:
+ physics: An instance of `control.Physics`.
+ random_state: An instance of `np.random.RandomState`.
+ """
+ pass
+
+ def _check_root_entity(self, callee_name):
+ try:
+ _ = self.root_entity
+ except: # pylint: disable=bare-except
+ err_type, err, tb = sys.exc_info()
+ message = (
+ 'call to `{}` made before `root_entity` is available;\n'
+ 'original error message: {}'.format(callee_name, str(err)))
+ six.reraise(err_type, err_type(message), tb)
+
+ @property
+ def control_timestep(self):
+ """Returns the agent's control timestep for this task (in seconds)."""
+ self._check_root_entity('control_timestep')
+ if hasattr(self, '_control_timestep'):
+ return self._control_timestep
+ else:
+ return self.physics_timestep
+
+ @control_timestep.setter
+ def control_timestep(self, new_value):
+ """Changes the agent's control timestep for this task.
+
+ Args:
+ new_value: the new control timestep (in seconds).
+
+ Raises:
+ ValueError: if `new_value` is set and is not divisible by
+ `physics_timestep`.
+ """
+ self._check_root_entity('control_timestep')
+ _check_timesteps_divisible(new_value, self.physics_timestep)
+ self._control_timestep = new_value
+
+ @property
+ def physics_timestep(self):
+ """Returns the physics timestep for this task (in seconds)."""
+ self._check_root_entity('physics_timestep')
+ if self.root_entity.mjcf_model.option.timestep is None:
+ return 0.002 # MuJoCo's default.
+ else:
+ return self.root_entity.mjcf_model.option.timestep
+
+ @physics_timestep.setter
+ def physics_timestep(self, new_value):
+ """Changes the physics simulation timestep for this task.
+
+ Args:
+ new_value: the new simulation timestep (in seconds).
+
+ Raises:
+ ValueError: if `control_timestep` is set and is not divisible by
+ `new_value`.
+ """
+ self._check_root_entity('physics_timestep')
+ if hasattr(self, '_control_timestep'):
+ _check_timesteps_divisible(self._control_timestep, new_value)
+ self.root_entity.mjcf_model.option.timestep = new_value
+
+ def set_timesteps(self, control_timestep, physics_timestep):
+ """Changes the agent's control timestep and physics simulation timestep.
+
+ This is equivalent to modifying `control_timestep` and `physics_timestep`
+ simultaneously. The divisibility check is performed between the two
+ new values.
+
+ Args:
+ control_timestep: the new agent's control timestep (in seconds).
+ physics_timestep: the new physics simulation timestep (in seconds).
+
+ Raises:
+ ValueError: if `control_timestep` is not divisible by `physics_timestep`.
+ """
+ self._check_root_entity('set_timesteps')
+ _check_timesteps_divisible(control_timestep, physics_timestep)
+ self.root_entity.mjcf_model.option.timestep = physics_timestep
+ self._control_timestep = control_timestep
+
+ @property
+ def physics_steps_per_control_step(self):
+ """Returns number of physics steps per agent's control step."""
+ return _check_timesteps_divisible(
+ self.control_timestep, self.physics_timestep)
+
+ def action_spec(self, physics):
+ """Returns a `BoundedArray` spec matching the `Physics` actuators.
+
+ BoundedArray.name should contain a tab-separated list of actuator names.
+ When overloading this method, non-MuJoCo actuators should be added to the
+ top of the list when possible, as a matter of convention.
+
+ Args:
+ physics: used to query actuator names in the model.
+ """
+ names = [physics.model.id2name(i, 'actuator') or str(i)
+ for i in range(physics.model.nu)]
+ action_spec = mujoco.action_spec(physics)
+ return specs.BoundedArray(shape=action_spec.shape,
+ dtype=action_spec.dtype,
+ minimum=action_spec.minimum,
+ maximum=action_spec.maximum,
+ name='\t'.join(names))
+
+ def get_reward_spec(self):
+ """Optional method to define non-scalar rewards for a `Task`."""
+ return None
+
+ def get_discount_spec(self):
+ """Optional method to define non-scalar discounts for a `Task`."""
+ return None
+
+ def initialize_episode_mjcf(self, random_state):
+ """Modifies the MJCF model of this task before the next episode begins.
+
+ The Environment calls this method and recompiles the physics
+ if necessary before calling `initialize_episode`.
+
+ Args:
+ random_state: An instance of `np.random.RandomState`.
+ """
+ pass
+
+ def initialize_episode(self, physics, random_state):
+ """Modifies the physics state before the next episode begins.
+
+ The Environment calls this method after `initialize_episode_mjcf`, and also
+ after the physics has been recompiled if necessary.
+
+ Args:
+ physics: An instance of `control.Physics`.
+ random_state: An instance of `np.random.RandomState`.
+ """
+ pass
+
+ def before_step(self, physics, action, random_state):
+ """A callback which is executed before an agent control step.
+
+ The default implementation sets the control signal for the actuators in
+ `physics` to be equal to `action`. Subclasses that override this method
+ should ensure that the overriding method also sets the control signal before
+ returning, either by calling `super(..., self).before_step`, or by setting
+ the control signal explicitly (e.g. in order to create a non-trivial mapping
+ between `action` and the control signal).
+
+ Args:
+ physics: An instance of `control.Physics`.
+ action: A NumPy array corresponding to agent actions.
+ random_state: An instance of `np.random.RandomState` (unused).
+ """
+ del random_state # Unused.
+ physics.set_control(action)
+
+ def before_substep(self, physics, action, random_state):
+ """A callback which is executed before a simulation step.
+
+ Actuation can be set, or overridden, in this callback.
+
+ Args:
+ physics: An instance of `control.Physics`.
+ action: A NumPy array corresponding to agent actions.
+ random_state: An instance of `np.random.RandomState`.
+ """
+ pass
+
+ def after_substep(self, physics, random_state):
+ """A callback which is executed after a simulation step.
+
+ Args:
+ physics: An instance of `control.Physics`.
+ random_state: An instance of `np.random.RandomState`.
+ """
+ pass
+
+ def after_step(self, physics, random_state):
+ """A callback which is executed after an agent control step.
+
+ Args:
+ physics: An instance of `control.Physics`.
+ random_state: An instance of `np.random.RandomState`.
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_reward(self, physics):
+ """Calculates the reward signal given the physics state.
+
+ Args:
+ physics: A Physics object.
+
+ Returns:
+ A float
+ """
+ raise NotImplementedError
+
+ def should_terminate_episode(self, physics): # pylint: disable=unused-argument
+ """Determines whether the episode should terminate given the physics state.
+
+ Args:
+ physics: A Physics object
+
+ Returns:
+ A boolean
+ """
+ return False
+
+ def get_discount(self, physics): # pylint: disable=unused-argument
+ """Calculates the reward discount factor given the physics state.
+
+ Args:
+ physics: A Physics object
+
+ Returns:
+ A float
+ """
+ return 1.0
+
+
+class NullTask(Task):
+ """A class that wraps a single `Entity` into a `Task` with no reward."""
+
+ def __init__(self, root_entity):
+ self._root_entity = root_entity
+
+ @property
+ def root_entity(self):
+ return self._root_entity
+
+ def get_reward(self, physics):
+ return 0.0
diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/__init__.py b/DMC/src/env/dm_control/dm_control/composer/variation/__init__.py
new file mode 100644
index 0000000..4831e1e
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/variation/__init__.py
@@ -0,0 +1,137 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A module that helps manage model variation in Composer environments."""
+
+import collections
+import copy
+
+from dm_control.composer.variation.base import Variation
+from dm_control.composer.variation.variation_values import evaluate
+import six
+
+
+class _VariationInfo(object):
+
+ __slots__ = ['initial_value', 'variation']
+
+ def __init__(self, initial_value=None, variation=None):
+ self.initial_value = initial_value
+ self.variation = variation
+
+
+class MJCFVariator(object):
+ """Helper object for applying variations to MJCF attributes.
+
+ An instance of this class remembers the original value of each MJCF attribute
+ the first time a variation is applied. The original value is then passed as an
+ argument to each variation callable.
+ """
+
+ def __init__(self):
+ self._variations = collections.defaultdict(dict)
+
+ def bind_attributes(self, element, **kwargs):
+ """Binds variations to attributes of an MJCF element.
+
+ Args:
+ element: An `mjcf.Element` object.
+ **kwargs: Keyword arguments mapping attribute names to the corresponding
+ variations. A variation is either a fixed value or a callable that
+ optionally takes the original value of an attribute and returns a
+ new value.
+ """
+ for attribute_name, variation in six.iteritems(kwargs):
+ if variation is None and attribute_name in self._variations[element]:
+ del self._variations[element][attribute_name]
+ else:
+ initial_value = copy.copy(getattr(element, attribute_name))
+ self._variations[element][attribute_name] = (
+ _VariationInfo(initial_value, variation))
+
+ def apply_variations(self, random_state):
+ """Applies variations in-place to the specified MJCF element.
+
+ Args:
+ random_state: A `numpy.random.RandomState` instance.
+ """
+ for element, attribute_variations in six.iteritems(self._variations):
+ new_values = {}
+ for attribute_name, variation_info in six.iteritems(attribute_variations):
+ current_value = getattr(element, attribute_name)
+ if variation_info.initial_value is None:
+ variation_info.initial_value = copy.copy(current_value)
+ new_values[attribute_name] = evaluate(
+ variation_info.variation, variation_info.initial_value,
+ current_value, random_state)
+ element.set_attributes(**new_values)
+
+ def clear(self):
+ """Clears all bound attribute variations."""
+ self._variations.clear()
+
+ def reset_initial_values(self):
+ for variations in six.itervalues(self._variations):
+ for variation_info in six.itervalues(variations):
+ variation_info.initial_value = None
+
+
+class PhysicsVariator(object):
+ """Helper object for applying variations to MjModel and MjData.
+
+ An instance of this class remembers the original value of each attribute
+ the first time a variation is applied. The original value is then passed as an
+ argument to each variation callable.
+ """
+
+ def __init__(self):
+ self._variations = collections.defaultdict(dict)
+
+ def bind_attributes(self, element, **kwargs):
+ """Binds variations to attributes of an MJCF element.
+
+ Args:
+ element: An `mjcf.Element` object.
+ **kwargs: Keyword arguments mapping attribute names to the corresponding
+ variations. A variation is either a fixed value or a callable that
+ optionally takes the original value of an attribute and returns a
+ new value.
+ """
+ for attribute_name, variation in six.iteritems(kwargs):
+ if variation is None and attribute_name in self._variations[element]:
+ del self._variations[element][attribute_name]
+ else:
+ self._variations[element][attribute_name] = (
+ _VariationInfo(None, variation))
+
+ def apply_variations(self, physics, random_state):
+ for element, variations in six.iteritems(self._variations):
+ binding = physics.bind(element)
+ for attribute_name, variation_info in six.iteritems(variations):
+ current_value = getattr(binding, attribute_name)
+ if variation_info.initial_value is None:
+ variation_info.initial_value = copy.copy(current_value)
+ setattr(binding, attribute_name, evaluate(
+ variation_info.variation, variation_info.initial_value,
+ current_value, random_state))
+
+ def clear(self):
+ """Clears all bound attribute variations."""
+ self._variations.clear()
+
+ def reset_initial_values(self):
+ for variations in six.itervalues(self._variations):
+ for variation_info in six.itervalues(variations):
+ variation_info.initial_value = None
diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/base.py b/DMC/src/env/dm_control/dm_control/composer/variation/base.py
new file mode 100644
index 0000000..7783c8d
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/variation/base.py
@@ -0,0 +1,99 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Base class for variations and binary operations on variations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import operator
+
+from dm_control.composer.variation import variation_values
+import six
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Variation(object):
+ """Abstract base class for variations."""
+
+ @abc.abstractmethod
+ def __call__(self, initial_value, current_value, random_state):
+ """Generates a value for this variation.
+
+ Args:
+ initial_value: The original value of the attribute being varied.
+ Absolute variations may ignore this argument.
+ current_value: The current value of the attribute being varied.
+ Absolute variations may ignore this argument.
+ random_state: A `numpy.RandomState` used to generate the value.
+ Deterministic variations may ignore this argument.
+
+ Returns:
+ The next value for this variation.
+ """
+
+ def __add__(self, other):
+ return _BinaryOperation(operator.add, self, other)
+
+ def __radd__(self, other):
+ return _BinaryOperation(operator.add, other, self)
+
+ def __sub__(self, other):
+ return _BinaryOperation(operator.sub, self, other)
+
+ def __rsub__(self, other):
+ return _BinaryOperation(operator.sub, other, self)
+
+ def __mul__(self, other):
+ return _BinaryOperation(operator.mul, self, other)
+
+ def __rmul__(self, other):
+ return _BinaryOperation(operator.mul, other, self)
+
+ def __truediv__(self, other):
+ return _BinaryOperation(operator.truediv, self, other)
+
+ def __rtruediv__(self, other):
+ return _BinaryOperation(operator.truediv, other, self)
+
+ def __floordiv__(self, other):
+ return _BinaryOperation(operator.floordiv, self, other)
+
+ def __rfloordiv__(self, other):
+ return _BinaryOperation(operator.floordiv, other, self)
+
+ def __pow__(self, other):
+ return _BinaryOperation(operator.pow, self, other)
+
+ def __rpow__(self, other):
+ return _BinaryOperation(operator.pow, other, self)
+
+
+class _BinaryOperation(Variation):
+ """Represents the result of applying a binary operator to two Variations."""
+
+ def __init__(self, op, first, second):
+ self._first = first
+ self._second = second
+ self._op = op
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ first_value = variation_values.evaluate(
+ self._first, initial_value, current_value, random_state)
+ second_value = variation_values.evaluate(
+ self._second, initial_value, current_value, random_state)
+ return self._op(first_value, second_value)
diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/colors.py b/DMC/src/env/dm_control/dm_control/composer/variation/colors.py
new file mode 100644
index 0000000..2fc091d
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/variation/colors.py
@@ -0,0 +1,79 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Variations in colors.
+
+Classes in this module allow users to specify a variations for each channel in
+a variety of color spaces. The generated values are always RGBA arrays.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import colorsys
+
+from dm_control.composer.variation import base
+from dm_control.composer.variation import variation_values
+import numpy as np
+
+
+class RgbVariation(base.Variation):
+ """Represents a variation in the RGB color space.
+
+ This class allows users to specify independent variations in the R, G, B, and
+ alpha channels of a color, and generates the corresponding array of RGBA
+ values.
+ """
+
+ def __init__(self, r, g, b, alpha=1.0):
+ self._r, self._g, self._b = r, g, b
+ self._alpha = alpha
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ return np.asarray(
+ variation_values.evaluate([self._r, self._g, self._b, self._alpha],
+ initial_value, current_value, random_state))
+
+
+class HsvVariation(base.Variation):
+ """Represents a variation in the HSV color space.
+
+ This class allows users to specify independent variations in the H, S, V, and
+ alpha channels of a color, and generates the corresponding array of RGBA
+ values.
+ """
+
+ def __init__(self, h, s, v, alpha=1.0):
+ self._h, self._s, self._v = h, s, v
+ self._alpha = alpha
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ h, s, v, alpha = variation_values.evaluate(
+ (self._h, self._s, self._v, self._alpha), initial_value, current_value,
+ random_state)
+ return np.asarray(list(colorsys.hsv_to_rgb(h, s, v)) + [alpha])
+
+
+class GrayVariation(HsvVariation):
+ """Represents a variation in gray level.
+
+ This class allows users to specify independent variations in the gray level
+ and alpha channels of a color, and generates the corresponding array of RGBA
+ values.
+ """
+
+ def __init__(self, gray_level, alpha=1.0):
+ super(GrayVariation, self).__init__(h=0.0, s=0.0, v=gray_level, alpha=alpha)
diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/deterministic.py b/DMC/src/env/dm_control/dm_control/composer/variation/deterministic.py
new file mode 100644
index 0000000..afea874
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/variation/deterministic.py
@@ -0,0 +1,51 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Deterministic variations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from dm_control.composer.variation import base
+
+
+class Constant(base.Variation):
+ """Wraps a constant value into a Variation object.
+
+ This class is provided mainly for use in tests, to check that variations are
+ invoked correctly without having to introduce randomness in test cases.
+ """
+
+ def __init__(self, value):
+ self._value = value
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ return self._value
+
+
+class Sequence(base.Variation):
+ """Variation representing a fixed sequence of values."""
+
+ def __init__(self, values):
+ self._values = values
+ self._iterator = iter(self._values)
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ try:
+ return next(self._iterator)
+ except StopIteration:
+ self._iterator = iter(self._values)
+ return next(self._iterator)
diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/distributions.py b/DMC/src/env/dm_control/dm_control/composer/variation/distributions.py
new file mode 100644
index 0000000..e948bf9
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/variation/distributions.py
@@ -0,0 +1,212 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Standard statistical distributions that conform to the Variation API."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import functools
+
+from dm_control.composer import variation
+from dm_control.composer.variation import base
+import numpy as np
+import six
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Distribution(base.Variation):
+ """Base Distribution class for sampling a parametrized distribution.
+
+ Subclasses need to implement `_callable`, which needs to return a callable
+ based on the random_state passed as arg. This callable then gets called using
+ the arguments passed to the constructor, after being evaluated. This allows
+ the distribution parameters themselves to be instances of `base.Variation`.
+ By default samples are drawn in the shape of `initial_value`, unless the
+ optional `single_sample` constructor arg is set to `True`, in which case only
+ a single sample is drawn.
+ """
+ __slots__ = ('_single_sample', '_args', '_kwargs')
+
+ def __init__(self, *args, **kwargs):
+ self._single_sample = kwargs.pop('single_sample', False)
+ self._args = args
+ self._kwargs = kwargs
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ local_random_state = random_state or np.random
+ size = (None if self._single_sample or initial_value is None # pylint: disable=g-long-ternary
+ else np.shape(initial_value))
+ local_args = variation.evaluate(self._args,
+ initial_value=initial_value,
+ current_value=current_value,
+ random_state=random_state)
+ local_kwargs = variation.evaluate(self._kwargs,
+ initial_value=initial_value,
+ current_value=current_value,
+ random_state=random_state)
+ return self._callable(local_random_state)(*local_args,
+ size=size,
+ **local_kwargs)
+
+ @abc.abstractmethod
+ def _callable(self, random_state):
+ raise NotImplementedError
+
+
+class Uniform(Distribution):
+ __slots__ = ()
+
+ def __init__(self, low=0.0, high=1.0, single_sample=False):
+ super(Uniform, self).__init__(low=low, high=high,
+ single_sample=single_sample)
+
+ def _callable(self, random_state):
+ return random_state.uniform
+
+
+class UniformInteger(Distribution):
+ __slots__ = ()
+
+ def __init__(self, low, high=None, single_sample=False):
+ super(UniformInteger, self).__init__(low, high=high,
+ single_sample=single_sample)
+
+ def _callable(self, random_state):
+ return random_state.randint
+
+
+class UniformChoice(Distribution):
+ __slots__ = ()
+
+ def __init__(self, choices, single_sample=False):
+ super(UniformChoice, self).__init__(choices, single_sample=single_sample)
+
+ def _callable(self, random_state):
+ return random_state.choice
+
+
+class UniformPointOnSphere(base.Variation):
+ __slots__ = ()
+
+ def __call__(self, initial_value=None,
+ current_value=None, random_state=None):
+ random_state = random_state or np.random
+ size = 3 if initial_value is None else np.append(np.shape(initial_value), 3)
+ axis = random_state.normal(size=size)
+ axis /= np.linalg.norm(axis, axis=-1, keepdims=True)
+ return axis
+
+
+class Normal(Distribution):
+ __slots__ = ()
+
+ def __init__(self, loc=0.0, scale=1.0, single_sample=False):
+ super(Normal, self).__init__(loc=loc, scale=scale,
+ single_sample=single_sample)
+
+ def _callable(self, random_state):
+ return random_state.normal
+
+
+class LogNormal(Distribution):
+ __slots__ = ()
+
+ def __init__(self, mean=0.0, sigma=1.0, single_sample=False):
+ super(LogNormal, self).__init__(mean=mean, sigma=sigma,
+ single_sample=single_sample)
+
+ def _callable(self, random_state):
+ return random_state.lognormal
+
+
+class Exponential(Distribution):
+ __slots__ = ()
+
+ def __init__(self, scale=1.0, single_sample=False):
+ super(Exponential, self).__init__(scale=scale, single_sample=single_sample)
+
+ def _callable(self, random_state):
+ return random_state.exponential
+
+
+class Poisson(Distribution):
+ __slots__ = ()
+
+ def __init__(self, lam=1.0, single_sample=False):
+ super(Poisson, self).__init__(lam=lam, single_sample=single_sample)
+
+ def _callable(self, random_state):
+ return random_state.poisson
+
+
+class Bernoulli(Distribution):
+ __slots__ = ()
+
+ def __init__(self, prob=0.5, single_sample=False):
+ super(Bernoulli, self).__init__(prob, single_sample=single_sample)
+
+ def _callable(self, random_state):
+ return functools.partial(random_state.binomial, 1)
+
+
+_NEGATIVE_STDEV = '`stdev` must be >= 0, got {}.'
+_NEGATIVE_TIMESCALE = '`timescale` must be >= 0, got {}.'
+
+
+class BiasedRandomWalk(base.Variation):
+ """A Class for generating noise from a zero-mean Ornstein-Uhlenbeck process.
+
+ Let
+ `retain = np.exp(-1. / timescale)`
+ and
+ `scale = stdev * sqrt(1 - (retain * retain))`
+ Then the discete-time first-order filtered diffusion process
+ `x_next = retain * x + N(0, scale))`
+ has standard deviation `stdev` and characteristic timescale `timescale`.
+ """
+ __slots__ = ('_scale', '_value')
+
+ def __init__(self, stdev=0.1, timescale=10.):
+ """Initializes a `BiasedRandomWalk`.
+
+ Args:
+ stdev: Float. Standard deviation of the output sequence.
+ timescale: Integer. Number of timesteps characteristic of the random walk.
+ After `timescale` steps the correlation is reduced by exp(-1). Larger
+ or equal to 0, where a value of 0 is an uncorrelated normal
+ distribution.
+
+ Raises:
+ ValueError: if either `stdev` or `timescale` is negative.
+ """
+ if stdev < 0:
+ raise ValueError(_NEGATIVE_STDEV.format(stdev))
+ if timescale < 0:
+ raise ValueError(_NEGATIVE_TIMESCALE.format(timescale))
+ elif timescale == 0:
+ self._retain = 0.
+ else:
+ self._retain = np.exp(-1. / timescale)
+ self._scale = stdev * np.sqrt(1 - (self._retain * self._retain))
+ self._value = 0.0
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ random_state = random_state or np.random
+ self._value = (self._retain * self._value +
+ random_state.normal(loc=0.0, scale=self._scale))
+ return self._value
diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/distributions_test.py b/DMC/src/env/dm_control/dm_control/composer/variation/distributions_test.py
new file mode 100644
index 0000000..1abe269
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/variation/distributions_test.py
@@ -0,0 +1,115 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for distributions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.composer.variation import distributions
+import numpy as np
+from six.moves import range
+
+RANDOM_SEED = 123
+NUM_ITERATIONS = 100
+
+
+def _make_random_state():
+ return np.random.RandomState(RANDOM_SEED)
+
+
+class DistributionsTest(parameterized.TestCase):
+
+ def setUp(self):
+ super(DistributionsTest, self).setUp()
+ self._variation_random_state = _make_random_state()
+ self._np_random_state = _make_random_state()
+
+ def testUniform(self):
+ lower, upper = [2, 3, 4], [5, 6, 7]
+ variation = distributions.Uniform(low=lower, high=upper)
+ for _ in range(NUM_ITERATIONS):
+ np.testing.assert_array_equal(
+ variation(random_state=self._variation_random_state),
+ self._np_random_state.uniform(lower, upper))
+
+ def testUniformChoice(self):
+ choices = ['apple', 'banana', 'cherry']
+ variation = distributions.UniformChoice(choices)
+ for _ in range(NUM_ITERATIONS):
+ self.assertEqual(
+ variation(random_state=self._variation_random_state),
+ self._np_random_state.choice(choices))
+
+ def testUniformPointOnSphere(self):
+ variation = distributions.UniformPointOnSphere()
+ samples = []
+ for _ in range(NUM_ITERATIONS):
+ sample = variation(random_state=self._variation_random_state)
+ self.assertEqual(sample.size, 3)
+ np.testing.assert_approx_equal(np.linalg.norm(sample), 1.0)
+ samples.append(sample)
+ # Make sure that none of the samples are the same.
+ self.assertLen(
+ set(np.reshape(np.asarray(samples), -1)), 3 * NUM_ITERATIONS)
+
+ def testNormal(self):
+ loc, scale = 1, 2
+ variation = distributions.Normal(loc=loc, scale=scale)
+ for _ in range(NUM_ITERATIONS):
+ self.assertEqual(
+ variation(random_state=self._variation_random_state),
+ self._np_random_state.normal(loc, scale))
+
+ def testExponential(self):
+ scale = 3
+ variation = distributions.Exponential(scale=scale)
+ for _ in range(NUM_ITERATIONS):
+ self.assertEqual(
+ variation(random_state=self._variation_random_state),
+ self._np_random_state.exponential(scale))
+
+ def testPoisson(self):
+ lam = 4
+ variation = distributions.Poisson(lam=lam)
+ for _ in range(NUM_ITERATIONS):
+ self.assertEqual(
+ variation(random_state=self._variation_random_state),
+ self._np_random_state.poisson(lam))
+
+ @parameterized.parameters(0, 10)
+ def testBiasedRandomWalk(self, timescale):
+ stdev = 1.
+ variation = distributions.BiasedRandomWalk(stdev=stdev, timescale=timescale)
+ sequence = [variation(random_state=self._variation_random_state)
+ for _ in range(int(max(timescale, 1)*NUM_ITERATIONS*1000))]
+ self.assertAlmostEqual(np.mean(sequence), 0., delta=0.01)
+ self.assertAlmostEqual(np.std(sequence), stdev, delta=0.01)
+
+ @parameterized.parameters(
+ dict(arg_name='stdev', template=distributions._NEGATIVE_STDEV),
+ dict(arg_name='timescale', template=distributions._NEGATIVE_TIMESCALE))
+ def testBiasedRandomWalkExceptions(self, arg_name, template):
+ bad_value = -1.
+ with self.assertRaisesWithLiteralMatch(
+ ValueError, template.format(bad_value)):
+ _ = distributions.BiasedRandomWalk(**{arg_name: bad_value})
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/noises.py b/DMC/src/env/dm_control/dm_control/composer/variation/noises.py
new file mode 100644
index 0000000..a43dee3
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/variation/noises.py
@@ -0,0 +1,63 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Meta-variations that modify original values by a specified variation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from dm_control.composer.variation import base
+from dm_control.composer.variation import variation_values
+
+
+class Additive(base.Variation):
+ """A variation that adds to an existing value.
+
+ This variation takes a value generated by another variation and adds it to an
+ existing value. In cumulative mode, the generated value is added to the
+ current value being varied. In non-cumulative mode, the generated value is
+ added to a fixed initial value.
+ """
+
+ def __init__(self, variation, cumulative=False):
+ self._variation = variation
+ self._cumulative = cumulative
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ base_value = current_value if self._cumulative else initial_value
+ return base_value + (
+ variation_values.evaluate(self._variation, initial_value, current_value,
+ random_state))
+
+
+class Multiplicative(base.Variation):
+ """A variation that multiplies to an existing value.
+
+ This variation takes a value generated by another variation and multiplies it
+ to an existing value. In cumulative mode, the generated value is multiplied to
+ the current value being varied. In non-cumulative mode, the generated value is
+ multiplied to a fixed initial value.
+ """
+
+ def __init__(self, variation, cumulative=False):
+ self._variation = variation
+ self._cumulative = cumulative
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ base_value = current_value if self._cumulative else initial_value
+ return base_value * (
+ variation_values.evaluate(self._variation, initial_value, current_value,
+ random_state))
diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/noises_test.py b/DMC/src/env/dm_control/dm_control/composer/variation/noises_test.py
new file mode 100644
index 0000000..da8145f
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/variation/noises_test.py
@@ -0,0 +1,93 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for noises."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.composer.variation import deterministic
+from dm_control.composer.variation import noises
+from six.moves import range
+
+NUM_ITERATIONS = 100
+
+
+class NoisesTest(parameterized.TestCase):
+
+ @parameterized.parameters(False, True)
+ def testAdditive(self, use_constant_variation_object):
+ amount = 2
+ if use_constant_variation_object:
+ variation = noises.Additive(deterministic.Constant(amount))
+ else:
+ variation = noises.Additive(amount)
+ initial_value = 0
+ current_value = initial_value
+ for _ in range(NUM_ITERATIONS):
+ current_value = variation(
+ initial_value=initial_value, current_value=current_value)
+ self.assertEqual(current_value, initial_value + amount)
+
+ @parameterized.parameters(False, True)
+ def testAdditiveCumulative(self, use_constant_variation_object):
+ amount = 3
+ if use_constant_variation_object:
+ variation = noises.Additive(
+ deterministic.Constant(amount), cumulative=True)
+ else:
+ variation = noises.Additive(amount, cumulative=True)
+ initial_value = 1
+ current_value = initial_value
+ for i in range(NUM_ITERATIONS):
+ current_value = variation(
+ initial_value=initial_value, current_value=current_value)
+ self.assertEqual(current_value, initial_value + amount * (i + 1))
+
+ @parameterized.parameters(False, True)
+ def testMultiplicative(self, use_constant_variation_object):
+ amount = 23
+ if use_constant_variation_object:
+ variation = noises.Multiplicative(deterministic.Constant(amount))
+ else:
+ variation = noises.Multiplicative(amount)
+ initial_value = 3
+ current_value = initial_value
+ for _ in range(NUM_ITERATIONS):
+ current_value = variation(
+ initial_value=initial_value, current_value=current_value)
+ self.assertEqual(current_value, initial_value * amount)
+
+ @parameterized.parameters(False, True)
+ def testMultiplicativeCumulative(self, use_constant_variation_object):
+ amount = 2
+ if use_constant_variation_object:
+ variation = noises.Multiplicative(
+ deterministic.Constant(amount), cumulative=True)
+ else:
+ variation = noises.Multiplicative(amount, cumulative=True)
+ initial_value = 3
+ current_value = initial_value
+ for i in range(NUM_ITERATIONS):
+ current_value = variation(
+ initial_value=initial_value, current_value=current_value)
+ self.assertEqual(current_value, initial_value * amount ** (i + 1))
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/rotations.py b/DMC/src/env/dm_control/dm_control/composer/variation/rotations.py
new file mode 100644
index 0000000..b84e19b
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/variation/rotations.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Variations in 3D rotations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from dm_control.composer.variation import base
+from dm_control.composer.variation import variation_values
+import numpy as np
+
+IDENTITY_QUATERNION = np.array([1., 0., 0., 0.])
+
+
+class UniformQuaternion(base.Variation):
+ """Uniformly distributed unit quaternions."""
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ random_state = random_state or np.random
+ u1, u2, u3 = random_state.uniform([0.] * 3, [1., 2. * np.pi, 2. * np.pi])
+ return np.array([np.sqrt(1. - u1) * np.sin(u2),
+ np.sqrt(1. - u1) * np.cos(u2),
+ np.sqrt(u1) * np.sin(u3),
+ np.sqrt(u1) * np.cos(u3)])
+
+
+class QuaternionFromAxisAngle(base.Variation):
+ """Quaternion variation specified in terms of variations in axis and angle."""
+
+ def __init__(self, axis, angle):
+ self._axis = axis
+ self._angle = angle
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ random_state = random_state or np.random
+ axis = variation_values.evaluate(
+ self._axis, initial_value, current_value, random_state)
+ angle = variation_values.evaluate(
+ self._angle, initial_value, current_value, random_state)
+ sine, cosine = np.sin(angle / 2), np.cos(angle / 2)
+ return np.array([cosine, axis[0] * sine, axis[1] * sine, axis[2] * sine])
+
+
+class QuaternionPreMultiply(base.Variation):
+ """A variation that pre-multiplies an existing quaternion value.
+
+ This variation takes a quaternion value generated by another variation and
+ pre-multiplies it to an existing value. In cumulative mode, the new quaternion
+ is pre-multiplied to the current value being varied. In non-cumulative mode,
+ the new quaternion is pre-multiplied to a fixed initial value.
+ """
+
+ def __init__(self, quat, cumulative=False):
+ self._quat = quat
+ self._cumulative = cumulative
+
+ def __call__(self, initial_value=None, current_value=None, random_state=None):
+ random_state = random_state or np.random
+ q1 = variation_values.evaluate(self._quat, initial_value, current_value,
+ random_state)
+ q2 = current_value if self._cumulative else initial_value
+ return np.array([
+ q1[0]*q2[0] - q1[1]*q2[1] - q1[2]*q2[2] - q1[3]*q2[3],
+ q1[0]*q2[1] + q1[1]*q2[0] + q1[2]*q2[3] - q1[3]*q2[2],
+ q1[0]*q2[2] - q1[1]*q2[3] + q1[2]*q2[0] + q1[3]*q2[1],
+ q1[0]*q2[3] + q1[1]*q2[2] - q1[2]*q2[1] + q1[3]*q2[0]])
diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/variation_test.py b/DMC/src/env/dm_control/dm_control/composer/variation/variation_test.py
new file mode 100644
index 0000000..8e4b752
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/variation/variation_test.py
@@ -0,0 +1,52 @@
+# Copyright 2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for base variation operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import operator
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.composer import variation
+from dm_control.composer.variation import deterministic
+
+
+class VariationTest(parameterized.TestCase):
+
+ def setUp(self):
+ self.value_1 = 3
+ self.variation_1 = deterministic.Constant(self.value_1)
+ self.value_2 = 5
+ self.variation_2 = deterministic.Constant(self.value_2)
+
+ @parameterized.parameters(['add', 'sub', 'mul', 'truediv', 'floordiv', 'pow'])
+ def test_operator(self, name):
+ func = getattr(operator, name)
+ self.assertEqual(
+ variation.evaluate(func(self.value_1, self.variation_2)),
+ func(self.value_1, self.value_2))
+ self.assertEqual(
+ variation.evaluate(func(self.variation_1, self.value_2)),
+ func(self.value_1, self.value_2))
+ self.assertEqual(
+ variation.evaluate(func(self.variation_1, self.variation_2)),
+ func(self.value_1, self.value_2))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/variation_values.py b/DMC/src/env/dm_control/dm_control/composer/variation/variation_values.py
new file mode 100644
index 0000000..8791225
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/composer/variation/variation_values.py
@@ -0,0 +1,39 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Utilities for handling nested structures of callables or constants."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tree
+
+
+def evaluate(structure, *args, **kwargs):
+ """Evaluates a arbitrarily nested structure of callables or constant values.
+
+ Args:
+ structure: An arbitrarily nested structure of callables or constant values.
+ By "structures", we mean lists, tuples, namedtuples, or dicts.
+ *args: Positional arguments passed to each callable in `structure`.
+ **kwargs: Keyword arguments passed to each callable in `structure.
+
+ Returns:
+ The same nested structure, with each callable replaced by the value returned
+ by calling it.
+ """
+ return tree.map_structure(
+ lambda x: x(*args, **kwargs) if callable(x) else x, structure)
diff --git a/DMC/src/env/dm_control/dm_control/entities/__init__.py b/DMC/src/env/dm_control/dm_control/entities/__init__.py
new file mode 100644
index 0000000..4224c02
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/entities/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/DMC/src/env/dm_control/dm_control/entities/props/__init__.py b/DMC/src/env/dm_control/dm_control/entities/props/__init__.py
new file mode 100644
index 0000000..5b8863c
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/entities/props/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Composer entities corresponding to props.
+
+A "prop" is typically a non-actuated entity representing an object in the world.
+"""
+
+# Prop imports
+from dm_control.entities.props.position_detector import PositionDetector
+from dm_control.entities.props.primitive import Primitive
diff --git a/DMC/src/env/dm_control/dm_control/entities/props/position_detector.py b/DMC/src/env/dm_control/dm_control/entities/props/position_detector.py
new file mode 100644
index 0000000..7332ab7
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/entities/props/position_detector.py
@@ -0,0 +1,270 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Detects the presence of registered entities within a cuboidal region."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from dm_control import composer
+from dm_control import mjcf
+import numpy as np
+
+_RENDERED_HEIGHT_IN_2D_MODE = 0.1
+
+
+def _ensure_3d(pos):
+ # Pad the array with a zero if its length is 2.
+ if len(pos) == 2:
+ return np.hstack([pos, 0.])
+ return pos
+
+
+class _Detection(object):
+
+ __slots__ = ('entity', 'detected')
+
+ def __init__(self, entity, detected=False):
+ self.entity = entity
+ self.detected = detected
+
+
+class PositionDetector(composer.Entity):
+ """Detects the presence of registered entities within a cuboidal region.
+
+ An entity is considered "detected" if the `xpos` value of any one of its geom
+ lies within the active region defined by this detector. Note that this is NOT
+ a contact-based detector. Generally speaking, a geom will not be detected
+ until it is already "half inside" the region.
+
+ This detector supports both 2D and 3D modes. In 2D mode, the active region
+ has an effective infinite height along the z-direction.
+
+ This detector also provides an "inverted" detection mode, where an entity is
+ detected when it is not inside the detector's region.
+ """
+
+ def _build(self,
+ pos,
+ size,
+ inverted=False,
+ visible=False,
+ rgba=(1, 0, 0, 0.25),
+ detected_rgba=(0, 1, 0, 0.25),
+ name='position_detector'):
+ """Builds the detector.
+
+ Args:
+ pos: The position at the center of this detector's active region. Should
+ be an array-like object of length 3 in 3D mode, or length 2 in 2D mode.
+ size: The half-lengths of this detector's active region. Should
+ be an array-like object of length 3 in 3D mode, or length 2 in 2D mode.
+ inverted: (optional) A boolean, whether to operate in inverted detection
+ mode. If `True`, an entity is detected when it is not in the active
+ region.
+ visible: (optional) A boolean, whether this detector is visible by
+ default in rendered images. If `False`, this detector's active zone
+ is placed in MuJoCo rendering group 4, which is not rendered by default,
+ but can be toggled on (e.g. in `dm_control.viewer`) for debugging
+ purposes.
+ rgba: (optional) The color to render when nothing is detected.
+ detected_rgba: (optional) The color to render when an entity is detected.
+ name: (optional) XML element name of this position detector.
+
+ Raises:
+ ValueError: If the `pos` and `size` arrays do not have the same length.
+ """
+ if len(pos) != len(size):
+ raise ValueError('`pos` and `size` should have the same length: '
+ 'got {!r} and {!r}'.format(pos, size))
+
+ self._inverted = inverted
+ self._detected = False
+ self._lower = np.array(pos) - np.array(size)
+ self._upper = np.array(pos) + np.array(size)
+ self._lower_3d = _ensure_3d(self._lower)
+ self._upper_3d = _ensure_3d(self._upper)
+ self._mid_3d = (self._lower_3d + self._upper_3d) / 2.
+
+ self._entities = []
+ self._entity_geoms = {}
+
+ self._rgba = np.asarray(rgba)
+ self._detected_rgba = np.asarray(detected_rgba)
+
+ render_pos = np.zeros(3)
+ render_pos[:len(pos)] = pos
+
+ render_size = np.full(3, _RENDERED_HEIGHT_IN_2D_MODE)
+ render_size[:len(size)] = size
+
+ self._mjcf_root = mjcf.RootElement(model=name)
+ self._site = self._mjcf_root.worldbody.add(
+ 'site', name='detection_zone', type='box',
+ pos=render_pos, size=render_size, rgba=self._rgba)
+ self._lower_site = self._mjcf_root.worldbody.add(
+ 'site', name='lower', pos=self._lower_3d, size=[0.05],
+ rgba=self._rgba)
+ self._mid_site = self._mjcf_root.worldbody.add(
+ 'site', name='mid', pos=self._mid_3d, size=[0.05],
+ rgba=self._rgba)
+ self._upper_site = self._mjcf_root.worldbody.add(
+ 'site', name='upper', pos=self._upper_3d, size=[0.05],
+ rgba=self._rgba)
+ self._lower_sensor = self._mjcf_root.sensor.add(
+ 'framepos', objtype='site', objname=self._lower_site,
+ name='{}_lower'.format(name))
+ self._mid_sensor = self._mjcf_root.sensor.add(
+ 'framepos', objtype='site', objname=self._mid_site,
+ name='{}_mid'.format(name))
+ self._upper_sensor = self._mjcf_root.sensor.add(
+ 'framepos', objtype='site', objname=self._upper_site,
+ name='{}_upper'.format(name))
+
+ if not visible:
+ self._site.group = composer.SENSOR_SITES_GROUP
+ self._lower_site.group = composer.SENSOR_SITES_GROUP
+ self._mid_site.group = composer.SENSOR_SITES_GROUP
+ self._upper_site.group = composer.SENSOR_SITES_GROUP
+
+ def resize(self, pos, size):
+ if len(pos) != len(size):
+ raise ValueError('`pos` and `size` should have the same length: '
+ 'got {!r} and {!r}'.format(pos, size))
+ self._lower = np.array(pos) - np.array(size)
+ self._upper = np.array(pos) + np.array(size)
+
+ self._lower_3d = _ensure_3d(self._lower)
+ self._upper_3d = _ensure_3d(self._upper)
+ self._mid_3d = (self._lower_3d + self._upper_3d) / 2.
+
+ render_pos = np.zeros(3)
+ render_pos[:len(pos)] = pos
+
+ render_size = np.full(3, _RENDERED_HEIGHT_IN_2D_MODE)
+ render_size[:len(size)] = size
+
+ self._site.pos = render_pos
+ self._site.size = render_size
+ self._lower_site.pos = self._lower_3d
+ self._mid_site.pos = self._mid_3d
+ self._upper_site.pos = self._upper_3d
+
+ def set_colors(self, rgba, detected_rgba):
+ self.set_color(rgba)
+ self.set_detected_color(detected_rgba)
+
+ def set_color(self, rgba):
+ self._rgba[:3] = rgba
+ self._site.rgba = self._rgba
+
+ def set_detected_color(self, detected_rgba):
+ self._detected_rgba[:3] = detected_rgba
+
+ def set_position(self, physics, pos):
+ physics.bind(self._site).pos = pos
+ size = physics.bind(self._site).size[:3]
+ self._lower = np.array(pos) - np.array(size)
+ self._upper = np.array(pos) + np.array(size)
+
+ self._lower_3d = _ensure_3d(self._lower)
+ self._upper_3d = _ensure_3d(self._upper)
+ self._mid_3d = (self._lower_3d + self._upper_3d) / 2.
+
+ physics.bind(self._lower_site).pos = self._lower_3d
+ physics.bind(self._mid_site).pos = self._mid_3d
+ physics.bind(self._upper_site).pos = self._upper_3d
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ def register_entities(self, *entities):
+ for entity in entities:
+ self._entities.append(_Detection(entity))
+ self._entity_geoms[entity] = entity.mjcf_model.find_all('geom')
+
+ def deregister_entities(self):
+ self._entities = []
+
+ @property
+ def detected_entities(self):
+ """A list of detected entities."""
+ return [
+ detection.entity for detection in self._entities if detection.detected]
+
+ def initialize_episode_mjcf(self, unused_random_state):
+ self._entity_geoms = {}
+ for detection in self._entities:
+ entity = detection.entity
+ self._entity_geoms[entity] = entity.mjcf_model.find_all('geom')
+
+ def initialize_episode(self, physics, unused_random_state):
+ self._update_detection(physics)
+
+ def after_substep(self, physics, unused_random_state):
+ self._update_detection(physics)
+
+ def _is_in_zone(self, xpos):
+ return (np.all(self._lower < xpos[:len(self._lower)])
+ and np.all(self._upper > xpos[:len(self._upper)]))
+
+ def _update_detection(self, physics):
+ previously_detected = self._detected
+ self._detected = False
+ for detection in self._entities:
+ detection.detected = False
+ for geom in self._entity_geoms[detection.entity]:
+ if self._is_in_zone(physics.bind(geom).xpos) != self._inverted:
+ detection.detected = True
+ self._detected = True
+ break
+
+ if self._detected and not previously_detected:
+ physics.bind(self._site).rgba = self._detected_rgba
+ elif previously_detected and not self._detected:
+ physics.bind(self._site).rgba = self._rgba
+
+ def site_pos(self, physics):
+ return physics.bind(self._site).pos
+
+ @property
+ def activated(self):
+ return self._detected
+
+ @property
+ def upper(self):
+ return self._upper
+
+ @property
+ def lower(self):
+ return self._lower
+
+ @property
+ def mid(self):
+ return (self._lower + self._upper) / 2.
+
+ @property
+ def lower_sensor(self):
+ return self._lower_sensor
+
+ @property
+ def mid_sensor(self):
+ return self._mid_sensor
+
+ @property
+ def upper_sensor(self):
+ return self._upper_sensor
diff --git a/DMC/src/env/dm_control/dm_control/entities/props/position_detector_test.py b/DMC/src/env/dm_control/dm_control/entities/props/position_detector_test.py
new file mode 100644
index 0000000..cfdd073
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/entities/props/position_detector_test.py
@@ -0,0 +1,133 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.composer.props.position_detector."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control.entities.props import position_detector
+from dm_control.entities.props import primitive
+import numpy as np
+
+
+class PositionDetectorTest(parameterized.TestCase):
+
+ def setUp(self):
+ super(PositionDetectorTest, self).setUp()
+ self.arena = composer.Arena()
+ self.props = [
+ primitive.Primitive(geom_type='sphere', size=(0.1,)),
+ primitive.Primitive(geom_type='sphere', size=(0.1,))
+ ]
+ for prop in self.props:
+ self.arena.add_free_entity(prop)
+ self.task = composer.NullTask(self.arena)
+
+ def assertDetected(self, entity, detector):
+ if not self.inverted:
+ self.assertIn(entity, detector.detected_entities)
+ else:
+ self.assertNotIn(entity, detector.detected_entities)
+
+ def assertNotDetected(self, entity, detector):
+ if not self.inverted:
+ self.assertNotIn(entity, detector.detected_entities)
+ else:
+ self.assertIn(entity, detector.detected_entities)
+
+ @parameterized.parameters(False, True)
+ def test3DDetection(self, inverted):
+ self.inverted = inverted
+
+ detector_pos = np.array([0.3, 0.2, 0.1])
+ detector_size = np.array([0.1, 0.2, 0.3])
+ detector = position_detector.PositionDetector(
+ pos=detector_pos, size=detector_size, inverted=inverted)
+ detector.register_entities(*self.props)
+ self.arena.attach(detector)
+ env = composer.Environment(self.task)
+
+ env.reset()
+ self.assertNotDetected(self.props[0], detector)
+ self.assertNotDetected(self.props[1], detector)
+
+ def initialize_episode(physics, unused_random_state):
+ for prop in self.props:
+ prop.set_pose(physics, detector_pos)
+ self.task.initialize_episode = initialize_episode
+ env.reset()
+ self.assertDetected(self.props[0], detector)
+ self.assertDetected(self.props[1], detector)
+
+ self.props[0].set_pose(env.physics, detector_pos - detector_size)
+ env.step([])
+ self.assertNotDetected(self.props[0], detector)
+ self.assertDetected(self.props[1], detector)
+
+ self.props[0].set_pose(env.physics, detector_pos - detector_size / 2)
+ self.props[1].set_pose(env.physics, detector_pos + detector_size * 1.01)
+ env.step([])
+ self.assertDetected(self.props[0], detector)
+ self.assertNotDetected(self.props[1], detector)
+
+ @parameterized.parameters(False, True)
+ def test2DDetection(self, inverted):
+ self.inverted = inverted
+
+ detector_pos = np.array([0.3, 0.2])
+ detector_size = np.array([0.1, 0.2])
+ detector = position_detector.PositionDetector(
+ pos=detector_pos, size=detector_size, inverted=inverted)
+ detector.register_entities(*self.props)
+ self.arena.attach(detector)
+ env = composer.Environment(self.task)
+
+ env.reset()
+ self.assertNotDetected(self.props[0], detector)
+ self.assertNotDetected(self.props[1], detector)
+
+ def initialize_episode(physics, unused_random_state):
+ # In 2D mode, detection should occur no matter how large |z| is.
+ self.props[0].set_pose(physics, [detector_pos[0], detector_pos[1], 1e+6])
+ self.props[1].set_pose(physics, [detector_pos[0], detector_pos[1], -1e+6])
+ self.task.initialize_episode = initialize_episode
+ env.reset()
+ self.assertDetected(self.props[0], detector)
+ self.assertDetected(self.props[1], detector)
+
+ self.props[0].set_pose(
+ env.physics, [detector_pos[0] - detector_size[0], detector_pos[1], 0])
+ env.step([])
+ self.assertNotDetected(self.props[0], detector)
+ self.assertDetected(self.props[1], detector)
+
+ self.props[0].set_pose(
+ env.physics, [detector_pos[0] - detector_size[0] / 2,
+ detector_pos[1] + detector_size[1] / 2, 0])
+ self.props[1].set_pose(
+ env.physics, [detector_pos[0], detector_pos[1] + detector_size[1], 0])
+ env.step([])
+ self.assertDetected(self.props[0], detector)
+ self.assertNotDetected(self.props[1], detector)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/entities/props/primitive.py b/DMC/src/env/dm_control/dm_control/entities/props/primitive.py
new file mode 100644
index 0000000..718e8a8
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/entities/props/primitive.py
@@ -0,0 +1,112 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Prop consisting of a single geom with position and velocity sensors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer import define
+from dm_control.composer.observation import observable
+
+
+class Primitive(composer.Entity):
+ """A prop consisting of a single geom with position and velocity sensors."""
+
+ def _build(self, geom_type, size, name=None, **kwargs):
+ """Initializes the prop.
+
+ Args:
+ geom_type: String specifying the geom type.
+ size: List or numpy array of up to 3 numbers, depending on `geom_type`:
+ geom_type='box', size=[x_half_length, y_half_length, z_half_length]
+ geom_type='capsule', size=[radius, half_length]
+ geom_type='cylinder', size=[radius, half_length]
+ geom_type='ellipsoid', size=[x_radius, y_radius, z_radius]
+ geom_type='sphere', size=[radius]
+ name: (optional) A string, the name of this prop.
+ **kwargs: Additional geom parameters. Please see the MuJoCo documentation
+ for further details: http://www.mujoco.org/book/XMLreference.html#geom.
+ """
+ self._mjcf_root = mjcf.element.RootElement(model=name)
+ self._geom = self._mjcf_root.worldbody.add(
+ 'geom', name='geom', type=geom_type, size=size, **kwargs)
+ self._position = self._mjcf_root.sensor.add(
+ 'framepos', name='position', objtype='geom', objname=self.geom)
+ self._orientation = self._mjcf_root.sensor.add(
+ 'framequat', name='orientation', objtype='geom', objname=self.geom)
+ self._linear_velocity = self._mjcf_root.sensor.add(
+ 'framelinvel', name='linear_velocity', objtype='geom',
+ objname=self.geom)
+ self._angular_velocity = self._mjcf_root.sensor.add(
+ 'frameangvel', name='angular_velocity', objtype='geom',
+ objname=self.geom)
+
+ def _build_observables(self):
+ return PrimitiveObservables(self)
+
+ @property
+ def geom(self):
+ """The geom belonging to this prop."""
+ return self._geom
+
+ @property
+ def position(self):
+ """Sensor that returns the prop position."""
+ return self._position
+
+ @property
+ def orientation(self):
+ """Sensor that returns the prop orientation (as a quaternion)."""
+ # TODO(b/120829807): Consider returning a rotation matrix instead.
+ return self._orientation
+
+ @property
+ def linear_velocity(self):
+ """Sensor that returns the linear velocity of the prop."""
+ return self._linear_velocity
+
+ @property
+ def angular_velocity(self):
+ """Sensor that returns the angular velocity of the prop."""
+ return self._angular_velocity
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+
+class PrimitiveObservables(composer.Observables,
+ composer.FreePropObservableMixin):
+ """Primitive entity's observables."""
+
+ @define.observable
+ def position(self):
+ return observable.MJCFFeature('sensordata', self._entity.position)
+
+ @define.observable
+ def orientation(self):
+ return observable.MJCFFeature('sensordata', self._entity.orientation)
+
+ @define.observable
+ def linear_velocity(self):
+ return observable.MJCFFeature('sensordata', self._entity.linear_velocity)
+
+ @define.observable
+ def angular_velocity(self):
+ return observable.MJCFFeature('sensordata', self._entity.angular_velocity)
diff --git a/DMC/src/env/dm_control/dm_control/entities/props/primitive_test.py b/DMC/src/env/dm_control/dm_control/entities/props/primitive_test.py
new file mode 100644
index 0000000..a731e0d
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/entities/props/primitive_test.py
@@ -0,0 +1,100 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.composer.props.primitive."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.entities.props import primitive
+import numpy as np
+
+
+class PrimitiveTest(parameterized.TestCase):
+
+ def _make_free_prop(self, geom_type='sphere', size=(0.1,), **kwargs):
+ prop = primitive.Primitive(geom_type=geom_type, size=size, **kwargs)
+ arena = composer.Arena()
+ arena.add_free_entity(prop)
+ physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+ return prop, physics
+
+ @parameterized.parameters([
+ dict(geom_type='sphere', size=[0.1]),
+ dict(geom_type='capsule', size=[0.1, 0.2]),
+ dict(geom_type='cylinder', size=[0.1, 0.2]),
+ dict(geom_type='box', size=[0.1, 0.2, 0.3]),
+ dict(geom_type='ellipsoid', size=[0.1, 0.2, 0.3]),
+ ])
+ def test_instantiation(self, geom_type, size):
+ name = 'foo'
+ rgba = [1., 0., 1., 0.5]
+ prop, physics = self._make_free_prop(
+ geom_type=geom_type, size=size, name=name, rgba=rgba)
+ # Check that the name and other kwargs are set correctly.
+ self.assertEqual(prop.mjcf_model.model, name)
+ np.testing.assert_array_equal(physics.bind(prop.geom).rgba, rgba)
+ # Check that we can step without anything breaking.
+ physics.step()
+
+ @parameterized.parameters([
+ dict(position=[0., 0., 0.]),
+ dict(position=[0.1, -0.2, 0.3]),
+ ])
+ def test_position_observable(self, position):
+ prop, physics = self._make_free_prop()
+ prop.set_pose(physics, position=position)
+ observation = prop.observables.position(physics)
+ np.testing.assert_array_equal(position, observation)
+
+ @parameterized.parameters([
+ dict(quat=[1., 0., 0., 0.]),
+ dict(quat=[0., -1., 1., 0.]),
+ ])
+ def test_orientation_observable(self, quat):
+ prop, physics = self._make_free_prop()
+ normalized_quat = np.array(quat) / np.linalg.norm(quat)
+ prop.set_pose(physics, quaternion=normalized_quat)
+ observation = prop.observables.orientation(physics)
+ np.testing.assert_array_almost_equal(normalized_quat, observation)
+
+ @parameterized.parameters([
+ dict(velocity=[0., 0., 0.]),
+ dict(velocity=[0.1, -0.2, 0.3]),
+ ])
+ def test_linear_velocity_observable(self, velocity):
+ prop, physics = self._make_free_prop()
+ prop.set_velocity(physics, velocity=velocity)
+ observation = prop.observables.linear_velocity(physics)
+ np.testing.assert_array_almost_equal(velocity, observation)
+
+ @parameterized.parameters([
+ dict(angular_velocity=[0., 0., 0.]),
+ dict(angular_velocity=[0.1, -0.2, 0.3]),
+ ])
+ def test_angular_velocity_observable(self, angular_velocity):
+ prop, physics = self._make_free_prop()
+ prop.set_velocity(physics, angular_velocity=angular_velocity)
+ observation = prop.observables.angular_velocity(physics)
+ np.testing.assert_array_almost_equal(angular_velocity, observation)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/README.md b/DMC/src/env/dm_control/dm_control/locomotion/README.md
new file mode 100644
index 0000000..0941561
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/README.md
@@ -0,0 +1,88 @@
+# Locomotion task library
+
+This package contains reusable components for defining control tasks that are
+related to locomotion. New users are encouraged to start by browsing the
+`examples/` subdirectory, which contains preconfigured RL environments
+associated with various research papers. These examples can serve as starting
+points or be customized to design new environments using the components
+available from this library.
+
+
+
+
+
+
+## Terminology
+
+This library facilitates the creation of environments that require **walkers**
+to perform a **task** in an **arena**.
+
+- **walkers** refer to detached bodies that can move around in the
+ environment.
+
+- **arenas** refer to the surroundings in which the walkers and possibly other
+ objects exist.
+
+- **tasks** refer to the specification of observations and rewards that are
+ passed from the "environment" to the "agent", along with runtime details
+ such as initialization and termination logic.
+
+## Installation and requirements
+
+See [the documentation for `dm_control`][installation-and-requirements].
+
+## Quickstart
+
+```python
+from dm_control import composer
+from dm_control.locomotion.examples import basic_cmu_2019
+import numpy as np
+
+# Build an example environment.
+env = basic_cmu_2019.cmu_humanoid_run_walls()
+
+# Get the `action_spec` describing the control inputs.
+action_spec = env.action_spec()
+
+# Step through the environment for one episode with random actions.
+time_step = env.reset()
+while not time_step.last():
+ action = np.random.uniform(action_spec.minimum, action_spec.maximum,
+ size=action_spec.shape)
+ time_step = env.step(action)
+ print("reward = {}, discount = {}, observations = {}.".format(
+ time_step.reward, time_step.discount, time_step.observation))
+```
+
+[`dm_control.viewer`] can also be used to visualize and interact with the
+environment, e.g.:
+
+```python
+from dm_control import viewer
+
+viewer.launch(environment_loader=basic_cmu_2019.cmu_humanoid_run_walls)
+```
+
+## Publications
+
+This library contains environments that were adapted from several research
+papers. Relevant references include:
+
+- [Emergence of Locomotion Behaviours in Rich Environments (2017)][heess2017].
+
+- [Learning human behaviors from motion capture by adversarial imitation
+ (2017)][merel2017].
+
+- [Hierarchical visuomotor control of humanoids (2019)][merel2019a].
+
+- [Neural probabilistic motor primitives for humanoid control (2019)][merel2019b].
+
+- [Deep neuroethology of a virtual rodent (2020)][merel2020].
+
+[installation-and-requirements]: ../../README.md#installation-and-requirements
+[`dm_control.viewer`]: ../viewer/README.md
+[heess2017]: https://arxiv.org/abs/1707.02286
+[merel2017]: https://arxiv.org/abs/1707.02201
+[merel2019a]: https://arxiv.org/abs/1811.09656
+[merel2019b]: https://arxiv.org/abs/1811.11711
+[merel2020]: https://openreview.net/pdf?id=SyxrxR4KPS
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/__init__.py b/DMC/src/env/dm_control/dm_control/locomotion/__init__.py
new file mode 100644
index 0000000..4224c02
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/__init__.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/__init__.py
new file mode 100644
index 0000000..3f81949
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Arenas for Locomotion tasks."""
+
+from dm_control.locomotion.arenas.bowl import Bowl
+from dm_control.locomotion.arenas.corridors import EmptyCorridor
+from dm_control.locomotion.arenas.corridors import GapsCorridor
+from dm_control.locomotion.arenas.corridors import WallsCorridor
+from dm_control.locomotion.arenas.floors import Floor
+from dm_control.locomotion.arenas.labmaze_textures import FloorTextures
+from dm_control.locomotion.arenas.labmaze_textures import SkyBox
+from dm_control.locomotion.arenas.labmaze_textures import WallTextures
+from dm_control.locomotion.arenas.mazes import MazeWithTargets
+from dm_control.locomotion.arenas.mazes import RandomMazeWithTargets
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/assets/__init__.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/assets/__init__.py
new file mode 100644
index 0000000..74b35c0
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/assets/__init__.py
@@ -0,0 +1,63 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Locomotion texture assets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os
+import sys
+
+ROOT_DIR = '../locomotion/arenas/assets'
+
+
+def get_texturedir(style):
+ return os.path.join(ROOT_DIR, style)
+
+SKY_STYLES = ('outdoor_natural')
+
+SkyBox = collections.namedtuple(
+ 'SkyBox', ('file', 'gridsize', 'gridlayout'))
+
+
+def get_sky_texture_info(style):
+ if style not in SKY_STYLES:
+ raise ValueError('`style` should be one of {}: got {!r}'.format(
+ SKY_STYLES, style))
+ return SkyBox(file='OutdoorSkybox2048.png',
+ gridsize='3 4',
+ gridlayout='.U..LFRB.D..')
+
+
+GROUND_STYLES = ('outdoor_natural')
+
+GroundTexture = collections.namedtuple(
+ 'GroundTexture', ('file', 'type'))
+
+
+def get_ground_texture_info(style):
+ if style not in GROUND_STYLES:
+ raise ValueError('`style` should be one of {}: got {!r}'.format(
+ GROUND_STYLES, style))
+ return GroundTexture(
+ file='OutdoorGrassFloorD.png',
+ type='2d')
+
+
+
+
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/assets/outdoor_natural/OutdoorGrassFloorD.png b/DMC/src/env/dm_control/dm_control/locomotion/arenas/assets/outdoor_natural/OutdoorGrassFloorD.png
new file mode 100644
index 0000000..2a93f5b
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/arenas/assets/outdoor_natural/OutdoorGrassFloorD.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/assets/outdoor_natural/OutdoorSkybox2048.png b/DMC/src/env/dm_control/dm_control/locomotion/arenas/assets/outdoor_natural/OutdoorSkybox2048.png
new file mode 100644
index 0000000..d6f9a58
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/arenas/assets/outdoor_natural/OutdoorSkybox2048.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/bowl.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/bowl.py
new file mode 100644
index 0000000..d591f51
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/bowl.py
@@ -0,0 +1,139 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Bowl arena with bumps."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from dm_control import composer
+from dm_control.locomotion.arenas import assets as locomotion_arenas_assets
+from dm_control.mujoco.wrapper import mjbindings
+
+import numpy as np
+from scipy import ndimage
+
+mjlib = mjbindings.mjlib
+
+_TOP_CAMERA_DISTANCE = 100
+_TOP_CAMERA_Y_PADDING_FACTOR = 1.1
+
+# Constants related to terrain generation.
+_TERRAIN_SMOOTHNESS = .5 # 0.0: maximally bumpy; 1.0: completely smooth.
+_TERRAIN_BUMP_SCALE = .2 # Spatial scale of terrain bumps (in meters).
+
+
+class Bowl(composer.Arena):
+ """A bowl arena with sinusoidal bumps."""
+
+ def _build(self, size=(10, 10), aesthetic='default', name='bowl'):
+ super(Bowl, self)._build(name=name)
+
+ self._hfield = self._mjcf_root.asset.add(
+ 'hfield',
+ name='terrain',
+ nrow=201,
+ ncol=201,
+ size=(6, 6, 0.5, 0.1))
+
+ if aesthetic != 'default':
+ ground_info = locomotion_arenas_assets.get_ground_texture_info(aesthetic)
+ sky_info = locomotion_arenas_assets.get_sky_texture_info(aesthetic)
+ texturedir = locomotion_arenas_assets.get_texturedir(aesthetic)
+ self._mjcf_root.compiler.texturedir = texturedir
+
+ self._texture = self._mjcf_root.asset.add(
+ 'texture', name='aesthetic_texture', file=ground_info.file,
+ type=ground_info.type)
+ self._material = self._mjcf_root.asset.add(
+ 'material', name='aesthetic_material', texture=self._texture,
+ texuniform='true')
+ self._skybox = self._mjcf_root.asset.add(
+ 'texture', name='aesthetic_skybox', file=sky_info.file,
+ type='skybox', gridsize=sky_info.gridsize,
+ gridlayout=sky_info.gridlayout)
+ self._terrain_geom = self._mjcf_root.worldbody.add(
+ 'geom',
+ name='terrain',
+ type='hfield',
+ pos=(0, 0, -0.01),
+ hfield='terrain',
+ material=self._material)
+ self._ground_geom = self._mjcf_root.worldbody.add(
+ 'geom',
+ type='plane',
+ name='groundplane',
+ size=list(size) + [0.5],
+ material=self._material)
+ else:
+ self._terrain_geom = self._mjcf_root.worldbody.add(
+ 'geom',
+ name='terrain',
+ type='hfield',
+ rgba=(0.2, 0.3, 0.4, 1),
+ pos=(0, 0, -0.01),
+ hfield='terrain')
+ self._ground_geom = self._mjcf_root.worldbody.add(
+ 'geom',
+ type='plane',
+ name='groundplane',
+ rgba=(0.2, 0.3, 0.4, 1),
+ size=list(size) + [0.5])
+
+ self._mjcf_root.visual.headlight.set_attributes(
+ ambient=[.4, .4, .4], diffuse=[.8, .8, .8], specular=[.1, .1, .1])
+
+ self._regenerate = True
+
+ def regenerate(self, random_state):
+ # regeneration of the bowl requires physics, so postponed to initialization.
+ self._regenerate = True
+
+ def initialize_episode(self, physics, random_state):
+ if self._regenerate:
+ self._regenerate = False
+
+ # Get heightfield resolution, assert that it is square.
+ res = physics.bind(self._hfield).nrow
+ assert res == physics.bind(self._hfield).ncol
+
+ # Sinusoidal bowl shape.
+ row_grid, col_grid = np.ogrid[-1:1:res*1j, -1:1:res*1j]
+ radius = np.clip(np.sqrt(col_grid**2 + row_grid**2), .1, 1)
+ bowl_shape = .5 - np.cos(2*np.pi*radius)/2
+
+ # Random smooth bumps.
+ terrain_size = 2 * physics.bind(self._hfield).size[0]
+ bump_res = int(terrain_size / _TERRAIN_BUMP_SCALE)
+ bumps = random_state.uniform(_TERRAIN_SMOOTHNESS, 1, (bump_res, bump_res))
+ smooth_bumps = ndimage.zoom(bumps, res / float(bump_res))
+
+ # Terrain is elementwise product.
+ terrain = bowl_shape * smooth_bumps
+ start_idx = physics.bind(self._hfield).adr
+ physics.model.hfield_data[start_idx:start_idx+res**2] = terrain.ravel()
+
+ # If we have a rendering context, we need to re-upload the modified
+ # heightfield data.
+ if physics.contexts:
+ with physics.contexts.gl.make_current() as ctx:
+ ctx.call(mjlib.mjr_uploadHField,
+ physics.model.ptr,
+ physics.contexts.mujoco.ptr,
+ physics.bind(self._hfield).element_id)
+
+ @property
+ def ground_geoms(self):
+ return (self._terrain_geom, self._ground_geom)
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/bowl_test.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/bowl_test.py
new file mode 100644
index 0000000..5699346
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/bowl_test.py
@@ -0,0 +1,35 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for locomotion.arenas.bowl."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import absltest
+from dm_control import mjcf
+from dm_control.locomotion.arenas import bowl
+
+
+class BowlTest(absltest.TestCase):
+
+ def test_can_compile_mjcf(self):
+
+ arena = bowl.Bowl()
+ mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/corridors.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/corridors.py
new file mode 100644
index 0000000..4129023
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/corridors.py
@@ -0,0 +1,433 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Corridor-based arenas."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+from dm_control import composer
+from dm_control.composer import variation
+from dm_control.locomotion.arenas import assets as locomotion_arenas_assets
+import six
+
+_SIDE_WALLS_GEOM_GROUP = 3
+_CORRIDOR_X_PADDING = 2.0
+_WALL_THICKNESS = 0.16
+_SIDE_WALL_HEIGHT = 4.0
+_DEFAULT_ALPHA = 0.5
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Corridor(composer.Arena):
+ """Abstract base class for corridor-type arenas."""
+
+ @abc.abstractmethod
+ def regenerate(self, random_state):
+ raise NotImplementedError
+
+ @abc.abstractproperty
+ def corridor_length(self):
+ raise NotImplementedError
+
+ @abc.abstractproperty
+ def corridor_width(self):
+ raise NotImplementedError
+
+ @abc.abstractproperty
+ def ground_geoms(self):
+ raise NotImplementedError
+
+ def is_at_target_position(self, position, tolerance=0.0):
+ """Checks if a `position` is within `tolerance' of an end of the corridor.
+
+ This can also be used to evaluate more complicated T-shaped or L-shaped
+ corridors.
+
+ Args:
+ position: An iterable of 2 elements corresponding to the x and y location
+ of the position to evaluate.
+ tolerance: A `float` tolerance to use while evaluating the position.
+
+ Returns:
+ A `bool` indicating whether the `position` is within the `tolerance` of an
+ end of the corridor.
+ """
+ x, _ = position
+ return x > self.corridor_length - tolerance
+
+
+class EmptyCorridor(Corridor):
+ """An empty corridor with planes around the perimeter."""
+
+ def _build(self,
+ corridor_width=4,
+ corridor_length=40,
+ visible_side_planes=True,
+ name='empty_corridor'):
+ """Builds the corridor.
+
+ Args:
+ corridor_width: A number or a `composer.variation.Variation` object that
+ specifies the width of the corridor.
+ corridor_length: A number or a `composer.variation.Variation` object that
+ specifies the length of the corridor.
+ visible_side_planes: Whether to the side planes that bound the corridor's
+ perimeter should be rendered.
+ name: The name of this arena.
+ """
+ super(EmptyCorridor, self)._build(name=name)
+
+ self._corridor_width = corridor_width
+ self._corridor_length = corridor_length
+
+ self._walls_body = self._mjcf_root.worldbody.add('body', name='walls')
+
+ self._mjcf_root.visual.map.znear = 0.0005
+ self._mjcf_root.asset.add(
+ 'texture', type='skybox', builtin='gradient',
+ rgb1=[0.4, 0.6, 0.8], rgb2=[0, 0, 0], width=100, height=600)
+ self._mjcf_root.visual.headlight.set_attributes(
+ ambient=[0.4, 0.4, 0.4], diffuse=[0.8, 0.8, 0.8],
+ specular=[0.1, 0.1, 0.1])
+
+ alpha = _DEFAULT_ALPHA if visible_side_planes else 0.0
+ self._ground_plane = self._mjcf_root.worldbody.add(
+ 'geom', type='plane', rgba=[0.5, 0.5, 0.5, 1], size=[1, 1, 1])
+ self._left_plane = self._mjcf_root.worldbody.add(
+ 'geom', type='plane', xyaxes=[1, 0, 0, 0, 0, 1], size=[1, 1, 1],
+ rgba=[1, 0, 0, alpha], group=_SIDE_WALLS_GEOM_GROUP)
+ self._right_plane = self._mjcf_root.worldbody.add(
+ 'geom', type='plane', xyaxes=[-1, 0, 0, 0, 0, 1], size=[1, 1, 1],
+ rgba=[1, 0, 0, alpha], group=_SIDE_WALLS_GEOM_GROUP)
+ self._near_plane = self._mjcf_root.worldbody.add(
+ 'geom', type='plane', xyaxes=[0, 1, 0, 0, 0, 1], size=[1, 1, 1],
+ rgba=[1, 0, 0, alpha], group=_SIDE_WALLS_GEOM_GROUP)
+ self._far_plane = self._mjcf_root.worldbody.add(
+ 'geom', type='plane', xyaxes=[0, -1, 0, 0, 0, 1], size=[1, 1, 1],
+ rgba=[1, 0, 0, alpha], group=_SIDE_WALLS_GEOM_GROUP)
+
+ self._current_corridor_length = None
+ self._current_corridor_width = None
+
+ def regenerate(self, random_state):
+ """Regenerates this corridor.
+
+ New values are drawn from the `corridor_width` and `corridor_height`
+ distributions specified in `_build`. The corridor is resized accordingly.
+
+ Args:
+ random_state: A `numpy.random.RandomState` object that is passed to the
+ `Variation` objects.
+ """
+ self._walls_body.geom.clear()
+ corridor_width = variation.evaluate(self._corridor_width,
+ random_state=random_state)
+ corridor_length = variation.evaluate(self._corridor_length,
+ random_state=random_state)
+ self._current_corridor_length = corridor_length
+ self._current_corridor_width = corridor_width
+
+ self._ground_plane.pos = [corridor_length / 2, 0, 0]
+ self._ground_plane.size = [
+ corridor_length / 2 + _CORRIDOR_X_PADDING, corridor_width / 2, 1]
+
+ self._left_plane.pos = [
+ corridor_length / 2, corridor_width / 2, _SIDE_WALL_HEIGHT / 2]
+ self._left_plane.size = [
+ corridor_length / 2 + _CORRIDOR_X_PADDING, _SIDE_WALL_HEIGHT / 2, 1]
+
+ self._right_plane.pos = [
+ corridor_length / 2, -corridor_width / 2, _SIDE_WALL_HEIGHT / 2]
+ self._right_plane.size = [
+ corridor_length / 2 + _CORRIDOR_X_PADDING, _SIDE_WALL_HEIGHT / 2, 1]
+
+ self._near_plane.pos = [
+ -_CORRIDOR_X_PADDING, 0, _SIDE_WALL_HEIGHT / 2]
+ self._near_plane.size = [corridor_width / 2, _SIDE_WALL_HEIGHT / 2, 1]
+
+ self._far_plane.pos = [
+ corridor_length + _CORRIDOR_X_PADDING, 0, _SIDE_WALL_HEIGHT / 2]
+ self._far_plane.size = [corridor_width / 2, _SIDE_WALL_HEIGHT / 2, 1]
+
+ @property
+ def corridor_length(self):
+ return self._current_corridor_length
+
+ @property
+ def corridor_width(self):
+ return self._current_corridor_width
+
+ @property
+ def ground_geoms(self):
+ return (self._ground_plane,)
+
+
+class GapsCorridor(EmptyCorridor):
+ """A corridor that consists of multiple platforms separated by gaps."""
+
+ def _build(self,
+ platform_length=1.,
+ gap_length=2.5,
+ corridor_width=4,
+ corridor_length=40,
+ ground_rgba=(0.5, 0.5, 0.5, 1),
+ visible_side_planes=False,
+ aesthetic='default',
+ name='gaps_corridor'):
+ """Builds the corridor.
+
+ Args:
+ platform_length: A number or a `composer.variation.Variation` object that
+ specifies the size of the platforms along the corridor.
+ gap_length: A number or a `composer.variation.Variation` object that
+ specifies the size of the gaps along the corridor.
+ corridor_width: A number or a `composer.variation.Variation` object that
+ specifies the width of the corridor.
+ corridor_length: A number or a `composer.variation.Variation` object that
+ specifies the length of the corridor.
+ ground_rgba: A sequence of 4 numbers or a `composer.variation.Variation`
+ object specifying the color of the ground.
+ visible_side_planes: Whether to the side planes that bound the corridor's
+ perimeter should be rendered.
+ aesthetic: option to adjust the material properties and skybox
+ name: The name of this arena.
+ """
+ super(GapsCorridor, self)._build(
+ corridor_width=corridor_width,
+ corridor_length=corridor_length,
+ visible_side_planes=visible_side_planes,
+ name=name)
+
+ self._platform_length = platform_length
+ self._gap_length = gap_length
+ self._ground_rgba = ground_rgba
+ self._aesthetic = aesthetic
+
+ if self._aesthetic != 'default':
+ ground_info = locomotion_arenas_assets.get_ground_texture_info(aesthetic)
+ sky_info = locomotion_arenas_assets.get_sky_texture_info(aesthetic)
+ texturedir = locomotion_arenas_assets.get_texturedir(aesthetic)
+ self._mjcf_root.compiler.texturedir = texturedir
+
+ self._ground_texture = self._mjcf_root.asset.add(
+ 'texture', name='aesthetic_texture', file=ground_info.file,
+ type=ground_info.type)
+ self._ground_material = self._mjcf_root.asset.add(
+ 'material', name='aesthetic_material', texture=self._ground_texture,
+ texuniform='true')
+ # remove existing skybox
+ for texture in self._mjcf_root.asset.find_all('texture'):
+ if texture.type == 'skybox':
+ texture.remove()
+ self._skybox = self._mjcf_root.asset.add(
+ 'texture', name='aesthetic_skybox', file=sky_info.file,
+ type='skybox', gridsize=sky_info.gridsize,
+ gridlayout=sky_info.gridlayout)
+
+ self._ground_body = self._mjcf_root.worldbody.add('body', name='ground')
+
+ def regenerate(self, random_state):
+ """Regenerates this corridor.
+
+ New values are drawn from the `corridor_width` and `corridor_height`
+ distributions specified in `_build`. The corridor resized accordingly, and
+ new sets of platforms are created according to values drawn from the
+ `platform_length`, `gap_length`, and `ground_rgba` distributions specified
+ in `_build`.
+
+ Args:
+ random_state: A `numpy.random.RandomState` object that is passed to the
+ `Variation` objects.
+ """
+ # Resize the entire corridor first.
+ super(GapsCorridor, self).regenerate(random_state)
+
+ # Move the ground plane down and make it invisible.
+ self._ground_plane.pos = [self._current_corridor_length / 2, 0, -10]
+ self._ground_plane.rgba = [0, 0, 0, 0]
+
+ # Clear the existing platform pieces.
+ self._ground_body.geom.clear()
+
+ # Make the first platform larger.
+ platform_length = 3. * _CORRIDOR_X_PADDING
+ platform_pos = [
+ platform_length / 2,
+ 0,
+ -_WALL_THICKNESS,
+ ]
+ platform_size = [
+ platform_length / 2,
+ self._current_corridor_width / 2,
+ _WALL_THICKNESS,
+ ]
+ if self._aesthetic != 'default':
+ self._ground_body.add(
+ 'geom',
+ type='box',
+ name='start_floor',
+ pos=platform_pos,
+ size=platform_size,
+ material=self._ground_material)
+ else:
+ self._ground_body.add(
+ 'geom',
+ type='box',
+ rgba=variation.evaluate(self._ground_rgba, random_state),
+ name='start_floor',
+ pos=platform_pos,
+ size=platform_size)
+
+ current_x = platform_length
+ platform_id = 0
+ while current_x < self._current_corridor_length:
+ platform_length = variation.evaluate(
+ self._platform_length, random_state=random_state)
+ platform_pos = [
+ current_x + platform_length / 2.,
+ 0,
+ -_WALL_THICKNESS,
+ ]
+ platform_size = [
+ platform_length / 2,
+ self._current_corridor_width / 2,
+ _WALL_THICKNESS,
+ ]
+ if self._aesthetic != 'default':
+ self._ground_body.add(
+ 'geom',
+ type='box',
+ name='floor_{}'.format(platform_id),
+ pos=platform_pos,
+ size=platform_size,
+ material=self._ground_material)
+ else:
+ self._ground_body.add(
+ 'geom',
+ type='box',
+ rgba=variation.evaluate(self._ground_rgba, random_state),
+ name='floor_{}'.format(platform_id),
+ pos=platform_pos,
+ size=platform_size)
+
+ platform_id += 1
+
+ # Move x to start of the next platform.
+ current_x += platform_length + variation.evaluate(
+ self._gap_length, random_state=random_state)
+
+ @property
+ def ground_geoms(self):
+ return (self._ground_plane,) + tuple(self._ground_body.find_all('geom'))
+
+
+class WallsCorridor(EmptyCorridor):
+ """A corridor obstructed by multiple walls aligned against the two sides."""
+
+ def _build(self,
+ wall_gap=2.5,
+ wall_width=2.5,
+ wall_height=2.0,
+ swap_wall_side=True,
+ wall_rgba=(1, 1, 1, 1),
+ corridor_width=4,
+ corridor_length=40,
+ visible_side_planes=False,
+ name='walls_corridor'):
+ """Builds the corridor.
+
+ Args:
+ wall_gap: A number or a `composer.variation.Variation` object that
+ specifies the gap between each consecutive pair obstructing walls.
+ wall_width: A number or a `composer.variation.Variation` object that
+ specifies the width that the obstructing walls extend into the corridor.
+ wall_height: A number or a `composer.variation.Variation` object that
+ specifies the height of the obstructing walls.
+ swap_wall_side: A boolean or a `composer.variation.Variation` object that
+ specifies whether the next obstructing wall should be aligned against
+ the opposite side of the corridor compared to the previous one.
+ wall_rgba: A sequence of 4 numbers or a `composer.variation.Variation`
+ object specifying the color of the walls.
+ corridor_width: A number or a `composer.variation.Variation` object that
+ specifies the width of the corridor.
+ corridor_length: A number or a `composer.variation.Variation` object that
+ specifies the length of the corridor.
+ visible_side_planes: Whether to the side planes that bound the corridor's
+ perimeter should be rendered.
+ name: The name of this arena.
+ """
+ super(WallsCorridor, self)._build(
+ corridor_width=corridor_width,
+ corridor_length=corridor_length,
+ visible_side_planes=visible_side_planes,
+ name=name)
+
+ self._wall_height = wall_height
+ self._wall_rgba = wall_rgba
+ self._wall_gap = wall_gap
+ self._wall_width = wall_width
+ self._swap_wall_side = swap_wall_side
+
+ def regenerate(self, random_state):
+ """Regenerates this corridor.
+
+ New values are drawn from the `corridor_width` and `corridor_height`
+ distributions specified in `_build`. The corridor resized accordingly, and
+ new sets of obstructing walls are created according to values drawn from the
+ `wall_gap`, `wall_width`, `wall_height`, and `wall_rgba` distributions
+ specified in `_build`.
+
+ Args:
+ random_state: A `numpy.random.RandomState` object that is passed to the
+ `Variation` objects.
+ """
+ super(WallsCorridor, self).regenerate(random_state)
+ wall_x = variation.evaluate(
+ self._wall_gap, random_state=random_state) - _CORRIDOR_X_PADDING
+ wall_side = 0
+ wall_id = 0
+ while wall_x < self._current_corridor_length:
+ wall_width = variation.evaluate(
+ self._wall_width, random_state=random_state)
+ wall_height = variation.evaluate(
+ self._wall_height, random_state=random_state)
+ wall_rgba = variation.evaluate(self._wall_rgba, random_state=random_state)
+ if variation.evaluate(self._swap_wall_side, random_state=random_state):
+ wall_side = 1 - wall_side
+
+ wall_pos = [
+ wall_x,
+ (2 * wall_side - 1) * (self._current_corridor_width - wall_width) / 2,
+ wall_height / 2
+ ]
+ wall_size = [_WALL_THICKNESS / 2, wall_width / 2, wall_height / 2]
+ self._walls_body.add(
+ 'geom',
+ type='box',
+ name='wall_{}'.format(wall_id),
+ pos=wall_pos,
+ size=wall_size,
+ rgba=wall_rgba)
+
+ wall_id += 1
+ wall_x += variation.evaluate(self._wall_gap, random_state=random_state)
+
+ @property
+ def ground_geoms(self):
+ return (self._ground_plane,)
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/corridors_test.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/corridors_test.py
new file mode 100644
index 0000000..328a455
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/corridors_test.py
@@ -0,0 +1,89 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for locomotion.arenas.corridors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import mjcf
+from dm_control.composer.variation import deterministic
+from dm_control.locomotion.arenas import corridors
+from six.moves import zip
+
+
+class CorridorsTest(parameterized.TestCase):
+
+ @parameterized.parameters([
+ corridors.EmptyCorridor,
+ corridors.GapsCorridor,
+ corridors.WallsCorridor,
+ ])
+ def test_can_compile_mjcf(self, arena_type):
+ arena = arena_type()
+ mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+
+ @parameterized.parameters([
+ corridors.EmptyCorridor,
+ corridors.GapsCorridor,
+ corridors.WallsCorridor,
+ ])
+ def test_can_regenerate_corridor_size(self, arena_type):
+ width_sequence = [5.2, 3.8, 7.4]
+ length_sequence = [21.1, 19.4, 16.3]
+
+ arena = arena_type(
+ corridor_width=deterministic.Sequence(width_sequence),
+ corridor_length=deterministic.Sequence(length_sequence))
+
+ # Add a probe geom that will generate contacts with the side walls.
+ probe_body = arena.mjcf_model.worldbody.add('body', name='probe')
+ probe_joint = probe_body.add('freejoint')
+ probe_geom = probe_body.add('geom', name='probe', type='box')
+
+ for expected_width, expected_length in zip(width_sequence, length_sequence):
+ # No random_state is required since we are using deterministic variations.
+ arena.regenerate(random_state=None)
+
+ def resize_probe_geom_and_assert_num_contacts(
+ delta_size, expected_num_contacts,
+ expected_width=expected_width, expected_length=expected_length):
+ probe_geom.size = [
+ (expected_length / 2 + corridors._CORRIDOR_X_PADDING) + delta_size,
+ expected_width / 2 + delta_size, 0.1]
+ physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+ probe_geomid = physics.bind(probe_geom).element_id
+ physics.bind(probe_joint).qpos[:3] = [expected_length / 2, 0, 100]
+ physics.forward()
+ probe_contacts = [c for c in physics.data.contact
+ if c.geom1 == probe_geomid or c.geom2 == probe_geomid]
+ self.assertLen(probe_contacts, expected_num_contacts)
+
+ epsilon = 1e-7
+
+ # If the probe geom is epsilon-smaller than the expected corridor size,
+ # then we expect to detect no contact.
+ resize_probe_geom_and_assert_num_contacts(-epsilon, 0)
+
+ # If the probe geom is epsilon-larger than the expected corridor size,
+ # then we expect to generate 4 contacts with each side wall, so 16 total.
+ resize_probe_geom_and_assert_num_contacts(epsilon, 16)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/covering.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/covering.py
new file mode 100644
index 0000000..82de82b
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/covering.py
@@ -0,0 +1,143 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Calculates a covering of text mazes with overlapping rectangular walls."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import numpy as np
+from six.moves import range
+
+GridCoordinates = collections.namedtuple('GridCoordinates', ('y', 'x'))
+MazeWall = collections.namedtuple('MazeWall', ('start', 'end'))
+
+
+class _MazeWallCoveringContext(object):
+ """Calculates a covering of text mazes with overlapping rectangular walls.
+
+ This class uses a greedy algorithm to try and minimize the number of geoms
+ generated to create a given maze. The solution is not guaranteed to be
+ optimal, but in most cases should result in a significantly smaller number of
+ geoms than if each cell were treated as an individual box.
+ """
+
+ def __init__(self, text_maze, wall_char='*', make_odd_sized_walls=False):
+ """Initializes this _MazeWallCoveringContext.
+
+ Args:
+ text_maze: A `labmaze.TextGrid` instance.
+ wall_char: (optional) The character that signifies a wall.
+ make_odd_sized_walls: (optional) A boolean, if `True` all wall sections
+ generated span odd numbers of grid cells. This option exists primarily
+ to appease MuJoCo's texture repeating algorithm.
+ """
+ self._text_maze = text_maze
+ self._wall_char = wall_char
+ self._make_odd_sized_walls = make_odd_sized_walls
+ self._covered = np.full(text_maze.shape, False, dtype=np.bool)
+ self._maze_size = GridCoordinates(*text_maze.shape)
+ self._next_start = GridCoordinates(0, 0)
+ self._calculated = False
+ self._walls = ()
+
+ def calculate(self):
+ """Calculates a covering of text mazes with overlapping rectangular walls.
+
+ Returns:
+ A tuple of `MazeWall` objects, each describing the corners of a wall.
+ """
+ if not self._calculated:
+ self._calculated = True
+ self._find_next_start()
+ walls = []
+ while self._next_start.y < self._maze_size.y:
+ walls.append(self._find_next_wall())
+ self._find_next_start()
+ self._walls = tuple(walls)
+ return self._walls
+
+ def _find_next_start(self):
+ """Moves `self._next_start` to the top-left corner of the next wall."""
+ for y in range(self._next_start.y, self._maze_size.y):
+ start_x = self._next_start.x if y == self._next_start.y else 0
+ for x in range(start_x, self._maze_size.x):
+ if self._text_maze[y, x] == self._wall_char and not self._covered[y, x]:
+ self._next_start = GridCoordinates(y, x)
+ return
+ self._next_start = self._maze_size
+
+ def _scan_row(self, row, start_col, end_col):
+ """Scans a row of text maze to find the longest strip of wall."""
+ for col in range(start_col, end_col):
+ if (self._text_maze[row, col] != self._wall_char
+ or self._covered[row, col]):
+ return col
+ return end_col
+
+ def _find_next_wall(self):
+ """Finds the largest piece of rectangular wall at the current location.
+
+ This function assumes that `self._next_start` is already at the top-left
+ corner of the next piece of wall.
+
+ Returns:
+ A `MazeWall` named tuple representing the next piece of wall created.
+ """
+ start = self._next_start
+ x = self._maze_size.x
+ end_x_for_rows = []
+ total_cells = []
+
+ for y in range(start.y, self._maze_size.y):
+ x = self._scan_row(y, start.x, x)
+ if x > start.x:
+ if self._make_odd_sized_walls and (x - start.x) % 2 == 0:
+ x -= 1
+ end_x_for_rows.append(x)
+ total_cells.append((x - start.x) * (y - start.y + 1))
+ y += 1
+ else:
+ break
+
+ if not self._make_odd_sized_walls:
+ end_y_offset = total_cells.index(max(total_cells))
+ else:
+ end_y_offset = 2 * total_cells[::2].index(max(total_cells[::2]))
+ end = GridCoordinates(start.y + end_y_offset + 1,
+ end_x_for_rows[end_y_offset])
+ self._covered[start.y:end.y, start.x:end.x] = True
+ self._next_start = GridCoordinates(start.y, end.x)
+ return MazeWall(start, end)
+
+
+def make_walls(text_maze, wall_char='*', make_odd_sized_walls=False):
+ """Calculates a covering of text mazes with overlapping rectangular walls.
+
+ Args:
+ text_maze: A `labmaze.TextMaze` instance.
+ wall_char: (optional) The character that signifies a wall.
+ make_odd_sized_walls: (optional) A boolean, if `True` all wall sections
+ generated span odd numbers of grid cells. This option exists primarily
+ to appease MuJoCo's texture repeating algorithm.
+
+ Returns:
+ A tuple of `MazeWall` objects, each describing the corners of a wall.
+ """
+ wall_covering_context = _MazeWallCoveringContext(
+ text_maze, wall_char=wall_char, make_odd_sized_walls=make_odd_sized_walls)
+ return wall_covering_context.calculate()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/covering_test.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/covering_test.py
new file mode 100644
index 0000000..b008319
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/covering_test.py
@@ -0,0 +1,80 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for arenas.mazes.covering."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import absltest
+from dm_control.locomotion.arenas import covering
+import labmaze
+import numpy as np
+import six
+from six.moves import range
+
+if six.PY3:
+ _STRING_DTYPE = '|U1'
+else:
+ _STRING_DTYPE = '|S1'
+
+
+class CoveringTest(absltest.TestCase):
+
+ def testRandomMazes(self):
+ maze = labmaze.RandomMaze(height=17, width=17,
+ max_rooms=5, room_min_size=3, room_max_size=5,
+ spawns_per_room=0, objects_per_room=0,
+ random_seed=54321)
+ for _ in range(1000):
+ maze.regenerate()
+ walls = covering.make_walls(maze.entity_layer)
+ reconstructed = np.full(maze.entity_layer.shape, ' ', dtype=_STRING_DTYPE)
+ for wall in walls:
+ reconstructed[wall.start.y:wall.end.y, wall.start.x:wall.end.x] = '*'
+ np.testing.assert_array_equal(reconstructed, maze.entity_layer)
+
+ def testOddCovering(self):
+ maze = labmaze.RandomMaze(height=17, width=17,
+ max_rooms=5, room_min_size=3, room_max_size=5,
+ spawns_per_room=0, objects_per_room=0,
+ random_seed=54321)
+ for _ in range(1000):
+ maze.regenerate()
+ walls = covering.make_walls(maze.entity_layer, make_odd_sized_walls=True)
+ reconstructed = np.full(maze.entity_layer.shape, ' ', dtype=_STRING_DTYPE)
+ for wall in walls:
+ reconstructed[wall.start.y:wall.end.y, wall.start.x:wall.end.x] = '*'
+ np.testing.assert_array_equal(reconstructed, maze.entity_layer)
+ for wall in walls:
+ self.assertEqual((wall.end.y - wall.start.y) % 2, 1)
+ self.assertEqual((wall.end.x - wall.start.x) % 2, 1)
+
+ def testNoOverlappingWalls(self):
+ maze_string = """..**
+ .***
+ .***
+ """.replace(' ', '')
+ walls = covering.make_walls(labmaze.TextGrid(maze_string))
+ surface = 0
+ for wall in walls:
+ size_x = wall.end.x - wall.start.x
+ size_y = wall.end.y - wall.start.y
+ surface += size_x * size_y
+ self.assertEqual(surface, 8)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/floors.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/floors.py
new file mode 100644
index 0000000..47afc91
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/floors.py
@@ -0,0 +1,105 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Simple floor arenas."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from dm_control import composer
+from dm_control.locomotion.arenas import assets as locomotion_arenas_assets
+import numpy as np
+
+_TOP_CAMERA_DISTANCE = 100
+_TOP_CAMERA_Y_PADDING_FACTOR = 1.1
+
+
+class Floor(composer.Arena):
+ """A simple floor arena with a checkered pattern."""
+
+ def _build(self, size=(8, 8), reflectance=.2, aesthetic='default',
+ name='floor'):
+ super(Floor, self)._build(name=name)
+ self._size = size
+
+ self._mjcf_root.visual.headlight.set_attributes(
+ ambient=[.4, .4, .4], diffuse=[.8, .8, .8], specular=[.1, .1, .1])
+
+ if aesthetic != 'default':
+ ground_info = locomotion_arenas_assets.get_ground_texture_info(aesthetic)
+ sky_info = locomotion_arenas_assets.get_sky_texture_info(aesthetic)
+ texturedir = locomotion_arenas_assets.get_texturedir(aesthetic)
+ self._mjcf_root.compiler.texturedir = texturedir
+
+ self._ground_texture = self._mjcf_root.asset.add(
+ 'texture', name='aesthetic_texture', file=ground_info.file,
+ type=ground_info.type)
+ self._ground_material = self._mjcf_root.asset.add(
+ 'material', name='aesthetic_material', texture=self._ground_texture,
+ texuniform='true')
+ self._skybox = self._mjcf_root.asset.add(
+ 'texture', name='aesthetic_skybox', file=sky_info.file,
+ type='skybox', gridsize=sky_info.gridsize,
+ gridlayout=sky_info.gridlayout)
+ else:
+ self._ground_texture = self._mjcf_root.asset.add(
+ 'texture',
+ rgb1=[.2, .3, .4],
+ rgb2=[.1, .2, .3],
+ type='2d',
+ builtin='checker',
+ name='groundplane',
+ width=300,
+ height=300,
+ mark='edge',
+ markrgb=[0.8, 0.8, 0.8])
+ self._ground_material = self._mjcf_root.asset.add(
+ 'material',
+ name='groundplane',
+ texrepeat=[3, 3],
+ texuniform=True,
+ reflectance=reflectance,
+ texture=self._ground_texture)
+
+ # Build groundplane.
+ self._ground_geom = self._mjcf_root.worldbody.add(
+ 'geom',
+ type='plane',
+ name='groundplane',
+ material=self._ground_material,
+ size=list(size) + [0.5])
+
+ # Choose the FOV so that the floor always fits nicely within the frame
+ # irrespective of actual floor size.
+ fovy_radians = 2 * np.arctan2(_TOP_CAMERA_Y_PADDING_FACTOR * size[1],
+ _TOP_CAMERA_DISTANCE)
+ self._top_camera = self._mjcf_root.worldbody.add(
+ 'camera',
+ name='top_camera',
+ pos=[0, 0, _TOP_CAMERA_DISTANCE],
+ zaxis=[0, 0, 1],
+ fovy=np.rad2deg(fovy_radians))
+
+ @property
+ def ground_geoms(self):
+ return (self._ground_geom,)
+
+ def regenerate(self, random_state):
+ pass
+
+ @property
+ def size(self):
+ return self._size
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/floors_test.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/floors_test.py
new file mode 100644
index 0000000..8979036
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/floors_test.py
@@ -0,0 +1,53 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for locomotion.arenas.floors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import absltest
+from dm_control import mjcf
+from dm_control.locomotion.arenas import floors
+import numpy as np
+
+
+class FloorsTest(absltest.TestCase):
+
+ def test_can_compile_mjcf(self):
+ arena = floors.Floor()
+ mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+
+ def test_size(self):
+ floor_size = (12.9, 27.1)
+ arena = floors.Floor(size=floor_size)
+ self.assertEqual(tuple(arena.ground_geoms[0].size[:2]), floor_size)
+
+ def test_top_camera(self):
+ floor_width, floor_height = 12.9, 27.1
+ arena = floors.Floor(size=[floor_width, floor_height])
+
+ self.assertGreater(floors._TOP_CAMERA_Y_PADDING_FACTOR, 1)
+ np.testing.assert_array_equal(arena._top_camera.zaxis, (0, 0, 1))
+
+ expected_camera_y = floor_height * floors._TOP_CAMERA_Y_PADDING_FACTOR
+ np.testing.assert_allclose(
+ np.tan(np.deg2rad(arena._top_camera.fovy / 2)),
+ expected_camera_y / arena._top_camera.pos[2])
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/labmaze_textures.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/labmaze_textures.py
new file mode 100644
index 0000000..593146a
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/labmaze_textures.py
@@ -0,0 +1,87 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""LabMaze textures."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from dm_control import composer
+from dm_control import mjcf
+from labmaze import assets as labmaze_assets
+import six
+
+
+class SkyBox(composer.Entity):
+ """Represents a texture asset for the sky box."""
+
+ def _build(self, style):
+ labmaze_textures = labmaze_assets.get_sky_texture_paths(style)
+ self._mjcf_root = mjcf.RootElement(model='labmaze_' + style)
+ self._texture = self._mjcf_root.asset.add(
+ 'texture', type='skybox', name='texture',
+ fileleft=labmaze_textures.left, fileright=labmaze_textures.right,
+ fileup=labmaze_textures.up, filedown=labmaze_textures.down,
+ filefront=labmaze_textures.front, fileback=labmaze_textures.back)
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @property
+ def texture(self):
+ return self._texture
+
+
+class WallTextures(composer.Entity):
+ """Represents wall texture assets."""
+
+ def _build(self, style):
+ labmaze_textures = labmaze_assets.get_wall_texture_paths(style)
+ self._mjcf_root = mjcf.RootElement(model='labmaze_' + style)
+ self._textures = []
+ for texture_name, texture_path in six.iteritems(labmaze_textures):
+ self._textures.append(self._mjcf_root.asset.add(
+ 'texture', type='2d', name=texture_name,
+ file=texture_path.format(texture_name)))
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @property
+ def textures(self):
+ return self._textures
+
+
+class FloorTextures(composer.Entity):
+ """Represents floor texture assets."""
+
+ def _build(self, style):
+ labmaze_textures = labmaze_assets.get_floor_texture_paths(style)
+ self._mjcf_root = mjcf.RootElement(model='labmaze_' + style)
+ self._textures = []
+ for texture_name, texture_path in six.iteritems(labmaze_textures):
+ self._textures.append(self._mjcf_root.asset.add(
+ 'texture', type='2d', name=texture_name,
+ file=texture_path.format(texture_name)))
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @property
+ def textures(self):
+ return self._textures
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/mazes.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/mazes.py
new file mode 100644
index 0000000..b30afb9
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/mazes.py
@@ -0,0 +1,466 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Maze-based arenas."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import string
+
+from absl import logging
+from dm_control import composer
+from dm_control.composer.observation import observable
+from dm_control.locomotion.arenas import assets as locomotion_arenas_assets
+from dm_control.locomotion.arenas import covering
+import labmaze
+import numpy as np
+import six
+from six.moves import range
+from six.moves import zip
+
+
+# Put all "actual" wall geoms in a separate group since they are not rendered.
+_WALL_GEOM_GROUP = 3
+
+_TOP_CAMERA_DISTANCE = 100
+_TOP_CAMERA_Y_PADDING_FACTOR = 1.1
+
+_DEFAULT_WALL_CHAR = '*'
+_DEFAULT_FLOOR_CHAR = '.'
+
+
+class MazeWithTargets(composer.Arena):
+ """A 2D maze with target positions specified by a LabMaze-style text maze."""
+
+ def _build(self, maze, xy_scale=2.0, z_height=2.0,
+ skybox_texture=None, wall_textures=None, floor_textures=None,
+ aesthetic='default', name='maze'):
+ """Initializes this maze arena.
+
+ Args:
+ maze: A `labmaze.BaseMaze` instance.
+ xy_scale: The size of each maze cell in metres.
+ z_height: The z-height of the maze in metres.
+ skybox_texture: (optional) A `composer.Entity` that provides a texture
+ asset for the skybox.
+ wall_textures: (optional) Either a `composer.Entity` that provides texture
+ assets for the maze walls, or a dict mapping printable characters to
+ such Entities. In the former case, the maze walls are assumed to be
+ represented by '*' in the maze's entity layer. In the latter case,
+ the dict's keys specify the different characters that can be present
+ in the maze's entity layer, and the dict's values are the corresponding
+ texture providers.
+ floor_textures: (optional) A `composer.Entity` that provides texture
+ assets for the maze floor. Unlike with walls, we do not currently
+ support per-variation floor texture. Instead, we sample textures from
+ the same texture provider for each variation in the variations layer.
+ aesthetic: option to adjust the material properties and skybox
+ name: (optional) A string, the name of this arena.
+ """
+ super(MazeWithTargets, self)._build(name)
+ self._maze = maze
+ self._xy_scale = xy_scale
+ self._z_height = z_height
+
+ self._x_offset = (self._maze.width - 1) / 2
+ self._y_offset = (self._maze.height - 1) / 2
+
+ self._mjcf_root.default.geom.rgba = [1, 1, 1, 1]
+
+ if aesthetic != 'default':
+ sky_info = locomotion_arenas_assets.get_sky_texture_info(aesthetic)
+ texturedir = locomotion_arenas_assets.get_texturedir(aesthetic)
+ self._mjcf_root.compiler.texturedir = texturedir
+ self._skybox = self._mjcf_root.asset.add(
+ 'texture', name='aesthetic_skybox', file=sky_info.file,
+ type='skybox', gridsize=sky_info.gridsize,
+ gridlayout=sky_info.gridlayout)
+ elif skybox_texture:
+ self._skybox_texture = skybox_texture.texture
+ self.attach(skybox_texture)
+ else:
+ self._skybox_texture = self._mjcf_root.asset.add(
+ 'texture', type='skybox', name='skybox', builtin='gradient',
+ rgb1=[.4, .6, .8], rgb2=[0, 0, 0], width=100, height=100)
+
+ self._texturing_geom_names = []
+ self._texturing_material_names = []
+ if wall_textures:
+ if isinstance(wall_textures, dict):
+ for texture_provider in set(wall_textures.values()):
+ self.attach(texture_provider)
+ self._wall_textures = {
+ wall_char: texture_provider.textures
+ for wall_char, texture_provider in six.iteritems(wall_textures)
+ }
+ else:
+ self.attach(wall_textures)
+ self._wall_textures = {_DEFAULT_WALL_CHAR: wall_textures.textures}
+ else:
+ self._wall_textures = {_DEFAULT_WALL_CHAR: [self._mjcf_root.asset.add(
+ 'texture', type='2d', name='wall', builtin='flat',
+ rgb1=[.8, .8, .8], width=100, height=100)]}
+
+ if aesthetic != 'default':
+ ground_info = locomotion_arenas_assets.get_ground_texture_info(aesthetic)
+ self._floor_textures = [
+ self._mjcf_root.asset.add(
+ 'texture',
+ name='aesthetic_texture_main',
+ file=ground_info.file,
+ type=ground_info.type),
+ self._mjcf_root.asset.add(
+ 'texture',
+ name='aesthetic_texture',
+ file=ground_info.file,
+ type=ground_info.type)
+ ]
+ elif floor_textures:
+ self._floor_textures = floor_textures.textures
+ self.attach(floor_textures)
+ else:
+ self._floor_textures = [self._mjcf_root.asset.add(
+ 'texture', type='2d', name='floor', builtin='flat',
+ rgb1=[.2, .2, .2], width=100, height=100)]
+
+ ground_x = ((self._maze.width - 1) + 1) * (xy_scale / 2)
+ ground_y = ((self._maze.height - 1) + 1) * (xy_scale / 2)
+ self._mjcf_root.worldbody.add(
+ 'geom', name='ground', type='plane',
+ pos=[0, 0, 0], size=[ground_x, ground_y, 1], rgba=[0, 0, 0, 0])
+
+ self._maze_body = self._mjcf_root.worldbody.add('body', name='maze_body')
+
+ self._mjcf_root.visual.map.znear = 0.0005
+
+ # Choose the FOV so that the maze always fits nicely within the frame
+ # irrespective of actual maze size.
+ maze_size = max(self._maze.width, self._maze.height)
+ top_camera_fovy = (360 / np.pi) * np.arctan2(
+ _TOP_CAMERA_Y_PADDING_FACTOR * maze_size * self._xy_scale / 2,
+ _TOP_CAMERA_DISTANCE)
+ self._top_camera = self._mjcf_root.worldbody.add(
+ 'camera', name='top_camera',
+ pos=[0, 0, _TOP_CAMERA_DISTANCE], zaxis=[0, 0, 1], fovy=top_camera_fovy)
+
+ self._target_positions = ()
+ self._spawn_positions = ()
+
+ self._text_maze_regenerated_hook = None
+ self._tile_geom_names = {}
+
+ def _build_observables(self):
+ return MazeObservables(self)
+
+ @property
+ def top_camera(self):
+ return self._top_camera
+
+ @property
+ def xy_scale(self):
+ return self._xy_scale
+
+ @property
+ def z_height(self):
+ return self._z_height
+
+ @property
+ def maze(self):
+ return self._maze
+
+ @property
+ def text_maze_regenerated_hook(self):
+ """A callback that is executed after the LabMaze object is regenerated."""
+ return self._text_maze_modifier
+
+ @text_maze_regenerated_hook.setter
+ def text_maze_regenerated_hook(self, hook):
+ self._text_maze_regenerated_hook = hook
+
+ @property
+ def target_positions(self):
+ """A tuple of Cartesian target positions generated for the current maze."""
+ return self._target_positions
+
+ @property
+ def spawn_positions(self):
+ """The Cartesian position at which the agent should be spawned."""
+ return self._spawn_positions
+
+ @property
+ def target_grid_positions(self):
+ """A tuple of grid coordinates of targets generated for the current maze."""
+ return self._target_grid_positions
+
+ @property
+ def spawn_grid_positions(self):
+ """The grid-coordinate position at which the agent should be spawned."""
+ return self._spawn_grid_positions
+
+ def regenerate(self):
+ """Generates a new maze layout."""
+ self._maze.regenerate()
+ logging.debug('GENERATED MAZE:\n%s', self._maze.entity_layer)
+ self._find_spawn_and_target_positions()
+
+ if self._text_maze_regenerated_hook:
+ self._text_maze_regenerated_hook()
+
+ # Remove old texturing planes.
+ for geom_name in self._texturing_geom_names:
+ del self._mjcf_root.worldbody.geom[geom_name]
+ self._texturing_geom_names = []
+
+ # Remove old texturing materials.
+ for material_name in self._texturing_material_names:
+ del self._mjcf_root.asset.material[material_name]
+ self._texturing_material_names = []
+
+ # Remove old actual-wall geoms.
+ self._maze_body.geom.clear()
+
+ self._current_wall_texture = {
+ wall_char: np.random.choice(wall_textures)
+ for wall_char, wall_textures in six.iteritems(self._wall_textures)
+ }
+
+ for wall_char in self._wall_textures:
+ self._make_wall_geoms(wall_char)
+ self._make_floor_variations()
+
+ def _make_wall_geoms(self, wall_char):
+ walls = covering.make_walls(
+ self._maze.entity_layer, wall_char=wall_char, make_odd_sized_walls=True)
+ for i, wall in enumerate(walls):
+ wall_mid = covering.GridCoordinates(
+ (wall.start.y + wall.end.y - 1) / 2,
+ (wall.start.x + wall.end.x - 1) / 2)
+ wall_pos = np.array([(wall_mid.x - self._x_offset) * self._xy_scale,
+ -(wall_mid.y - self._y_offset) * self._xy_scale,
+ self._z_height / 2])
+ wall_size = np.array([(wall.end.x - wall_mid.x - 0.5) * self._xy_scale,
+ (wall.end.y - wall_mid.y - 0.5) * self._xy_scale,
+ self._z_height / 2])
+ self._maze_body.add('geom', name='wall{}_{}'.format(wall_char, i),
+ type='box', pos=wall_pos, size=wall_size,
+ group=_WALL_GEOM_GROUP)
+ self._make_wall_texturing_planes(wall_char, i, wall_pos, wall_size)
+
+ def _make_wall_texturing_planes(self, wall_char, wall_id,
+ wall_pos, wall_size):
+ xyaxes = {
+ 'x': {-1: [0, -1, 0, 0, 0, 1], 1: [0, 1, 0, 0, 0, 1]},
+ 'y': {-1: [1, 0, 0, 0, 0, 1], 1: [-1, 0, 0, 0, 0, 1]},
+ 'z': {-1: [-1, 0, 0, 0, 1, 0], 1: [1, 0, 0, 0, 1, 0]}
+ }
+ for direction_index, direction in enumerate(('x', 'y', 'z')):
+ index = list(i for i in range(3) if i != direction_index)
+ delta_vector = np.array([int(i == direction_index) for i in range(3)])
+ material_name = 'wall{}_{}_{}'.format(wall_char, wall_id, direction)
+ self._texturing_material_names.append(material_name)
+ mat = self._mjcf_root.asset.add(
+ 'material', name=material_name,
+ texture=self._current_wall_texture[wall_char],
+ texrepeat=(2 * wall_size[index] / self._xy_scale))
+ for sign, sign_name in zip((-1, 1), ('neg', 'pos')):
+ if direction == 'z' and sign == -1:
+ continue
+ geom_name = (
+ 'wall{}_{}_texturing_{}_{}'.format(
+ wall_char, wall_id, sign_name, direction))
+ self._texturing_geom_names.append(geom_name)
+ self._mjcf_root.worldbody.add(
+ 'geom', type='plane', name=geom_name,
+ pos=(wall_pos + sign * delta_vector * wall_size),
+ size=np.concatenate([wall_size[index], [self._xy_scale]]),
+ xyaxes=xyaxes[direction][sign], material=mat,
+ contype=0, conaffinity=0)
+
+ def _make_floor_variations(self, build_tile_geoms_fn=None):
+ """Builds the floor tiles.
+
+ Args:
+ build_tile_geoms_fn: An optional callable returning floor tile geoms.
+ If not passed, the floor will be built using a default covering method.
+ Takes a kwarg `wall_char` that can be used control how active floor
+ tiles are selected.
+ """
+ main_floor_texture = np.random.choice(self._floor_textures)
+ for variation in _DEFAULT_FLOOR_CHAR + string.ascii_uppercase:
+ if variation not in self._maze.variations_layer:
+ break
+
+ if build_tile_geoms_fn is None:
+ # Break the floor variation down to odd-sized tiles.
+ tiles = covering.make_walls(self._maze.variations_layer,
+ wall_char=variation,
+ make_odd_sized_walls=True)
+ else:
+ tiles = build_tile_geoms_fn(wall_char=variation)
+
+ # Sample a texture that's not the same as the main floor texture.
+ variation_texture = main_floor_texture
+ if variation != _DEFAULT_FLOOR_CHAR:
+ if len(self._floor_textures) == 1:
+ return
+ else:
+ while variation_texture is main_floor_texture:
+ variation_texture = np.random.choice(self._floor_textures)
+
+ for i, tile in enumerate(tiles):
+ tile_mid = covering.GridCoordinates(
+ (tile.start.y + tile.end.y - 1) / 2,
+ (tile.start.x + tile.end.x - 1) / 2)
+ tile_pos = np.array([(tile_mid.x - self._x_offset) * self._xy_scale,
+ -(tile_mid.y - self._y_offset) * self._xy_scale,
+ 0.0])
+ tile_size = np.array([(tile.end.x - tile_mid.x - 0.5) * self._xy_scale,
+ (tile.end.y - tile_mid.y - 0.5) * self._xy_scale,
+ self._xy_scale])
+ if variation == _DEFAULT_FLOOR_CHAR:
+ tile_name = 'floor_{}'.format(i)
+ else:
+ tile_name = 'floor_{}_{}'.format(variation, i)
+ self._tile_geom_names[tile.start] = tile_name
+ self._texturing_material_names.append(tile_name)
+ self._texturing_geom_names.append(tile_name)
+ material = self._mjcf_root.asset.add(
+ 'material', name=tile_name, texture=variation_texture,
+ texrepeat=(2 * tile_size[[0, 1]] / self._xy_scale))
+ self._mjcf_root.worldbody.add(
+ 'geom', name=tile_name, type='plane', material=material,
+ pos=tile_pos, size=tile_size, contype=0, conaffinity=0)
+
+ @property
+ def ground_geoms(self):
+ return tuple([
+ geom for geom in self.mjcf_model.find_all('geom')
+ if 'ground' in geom.name
+ ])
+
+ def find_token_grid_positions(self, tokens):
+ out = {token: [] for token in tokens}
+ for y in range(self._maze.entity_layer.shape[0]):
+ for x in range(self._maze.entity_layer.shape[1]):
+ for token in tokens:
+ if self._maze.entity_layer[y, x] == token:
+ out[token].append((y, x))
+ return out
+
+ def grid_to_world_positions(self, grid_positions):
+ out = []
+ for y, x in grid_positions:
+ out.append(np.array([(x - self._x_offset) * self._xy_scale,
+ -(y - self._y_offset) * self._xy_scale,
+ 0.0]))
+ return out
+
+ def world_to_grid_positions(self, world_positions):
+ out = []
+ # the order of x, y is reverse between grid positions format and
+ # world positions format.
+ for x, y, _ in world_positions:
+ out.append(np.array([self._y_offset - y / self._xy_scale,
+ self._x_offset + x / self._xy_scale]))
+ return out
+
+ def _find_spawn_and_target_positions(self):
+ grid_positions = self.find_token_grid_positions([
+ labmaze.defaults.OBJECT_TOKEN, labmaze.defaults.SPAWN_TOKEN])
+ self._target_grid_positions = tuple(
+ grid_positions[labmaze.defaults.OBJECT_TOKEN])
+ self._spawn_grid_positions = tuple(
+ grid_positions[labmaze.defaults.SPAWN_TOKEN])
+ self._target_positions = tuple(
+ self.grid_to_world_positions(self._target_grid_positions))
+ self._spawn_positions = tuple(
+ self.grid_to_world_positions(self._spawn_grid_positions))
+
+
+class MazeObservables(composer.Observables):
+
+ @composer.observable
+ def top_camera(self):
+ return observable.MJCFCamera(self._entity.top_camera)
+
+
+class RandomMazeWithTargets(MazeWithTargets):
+ """A randomly generated 2D maze with target positions."""
+
+ def _build(self,
+ x_cells,
+ y_cells,
+ xy_scale=2.0,
+ z_height=2.0,
+ max_rooms=labmaze.defaults.MAX_ROOMS,
+ room_min_size=labmaze.defaults.ROOM_MIN_SIZE,
+ room_max_size=labmaze.defaults.ROOM_MAX_SIZE,
+ spawns_per_room=labmaze.defaults.SPAWN_COUNT,
+ targets_per_room=labmaze.defaults.OBJECT_COUNT,
+ max_variations=labmaze.defaults.MAX_VARIATIONS,
+ simplify=labmaze.defaults.SIMPLIFY,
+ skybox_texture=None,
+ wall_textures=None,
+ floor_textures=None,
+ aesthetic='default',
+ name='random_maze'):
+ """Initializes this random maze arena.
+
+ Args:
+ x_cells: The number of cells along the x-direction of the maze. Must be
+ an odd integer.
+ y_cells: The number of cells along the y-direction of the maze. Must be
+ an odd integer.
+ xy_scale: The size of each maze cell in metres.
+ z_height: The z-height of the maze in metres.
+ max_rooms: (optional) The maximum number of rooms in each generated maze.
+ room_min_size: (optional) The minimum size of each room generated.
+ room_max_size: (optional) The maximum size of each room generated.
+ spawns_per_room: (optional) Number of spawn points
+ to generate in each room.
+ targets_per_room: (optional) Number of targets to generate in each room.
+ max_variations: (optional) Maximum number of variations to generate
+ in the variations layer.
+ simplify: (optional) flag to simplify the maze.
+ skybox_texture: (optional) A `composer.Entity` that provides a texture
+ asset for the skybox.
+ wall_textures: (optional) A `composer.Entity` that provides texture
+ assets for the maze walls.
+ floor_textures: (optional) A `composer.Entity` that provides texture
+ assets for the maze floor.
+ aesthetic: option to adjust the material properties and skybox
+ name: (optional) A string, the name of this arena.
+ """
+ random_seed = np.random.randint(2147483648) # 2**31
+ super(RandomMazeWithTargets, self)._build(
+ maze=labmaze.RandomMaze(
+ height=y_cells,
+ width=x_cells,
+ max_rooms=max_rooms,
+ room_min_size=room_min_size,
+ room_max_size=room_max_size,
+ max_variations=max_variations,
+ spawns_per_room=spawns_per_room,
+ objects_per_room=targets_per_room,
+ simplify=simplify,
+ random_seed=random_seed),
+ xy_scale=xy_scale,
+ z_height=z_height,
+ skybox_texture=skybox_texture,
+ wall_textures=wall_textures,
+ floor_textures=floor_textures,
+ aesthetic=aesthetic,
+ name=name)
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/mazes_test.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/mazes_test.py
new file mode 100644
index 0000000..80a61c7
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/mazes_test.py
@@ -0,0 +1,52 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for locomotion.arenas.mazes."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import absltest
+from dm_control import mjcf
+from dm_control.locomotion.arenas import labmaze_textures
+from dm_control.locomotion.arenas import mazes
+
+
+class MazesTest(absltest.TestCase):
+
+ def test_can_compile_mjcf(self):
+
+ # Set the wall and floor textures to match DMLab and set the skybox.
+ skybox_texture = labmaze_textures.SkyBox(style='sky_03')
+ wall_textures = labmaze_textures.WallTextures(style='style_01')
+ floor_textures = labmaze_textures.FloorTextures(style='style_01')
+
+ arena = mazes.RandomMazeWithTargets(
+ x_cells=11,
+ y_cells=11,
+ xy_scale=3,
+ max_rooms=4,
+ room_min_size=4,
+ room_max_size=5,
+ spawns_per_room=1,
+ targets_per_room=3,
+ skybox_texture=skybox_texture,
+ wall_textures=wall_textures,
+ floor_textures=floor_textures)
+ mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/examples/__init__.py b/DMC/src/env/dm_control/dm_control/locomotion/examples/__init__.py
new file mode 100644
index 0000000..4224c02
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/examples/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/examples/basic_cmu_2019.py b/DMC/src/env/dm_control/dm_control/locomotion/examples/basic_cmu_2019.py
new file mode 100644
index 0000000..d508fcd
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/examples/basic_cmu_2019.py
@@ -0,0 +1,225 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Produces reference environments for CMU humanoid locomotion tasks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+from dm_control import composer
+from dm_control.composer.variation import distributions
+from dm_control.locomotion.arenas import corridors as corr_arenas
+from dm_control.locomotion.arenas import floors
+from dm_control.locomotion.arenas import labmaze_textures
+from dm_control.locomotion.arenas import mazes
+from dm_control.locomotion.props import target_sphere
+from dm_control.locomotion.tasks import corridors as corr_tasks
+from dm_control.locomotion.tasks import go_to_target
+from dm_control.locomotion.tasks import random_goal_maze
+from dm_control.locomotion.walkers import cmu_humanoid
+from labmaze import fixed_maze
+
+
+def cmu_humanoid_run_walls(random_state=None):
+ """Requires a CMU humanoid to run down a corridor obstructed by walls."""
+
+ # Build a position-controlled CMU humanoid walker.
+ walker = cmu_humanoid.CMUHumanoidPositionControlled(
+ observable_options={'egocentric_camera': dict(enabled=True)})
+
+ # Build a corridor-shaped arena that is obstructed by walls.
+ arena = corr_arenas.WallsCorridor(
+ wall_gap=4.,
+ wall_width=distributions.Uniform(1, 7),
+ wall_height=3.0,
+ corridor_width=10,
+ corridor_length=100)
+
+ # Build a task that rewards the agent for running down the corridor at a
+ # specific velocity.
+ task = corr_tasks.RunThroughCorridor(
+ walker=walker,
+ arena=arena,
+ walker_spawn_position=(0.5, 0, 0),
+ target_velocity=3.0,
+ physics_timestep=0.005,
+ control_timestep=0.03)
+
+ return composer.Environment(time_limit=30,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
+
+
+def cmu_humanoid_run_gaps(random_state=None):
+ """Requires a CMU humanoid to run down a corridor with gaps."""
+
+ # Build a position-controlled CMU humanoid walker.
+ walker = cmu_humanoid.CMUHumanoidPositionControlled(
+ observable_options={'egocentric_camera': dict(enabled=True)})
+
+ # Build a corridor-shaped arena with gaps, where the sizes of the gaps and
+ # platforms are uniformly randomized.
+ arena = corr_arenas.GapsCorridor(
+ platform_length=distributions.Uniform(.3, 2.5),
+ gap_length=distributions.Uniform(.5, 1.25),
+ corridor_width=10,
+ corridor_length=100)
+
+ # Build a task that rewards the agent for running down the corridor at a
+ # specific velocity.
+ task = corr_tasks.RunThroughCorridor(
+ walker=walker,
+ arena=arena,
+ walker_spawn_position=(0.5, 0, 0),
+ target_velocity=3.0,
+ physics_timestep=0.005,
+ control_timestep=0.03)
+
+ return composer.Environment(time_limit=30,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
+
+
+def cmu_humanoid_go_to_target(random_state=None):
+ """Requires a CMU humanoid to go to a target."""
+
+ # Build a position-controlled CMU humanoid walker.
+ walker = cmu_humanoid.CMUHumanoidPositionControlled()
+
+ # Build a standard floor arena.
+ arena = floors.Floor()
+
+ # Build a task that rewards the agent for going to a target.
+ task = go_to_target.GoToTarget(
+ walker=walker,
+ arena=arena,
+ physics_timestep=0.005,
+ control_timestep=0.03)
+
+ return composer.Environment(time_limit=30,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
+
+
+def cmu_humanoid_maze_forage(random_state=None):
+ """Requires a CMU humanoid to find all items in a maze."""
+
+ # Build a position-controlled CMU humanoid walker.
+ walker = cmu_humanoid.CMUHumanoidPositionControlled(
+ observable_options={'egocentric_camera': dict(enabled=True)})
+
+ # Build a maze with rooms and targets.
+ skybox_texture = labmaze_textures.SkyBox(style='sky_03')
+ wall_textures = labmaze_textures.WallTextures(style='style_01')
+ floor_textures = labmaze_textures.FloorTextures(style='style_01')
+ arena = mazes.RandomMazeWithTargets(
+ x_cells=11,
+ y_cells=11,
+ xy_scale=3,
+ max_rooms=4,
+ room_min_size=4,
+ room_max_size=5,
+ spawns_per_room=1,
+ targets_per_room=3,
+ skybox_texture=skybox_texture,
+ wall_textures=wall_textures,
+ floor_textures=floor_textures,
+ )
+
+ # Build a task that rewards the agent for obtaining targets.
+ task = random_goal_maze.ManyGoalsMaze(
+ walker=walker,
+ maze_arena=arena,
+ target_builder=functools.partial(
+ target_sphere.TargetSphere,
+ radius=0.4,
+ rgb1=(0, 0, 0.4),
+ rgb2=(0, 0, 0.7)),
+ target_reward_scale=50.,
+ physics_timestep=0.005,
+ control_timestep=0.03,
+ )
+
+ return composer.Environment(time_limit=30,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
+
+
+def cmu_humanoid_heterogeneous_forage(random_state=None):
+ """Requires a CMU humanoid to find all items of a particular type in a maze."""
+ level = ('*******\n'
+ '* *\n'
+ '* P *\n'
+ '* *\n'
+ '* G *\n'
+ '* *\n'
+ '*******\n')
+
+ # Build a position-controlled CMU humanoid walker.
+ walker = cmu_humanoid.CMUHumanoidPositionControlled(
+ observable_options={'egocentric_camera': dict(enabled=True)})
+
+ skybox_texture = labmaze_textures.SkyBox(style='sky_03')
+ wall_textures = labmaze_textures.WallTextures(style='style_01')
+ floor_textures = labmaze_textures.FloorTextures(style='style_01')
+ maze = fixed_maze.FixedMazeWithRandomGoals(
+ entity_layer=level,
+ variations_layer=None,
+ num_spawns=1,
+ num_objects=6,
+ )
+ arena = mazes.MazeWithTargets(
+ maze=maze,
+ xy_scale=3.0,
+ z_height=2.0,
+ skybox_texture=skybox_texture,
+ wall_textures=wall_textures,
+ floor_textures=floor_textures,
+ )
+ task = random_goal_maze.ManyHeterogeneousGoalsMaze(
+ walker=walker,
+ maze_arena=arena,
+ target_builders=[
+ functools.partial(
+ target_sphere.TargetSphere,
+ radius=0.4,
+ rgb1=(0, 0.4, 0),
+ rgb2=(0, 0.7, 0)),
+ functools.partial(
+ target_sphere.TargetSphere,
+ radius=0.4,
+ rgb1=(0.4, 0, 0),
+ rgb2=(0.7, 0, 0)),
+ ],
+ randomize_spawn_rotation=False,
+ target_type_rewards=[30., -10.],
+ target_type_proportions=[1, 1],
+ shuffle_target_builders=True,
+ aliveness_reward=0.01,
+ control_timestep=.03,
+ )
+
+ return composer.Environment(
+ time_limit=25,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/examples/basic_rodent_2020.py b/DMC/src/env/dm_control/dm_control/locomotion/examples/basic_rodent_2020.py
new file mode 100644
index 0000000..1b6283d
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/examples/basic_rodent_2020.py
@@ -0,0 +1,175 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Produces reference environments for rodent tasks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+from dm_control import composer
+from dm_control.composer.variation import distributions
+from dm_control.locomotion.arenas import bowl
+from dm_control.locomotion.arenas import corridors as corr_arenas
+from dm_control.locomotion.arenas import floors
+from dm_control.locomotion.arenas import labmaze_textures
+from dm_control.locomotion.arenas import mazes
+from dm_control.locomotion.props import target_sphere
+from dm_control.locomotion.tasks import corridors as corr_tasks
+from dm_control.locomotion.tasks import escape
+from dm_control.locomotion.tasks import random_goal_maze
+from dm_control.locomotion.tasks import reach
+from dm_control.locomotion.walkers import rodent
+
+_CONTROL_TIMESTEP = .02
+_PHYSICS_TIMESTEP = 0.001
+
+
+def rodent_escape_bowl(random_state=None):
+ """Requires a rodent to climb out of a bowl-shaped terrain."""
+
+ # Build a position-controlled rodent walker.
+ walker = rodent.Rat(
+ observable_options={'egocentric_camera': dict(enabled=True)})
+
+ # Build a bowl-shaped arena.
+ arena = bowl.Bowl(
+ size=(20., 20.),
+ aesthetic='outdoor_natural')
+
+ # Build a task that rewards the agent for being far from the origin.
+ task = escape.Escape(
+ walker=walker,
+ arena=arena,
+ physics_timestep=_PHYSICS_TIMESTEP,
+ control_timestep=_CONTROL_TIMESTEP)
+
+ return composer.Environment(time_limit=20,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
+
+
+def rodent_run_gaps(random_state=None):
+ """Requires a rodent to run down a corridor with gaps."""
+
+ # Build a position-controlled rodent walker.
+ walker = rodent.Rat(
+ observable_options={'egocentric_camera': dict(enabled=True)})
+
+ # Build a corridor-shaped arena with gaps, where the sizes of the gaps and
+ # platforms are uniformly randomized.
+ arena = corr_arenas.GapsCorridor(
+ platform_length=distributions.Uniform(.4, .8),
+ gap_length=distributions.Uniform(.05, .2),
+ corridor_width=2,
+ corridor_length=40,
+ aesthetic='outdoor_natural')
+
+ # Build a task that rewards the agent for running down the corridor at a
+ # specific velocity.
+ task = corr_tasks.RunThroughCorridor(
+ walker=walker,
+ arena=arena,
+ walker_spawn_position=(5, 0, 0),
+ walker_spawn_rotation=0,
+ target_velocity=1.0,
+ contact_termination=False,
+ terminate_at_height=-0.3,
+ physics_timestep=_PHYSICS_TIMESTEP,
+ control_timestep=_CONTROL_TIMESTEP)
+
+ return composer.Environment(time_limit=30,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
+
+
+def rodent_maze_forage(random_state=None):
+ """Requires a rodent to find all items in a maze."""
+
+ # Build a position-controlled rodent walker.
+ walker = rodent.Rat(
+ observable_options={'egocentric_camera': dict(enabled=True)})
+
+ # Build a maze with rooms and targets.
+ wall_textures = labmaze_textures.WallTextures(style='style_01')
+ arena = mazes.RandomMazeWithTargets(
+ x_cells=11,
+ y_cells=11,
+ xy_scale=.5,
+ z_height=.3,
+ max_rooms=4,
+ room_min_size=4,
+ room_max_size=5,
+ spawns_per_room=1,
+ targets_per_room=3,
+ wall_textures=wall_textures,
+ aesthetic='outdoor_natural')
+
+ # Build a task that rewards the agent for obtaining targets.
+ task = random_goal_maze.ManyGoalsMaze(
+ walker=walker,
+ maze_arena=arena,
+ target_builder=functools.partial(
+ target_sphere.TargetSphere,
+ radius=0.05,
+ height_above_ground=.125,
+ rgb1=(0, 0, 0.4),
+ rgb2=(0, 0, 0.7)),
+ target_reward_scale=50.,
+ contact_termination=False,
+ physics_timestep=_PHYSICS_TIMESTEP,
+ control_timestep=_CONTROL_TIMESTEP)
+
+ return composer.Environment(time_limit=30,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
+
+
+def rodent_two_touch(random_state=None):
+ """Requires a rodent to tap an orb, wait an interval, and tap it again."""
+
+ # Build a position-controlled rodent walker.
+ walker = rodent.Rat(
+ observable_options={'egocentric_camera': dict(enabled=True)})
+
+ # Build an open floor arena
+ arena = floors.Floor(
+ size=(10., 10.),
+ aesthetic='outdoor_natural')
+
+ # Build a task that rewards the walker for touching/reaching orbs with a
+ # specific time interval between touches
+ task = reach.TwoTouch(
+ walker=walker,
+ arena=arena,
+ target_builders=[
+ functools.partial(target_sphere.TargetSphereTwoTouch, radius=0.025),
+ ],
+ randomize_spawn_rotation=True,
+ target_type_rewards=[25.],
+ shuffle_target_builders=False,
+ target_area=(1.5, 1.5),
+ physics_timestep=_PHYSICS_TIMESTEP,
+ control_timestep=_CONTROL_TIMESTEP,
+ )
+
+ return composer.Environment(time_limit=30,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True)
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/examples/examples_test.py b/DMC/src/env/dm_control/dm_control/locomotion/examples/examples_test.py
new file mode 100644
index 0000000..2c210f1
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/examples/examples_test.py
@@ -0,0 +1,89 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for `dm_control.locomotion.examples`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.locomotion.examples import basic_cmu_2019
+from dm_control.locomotion.examples import basic_rodent_2020
+
+import numpy as np
+from six.moves import range
+
+
+_NUM_EPISODES = 5
+_NUM_STEPS_PER_EPISODE = 10
+
+
+class ExampleEnvironmentsTest(parameterized.TestCase):
+ """Tests run on all the tasks registered."""
+
+ def _validate_observation(self, observation, observation_spec):
+ self.assertEqual(list(observation.keys()), list(observation_spec.keys()))
+ for name, array_spec in observation_spec.items():
+ array_spec.validate(observation[name])
+
+ def _validate_reward_range(self, reward):
+ self.assertIsInstance(reward, float)
+ self.assertBetween(reward, 0, 1)
+
+ def _validate_discount(self, discount):
+ self.assertIsInstance(discount, float)
+ self.assertBetween(discount, 0, 1)
+
+ @parameterized.named_parameters(
+ ('cmu_humanoid_run_walls', basic_cmu_2019.cmu_humanoid_run_walls),
+ ('cmu_humanoid_run_gaps', basic_cmu_2019.cmu_humanoid_run_gaps),
+ ('cmu_humanoid_go_to_target', basic_cmu_2019.cmu_humanoid_go_to_target),
+ ('cmu_humanoid_maze_forage', basic_cmu_2019.cmu_humanoid_maze_forage),
+ ('cmu_humanoid_heterogeneous_forage',
+ basic_cmu_2019.cmu_humanoid_heterogeneous_forage),
+ ('rodent_escape_bowl', basic_rodent_2020.rodent_escape_bowl),
+ ('rodent_run_gaps', basic_rodent_2020.rodent_run_gaps),
+ ('rodent_maze_forage', basic_rodent_2020.rodent_maze_forage),
+ ('rodent_two_touch', basic_rodent_2020.rodent_two_touch),
+ )
+ def test_env_runs(self, env_constructor):
+ """Tests that the environment runs and is coherent with its specs."""
+ random_state = np.random.RandomState(99)
+
+ env = env_constructor(random_state=random_state)
+ observation_spec = env.observation_spec()
+ action_spec = env.action_spec()
+ self.assertTrue(np.all(np.isfinite(action_spec.minimum)))
+ self.assertTrue(np.all(np.isfinite(action_spec.maximum)))
+
+ # Run a partial episode, check observations, rewards, discount.
+ for _ in range(_NUM_EPISODES):
+ time_step = env.reset()
+ for _ in range(_NUM_STEPS_PER_EPISODE):
+ self._validate_observation(time_step.observation, observation_spec)
+ if time_step.first():
+ self.assertIsNone(time_step.reward)
+ self.assertIsNone(time_step.discount)
+ else:
+ self._validate_reward_range(time_step.reward)
+ self._validate_discount(time_step.discount)
+ action = random_state.uniform(action_spec.minimum, action_spec.maximum)
+ env.step(action)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/examples/explore.py b/DMC/src/env/dm_control/dm_control/locomotion/examples/explore.py
new file mode 100644
index 0000000..a966283
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/examples/explore.py
@@ -0,0 +1,28 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Simple script to launch viewer with an example environment."""
+
+from absl import app
+
+from dm_control import viewer
+from dm_control.locomotion.examples import basic_cmu_2019
+
+
+def main(unused_argv):
+ viewer.launch(environment_loader=basic_cmu_2019.cmu_humanoid_run_gaps)
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/gaps.png b/DMC/src/env/dm_control/dm_control/locomotion/gaps.png
new file mode 100644
index 0000000..9a6c16d
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/gaps.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/props/__init__.py b/DMC/src/env/dm_control/dm_control/locomotion/props/__init__.py
new file mode 100644
index 0000000..84e687e
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/props/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Props for Locomotion tasks."""
+
+from dm_control.locomotion.props.target_sphere import TargetSphere
+from dm_control.locomotion.props.target_sphere import TargetSphereTwoTouch
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/props/target_sphere.py b/DMC/src/env/dm_control/dm_control/locomotion/props/target_sphere.py
new file mode 100644
index 0000000..0c55b0c
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/props/target_sphere.py
@@ -0,0 +1,227 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""A non-colliding sphere that is activated through touch."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from dm_control import composer
+from dm_control import mjcf
+
+
+class TargetSphere(composer.Entity):
+ """A non-colliding sphere that is activated through touch.
+
+ Once the target has been reached, it remains in the "activated" state
+ for the remainder of the current episode.
+
+ The target is automatically reset to "not activated" state at episode
+ initialization time.
+ """
+
+ def _build(self,
+ radius=0.6,
+ height_above_ground=1,
+ rgb1=(0, 0.4, 0),
+ rgb2=(0, 0.7, 0),
+ specific_collision_geom_ids=None,
+ name='target'):
+ """Builds this target sphere.
+
+ Args:
+ radius: The radius (in meters) of this target sphere.
+ height_above_ground: The height (in meters) of this target above ground.
+ rgb1: A sequence of three floating point values between 0.0 and 1.0
+ (inclusive) representing the color of the first element in the stripe
+ pattern of the target.
+ rgb2: A sequence of three floating point values between 0.0 and 1.0
+ (inclusive) representing the color of the second element in the stripe
+ pattern of the target.
+ specific_collision_geom_ids: Only activate if collides with these geoms.
+ name: The name of this entity.
+ """
+ self._mjcf_root = mjcf.RootElement(model=name)
+ self._texture = self._mjcf_root.asset.add(
+ 'texture', name='target_sphere', type='cube',
+ builtin='checker', rgb1=rgb1, rgb2=rgb2,
+ width='100', height='100')
+ self._material = self._mjcf_root.asset.add(
+ 'material', name='target_sphere', texture=self._texture)
+ self._geom = self._mjcf_root.worldbody.add(
+ 'geom', type='sphere', name='geom', gap=2*radius,
+ pos=[0, 0, height_above_ground], size=[radius], material=self._material)
+ self._geom_id = -1
+ self._activated = False
+ self._specific_collision_geom_ids = specific_collision_geom_ids
+
+ @property
+ def geom(self):
+ return self._geom
+
+ @property
+ def material(self):
+ return self._material
+
+ @property
+ def activated(self):
+ """Whether this target has been reached during this episode."""
+ return self._activated
+
+ def reset(self, physics):
+ self._activated = False
+ physics.bind(self._material).rgba[-1] = 1
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ def initialize_episode_mjcf(self, unused_random_state):
+ self._activated = False
+
+ def _update_activation(self, physics):
+ if not self._activated:
+ for contact in physics.data.contact:
+ if self._specific_collision_geom_ids:
+ has_specific_collision = (
+ contact.geom1 in self._specific_collision_geom_ids or
+ contact.geom2 in self._specific_collision_geom_ids)
+ else:
+ has_specific_collision = True
+ if (has_specific_collision and
+ self._geom_id in (contact.geom1, contact.geom2)):
+ self._activated = True
+ physics.bind(self._material).rgba[-1] = 0
+
+ def initialize_episode(self, physics, unused_random_state):
+ self._geom_id = physics.model.name2id(self._geom.full_identifier, 'geom')
+ self._update_activation(physics)
+
+ def after_substep(self, physics, unused_random_state):
+ self._update_activation(physics)
+
+
+class TargetSphereTwoTouch(composer.Entity):
+ """A non-colliding sphere that is activated through touch.
+
+ The target indicates if it has been touched at least once and touched at least
+ twice this episode with a two-bit activated state tuple. It remains activated
+ for the remainder of the current episode.
+
+ The target is automatically reset at episode initialization.
+ """
+
+ def _build(self,
+ radius=0.6,
+ height_above_ground=1,
+ rgb_initial=((0, 0.4, 0), (0, 0.7, 0)),
+ rgb_interval=((1., 1., .4), (0.7, 0.7, 0.)),
+ rgb_final=((.4, 0.7, 1.), (0, 0.4, .7)),
+ touch_debounce=.2,
+ specific_collision_geom_ids=None,
+ name='target'):
+ """Builds this target sphere.
+
+ Args:
+ radius: The radius (in meters) of this target sphere.
+ height_above_ground: The height (in meters) of this target above ground.
+ rgb_initial: A tuple of two colors for the stripe pattern of the target.
+ rgb_interval: A tuple of two colors for the stripe pattern of the target.
+ rgb_final: A tuple of two colors for the stripe pattern of the target.
+ touch_debounce: duration to not count second touch.
+ specific_collision_geom_ids: Only activate if collides with these geoms.
+ name: The name of this entity.
+ """
+ self._mjcf_root = mjcf.RootElement(model=name)
+ self._texture_initial = self._mjcf_root.asset.add(
+ 'texture', name='target_sphere_init', type='cube',
+ builtin='checker', rgb1=rgb_initial[0], rgb2=rgb_initial[1],
+ width='100', height='100')
+ self._texture_interval = self._mjcf_root.asset.add(
+ 'texture', name='target_sphere_inter', type='cube',
+ builtin='checker', rgb1=rgb_interval[0], rgb2=rgb_interval[1],
+ width='100', height='100')
+ self._texture_final = self._mjcf_root.asset.add(
+ 'texture', name='target_sphere_final', type='cube',
+ builtin='checker', rgb1=rgb_final[0], rgb2=rgb_final[1],
+ width='100', height='100')
+ self._material = self._mjcf_root.asset.add(
+ 'material', name='target_sphere_init', texture=self._texture_initial)
+ self._geom = self._mjcf_root.worldbody.add(
+ 'geom', type='sphere', name='geom', gap=2*radius,
+ pos=[0, 0, height_above_ground], size=[radius],
+ material=self._material)
+ self._geom_id = -1
+ self._touched_once = False
+ self._touched_twice = False
+ self._touch_debounce = touch_debounce
+ self._specific_collision_geom_ids = specific_collision_geom_ids
+
+ @property
+ def geom(self):
+ return self._geom
+
+ @property
+ def material(self):
+ return self._material
+
+ @property
+ def activated(self):
+ """Whether this target has been reached during this episode."""
+ return (self._touched_once, self._touched_twice)
+
+ def reset(self, physics):
+ self._touched_once = False
+ self._touched_twice = False
+ self._geom.material = self._material
+ physics.bind(self._material).texid = physics.bind(
+ self._texture_initial).element_id
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ def initialize_episode_mjcf(self, unused_random_state):
+ self._touched_once = False
+ self._touched_twice = False
+
+ def _update_activation(self, physics):
+ if not (self._touched_once and self._touched_twice):
+ for contact in physics.data.contact:
+ if self._specific_collision_geom_ids:
+ has_specific_collision = (
+ contact.geom1 in self._specific_collision_geom_ids or
+ contact.geom2 in self._specific_collision_geom_ids)
+ else:
+ has_specific_collision = True
+ if (has_specific_collision and
+ self._geom_id in (contact.geom1, contact.geom2)):
+ if not self._touched_once:
+ self._touched_once = True
+ self._touch_time = physics.time()
+ physics.bind(self._material).texid = physics.bind(
+ self._texture_interval).element_id
+ if self._touched_once and (
+ physics.time() > (self._touch_time + self._touch_debounce)):
+ self._touched_twice = True
+ physics.bind(self._material).texid = physics.bind(
+ self._texture_final).element_id
+
+ def initialize_episode(self, physics, unused_random_state):
+ self._geom_id = physics.model.name2id(self._geom.full_identifier, 'geom')
+ self._update_activation(physics)
+
+ def after_substep(self, physics, unused_random_state):
+ self._update_activation(physics)
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/props/target_sphere_test.py b/DMC/src/env/dm_control/dm_control/locomotion/props/target_sphere_test.py
new file mode 100644
index 0000000..a83a39d
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/props/target_sphere_test.py
@@ -0,0 +1,70 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for props.target_sphere."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+from absl.testing import absltest
+from dm_control import composer
+
+from dm_control.entities.props import primitive
+from dm_control.locomotion.arenas import floors
+from dm_control.locomotion.props import target_sphere
+
+
+class TargetSphereTest(absltest.TestCase):
+
+ def testActivation(self):
+ target_radius = 0.6
+ prop_radius = 0.1
+ target_height = 1
+
+ arena = floors.Floor()
+ target = target_sphere.TargetSphere(radius=target_radius,
+ height_above_ground=target_height)
+ prop = primitive.Primitive(geom_type='sphere', size=[prop_radius])
+ arena.attach(target)
+ arena.add_free_entity(prop)
+
+ task = composer.NullTask(arena)
+ task.initialize_episode = (
+ lambda physics, random_state: prop.set_pose(physics, [0, 0, 2]))
+
+ env = composer.Environment(task)
+ env.reset()
+
+ max_activated_height = target_height + target_radius + prop_radius
+
+ while env.physics.bind(prop.geom).xpos[2] > max_activated_height:
+ self.assertFalse(target.activated)
+ self.assertEqual(env.physics.bind(target.material).rgba[-1], 1)
+ env.step([])
+
+ while env.physics.bind(prop.geom).xpos[2] > 0.2:
+ self.assertTrue(target.activated)
+ self.assertEqual(env.physics.bind(target.material).rgba[-1], 0)
+ env.step([])
+
+ # Target should be reset when the environment is reset.
+ env.reset()
+ self.assertFalse(target.activated)
+ self.assertEqual(env.physics.bind(target.material).rgba[-1], 1)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/README.md b/DMC/src/env/dm_control/dm_control/locomotion/soccer/README.md
new file mode 100644
index 0000000..49b8504
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/README.md
@@ -0,0 +1,64 @@
+# DeepMind MuJoCo Multi-Agent Soccer Environment.
+
+This submodule contains the components and environment described in ICLR 2019
+paper [Emergent Coordination through Competition][website].
+
+# ![soccer](soccer.png)
+
+## Installation and requirements
+
+See [dm_control](../../../README.md#installation-and-requirements) for instructions.
+
+## Quickstart
+
+```python
+import numpy as np
+from dm_control.locomotion import soccer as dm_soccer
+
+# Load the 2-vs-2 soccer environment with episodes of 10 seconds:
+env = dm_soccer.load(team_size=2, time_limit=10.)
+
+# Retrieves action_specs for all 4 players.
+action_specs = env.action_spec()
+
+# Step through the environment for one episode with random actions.
+time_step = env.reset()
+while not time_step.last():
+ actions = []
+ for action_spec in action_specs:
+ action = np.random.uniform(
+ action_spec.minimum, action_spec.maximum, size=action_spec.shape)
+ actions.append(action)
+ time_step = env.step(actions)
+
+ for i in range(len(action_specs)):
+ print(
+ "Player {}: reward = {}, discount = {}, observations = {}.".format(
+ i, time_step.reward[i], time_step.discount,
+ time_step.observation[i]))
+```
+
+## Rewards
+
+The environment provides a reward of +1 to each player when their team
+scores a goal, -1 when their team concedes a goal, or 0 if neither team scored
+on the current timestep.
+
+In addition to the sparse reward returned the environment, the player
+observations also contain various environment statistics that may be used to
+derive custom per-player shaping rewards (as was done in
+http://arxiv.org/abs/1902.07151, where the environment reward was ignored).
+
+## Episode terminations
+
+Episodes will terminate immediately with a discount factor of 0 when either side
+scores a goal. There is also a per-episode `time_limit` (45 seconds by default).
+If neither team scores within this time then the episode will terminate with a
+discount factor of 1.
+
+## Environment Viewer
+
+To visualize an example 2-vs-2 soccer environment in the `dm_control`
+interactive viewer, execute `dm_control/locomotion/soccer/explore.py`.
+
+[website]: https://sites.google.com/corp/view/emergent-coordination/home
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/__init__.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/__init__.py
new file mode 100644
index 0000000..20dc585
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/__init__.py
@@ -0,0 +1,94 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Multi-agent MuJoCo soccer environment."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from dm_control import composer
+from dm_control.locomotion.soccer.boxhead import BoxHead
+from dm_control.locomotion.soccer.initializers import Initializer
+from dm_control.locomotion.soccer.initializers import UniformInitializer
+from dm_control.locomotion.soccer.observables import CoreObservablesAdder
+from dm_control.locomotion.soccer.observables import InterceptionObservablesAdder
+from dm_control.locomotion.soccer.observables import MultiObservablesAdder
+from dm_control.locomotion.soccer.observables import ObservablesAdder
+from dm_control.locomotion.soccer.pitch import Pitch
+from dm_control.locomotion.soccer.pitch import RandomizedPitch
+from dm_control.locomotion.soccer.soccer_ball import SoccerBall
+from dm_control.locomotion.soccer.task import Task
+from dm_control.locomotion.soccer.team import Player
+from dm_control.locomotion.soccer.team import Team
+from six.moves import range
+
+_RGBA_BLUE = [.1, .1, .8, 1.]
+_RGBA_RED = [.8, .1, .1, 1.]
+
+
+def _make_walker(name, walker_id, marker_rgba):
+ """Construct a BoxHead walker."""
+ return BoxHead(
+ name=name,
+ walker_id=walker_id,
+ marker_rgba=marker_rgba,
+ )
+
+
+def _make_players(team_size):
+ """Construct home and away teams each of `team_size` players."""
+ home_players = []
+ away_players = []
+ for i in range(team_size):
+ home_players.append(
+ Player(Team.HOME, _make_walker("home%d" % i, i, _RGBA_BLUE)))
+ away_players.append(
+ Player(Team.AWAY, _make_walker("away%d" % i, i, _RGBA_RED)))
+ return home_players + away_players
+
+
+def load(team_size,
+ time_limit=45.,
+ random_state=None,
+ disable_walker_contacts=False):
+ """Construct `team_size`-vs-`team_size` soccer environment.
+
+ Args:
+ team_size: Integer, the number of players per team. Must be between 1 and
+ 11.
+ time_limit: Float, the maximum duration of each episode in seconds.
+ random_state: (optional) an int seed or `np.random.RandomState` instance.
+ disable_walker_contacts: (optional) if `True`, disable physical contacts
+ between walkers.
+
+ Returns:
+ A `composer.Environment` instance.
+
+ Raises:
+ ValueError: If `team_size` is not between 1 and 11.
+ """
+ if team_size < 0 or team_size > 11:
+ raise ValueError(
+ "Team size must be between 1 and 11 (received %d)." % team_size)
+
+ return composer.Environment(
+ task=Task(
+ players=_make_players(team_size),
+ arena=RandomizedPitch(
+ min_size=(32, 24), max_size=(48, 36), keep_aspect_ratio=True),
+ disable_walker_contacts=disable_walker_contacts),
+ time_limit=time_limit,
+ random_state=random_state)
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/boxhead.xml b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/boxhead.xml
new file mode 100644
index 0000000..17cb28d
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/boxhead.xml
@@ -0,0 +1,56 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/00.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/00.png
new file mode 100644
index 0000000..b5fde93
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/00.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/01.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/01.png
new file mode 100644
index 0000000..f52bae4
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/01.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/02.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/02.png
new file mode 100644
index 0000000..bec2be6
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/02.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/03.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/03.png
new file mode 100644
index 0000000..654dba2
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/03.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/04.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/04.png
new file mode 100644
index 0000000..f320726
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/04.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/05.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/05.png
new file mode 100644
index 0000000..64d8cd6
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/05.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/06.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/06.png
new file mode 100644
index 0000000..ed97f8b
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/06.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/07.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/07.png
new file mode 100644
index 0000000..4808c29
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/07.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/08.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/08.png
new file mode 100644
index 0000000..41d09b9
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/08.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/09.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/09.png
new file mode 100644
index 0000000..649723c
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/09.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/10.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/10.png
new file mode 100644
index 0000000..d2cc0d0
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/10.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/back.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/back.png
new file mode 100644
index 0000000..c01b517
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/back.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/down.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/down.png
new file mode 100644
index 0000000..49ace5b
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/down.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/front.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/front.png
new file mode 100644
index 0000000..c18fbf0
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/front.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/left.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/left.png
new file mode 100644
index 0000000..d3b14f7
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/left.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/right.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/right.png
new file mode 100644
index 0000000..69965a8
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/right.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/up.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/up.png
new file mode 100644
index 0000000..8c729d7
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/up.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/boxhead.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/boxhead.py
new file mode 100644
index 0000000..8963b7a
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/boxhead.py
@@ -0,0 +1,292 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Walkers based on an actuated jumping ball."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation import observable
+from dm_control.locomotion.walkers import legacy_base
+import numpy as np
+from PIL import Image
+import six
+
+from dm_control.utils import io as resources
+
+_ASSETS_PATH = os.path.join(os.path.dirname(__file__), 'assets', 'boxhead')
+_MAX_WALKER_ID = 10
+_INVALID_WALKER_ID = 'walker_id must be in [0-{}], got: {{}}.'.format(
+ _MAX_WALKER_ID)
+
+
+def _compensate_gravity(physics, body_elements):
+ """Applies Cartesian forces to bodies in order to exactly counteract gravity.
+
+ Note that this will also affect the output of pressure, force, or torque
+ sensors within the kinematic chain leading from the worldbody to the bodies
+ that are being gravity-compensated.
+
+ Args:
+ physics: An `mjcf.Physics` instance to modify.
+ body_elements: An iterable of `mjcf.Element`s specifying the bodies to which
+ gravity compensation will be applied.
+ """
+ gravity = np.hstack([physics.model.opt.gravity, [0, 0, 0]])
+ bodies = physics.bind(body_elements)
+ bodies.xfrc_applied = -gravity * bodies.mass[..., None]
+
+
+def _alpha_blend(foreground, background):
+ """Does alpha compositing of two RGBA images.
+
+ Both inputs must be (..., 4) numpy arrays whose shapes are compatible for
+ broadcasting. They are assumed to contain float RGBA values in [0, 1].
+
+ Args:
+ foreground: foreground RGBA image.
+ background: background RGBA image.
+
+ Returns:
+ A numpy array of shape (..., 4) containing the blended image.
+ """
+ fg, bg = np.broadcast_arrays(foreground, background)
+ fg_rgb = fg[..., :3]
+ fg_a = fg[..., 3:]
+ bg_rgb = bg[..., :3]
+ bg_a = bg[..., 3:]
+ out = np.empty_like(bg)
+ out_a = out[..., 3:]
+ out_rgb = out[..., :3]
+ # https://en.wikipedia.org/wiki/Alpha_compositing#Alpha_blending
+ out_a[:] = fg_a + bg_a * (1. - fg_a)
+ out_rgb[:] = fg_rgb * fg_a + bg_rgb * bg_a * (1. - fg_a)
+ # Avoid division by zero if foreground and background are both transparent.
+ out_rgb[:] = np.where(out_a, out_rgb / out_a, out_rgb)
+ return out
+
+
+def _asset_png_with_background_rgba_bytes(asset_fname, background_rgba):
+ """Decode PNG from asset file and add solid background."""
+
+ # Retrieve PNG image contents as a bytestring, convert to a numpy array.
+ contents = resources.GetResource(os.path.join(_ASSETS_PATH, asset_fname))
+ digit_rgba = np.array(Image.open(six.BytesIO(contents)), dtype=np.double)
+
+ # Add solid background with `background_rgba`.
+ blended = 255. * _alpha_blend(digit_rgba / 255., np.asarray(background_rgba))
+
+ # Encode composite image array to a PNG bytestring.
+ img = Image.fromarray(blended.astype(np.uint8), mode='RGBA')
+ buf = six.BytesIO()
+ img.save(buf, format='PNG')
+ png_encoding = buf.getvalue()
+ buf.close()
+
+ return png_encoding
+
+
+class BoxHeadObservables(legacy_base.WalkerObservables):
+
+ def __init__(self, entity, camera_resolution):
+ self._camera_resolution = camera_resolution
+ super(BoxHeadObservables, self).__init__(entity)
+
+ @composer.observable
+ def egocentric_camera(self):
+ width, height = self._camera_resolution
+ return observable.MJCFCamera(self._entity.egocentric_camera,
+ width=width, height=height)
+
+
+class BoxHead(legacy_base.Walker):
+ """A rollable and jumpable ball with a head."""
+
+ def _build(self,
+ name='walker',
+ marker_rgba=None,
+ camera_control=False,
+ camera_resolution=(28, 28),
+ roll_gear=-60,
+ steer_gear=55,
+ walker_id=None,
+ initializer=None):
+ """Build a BoxHead.
+
+ Args:
+ name: name of the walker.
+ marker_rgba: RGBA value set to walker.marker_geoms to distinguish between
+ walkers (in multi-agent setting).
+ camera_control: If `True`, the walker exposes two additional actuated
+ degrees of freedom to control the egocentric camera height and tilt.
+ camera_resolution: egocentric camera rendering resolution.
+ roll_gear: gear determining forward acceleration.
+ steer_gear: gear determining steering (spinning) torque.
+ walker_id: (Optional) An integer in [0-10], this number will be shown on
+ the walker's head. Defaults to `None` which does not show any number.
+ initializer: (Optional) A `WalkerInitializer` object.
+
+ Raises:
+ ValueError: if received invalid walker_id.
+ """
+ super(BoxHead, self)._build(initializer=initializer)
+ xml_path = os.path.join(_ASSETS_PATH, 'boxhead.xml')
+ self._mjcf_root = mjcf.from_xml_string(resources.GetResource(xml_path, 'r'))
+ if name:
+ self._mjcf_root.model = name
+
+ if walker_id is not None and not 0 <= walker_id <= _MAX_WALKER_ID:
+ raise ValueError(_INVALID_WALKER_ID.format(walker_id))
+
+ self._walker_id = walker_id
+ if walker_id is not None:
+ png_bytes = _asset_png_with_background_rgba_bytes(
+ 'digits/%02d.png' % walker_id, marker_rgba)
+ head_texture = self._mjcf_root.asset.add(
+ 'texture',
+ name='head_texture',
+ type='2d',
+ file=mjcf.Asset(png_bytes, '.png'))
+ head_material = self._mjcf_root.asset.add(
+ 'material', name='head_material', texture=head_texture)
+ self._mjcf_root.find('geom', 'head').material = head_material
+ self._mjcf_root.find('geom', 'head').rgba = None
+
+ self._mjcf_root.find('geom', 'top_down_cam_box').material = head_material
+ self._mjcf_root.find('geom', 'top_down_cam_box').rgba = None
+
+ self._body_texture = self._mjcf_root.asset.add(
+ 'texture',
+ name='ball_body',
+ type='cube',
+ builtin='checker',
+ rgb1=marker_rgba[:-1] if marker_rgba else '.4 .4 .4',
+ rgb2='.8 .8 .8',
+ width='100',
+ height='100')
+ self._body_material = self._mjcf_root.asset.add(
+ 'material', name='ball_body', texture=self._body_texture)
+ self._mjcf_root.find('geom', 'shell').material = self._body_material
+
+ # Set corresponding marker color if specified.
+ if marker_rgba is not None:
+ for geom in self.marker_geoms:
+ geom.set_attributes(rgba=marker_rgba)
+
+ self._root_joints = None
+ self._camera_control = camera_control
+ self._camera_resolution = camera_resolution
+ if not camera_control:
+ for name in ('camera_pitch', 'camera_yaw'):
+ self._mjcf_root.find('actuator', name).remove()
+ self._mjcf_root.find('joint', name).remove()
+ self._roll_gear = roll_gear
+ self._steer_gear = steer_gear
+ self._mjcf_root.find('actuator', 'roll').gear[0] = self._roll_gear
+ self._mjcf_root.find('actuator', 'steer').gear[0] = self._steer_gear
+
+ # Initialize previous action.
+ self._prev_action = np.zeros(shape=self.action_spec.shape,
+ dtype=self.action_spec.dtype)
+
+ def _build_observables(self):
+ return BoxHeadObservables(self, camera_resolution=self._camera_resolution)
+
+ @property
+ def marker_geoms(self):
+ geoms = [
+ self._mjcf_root.find('geom', 'arm_l'),
+ self._mjcf_root.find('geom', 'arm_r'),
+ self._mjcf_root.find('geom', 'eye_l'),
+ self._mjcf_root.find('geom', 'eye_r'),
+ ]
+ if self._walker_id is None:
+ geoms.append(self._mjcf_root.find('geom', 'head'))
+ return geoms
+
+ def create_root_joints(self, attachment_frame):
+ root_class = self._mjcf_root.find('default', 'root')
+ root_x = attachment_frame.add(
+ 'joint', name='root_x', type='slide', axis=[1, 0, 0], dclass=root_class)
+ root_y = attachment_frame.add(
+ 'joint', name='root_y', type='slide', axis=[0, 1, 0], dclass=root_class)
+ root_z = attachment_frame.add(
+ 'joint', name='root_z', type='slide', axis=[0, 0, 1], dclass=root_class)
+ self._root_joints = [root_x, root_y, root_z]
+
+ def set_pose(self, physics, position=None, quaternion=None):
+ if position is not None:
+ if self._root_joints is not None:
+ physics.bind(self._root_joints).qpos = position
+ else:
+ super(BoxHead, self).set_pose(physics, position, quaternion=None)
+ physics.bind(self._mjcf_root.find_all('joint')).qpos = 0.
+ if quaternion is not None:
+ # This walker can only rotate along the z-axis, so we extract only that
+ # component from the quaternion.
+ z_angle = np.arctan2(
+ 2 * (quaternion[0] * quaternion[3] + quaternion[1] * quaternion[2]),
+ 1 - 2 * (quaternion[2] ** 2 + quaternion[3] ** 2))
+ physics.bind(self._mjcf_root.find('joint', 'steer')).qpos = z_angle
+
+ def initialize_episode(self, physics, unused_random_state):
+ if self._camera_control:
+ _compensate_gravity(physics,
+ self._mjcf_root.find('body', 'egocentric_camera'))
+ self._prev_action = np.zeros(shape=self.action_spec.shape,
+ dtype=self.action_spec.dtype)
+
+ def apply_action(self, physics, action, random_state):
+ super(BoxHead, self).apply_action(physics, action, random_state)
+
+ # Updates previous action.
+ self._prev_action[:] = action
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @composer.cached_property
+ def actuators(self):
+ return self._mjcf_root.find_all('actuator')
+
+ @composer.cached_property
+ def root_body(self):
+ return self._mjcf_root.find('body', 'head_body')
+
+ @composer.cached_property
+ def end_effectors(self):
+ return (self._mjcf_root.find('body', 'head_body'),)
+
+ @composer.cached_property
+ def observable_joints(self):
+ return (self._mjcf_root.find('joint', 'kick'),)
+
+ @composer.cached_property
+ def egocentric_camera(self):
+ return self._mjcf_root.find('camera', 'egocentric')
+
+ @composer.cached_property
+ def ground_contact_geoms(self):
+ return (self._mjcf_root.find('geom', 'shell'),)
+
+ @property
+ def prev_action(self):
+ return self._prev_action
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/boxhead_test.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/boxhead_test.py
new file mode 100644
index 0000000..684b4e7
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/boxhead_test.py
@@ -0,0 +1,49 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.locomotion.soccer.boxhead."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.locomotion.soccer import boxhead
+
+
+class BoxheadTest(parameterized.TestCase):
+
+ @parameterized.parameters(
+ dict(camera_control=True, walker_id=None),
+ dict(camera_control=False, walker_id=None),
+ dict(camera_control=True, walker_id=0),
+ dict(camera_control=False, walker_id=10))
+ def test_instantiation(self, camera_control, walker_id):
+ boxhead.BoxHead(marker_rgba=[.8, .1, .1, 1.],
+ camera_control=camera_control,
+ walker_id=walker_id)
+
+ @parameterized.parameters(-1, 11)
+ def test_invalid_walker_id(self, walker_id):
+ with self.assertRaisesWithLiteralMatch(
+ ValueError, boxhead._INVALID_WALKER_ID.format(walker_id)):
+ boxhead.BoxHead(walker_id=walker_id)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/explore.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/explore.py
new file mode 100644
index 0000000..29d020d
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/explore.py
@@ -0,0 +1,35 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Interactive viewer for MuJoCo soccer enviornmnet."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+from absl import app
+from dm_control import viewer
+from dm_control.locomotion import soccer
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise app.UsageError('Too many command-line arguments.')
+ viewer.launch(environment_loader=functools.partial(soccer.load, team_size=2))
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/initializers.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/initializers.py
new file mode 100644
index 0000000..22785ce
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/initializers.py
@@ -0,0 +1,73 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Soccer task episode initializers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+import numpy as np
+import six
+
+
+_INIT_BALL_Z = 0.5
+_SPAWN_RATIO = 0.6
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Initializer(object):
+
+ @abc.abstractmethod
+ def __call__(self, task, physics, random_state):
+ """Initialize episode for a task."""
+
+
+class UniformInitializer(Initializer):
+ """Uniformly initialize walkers and soccer ball over spawn_range."""
+
+ def __init__(self, spawn_ratio=_SPAWN_RATIO, init_ball_z=_INIT_BALL_Z):
+ self._spawn_ratio = spawn_ratio
+ self._init_ball_z = init_ball_z
+
+ def _initialize_ball(self, ball, spawn_range, physics, random_state):
+ x, y = random_state.uniform(-spawn_range, spawn_range)
+ ball.set_pose(physics, [x, y, self._init_ball_z])
+ # Note: this method is not always called immediately after `physics.reset()`
+ # so we need to explicitly zero out the velocity.
+ ball.set_velocity(physics, velocity=0., angular_velocity=0.)
+
+ def _initialize_walker(self, walker, spawn_range, physics, random_state):
+ """Uniformly initialize walker in spawn_range."""
+ walker.reinitialize_pose(physics, random_state)
+ x, y = random_state.uniform(-spawn_range, spawn_range)
+ (_, _, z), quat = walker.get_pose(physics)
+ walker.set_pose(physics, [x, y, z], quat)
+ rotation = random_state.uniform(-np.pi, np.pi)
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walker.shift_pose(physics, quaternion=quat)
+ # TODO(b/132759890): `walker.set_velocity` has no effect for walkers without
+ # freejoints, such as `BoxHead`.
+ # Note: this method is not always called immediately after `physics.reset()`
+ # so we need to explicitly zero out the velocity.
+ walker.set_velocity(physics, velocity=0., angular_velocity=0.)
+
+ def __call__(self, task, physics, random_state):
+ spawn_range = np.asarray(task.arena.size) * self._spawn_ratio
+ self._initialize_ball(task.ball, spawn_range, physics, random_state)
+ for player in task.players:
+ self._initialize_walker(player.walker, spawn_range, physics, random_state)
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/loader_test.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/loader_test.py
new file mode 100644
index 0000000..8ec2ec9
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/loader_test.py
@@ -0,0 +1,107 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.locomotion.soccer.load."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import logging
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.locomotion import soccer
+import numpy as np
+from six.moves import range
+
+
+class LoadTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("2vs2_nocontacts", 2, True), ("2vs2_contacts", 2, False),
+ ("1vs1_nocontacts", 1, True), ("1vs1_contacts", 1, False))
+ def test_load_env(self, team_size, disable_walker_contacts):
+ env = soccer.load(team_size=team_size, time_limit=2.,
+ disable_walker_contacts=disable_walker_contacts)
+ action_specs = env.action_spec()
+
+ random_state = np.random.RandomState(0)
+ time_step = env.reset()
+ while not time_step.last():
+ actions = []
+ for action_spec in action_specs:
+ action = random_state.uniform(
+ action_spec.minimum, action_spec.maximum, size=action_spec.shape)
+ actions.append(action)
+ time_step = env.step(actions)
+
+ for i in range(len(action_specs)):
+ logging.info(
+ "Player %d: reward = %s, discount = %s, observations = %s.", i,
+ time_step.reward[i], time_step.discount, time_step.observation[i])
+
+ def assertSameObservation(self, expected_observation, actual_observation):
+ self.assertLen(actual_observation, len(expected_observation))
+ for player_id in range(len(expected_observation)):
+ expected_player_observations = expected_observation[player_id]
+ actual_player_observations = actual_observation[player_id]
+ expected_keys = expected_player_observations.keys()
+ actual_keys = actual_player_observations.keys()
+ msg = ("Observation keys differ for player {}.\nExpected: {}.\nActual: {}"
+ .format(player_id, expected_keys, actual_keys))
+ self.assertEqual(expected_keys, actual_keys, msg)
+ for key in expected_player_observations:
+ expected_array = expected_player_observations[key]
+ actual_array = actual_player_observations[key]
+ msg = ("Observation {!r} differs for player {}.\nExpected:\n{}\n"
+ "Actual:\n{}"
+ .format(key, player_id, expected_array, actual_array))
+ np.testing.assert_array_equal(expected_array, actual_array,
+ err_msg=msg)
+
+ @parameterized.parameters(True, False)
+ def test_same_first_observation_if_same_seed(self, disable_walker_contacts):
+ seed = 42
+ timestep_1 = soccer.load(
+ team_size=2,
+ random_state=seed,
+ disable_walker_contacts=disable_walker_contacts).reset()
+ timestep_2 = soccer.load(
+ team_size=2,
+ random_state=seed,
+ disable_walker_contacts=disable_walker_contacts).reset()
+ self.assertSameObservation(timestep_1.observation, timestep_2.observation)
+
+ @parameterized.parameters(True, False)
+ def test_different_first_observation_if_different_seed(
+ self, disable_walker_contacts):
+ timestep_1 = soccer.load(
+ team_size=2,
+ random_state=1,
+ disable_walker_contacts=disable_walker_contacts).reset()
+ timestep_2 = soccer.load(
+ team_size=2,
+ random_state=2,
+ disable_walker_contacts=disable_walker_contacts).reset()
+ try:
+ self.assertSameObservation(timestep_1.observation, timestep_2.observation)
+ except AssertionError:
+ pass
+ else:
+ self.fail("Observations are unexpectedly identical.")
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/observables.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/observables.py
new file mode 100644
index 0000000..da3d406
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/observables.py
@@ -0,0 +1,432 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Soccer observables modules."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+from dm_control.composer.observation import observable as base_observable
+from dm_control.locomotion.soccer import team as team_lib
+import numpy as np
+import six
+from six.moves import zip
+
+
+@six.add_metaclass(abc.ABCMeta)
+class ObservablesAdder(object):
+ """A callable that adds a set of per-player observables for a task."""
+
+ @abc.abstractmethod
+ def __call__(self, task, player):
+ """Adds observables to a player for the given task.
+
+ Args:
+ task: A `soccer.Task` instance.
+ player: A `Walker` instance to which observables will be added.
+ """
+
+
+class MultiObservablesAdder(ObservablesAdder):
+ """Applies multiple `ObservablesAdder`s to a soccer task and player."""
+
+ def __init__(self, observables):
+ """Initializes a `MultiObservablesAdder` instance.
+
+ Args:
+ observables: A list of `ObservablesAdder` instances.
+ """
+ self._observables = observables
+
+ def __call__(self, task, player):
+ """Adds observables to a player for the given task.
+
+ Args:
+ task: A `soccer.Task` instance.
+ player: A `Walker` instance to which observables will be added.
+ """
+ for observable in self._observables:
+ observable(task, player)
+
+
+class CoreObservablesAdder(ObservablesAdder):
+ """Core set of per player observables."""
+
+ def __call__(self, task, player):
+ """Adds observables to a player for the given task.
+
+ Args:
+ task: A `soccer.Task` instance.
+ player: A `Walker` instance to which observables will be added.
+ """
+ # Enable proprioceptive observables.
+ self._add_player_proprio_observables(player)
+
+ # Add egocentric observations of soccer ball.
+ self._add_player_observables_on_ball(player, task.ball)
+
+ # Add egocentric observations of others.
+ teammate_id = 0
+ opponent_id = 0
+ for other in task.players:
+ if other is player:
+ continue
+ # Infer team prefix for `other` conditioned on `player.team`.
+ if player.team != other.team:
+ prefix = 'opponent_{}'.format(opponent_id)
+ opponent_id += 1
+ else:
+ prefix = 'teammate_{}'.format(teammate_id)
+ teammate_id += 1
+
+ self._add_player_observables_on_other(player, other, prefix)
+
+ self._add_player_arena_observables(player, task.arena)
+
+ # Add per player game statistics.
+ self._add_player_stats_observables(task, player)
+
+ def _add_player_observables_on_other(self, player, other, prefix):
+ """Add observables of another player in this player's egocentric frame.
+
+ Args:
+ player: A `Walker` instance, the player we are adding observables to.
+ other: A `Walker` instance corresponding to a different player.
+ prefix: A string specifying a prefix to apply to the names of observables
+ belonging to `player`.
+ """
+ if player is other:
+ raise ValueError('Cannot add egocentric observables of player on itself.')
+ # Origin callable in xpos, xvel for `player`.
+ xpos_xyz_callable = lambda p: p.bind(player.walker.root_body).xpos
+ xvel_xyz_callable = lambda p: p.bind(player.walker.root_body).cvel[3:]
+ # Egocentric observation of other's position, orientation and
+ # linear velocities.
+ def _cvel_observation(physics, other=other):
+ # Velocitmeter reads in local frame but we need world frame observable
+ # for egocentric transformation.
+ return physics.bind(other.walker.root_body).cvel[3:]
+
+ def _egocentric_end_effectors_xpos(physics, other=other):
+ origin_xpos = xpos_xyz_callable(physics)
+ egocentric_end_effectors_xpos = []
+ for end_effector_body in other.walker.end_effectors:
+ xpos = physics.bind(end_effector_body).xpos
+ delta = xpos - origin_xpos
+ ego_xpos = player.walker.transform_vec_to_egocentric_frame(
+ physics, delta)
+ egocentric_end_effectors_xpos.append(ego_xpos)
+ return np.concatenate(egocentric_end_effectors_xpos)
+
+ player.walker.observables.add_egocentric_vector(
+ '{}_ego_linear_velocity'.format(prefix),
+ base_observable.Generic(_cvel_observation),
+ origin_callable=xvel_xyz_callable)
+ player.walker.observables.add_egocentric_vector(
+ '{}_ego_position'.format(prefix),
+ other.walker.observables.position,
+ origin_callable=xpos_xyz_callable)
+ player.walker.observables.add_egocentric_xmat(
+ '{}_ego_orientation'.format(prefix),
+ other.walker.observables.orientation)
+
+ # Adds end effectors of the other agents in the player's egocentric frame.
+ player.walker.observables.add_observable(
+ '{}_ego_end_effectors_pos'.format(prefix),
+ base_observable.Generic(_egocentric_end_effectors_xpos))
+
+ # Adds end effectors of the other agents in the other's egocentric frame.
+ # A is seeing B's hand extended to B's right.
+ player.walker.observables.add_observable(
+ '{}_end_effectors_pos'.format(prefix),
+ other.walker.observables.end_effectors_pos)
+
+ def _add_player_observables_on_ball(self, player, ball):
+ """Add observables of the soccer ball in this player's egocentric frame.
+
+ Args:
+ player: A `Walker` instance, the player we are adding observations for.
+ ball: A `SoccerBall` instance.
+ """
+ # Origin callables for egocentric transformations.
+ xpos_xyz_callable = lambda p: p.bind(player.walker.root_body).xpos
+ xvel_xyz_callable = lambda p: p.bind(player.walker.root_body).cvel[3:]
+
+ # Add egocentric ball observations.
+ player.walker.observables.add_egocentric_vector(
+ 'ball_ego_angular_velocity', ball.observables.angular_velocity)
+ player.walker.observables.add_egocentric_vector(
+ 'ball_ego_position',
+ ball.observables.position,
+ origin_callable=xpos_xyz_callable)
+ player.walker.observables.add_egocentric_vector(
+ 'ball_ego_linear_velocity',
+ ball.observables.linear_velocity,
+ origin_callable=xvel_xyz_callable)
+
+ def _add_player_proprio_observables(self, player):
+ """Add proprioceptive observables to the given player.
+
+ Args:
+ player: A `Walker` instance, the player we are adding observations for.
+ """
+ for observable in (player.walker.observables.proprioception +
+ player.walker.observables.kinematic_sensors):
+ observable.enabled = True
+
+ # Also enable previous action observable as part of proprioception.
+ player.walker.observables.prev_action.enabled = True
+
+ def _add_player_arena_observables(self, player, arena):
+ """Add observables of the arena.
+
+ Args:
+ player: A `Walker` instance to which observables will be added.
+ arena: A `Pitch` instance.
+ """
+ # Enable egocentric view of position detectors (goal, field).
+ # Corners named according to walker *facing towards opponent goal*.
+ clockwise_names = [
+ 'team_goal_back_right',
+ 'team_goal_mid',
+ 'team_goal_front_left',
+ 'field_front_left',
+ 'opponent_goal_back_left',
+ 'opponent_goal_mid',
+ 'opponent_goal_front_right',
+ 'field_back_right',
+ ]
+ clockwise_features = [
+ lambda _: arena.home_goal.lower[:2],
+ lambda _: arena.home_goal.mid,
+ lambda _: arena.home_goal.upper[:2],
+ lambda _: arena.field.upper,
+ lambda _: arena.away_goal.upper[:2],
+ lambda _: arena.away_goal.mid,
+ lambda _: arena.away_goal.lower[:2],
+ lambda _: arena.field.lower,
+ ]
+ xpos_xyz_callable = lambda p: p.bind(player.walker.root_body).xpos
+ xpos_xy_callable = lambda p: p.bind(player.walker.root_body).xpos[:2]
+ # A list of egocentric reference origin for each one of clockwise_features.
+ clockwise_origins = [
+ xpos_xy_callable,
+ xpos_xyz_callable,
+ xpos_xy_callable,
+ xpos_xy_callable,
+ xpos_xy_callable,
+ xpos_xyz_callable,
+ xpos_xy_callable,
+ xpos_xy_callable,
+ ]
+ if player.team != team_lib.Team.HOME:
+ half = len(clockwise_features) // 2
+ clockwise_features = clockwise_features[half:] + clockwise_features[:half]
+ clockwise_origins = clockwise_origins[half:] + clockwise_origins[:half]
+
+ for name, feature, origin in zip(clockwise_names, clockwise_features,
+ clockwise_origins):
+ player.walker.observables.add_egocentric_vector(
+ name, base_observable.Generic(feature), origin_callable=origin)
+
+ def _add_player_stats_observables(self, task, player):
+ """Add observables corresponding to game statistics.
+
+ Args:
+ task: A `soccer.Task` instance.
+ player: A `Walker` instance to which observables will be added.
+ """
+
+ def _stats_vel_to_ball(physics):
+ dir_ = (
+ physics.bind(task.ball.geom).xpos -
+ physics.bind(player.walker.root_body).xpos)
+ vel_to_ball = np.dot(dir_[:2] / (np.linalg.norm(dir_[:2]) + 1e-7),
+ physics.bind(player.walker.root_body).cvel[3:5])
+ return np.sum(vel_to_ball)
+
+ player.walker.observables.add_observable(
+ 'stats_vel_to_ball', base_observable.Generic(_stats_vel_to_ball))
+
+ def _stats_closest_vel_to_ball(physics):
+ """Velocity to the ball if this walker is the team's closest."""
+ closest = None
+ min_team_dist_to_ball = np.inf
+ for player_ in task.players:
+ if player_.team == player.team:
+ dist_to_ball = np.linalg.norm(
+ physics.bind(task.ball.geom).xpos -
+ physics.bind(player_.walker.root_body).xpos)
+ if dist_to_ball < min_team_dist_to_ball:
+ min_team_dist_to_ball = dist_to_ball
+ closest = player_
+ if closest is player:
+ return _stats_vel_to_ball(physics)
+ return 0.
+
+ player.walker.observables.add_observable(
+ 'stats_closest_vel_to_ball',
+ base_observable.Generic(_stats_closest_vel_to_ball))
+
+ def _stats_veloc_forward(physics):
+ """Player's forward velocity."""
+ return player.walker.observables.veloc_forward(physics)
+
+ player.walker.observables.add_observable(
+ 'stats_veloc_forward', base_observable.Generic(_stats_veloc_forward))
+
+ def _stats_vel_ball_to_goal(physics):
+ """Ball velocity towards opponents' goal."""
+ if player.team == team_lib.Team.HOME:
+ goal = task.arena.away_goal
+ else:
+ goal = task.arena.home_goal
+
+ goal_center = (goal.upper + goal.lower) / 2.
+ direction = goal_center - physics.bind(task.ball.geom).xpos
+ ball_vel_observable = task.ball.observables.linear_velocity
+ ball_vel = ball_vel_observable.observation_callable(physics)()
+
+ norm_dir = np.linalg.norm(direction)
+ normalized_dir = direction / norm_dir if norm_dir else direction
+ return np.sum(np.dot(normalized_dir, ball_vel))
+
+ player.walker.observables.add_observable(
+ 'stats_vel_ball_to_goal',
+ base_observable.Generic(_stats_vel_ball_to_goal))
+
+ def _stats_avg_teammate_dist(physics):
+ """Compute average distance from `walker` to its teammates."""
+ teammate_dists = []
+ for other in task.players:
+ if player is other:
+ continue
+ if other.team != player.team:
+ continue
+ dist = np.linalg.norm(
+ physics.bind(player.walker.root_body).xpos -
+ physics.bind(other.walker.root_body).xpos)
+ teammate_dists.append(dist)
+ return np.mean(teammate_dists) if teammate_dists else 0.
+
+ player.walker.observables.add_observable(
+ 'stats_home_avg_teammate_dist',
+ base_observable.Generic(_stats_avg_teammate_dist))
+
+ def _stats_teammate_spread_out(physics):
+ """Compute average distance from `walker` to its teammates."""
+ return _stats_avg_teammate_dist(physics) > 5.
+
+ player.walker.observables.add_observable(
+ 'stats_teammate_spread_out',
+ base_observable.Generic(_stats_teammate_spread_out))
+
+ def _stats_home_score(unused_physics):
+ if (task.arena.detected_goal() and
+ task.arena.detected_goal() == player.team):
+ return 1.
+ return 0.
+
+ player.walker.observables.add_observable(
+ 'stats_home_score', base_observable.Generic(_stats_home_score))
+
+ has_opponent = any([p.team != player.team for p in task.players])
+
+ def _stats_away_score(unused_physics):
+ if (has_opponent and task.arena.detected_goal() and
+ task.arena.detected_goal() != player.team):
+ return 1.
+ return 0.
+
+ player.walker.observables.add_observable(
+ 'stats_away_score', base_observable.Generic(_stats_away_score))
+
+
+# TODO(b/124848293): add unit-test interception observables.
+class InterceptionObservablesAdder(ObservablesAdder):
+ """Adds obervables representing interception events.
+
+ These observables represent events where this player received the ball from
+ another player, or when an opponent intercepted the ball from this player's
+ team. For each type of event there are three different thresholds applied to
+ the distance travelled by the ball since it last made contact with a player
+ (5, 10, or 15 meters).
+
+ For example, on a given timestep `stats_i_received_ball_10m` will be 1 if
+ * This player just made contact with the ball
+ * The last player to have made contact with the ball was a different player
+ * The ball travelled for at least 10 m since it last hit a player
+ and 0 otherwise.
+
+ Conversely, `stats_opponent_intercepted_ball_10m` will be 1 if:
+ * An opponent just made contact with the ball
+ * The last player to have made contact with the ball was on this player's team
+ * The ball travelled for at least 10 m since it last hit a player
+ """
+
+ def __call__(self, task, player):
+ """Adds observables to a player for the given task.
+
+ Args:
+ task: A `soccer.Task` instance.
+ player: A `Walker` instance to which observables will be added.
+ """
+
+ def _stats_i_received_ball(unused_physics):
+ if (task.ball.hit and task.ball.repossessed and
+ task.ball.last_hit is player):
+ return 1.
+ return 0.
+
+ player.walker.observables.add_observable(
+ 'stats_i_received_ball',
+ base_observable.Generic(_stats_i_received_ball))
+
+ def _stats_opponent_intercepted_ball(unused_physics):
+ """Indicator on if an opponent intercepted the ball."""
+ if (task.ball.hit and task.ball.intercepted and
+ task.ball.last_hit.team != player.team):
+ return 1.
+ return 0.
+
+ player.walker.observables.add_observable(
+ 'stats_opponent_intercepted_ball',
+ base_observable.Generic(_stats_opponent_intercepted_ball))
+
+ for dist in [5, 10, 15]:
+
+ def _stats_i_received_ball_dist(physics, dist=dist):
+ if (_stats_i_received_ball(physics) and
+ task.ball.dist_between_last_hits is not None and
+ task.ball.dist_between_last_hits > dist):
+ return 1.
+ return 0.
+
+ player.walker.observables.add_observable(
+ 'stats_i_received_ball_%dm' % dist,
+ base_observable.Generic(_stats_i_received_ball_dist))
+
+ def _stats_opponent_intercepted_ball_dist(physics, dist=dist):
+ if (_stats_opponent_intercepted_ball(physics) and
+ task.ball.dist_between_last_hits is not None and
+ task.ball.dist_between_last_hits > dist):
+ return 1.
+ return 0.
+
+ player.walker.observables.add_observable(
+ 'stats_opponent_intercepted_ball_%dm' % dist,
+ base_observable.Generic(_stats_opponent_intercepted_ball_dist))
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/pitch.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/pitch.py
new file mode 100644
index 0000000..2ee5295
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/pitch.py
@@ -0,0 +1,346 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A soccer pitch with home/away goals and one field with position detection."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import logging
+from dm_control import composer
+from dm_control.composer.variation import distributions
+from dm_control.entities import props
+from dm_control.locomotion.soccer import team
+import numpy as np
+
+
+_TOP_CAMERA_Y_PADDING_FACTOR = 1.1
+_TOP_CAMERA_DISTANCE = 100.
+_WALL_HEIGHT = 10.
+_WALL_THICKNESS = .5
+_SIDE_WIDTH = 32. / 6.
+_GROUND_GEOM_GRID_RATIO = 1. / 100 # Grid size for lighting.
+_FIELD_BOX_CONTACT_BIT = 1 << 7 # Use a higher bit to prevent potential clash.
+
+_DEFAULT_PITCH_SIZE = (12, 9)
+_DEFAULT_GOAL_LENGTH_RATIO = 0.33 # Goal length / pitch width.
+
+
+def _top_down_cam_fovy(size, top_camera_distance):
+ return (360 / np.pi) * np.arctan2(_TOP_CAMERA_Y_PADDING_FACTOR * max(size),
+ top_camera_distance)
+
+
+def _wall_pos_xyaxes(size):
+ """Infers position and size of bounding walls given pitch size.
+
+ Walls are placed around `ground_geom` that represents the pitch. Note that
+ the ball cannot travel beyond `field` but walkers can walk outside of the
+ `field` but not the surrounding walls.
+
+ Args:
+ size: a tuple of (length, width) of the pitch.
+
+ Returns:
+ a list of 4 tuples, each representing the position and xyaxes of a wall
+ plane. In order, walls are placed along x-negative, x-positive, y-negative,
+ y-positive relative the center of the pitch.
+ """
+ return [
+ ((0., -size[1], 0.), (-1, 0, 0, 0, 0, 1)),
+ ((0., size[1], 0.), (1, 0, 0, 0, 0, 1)),
+ ((-size[0], 0., 0.), (0, 1, 0, 0, 0, 1)),
+ ((size[0], 0., 0.), (0, -1, 0, 0, 0, 1)),
+ ]
+
+
+def _roof_size(size):
+ return (size[0], size[1], _WALL_THICKNESS)
+
+
+class Pitch(composer.Arena):
+ """A pitch with a plane, two goals and a field with position detection."""
+
+ def _build(self,
+ size=_DEFAULT_PITCH_SIZE,
+ goal_size=None,
+ top_camera_distance=_TOP_CAMERA_DISTANCE,
+ field_box=False,
+ name='pitch'):
+ """Construct a pitch with walls and position detectors.
+
+ Args:
+ size: a tuple of (length, width) of the pitch.
+ goal_size: optional (depth, width, height) indicating the goal size.
+ If not specified, the goal size is inferred from pitch size with a fixed
+ default ratio.
+ top_camera_distance: the distance of the top-down camera to the pitch.
+ field_box: adds a "field box" that collides with the ball but not the
+ walkers.
+ name: the name of this arena.
+ """
+ super(Pitch, self)._build(name=name)
+ self._size = size
+ self._goal_size = goal_size
+ self._top_camera_distance = top_camera_distance
+
+ self._top_camera = self._mjcf_root.worldbody.add(
+ 'camera',
+ name='top_down',
+ pos=[0, 0, top_camera_distance],
+ zaxis=[0, 0, 1],
+ fovy=_top_down_cam_fovy(self._size, top_camera_distance))
+
+ self._mjcf_root.visual.headlight.set_attributes(
+ ambient=[.4, .4, .4], diffuse=[.8, .8, .8], specular=[.1, .1, .1])
+
+ # Ensure close up geoms are rendered by egocentric cameras.
+ self._mjcf_root.visual.map.znear = 0.0005
+
+ # Build groundplane.
+ if len(self._size) != 2:
+ raise ValueError('`size` should be a sequence of length 2: got {!r}'
+ .format(self._size))
+ self._ground_texture = self._mjcf_root.asset.add(
+ 'texture',
+ type='2d',
+ builtin='checker',
+ name='groundplane',
+ rgb1=[0.3, 0.8, 0.3],
+ rgb2=[0.1, 0.6, 0.1],
+ width=300,
+ height=300,
+ mark='edge',
+ markrgb=[0.8, 0.8, 0.8])
+ self._ground_material = self._mjcf_root.asset.add(
+ 'material', name='groundplane', texture=self._ground_texture)
+ self._ground_geom = self._mjcf_root.worldbody.add(
+ 'geom',
+ type='plane',
+ material=self._ground_material,
+ size=list(self._size) + [max(self._size) * _GROUND_GEOM_GRID_RATIO])
+
+ # Build walls.
+ self._walls = []
+ for wall_pos, wall_xyaxes in _wall_pos_xyaxes(self._size):
+ self._walls.append(
+ self._mjcf_root.worldbody.add(
+ 'geom',
+ type='plane',
+ rgba=[.1, .1, .1, .8],
+ pos=wall_pos,
+ size=[1e-7, 1e-7, 1e-7],
+ xyaxes=wall_xyaxes))
+
+ # Build goal position detectors.
+ # If field_box is enabled, offset goal by 1.0 such that ball reaches the
+ # goal position detector before bouncing off the field_box.
+ self._fb_offset = 0.5 if field_box else 0.0
+ goal_size = self._get_goal_size()
+ self._home_goal = props.PositionDetector(
+ pos=(-self._size[0] + goal_size[0] + self._fb_offset, 0,
+ goal_size[2]),
+ size=goal_size,
+ rgba=(0, 0, 1, 0.5),
+ visible=True,
+ name='home_goal')
+ self.attach(self._home_goal)
+
+ self._away_goal = props.PositionDetector(
+ pos=(self._size[0] - goal_size[0] - self._fb_offset, 0, goal_size[2]),
+ size=goal_size,
+ rgba=(1, 0, 0, 0.5),
+ visible=True,
+ name='away_goal')
+ self.attach(self._away_goal)
+
+ # Build inverted field position detectors.
+ self._field = props.PositionDetector(
+ pos=(0, 0),
+ size=(self._size[0] - 2 * goal_size[0],
+ self._size[1] - 2 * goal_size[0]),
+ rgba=(1, 0, 0, 0.1),
+ inverted=True,
+ visible=True,
+ name='field')
+ self.attach(self._field)
+
+ # Build field box.
+ self._field_box = []
+ if field_box:
+ for wall_pos, wall_xyaxes in _wall_pos_xyaxes(
+ (self._field.upper - self._field.lower) / 2.0):
+ self._field_box.append(
+ self._mjcf_root.worldbody.add(
+ 'geom',
+ type='plane',
+ rgba=[.3, .3, .3, .3],
+ pos=wall_pos,
+ size=[1e-7, 1e-7, 1e-7],
+ xyaxes=wall_xyaxes))
+
+ def _get_goal_size(self):
+ goal_size = self._goal_size
+ if goal_size is None:
+ goal_size = (
+ _SIDE_WIDTH / 2,
+ self._size[1] * _DEFAULT_GOAL_LENGTH_RATIO,
+ _SIDE_WIDTH / 2,
+ )
+ return goal_size
+
+ def register_ball(self, ball):
+ self._home_goal.register_entities(ball)
+ self._away_goal.register_entities(ball)
+
+ if self._field_box:
+ # Geoms a and b collides if:
+ # (a.contype & b.conaffinity) || (b.contype & a.conaffinity) != 0.
+ # See: http://www.mujoco.org/book/computation.html#Collision
+ ball.geom.contype = (ball.geom.contype or 1) | _FIELD_BOX_CONTACT_BIT
+ for wall in self._field_box:
+ wall.conaffinity = _FIELD_BOX_CONTACT_BIT
+ wall.contype = _FIELD_BOX_CONTACT_BIT
+ else:
+ self._field.register_entities(ball)
+
+ def detected_goal(self):
+ """Returning the team that scored a goal."""
+ if self._home_goal.detected_entities:
+ return team.Team.AWAY
+ if self._away_goal.detected_entities:
+ return team.Team.HOME
+ return None
+
+ def detected_off_court(self):
+ return self._field.detected_entities
+
+ @property
+ def size(self):
+ return self._size
+
+ @property
+ def home_goal(self):
+ return self._home_goal
+
+ @property
+ def away_goal(self):
+ return self._away_goal
+
+ @property
+ def field(self):
+ return self._field
+
+ @property
+ def ground_geom(self):
+ return self._ground_geom
+
+
+class RandomizedPitch(Pitch):
+ """RandomizedPitch that randomizes its size between (min_size, max_size)."""
+
+ def __init__(self,
+ min_size,
+ max_size,
+ randomizer=None,
+ keep_aspect_ratio=False,
+ goal_size=None,
+ field_box=False,
+ top_camera_distance=_TOP_CAMERA_DISTANCE,
+ name='randomized_pitch'):
+ """Construct a randomized pitch.
+
+ Args:
+ min_size: a tuple of minimum (length, width) of the pitch.
+ max_size: a tuple of maximum (length, width) of the pitch.
+ randomizer: a callable that returns ratio between [0., 1.] that scales
+ between min_size, max_size.
+ keep_aspect_ratio: if `True`, keep the aspect ratio constant during
+ randomization.
+ goal_size: optional (depth, width, height) indicating the goal size.
+ If not specified, the goal size is inferred from pitch size with a fixed
+ default ratio.
+ field_box: optional indicating if we should construct field box containing
+ the ball (but not the walkers).
+ top_camera_distance: the distance of the top-down camera to the pitch.
+ name: the name of this arena.
+ """
+ super(RandomizedPitch, self).__init__(
+ size=max_size,
+ goal_size=goal_size,
+ top_camera_distance=top_camera_distance,
+ field_box=field_box,
+ name=name)
+
+ self._min_size = min_size
+ self._max_size = max_size
+
+ self._randomizer = randomizer or distributions.Uniform()
+ self._keep_aspect_ratio = keep_aspect_ratio
+
+ # Sample a new size and regenerate the soccer pitch.
+ logging.info('%s between (%s, %s) with %s', self.__class__.__name__,
+ min_size, max_size, self._randomizer)
+
+ def _resize_goals(self, goal_size):
+ self._home_goal.resize(
+ pos=(-self._size[0] + goal_size[0] + self._fb_offset, 0, goal_size[2]),
+ size=goal_size)
+ self._away_goal.resize(
+ pos=(self._size[0] - goal_size[0] - self._fb_offset, 0, goal_size[2]),
+ size=goal_size)
+
+ def initialize_episode_mjcf(self, random_state):
+ super(RandomizedPitch, self).initialize_episode_mjcf(random_state)
+ min_len, min_wid = self._min_size
+ max_len, max_wid = self._max_size
+
+ if self._keep_aspect_ratio:
+ len_ratio = self._randomizer(random_state=random_state)
+ wid_ratio = len_ratio
+ else:
+ len_ratio = self._randomizer(random_state=random_state)
+ wid_ratio = self._randomizer(random_state=random_state)
+
+ self._size = (min_len + len_ratio * (max_len - min_len),
+ min_wid + wid_ratio * (max_wid - min_wid))
+
+ # Reset top_down camera field of view.
+ self._top_camera.fovy = _top_down_cam_fovy(self._size,
+ self._top_camera_distance)
+
+ # Resize ground geom size.
+ self._ground_geom.size = list(
+ self._size) + [max(self._size) * _GROUND_GEOM_GRID_RATIO]
+
+ # Resize and reposition walls and roof geoms.
+ for i, (wall_pos, _) in enumerate(_wall_pos_xyaxes(self._size)):
+ self._walls[i].pos = wall_pos
+
+ goal_size = self._get_goal_size()
+ self._resize_goals(goal_size)
+
+ # Resize inverted field position detectors.
+ self._field.resize(
+ pos=(0, 0),
+ size=(self._size[0] - 2 * goal_size[0],
+ self._size[1] - 2 * goal_size[0]))
+
+ # Resize and reposition field box geoms.
+ if self._field_box:
+ for i, (pos, _) in enumerate(
+ _wall_pos_xyaxes((self._field.upper - self._field.lower) / 2.0)):
+ self._field_box[i].pos = pos
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/pitch_test.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/pitch_test.py
new file mode 100644
index 0000000..d4e9141
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/pitch_test.py
@@ -0,0 +1,86 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.locomotion.soccer.pitch."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control.composer.variation import distributions
+from dm_control.entities import props
+from dm_control.locomotion.soccer import pitch as pitch_lib
+from dm_control.locomotion.soccer import team as team_lib
+import numpy as np
+
+
+class PitchTest(parameterized.TestCase):
+
+ def _pitch_with_ball(self, pitch_size, ball_pos):
+ pitch = pitch_lib.Pitch(size=pitch_size)
+ self.assertEqual(pitch.size, pitch_size)
+
+ sphere = props.Primitive(geom_type='sphere', size=(0.1,), pos=ball_pos)
+ pitch.register_ball(sphere)
+ pitch.attach(sphere)
+
+ env = composer.Environment(
+ composer.NullTask(pitch), random_state=np.random.RandomState(42))
+ env.reset()
+ return pitch
+
+ def test_pitch_none_detected(self):
+ pitch = self._pitch_with_ball((12, 9), (0, 0, 0))
+ self.assertEmpty(pitch.detected_off_court())
+ self.assertIsNone(pitch.detected_goal())
+
+ def test_pitch_detected_off_court(self):
+ pitch = self._pitch_with_ball((12, 9), (20, 0, 0))
+ self.assertLen(pitch.detected_off_court(), 1)
+ self.assertIsNone(pitch.detected_goal())
+
+ def test_pitch_detected_away_goal(self):
+ pitch = self._pitch_with_ball((12, 9), (-9.5, 0, 1))
+ self.assertLen(pitch.detected_off_court(), 1)
+ self.assertEqual(team_lib.Team.AWAY, pitch.detected_goal())
+
+ def test_pitch_detected_home_goal(self):
+ pitch = self._pitch_with_ball((12, 9), (9.5, 0, 1))
+ self.assertLen(pitch.detected_off_court(), 1)
+ self.assertEqual(team_lib.Team.HOME, pitch.detected_goal())
+
+ @parameterized.parameters((True, distributions.Uniform()),
+ (False, distributions.Uniform()))
+ def test_randomize_pitch(self, keep_aspect_ratio, randomizer):
+ pitch = pitch_lib.RandomizedPitch(
+ min_size=(4, 3),
+ max_size=(8, 6),
+ randomizer=randomizer,
+ keep_aspect_ratio=keep_aspect_ratio)
+ pitch.initialize_episode_mjcf(np.random.RandomState(42))
+
+ self.assertBetween(pitch.size[0], 4, 8)
+ self.assertBetween(pitch.size[1], 3, 6)
+
+ if keep_aspect_ratio:
+ self.assertAlmostEqual((pitch.size[0] - 4) / (8. - 4.),
+ (pitch.size[1] - 3) / (6. - 3.))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/soccer.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/soccer.png
new file mode 100644
index 0000000..3511794
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/soccer.png differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/soccer_ball.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/soccer_ball.py
new file mode 100644
index 0000000..5ac8a6e
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/soccer_ball.py
@@ -0,0 +1,236 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A soccer ball that keeps track of ball-player contacts."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from dm_control import mjcf
+from dm_control.entities import props
+import numpy as np
+
+from dm_control.utils import io as resources
+
+_ASSETS_PATH = os.path.join(os.path.dirname(__file__), 'assets', 'soccer_ball')
+
+
+def _get_texture(name):
+ contents = resources.GetResource(
+ os.path.join(_ASSETS_PATH, '{}.png'.format(name)))
+ return mjcf.Asset(contents, '.png')
+
+
+class SoccerBall(props.Primitive):
+ """A soccer ball that keeps track of entities that come into contact."""
+
+ def _build(self, radius=0.35, mass=0.045, name='soccer_ball'):
+ """Builds this soccer ball.
+
+ Args:
+ radius: The radius (in meters) of this target sphere.
+ mass: Mass (in kilograms) of the ball.
+ name: The name of this entity.
+ """
+ super(SoccerBall, self)._build(
+ geom_type='sphere', size=(radius,), name=name)
+ texture = self._mjcf_root.asset.add(
+ 'texture',
+ name='soccer_ball',
+ type='cube',
+ fileup=_get_texture('up'),
+ filedown=_get_texture('down'),
+ filefront=_get_texture('front'),
+ fileback=_get_texture('back'),
+ fileleft=_get_texture('left'),
+ fileright=_get_texture('right'))
+ material = self._mjcf_root.asset.add(
+ 'material', name='soccer_ball', texture=texture)
+ self._geom.set_attributes(
+ pos=[0, 0, radius],
+ size=[radius],
+ condim=6,
+ friction=[.7, .075, .075],
+ mass=mass,
+ material=material)
+
+ # Add some tracking cameras for visualization and logging.
+ self._mjcf_root.worldbody.add(
+ 'camera',
+ name='ball_cam_near',
+ pos=[0, -2, 2],
+ zaxis=[0, -1, 1],
+ fovy=70,
+ mode='trackcom')
+ self._mjcf_root.worldbody.add(
+ 'camera',
+ name='ball_cam',
+ pos=[0, -7, 7],
+ zaxis=[0, -1, 1],
+ fovy=70,
+ mode='trackcom')
+ self._mjcf_root.worldbody.add(
+ 'camera',
+ name='ball_cam_far',
+ pos=[0, -10, 10],
+ zaxis=[0, -1, 1],
+ fovy=70,
+ mode='trackcom')
+
+ # Keep track of entities to team mapping.
+ self._players = []
+
+ # Initialize tracker attributes.
+ self.initialize_entity_trackers()
+
+ def register_player(self, player):
+ self._players.append(player)
+
+ def initialize_entity_trackers(self):
+ self._last_hit = None
+ self._hit = False
+ self._repossessed = False
+ self._intercepted = False
+
+ # Tracks distance traveled by the ball in between consecutive hits.
+ self._pos_at_last_step = None
+ self._dist_since_last_hit = None
+ self._dist_between_last_hits = None
+
+ def initialize_episode(self, physics, unused_random_state):
+ self._geom_id = physics.model.name2id(self._geom.full_identifier, 'geom')
+ self._geom_id_to_player = {}
+ for player in self._players:
+ geoms = player.walker.mjcf_model.find_all('geom')
+ for geom in geoms:
+ geom_id = physics.model.name2id(geom.full_identifier, 'geom')
+ self._geom_id_to_player[geom_id] = player
+
+ self.initialize_entity_trackers()
+
+ def after_substep(self, physics, unused_random_state):
+ """Resolve contacts and update ball-player contact trackers."""
+ if self._hit:
+ # Ball has already registered a valid contact within step (during one of
+ # previous after_substep calls).
+ return
+
+ # Iterate through all contacts to find the first contact between the ball
+ # and one of the registered entities.
+ for contact in physics.data.contact:
+ # Keep contacts that involve the ball and one of the registered entities.
+ has_self = False
+ for geom_id in (contact.geom1, contact.geom2):
+ if geom_id == self._geom_id:
+ has_self = True
+ else:
+ player = self._geom_id_to_player.get(geom_id)
+
+ if has_self and player:
+ # Detected a contact between the ball and an registered player.
+ if self._last_hit is not None:
+ self._intercepted = player.team != self._last_hit.team
+ else:
+ self._intercepted = True
+
+ # Register repossessed before updating last_hit player.
+ self._repossessed = player is not self._last_hit
+ self._last_hit = player
+ # Register hit event.
+ self._hit = True
+ break
+
+ def before_step(self, physics, random_state):
+ super(SoccerBall, self).before_step(physics, random_state)
+ # Reset per simulation step indicator.
+ self._hit = False
+ self._repossessed = False
+ self._intercepted = False
+
+ def after_step(self, physics, random_state):
+ super(SoccerBall, self).after_step(physics, random_state)
+ pos = physics.bind(self._geom).xpos
+ if self._hit:
+ # SoccerBall is hit on this step. Update dist_between_last_hits
+ # to dist_since_last_hit before resetting dist_since_last_hit.
+ self._dist_between_last_hits = self._dist_since_last_hit
+ self._dist_since_last_hit = 0.
+ self._pos_at_last_step = pos.copy()
+
+ if self._dist_since_last_hit is not None:
+ # Accumulate distance traveled since last hit event.
+ self._dist_since_last_hit += np.linalg.norm(pos - self._pos_at_last_step)
+
+ self._pos_at_last_step = pos.copy()
+
+ @property
+ def last_hit(self):
+ """The player that last came in contact with the ball or `None`."""
+ return self._last_hit
+
+ @property
+ def hit(self):
+ """Indicates if the ball is hit during the last simulation step.
+
+ For a timeline shown below:
+ ..., agent.step, simulation, agent.step, ...
+
+ Returns:
+ True: if the ball is hit by a registered player during simulation step.
+ False: if not.
+ """
+ return self._hit
+
+ @property
+ def repossessed(self):
+ """Indicates if the ball has been repossessed by a different player.
+
+ For a timeline shown below:
+ ..., agent.step, simulation, agent.step, ...
+
+ Returns:
+ True if the ball is hit by a registered player during simulation step
+ and that player is different from `last_hit`.
+ False: if the ball is not hit, or the ball is hit by `last_hit` player.
+ """
+ return self._repossessed
+
+ @property
+ def intercepted(self):
+ """Indicates if the ball has been intercepted by a different team.
+
+ For a timeline shown below:
+ ..., agent.step, simulation, agent.step, ...
+
+ Returns:
+ True: if the ball is hit for the first time, or repossessed by an player
+ from a different team.
+ False: if the ball is not hit, not repossessed, or repossessed by a
+ teammate to `last_hit`.
+ """
+ return self._intercepted
+
+ @property
+ def dist_between_last_hits(self):
+ """Distance between last consecutive hits.
+
+ Returns:
+ Distance between last two consecutive hit events or `None` if there has
+ not been two consecutive hits on the ball.
+ """
+ return self._dist_between_last_hits
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/soccer_ball_test.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/soccer_ball_test.py
new file mode 100644
index 0000000..372b703
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/soccer_ball_test.py
@@ -0,0 +1,73 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.locomotion.soccer.soccer_ball."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import absltest
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.entities import props
+from dm_control.locomotion.soccer import soccer_ball
+from dm_control.locomotion.soccer import team
+import numpy as np
+
+
+class SoccerBallTest(absltest.TestCase):
+
+ def test_detect_hit(self):
+ arena = composer.Arena()
+ ball = soccer_ball.SoccerBall(radius=0.35, mass=0.045, name='test_ball')
+ player = team.Player(
+ team=team.Team.HOME,
+ walker=props.Primitive(geom_type='sphere', size=(0.1,), name='home'))
+ arena.add_free_entity(player.walker)
+ ball.register_player(player)
+ arena.add_free_entity(ball)
+
+ random_state = np.random.RandomState(42)
+ physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model)
+ physics.step()
+
+ ball.initialize_episode(physics, random_state)
+ ball.before_step(physics, random_state)
+ self.assertEqual(ball.hit, False)
+ self.assertEqual(ball.repossessed, False)
+ self.assertEqual(ball.intercepted, False)
+ self.assertIsNone(ball.last_hit)
+ self.assertIsNone(ball.dist_between_last_hits)
+
+ ball.after_substep(physics, random_state)
+ ball.after_step(physics, random_state)
+
+ self.assertEqual(ball.hit, True)
+ self.assertEqual(ball.repossessed, True)
+ self.assertEqual(ball.intercepted, True)
+ self.assertEqual(ball.last_hit, player)
+ # Only one hit registered.
+ self.assertIsNone(ball.dist_between_last_hits)
+
+ def test_has_tracking_cameras(self):
+ ball = soccer_ball.SoccerBall(radius=0.35, mass=0.045, name='test_ball')
+ expected_camera_names = ['ball_cam_near', 'ball_cam', 'ball_cam_far']
+ camera_names = [cam.name for cam in ball.mjcf_model.find_all('camera')]
+ self.assertCountEqual(expected_camera_names, camera_names)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/task.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/task.py
new file mode 100644
index 0000000..e3f5044
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/task.py
@@ -0,0 +1,191 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+""""A task where players play a soccer game."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from dm_control import composer
+from dm_control.locomotion.soccer import initializers
+from dm_control.locomotion.soccer import observables as observables_lib
+from dm_control.locomotion.soccer import soccer_ball
+from dm_env import specs
+import numpy as np
+from six.moves import zip
+
+_THROW_IN_BALL_Z = 0.5
+
+
+def _disable_geom_contacts(entities):
+ for entity in entities:
+ mjcf_model = entity.mjcf_model
+ for geom in mjcf_model.find_all("geom"):
+ geom.set_attributes(contype=0)
+
+
+class Task(composer.Task):
+ """A task where two teams of walkers play soccer."""
+
+ def __init__(self,
+ players,
+ arena,
+ ball=None,
+ initializer=None,
+ observables=None,
+ disable_walker_contacts=False,
+ nconmax_per_player=200,
+ njmax_per_player=200,
+ control_timestep=0.025):
+ """Construct an instance of soccer.Task.
+
+ This task implements the high-level game logic of multi-agent MuJoCo soccer.
+
+ Args:
+ players: a sequence of `soccer.Player` instances, representing
+ participants to the game from both teams.
+ arena: an instance of `soccer.Pitch`, implementing the physical geoms and
+ the sensors associated with the pitch.
+ ball: optional instance of `soccer.SoccerBall`, implementing the physical
+ geoms and sensors associated with the soccer ball. If None, defaults to
+ using `soccer_ball.SoccerBall()`.
+ initializer: optional instance of `soccer.Initializer` that initializes
+ the task at the start of each episode. If None, defaults to
+ `initializers.UniformInitializer()`.
+ observables: optional instance of `soccer.ObservablesAdder` that adds
+ observables for each player. If None, defaults to
+ `observables.CoreObservablesAdder()`.
+ disable_walker_contacts: if `True`, disable physical contacts between
+ players.
+ nconmax_per_player: allocated maximum number of contacts per player. It
+ may be necessary to increase this value if you encounter errors due to
+ `mjWARN_CONTACTFULL`.
+ njmax_per_player: allocated maximum number of scalar constraints per
+ player. It may be necessary to increase this value if you encounter
+ errors due to `mjWARN_CNSTRFULL`.
+ control_timestep: control timestep of the agent.
+ """
+ self.arena = arena
+ self.players = players
+
+ self._initializer = initializer or initializers.UniformInitializer()
+ self._observables = observables or observables_lib.CoreObservablesAdder()
+
+ if disable_walker_contacts:
+ _disable_geom_contacts([p.walker for p in self.players])
+
+ # Create ball and attach ball to arena.
+ self.ball = ball or soccer_ball.SoccerBall()
+ self.arena.add_free_entity(self.ball)
+ self.arena.register_ball(self.ball)
+
+ # Register soccer ball contact tracking for players.
+ for player in self.players:
+ player.walker.create_root_joints(self.arena.attach(player.walker))
+ self.ball.register_player(player)
+ # Add per-walkers observables.
+ self._observables(self, player)
+
+ self.set_timesteps(
+ physics_timestep=0.005, control_timestep=control_timestep)
+ self.root_entity.mjcf_model.size.nconmax = nconmax_per_player * len(players)
+ self.root_entity.mjcf_model.size.njmax = njmax_per_player * len(players)
+
+ @property
+ def observables(self):
+ observables = []
+ for player in self.players:
+ observables.append(
+ player.walker.observables.as_dict(fully_qualified=False))
+ return observables
+
+ def _throw_in(self, physics, random_state, ball):
+ x, y, _ = physics.bind(ball.geom).xpos
+ shrink_x, shrink_y = random_state.uniform([0.7, 0.7], [0.9, 0.9])
+ ball.set_pose(physics, [x * shrink_x, y * shrink_y, _THROW_IN_BALL_Z])
+ ball.set_velocity(
+ physics, velocity=np.zeros(3), angular_velocity=np.zeros(3))
+ ball.initialize_entity_trackers()
+
+ def initialize_episode_mjcf(self, random_state):
+ self.arena.initialize_episode_mjcf(random_state)
+
+ def initialize_episode(self, physics, random_state):
+ self.arena.initialize_episode(physics, random_state)
+ self._initializer(self, physics, random_state)
+
+ @property
+ def root_entity(self):
+ return self.arena
+
+ def get_reward(self, physics):
+ """Returns a list of per-player rewards.
+
+ Each player will receive a reward of:
+ +1 if their team scored a goal
+ -1 if their team conceded a goal
+ 0 if no goals were scored on this timestep.
+
+ Note: the observations also contain various environment statistics that may
+ be used to derive per-player rewards (as done in
+ http://arxiv.org/abs/1902.07151).
+
+ Args:
+ physics: An instance of `Physics`.
+
+ Returns:
+ A list of 0-dimensional numpy arrays, one per player.
+ """
+ scoring_team = self.arena.detected_goal()
+ if not scoring_team:
+ return [np.zeros((), dtype=np.float32) for _ in self.players]
+
+ rewards = []
+ for p in self.players:
+ if p.team == scoring_team:
+ rewards.append(np.ones((), dtype=np.float32))
+ else:
+ rewards.append(-np.ones((), dtype=np.float32))
+ return rewards
+
+ def get_reward_spec(self):
+ return [
+ specs.Array(name="reward", shape=(), dtype=np.float32)
+ for _ in self.players
+ ]
+
+ def get_discount(self, physics):
+ if self.arena.detected_goal():
+ return np.zeros((), np.float32)
+ return np.ones((), np.float32)
+
+ def get_discount_spec(self):
+ return specs.Array(name="discount", shape=(), dtype=np.float32)
+
+ def should_terminate_episode(self, physics):
+ """Returns True if a goal was scored by either team."""
+ return self.arena.detected_goal() is not None
+
+ def before_step(self, physics, actions, random_state):
+ for player, action in zip(self.players, actions):
+ player.walker.apply_action(physics, action, random_state)
+
+ if self.arena.detected_off_court():
+ self._throw_in(physics, random_state, self.ball)
+
+ def action_spec(self, physics):
+ """Return multi-agent action_spec."""
+ return [player.walker.action_spec for player in self.players]
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/task_test.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/task_test.py
new file mode 100644
index 0000000..c4dfacc
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/task_test.py
@@ -0,0 +1,543 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for locomotion.tasks.soccer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+# Internal dependencies.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.locomotion import soccer
+from dm_control.locomotion.soccer import initializers
+from dm_control.mujoco.wrapper import mjbindings
+import numpy as np
+from six.moves import range
+from six.moves import zip
+
+RGBA_BLUE = [.1, .1, .8, 1.]
+RGBA_RED = [.8, .1, .1, 1.]
+
+
+def _walker(name, walker_id, marker_rgba):
+ return soccer.BoxHead(
+ name=name,
+ walker_id=walker_id,
+ marker_rgba=marker_rgba,
+ )
+
+
+def _team_players(team_size, team, team_name, team_color):
+ team_of_players = []
+ for i in range(team_size):
+ team_of_players.append(
+ soccer.Player(team, _walker("%s%d" % (team_name, i), i, team_color)))
+ return team_of_players
+
+
+def _home_team(team_size):
+ return _team_players(team_size, soccer.Team.HOME, "home", RGBA_BLUE)
+
+
+def _away_team(team_size):
+ return _team_players(team_size, soccer.Team.AWAY, "away", RGBA_RED)
+
+
+def _env(players, disable_walker_contacts=True, observables=None,
+ random_state=42, **task_kwargs):
+ return composer.Environment(
+ task=soccer.Task(
+ players=players,
+ arena=soccer.Pitch((20, 15)),
+ observables=observables,
+ disable_walker_contacts=disable_walker_contacts,
+ **task_kwargs
+ ),
+ random_state=random_state,
+ time_limit=1)
+
+
+def _observables_adder(observables_adder):
+ if observables_adder == "core":
+ return soccer.CoreObservablesAdder()
+ if observables_adder == "core_interception":
+ return soccer.MultiObservablesAdder(
+ [soccer.CoreObservablesAdder(),
+ soccer.InterceptionObservablesAdder()])
+ raise ValueError("Unrecognized observable_adder %s" % observables_adder)
+
+
+class TaskTest(parameterized.TestCase):
+
+ def _assert_all_count_equal(self, list_of_lists):
+ """Check all lists in the list are count equal."""
+ if not list_of_lists:
+ return
+
+ first = sorted(list_of_lists[0])
+ for other in list_of_lists[1:]:
+ self.assertCountEqual(first, other)
+
+ @parameterized.named_parameters(
+ ("1vs1_core", 1, "core", 33, True),
+ ("2vs2_core", 2, "core", 43, True),
+ ("1vs1_interception", 1, "core_interception", 41, True),
+ ("2vs2_interception", 2, "core_interception", 51, True),
+ ("1vs1_core_contact", 1, "core", 33, False),
+ ("2vs2_core_contact", 2, "core", 43, False),
+ ("1vs1_interception_contact", 1, "core_interception", 41, False),
+ ("2vs2_interception_contact", 2, "core_interception", 51, False),
+ )
+ def test_step_environment(self, team_size, observables_adder, num_obs,
+ disable_walker_contacts):
+ env = _env(
+ _home_team(team_size) + _away_team(team_size),
+ observables=_observables_adder(observables_adder),
+ disable_walker_contacts=disable_walker_contacts)
+ self.assertLen(env.action_spec(), 2 * team_size)
+ self.assertLen(env.observation_spec(), 2 * team_size)
+
+ actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()]
+
+ timestep = env.reset()
+
+ for observation, spec in zip(timestep.observation, env.observation_spec()):
+ self.assertLen(spec, num_obs)
+ self.assertCountEqual(list(observation.keys()), list(spec.keys()))
+ for key in observation.keys():
+ self.assertEqual(observation[key].shape, spec[key].shape)
+
+ while not timestep.last():
+ timestep = env.step(actions)
+
+ # TODO(b/124848293): consolidate environment stepping loop for task tests.
+ @parameterized.named_parameters(
+ ("1vs2", 1, 2, 38),
+ ("2vs1", 2, 1, 38),
+ ("3vs0", 3, 0, 38),
+ ("0vs2", 0, 2, 33),
+ ("2vs2", 2, 2, 43),
+ ("0vs0", 0, 0, None),
+ )
+ def test_num_players(self, home_size, away_size, num_observations):
+ env = _env(_home_team(home_size) + _away_team(away_size))
+ self.assertLen(env.action_spec(), home_size + away_size)
+ self.assertLen(env.observation_spec(), home_size + away_size)
+
+ actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()]
+
+ timestep = env.reset()
+
+ # Members of the same team should have identical specs.
+ self._assert_all_count_equal(
+ [spec.keys() for spec in env.observation_spec()[:home_size]])
+ self._assert_all_count_equal(
+ [spec.keys() for spec in env.observation_spec()[-away_size:]])
+
+ for observation, spec in zip(timestep.observation, env.observation_spec()):
+ self.assertCountEqual(list(observation.keys()), list(spec.keys()))
+ for key in observation.keys():
+ self.assertEqual(observation[key].shape, spec[key].shape)
+
+ self.assertLen(spec, num_observations)
+
+ while not timestep.last():
+ timestep = env.step(actions)
+
+ self.assertLen(timestep.observation, home_size + away_size)
+
+ self.assertLen(timestep.reward, home_size + away_size)
+ for player_spec, player_reward in zip(env.reward_spec(), timestep.reward):
+ player_spec.validate(player_reward)
+
+ discount_spec = env.discount_spec()
+ discount_spec.validate(timestep.discount)
+
+ def test_all_contacts(self):
+ env = _env(_home_team(1) + _away_team(1))
+
+ def _all_contact_configuration(physics, unused_random_state):
+ walkers = [p.walker for p in env.task.players]
+ ball = env.task.ball
+
+ x, y, rotation = 0., 0., np.pi / 6.
+ ball.set_pose(physics, [x, y, 5.])
+ ball.set_velocity(
+ physics, velocity=np.zeros(3), angular_velocity=np.zeros(3))
+
+ x, y, rotation = 0., 0., np.pi / 3.
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[0].set_pose(physics, [x, y, 3.], quat)
+ walkers[0].set_velocity(
+ physics, velocity=np.zeros(3), angular_velocity=np.zeros(3))
+
+ x, y, rotation = 0., 0., np.pi / 3. + np.pi
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[1].set_pose(physics, [x, y, 1.], quat)
+ walkers[1].set_velocity(
+ physics, velocity=np.zeros(3), angular_velocity=np.zeros(3))
+
+ env.add_extra_hook("initialize_episode", _all_contact_configuration)
+
+ actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()]
+
+ timestep = env.reset()
+ while not timestep.last():
+ timestep = env.step(actions)
+
+ def test_symmetric_observations(self):
+ env = _env(_home_team(1) + _away_team(1))
+
+ def _symmetric_configuration(physics, unused_random_state):
+ walkers = [p.walker for p in env.task.players]
+ ball = env.task.ball
+
+ x, y, rotation = 0., 0., np.pi / 6.
+ ball.set_pose(physics, [x, y, 0.5])
+ ball.set_velocity(
+ physics, velocity=np.zeros(3), angular_velocity=np.zeros(3))
+
+ x, y, rotation = 5., 3., np.pi / 3.
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[0].set_pose(physics, [x, y, 0.], quat)
+ walkers[0].set_velocity(
+ physics, velocity=np.zeros(3), angular_velocity=np.zeros(3))
+
+ x, y, rotation = -5., -3., np.pi / 3. + np.pi
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[1].set_pose(physics, [x, y, 0.], quat)
+ walkers[1].set_velocity(
+ physics, velocity=np.zeros(3), angular_velocity=np.zeros(3))
+
+ env.add_extra_hook("initialize_episode", _symmetric_configuration)
+
+ timestep = env.reset()
+ obs_a, obs_b = timestep.observation
+ self.assertCountEqual(list(obs_a.keys()), list(obs_b.keys()))
+ for k in sorted(obs_a.keys()):
+ o_a, o_b = obs_a[k], obs_b[k]
+ np.testing.assert_allclose(o_a, o_b, err_msg=k + " not equal.", atol=1e-6)
+
+ def test_symmetric_dynamic_observations(self):
+ env = _env(_home_team(1) + _away_team(1))
+
+ def _symmetric_configuration(physics, unused_random_state):
+ walkers = [p.walker for p in env.task.players]
+ ball = env.task.ball
+
+ x, y, rotation = 0., 0., np.pi / 6.
+ ball.set_pose(physics, [x, y, 0.5])
+ # Ball shooting up. Walkers going tangent.
+ ball.set_velocity(physics, velocity=[0., 0., 1.],
+ angular_velocity=[0., 0., 0.])
+
+ x, y, rotation = 5., 3., np.pi / 3.
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[0].set_pose(physics, [x, y, 0.], quat)
+ walkers[0].set_velocity(physics, velocity=[y, -x, 0.],
+ angular_velocity=[0., 0., 0.])
+
+ x, y, rotation = -5., -3., np.pi / 3. + np.pi
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[1].set_pose(physics, [x, y, 0.], quat)
+ walkers[1].set_velocity(physics, velocity=[y, -x, 0.],
+ angular_velocity=[0., 0., 0.])
+
+ env.add_extra_hook("initialize_episode", _symmetric_configuration)
+
+ timestep = env.reset()
+ obs_a, obs_b = timestep.observation
+ self.assertCountEqual(list(obs_a.keys()), list(obs_b.keys()))
+ for k in sorted(obs_a.keys()):
+ o_a, o_b = obs_a[k], obs_b[k]
+ np.testing.assert_allclose(o_a, o_b, err_msg=k + " not equal.", atol=1e-6)
+
+ def test_prev_actions(self):
+ env = _env(_home_team(1) + _away_team(1))
+
+ actions = []
+ for i, player in enumerate(env.task.players):
+ spec = player.walker.action_spec
+ actions.append((i + 1) * np.ones(spec.shape, dtype=spec.dtype))
+
+ env.reset()
+ timestep = env.step(actions)
+
+ for walker_idx, obs in enumerate(timestep.observation):
+ np.testing.assert_allclose(
+ np.squeeze(obs["prev_action"], axis=0),
+ actions[walker_idx],
+ err_msg="Walker {}: incorrect previous action.".format(walker_idx))
+
+ @parameterized.named_parameters(
+ dict(testcase_name="1vs2_draw",
+ home_size=1, away_size=2, ball_vel_x=0, expected_home_score=0),
+ dict(testcase_name="1vs2_home_score",
+ home_size=1, away_size=2, ball_vel_x=50, expected_home_score=1),
+ dict(testcase_name="2vs1_away_score",
+ home_size=2, away_size=1, ball_vel_x=-50, expected_home_score=-1),
+ dict(testcase_name="3vs0_home_score",
+ home_size=3, away_size=0, ball_vel_x=50, expected_home_score=1),
+ dict(testcase_name="0vs2_home_score",
+ home_size=0, away_size=2, ball_vel_x=50, expected_home_score=1),
+ dict(testcase_name="2vs2_away_score",
+ home_size=2, away_size=2, ball_vel_x=-50, expected_home_score=-1),
+ )
+ def test_scoring_rewards(
+ self, home_size, away_size, ball_vel_x, expected_home_score):
+ env = _env(_home_team(home_size) + _away_team(away_size))
+
+ def _score_configuration(physics, random_state):
+ del random_state # Unused.
+ # Send the ball shooting towards either the home or away goal.
+ env.task.ball.set_pose(physics, [0., 0., 0.5])
+ env.task.ball.set_velocity(physics,
+ velocity=[ball_vel_x, 0., 0.],
+ angular_velocity=[0., 0., 0.])
+
+ env.add_extra_hook("initialize_episode", _score_configuration)
+
+ actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()]
+
+ # Disable contacts and gravity so that the ball follows a straight path.
+ with env.physics.model.disable("contact", "gravity"):
+
+ timestep = env.reset()
+ with self.subTest("Reward and discount are None on the first timestep"):
+ self.assertTrue(timestep.first())
+ self.assertIsNone(timestep.reward)
+ self.assertIsNone(timestep.discount)
+
+ # Step until the episode ends.
+ timestep = env.step(actions)
+ while not timestep.last():
+ self.assertTrue(timestep.mid())
+ # For non-terminal timesteps, the reward should always be 0 and the
+ # discount should always be 1.
+ np.testing.assert_array_equal(np.hstack(timestep.reward), 0.)
+ self.assertEqual(timestep.discount, 1.)
+ timestep = env.step(actions)
+
+ # If a goal was scored then the epsiode should have ended with a discount of
+ # 0. If neither team scored and the episode ended due to hitting the time
+ # limit then the discount should be 1.
+ with self.subTest("Correct terminal discount"):
+ if expected_home_score != 0:
+ expected_discount = 0.
+ else:
+ expected_discount = 1.
+ self.assertEqual(timestep.discount, expected_discount)
+
+ with self.subTest("Correct terminal reward"):
+ reward = np.hstack(timestep.reward)
+ np.testing.assert_array_equal(reward[:home_size], expected_home_score)
+ np.testing.assert_array_equal(reward[home_size:], -expected_home_score)
+
+ def test_throw_in(self):
+ env = _env(_home_team(1) + _away_team(1))
+
+ def _throw_in_configuration(physics, unused_random_state):
+ walkers = [p.walker for p in env.task.players]
+ ball = env.task.ball
+
+ x, y, rotation = 0., 3., np.pi / 6.
+ ball.set_pose(physics, [x, y, 0.5])
+ # Ball shooting out of bounds.
+ ball.set_velocity(physics, velocity=[0., 50., 0.],
+ angular_velocity=[0., 0., 0.])
+
+ x, y, rotation = 0., -3., np.pi / 3.
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[0].set_pose(physics, [x, y, 0.], quat)
+ walkers[0].set_velocity(physics, velocity=[0., 0., 0.],
+ angular_velocity=[0., 0., 0.])
+ x, y, rotation = 0., -5., np.pi / 3.
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[1].set_pose(physics, [x, y, 0.], quat)
+ walkers[1].set_velocity(physics, velocity=[0., 0., 0.],
+ angular_velocity=[0., 0., 0.])
+
+ env.add_extra_hook("initialize_episode", _throw_in_configuration)
+
+ actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()]
+
+ timestep = env.reset()
+
+ while not timestep.last():
+ timestep = env.step(actions)
+
+ terminal_ball_vel = np.linalg.norm(
+ timestep.observation[0]["ball_ego_linear_velocity"])
+ self.assertAlmostEqual(terminal_ball_vel, 0.)
+
+ @parameterized.named_parameters(("score", 50., 0.), ("timeout", 0., 1.))
+ def test_terminal_discount(self, init_ball_vel_x, expected_terminal_discount):
+ env = _env(_home_team(1) + _away_team(1))
+
+ def _initial_configuration(physics, unused_random_state):
+ walkers = [p.walker for p in env.task.players]
+ ball = env.task.ball
+
+ x, y, rotation = 0., 0., np.pi / 6.
+ ball.set_pose(physics, [x, y, 0.5])
+ # Ball shooting up. Walkers going tangent.
+ ball.set_velocity(physics, velocity=[init_ball_vel_x, 0., 0.],
+ angular_velocity=[0., 0., 0.])
+
+ x, y, rotation = 0., -3., np.pi / 3.
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[0].set_pose(physics, [x, y, 0.], quat)
+ walkers[0].set_velocity(physics, velocity=[0., 0., 0.],
+ angular_velocity=[0., 0., 0.])
+ x, y, rotation = 0., 3., np.pi / 3.
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ walkers[1].set_pose(physics, [x, y, 0.], quat)
+ walkers[1].set_velocity(physics, velocity=[0., 0., 0.],
+ angular_velocity=[0., 0., 0.])
+
+ env.add_extra_hook("initialize_episode", _initial_configuration)
+
+ actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()]
+
+ timestep = env.reset()
+
+ while not timestep.last():
+ timestep = env.step(actions)
+
+ self.assertEqual(timestep.discount, expected_terminal_discount)
+
+
+class UniformInitializerTest(parameterized.TestCase):
+
+ @parameterized.parameters([0.3, 0.7])
+ def test_walker_position(self, spawn_ratio):
+ initializer = initializers.UniformInitializer(spawn_ratio=spawn_ratio)
+ env = _env(_home_team(2) + _away_team(2), initializer=initializer)
+ root_bodies = [p.walker.root_body for p in env.task.players]
+ xy_bounds = np.asarray(env.task.arena.size) * spawn_ratio
+ env.reset()
+ xy = env.physics.bind(root_bodies).xpos[:, :2].copy()
+ with self.subTest("X and Y positions within bounds"):
+ if np.any(abs(xy) > xy_bounds):
+ self.fail("Walker(s) spawned out of bounds. Expected abs(xy) "
+ "<= {}, got:\n{}".format(xy_bounds, xy))
+ env.reset()
+ xy2 = env.physics.bind(root_bodies).xpos[:, :2].copy()
+ with self.subTest("X and Y positions change after reset"):
+ if np.any(xy == xy2):
+ self.fail("Walker(s) have the same X and/or Y coordinates before and "
+ "after reset. Before: {}, after: {}.".format(xy, xy2))
+
+ def test_walker_rotation(self):
+ initializer = initializers.UniformInitializer()
+ env = _env(_home_team(2) + _away_team(2), initializer=initializer)
+
+ def quats_to_eulers(quats):
+ eulers = np.empty((len(quats), 3), dtype=np.double)
+ dt = 1.
+ for i, quat in enumerate(quats):
+ mjbindings.mjlib.mju_quat2Vel(eulers[i], quat, dt)
+ return eulers
+
+ # TODO(b/132671988): Switch to using `get_pose` to get the quaternion once
+ # `BoxHead.get_pose` and `BoxHead.set_pose` are
+ # implemented in a consistent way.
+ def get_quat(walker):
+ return env.physics.bind(walker.root_body).xquat
+
+ env.reset()
+ quats = [get_quat(p.walker) for p in env.task.players]
+ eulers = quats_to_eulers(quats)
+ with self.subTest("Rotation is about the Z-axis only"):
+ np.testing.assert_array_equal(eulers[:, :2], 0.)
+
+ env.reset()
+ quats2 = [get_quat(p.walker) for p in env.task.players]
+ eulers2 = quats_to_eulers(quats2)
+ with self.subTest("Rotation about Z changes after reset"):
+ if np.any(eulers[:, 2] == eulers2[:, 2]):
+ self.fail("Walker(s) have the same rotation about Z before and "
+ "after reset. Before: {}, after: {}."
+ .format(eulers[:, 2], eulers2[:, 2]))
+
+ # TODO(b/132759890): Remove `expectedFailure` decorator once `set_velocity`
+ # works correctly for the `BoxHead` walker.
+ @unittest.expectedFailure
+ def test_walker_velocity(self):
+ initializer = initializers.UniformInitializer()
+ env = _env(_home_team(2) + _away_team(2), initializer=initializer)
+ root_joints = []
+ non_root_joints = []
+ for player in env.task.players:
+ attachment_frame = mjcf.get_attachment_frame(player.walker.mjcf_model)
+ root_joints.extend(
+ attachment_frame.find_all("joint", immediate_children_only=True))
+ non_root_joints.extend(player.walker.mjcf_model.find_all("joint"))
+ # Assign a non-zero sentinel value to the velocities of all root and
+ # non-root joints.
+ sentinel_velocity = 3.14
+ env.physics.bind(root_joints + non_root_joints).qvel = sentinel_velocity
+ # The initializer should zero the velocities of the root joints, but not the
+ # non-root joints.
+ initializer(env.task, env.physics, env.random_state)
+ np.testing.assert_array_equal(env.physics.bind(non_root_joints).qvel,
+ sentinel_velocity)
+ np.testing.assert_array_equal(env.physics.bind(root_joints).qvel, 0.)
+
+ @parameterized.parameters([
+ dict(spawn_ratio=0.3, init_ball_z=0.4),
+ dict(spawn_ratio=0.5, init_ball_z=0.6),
+ ])
+ def test_ball_position(self, spawn_ratio, init_ball_z):
+ initializer = initializers.UniformInitializer(
+ spawn_ratio=spawn_ratio, init_ball_z=init_ball_z)
+ env = _env(_home_team(2) + _away_team(2), initializer=initializer)
+ xy_bounds = np.asarray(env.task.arena.size) * spawn_ratio
+ env.reset()
+ position, _ = env.task.ball.get_pose(env.physics)
+ xyz = position.copy()
+ with self.subTest("X and Y positions within bounds"):
+ if np.any(abs(xyz[:2]) > xy_bounds):
+ self.fail("Ball spawned out of bounds. Expected abs(xy) "
+ "<= {}, got:\n{}".format(xy_bounds, xyz[:2]))
+ with self.subTest("Z position equal to `init_ball_z`"):
+ self.assertEqual(xyz[2], init_ball_z)
+ env.reset()
+ position, _ = env.task.ball.get_pose(env.physics)
+ xyz2 = position.copy()
+ with self.subTest("X and Y positions change after reset"):
+ if np.any(xyz[:2] == xyz2[:2]):
+ self.fail("Ball has the same XY position before and after reset. "
+ "Before: {}, after: {}.".format(xyz[:2], xyz2[:2]))
+
+ def test_ball_velocity(self):
+ initializer = initializers.UniformInitializer()
+ env = _env(_home_team(1) + _away_team(1), initializer=initializer)
+ ball_root_joint = mjcf.get_frame_freejoint(env.task.ball.mjcf_model)
+ # Set the velocities of the ball root joint to a non-zero sentinel value.
+ env.physics.bind(ball_root_joint).qvel = 3.14
+ initializer(env.task, env.physics, env.random_state)
+ # The initializer should set the ball velocity to zero.
+ ball_velocity = env.physics.bind(ball_root_joint).qvel
+ np.testing.assert_array_equal(ball_velocity, 0.)
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/team.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/team.py
new file mode 100644
index 0000000..ce49e3b
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/team.py
@@ -0,0 +1,32 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Define teams and players participating in a match."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import enum
+
+
+class Team(enum.Enum):
+ HOME = 0
+ AWAY = 1
+
+
+Player = collections.namedtuple('Player', ['team', 'walker'])
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/__init__.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/__init__.py
new file mode 100644
index 0000000..f83b97a
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tasks in the Locomotion library."""
+
+
+from dm_control.locomotion.tasks.corridors import RunThroughCorridor
+from dm_control.locomotion.tasks.escape import Escape
+from dm_control.locomotion.tasks.go_to_target import GoToTarget
+from dm_control.locomotion.tasks.random_goal_maze import ManyGoalsMaze
+from dm_control.locomotion.tasks.random_goal_maze import ManyHeterogeneousGoalsMaze
+from dm_control.locomotion.tasks.random_goal_maze import RepeatSingleGoalMaze
+from dm_control.locomotion.tasks.random_goal_maze import RepeatSingleGoalMazeAugmentedWithTargets
+from dm_control.locomotion.tasks.reach import TwoTouch
+
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/corridors.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/corridors.py
new file mode 100644
index 0000000..f865643
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/corridors.py
@@ -0,0 +1,161 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Corridor-based locomotion tasks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from dm_control import composer
+from dm_control.composer import variation
+from dm_control.utils import rewards
+import numpy as np
+
+
+class RunThroughCorridor(composer.Task):
+ """A task that requires a walker to run through a corridor.
+
+ This task rewards an agent for controlling a walker to move at a specific
+ target velocity along the corridor, and for minimising the magnitude of the
+ control signals used to achieve this.
+ """
+
+ def __init__(self,
+ walker,
+ arena,
+ walker_spawn_position=(0, 0, 0),
+ walker_spawn_rotation=None,
+ target_velocity=3.0,
+ contact_termination=True,
+ terminate_at_height=-0.5,
+ physics_timestep=0.005,
+ control_timestep=0.025):
+ """Initializes this task.
+
+ Args:
+ walker: an instance of `locomotion.walkers.base.Walker`.
+ arena: an instance of `locomotion.arenas.corridors.Corridor`.
+ walker_spawn_position: a sequence of 3 numbers, or a `composer.Variation`
+ instance that generates such sequences, specifying the position at
+ which the walker is spawned at the beginning of an episode.
+ walker_spawn_rotation: a number, or a `composer.Variation` instance that
+ generates a number, specifying the yaw angle offset (in radians) that is
+ applied to the walker at the beginning of an episode.
+ target_velocity: a number specifying the target velocity (in meters per
+ second) for the walker.
+ contact_termination: whether to terminate if a non-foot geom touches the
+ ground.
+ terminate_at_height: a number specifying the height of end effectors below
+ which the episode terminates.
+ physics_timestep: a number specifying the timestep (in seconds) of the
+ physics simulation.
+ control_timestep: a number specifying the timestep (in seconds) at which
+ the agent applies its control inputs (in seconds).
+ """
+
+ self._arena = arena
+ self._walker = walker
+ self._walker.create_root_joints(self._arena.attach(self._walker))
+ self._walker_spawn_position = walker_spawn_position
+ self._walker_spawn_rotation = walker_spawn_rotation
+
+ enabled_observables = []
+ enabled_observables += self._walker.observables.proprioception
+ enabled_observables += self._walker.observables.kinematic_sensors
+ enabled_observables += self._walker.observables.dynamic_sensors
+ enabled_observables.append(self._walker.observables.sensors_touch)
+ enabled_observables.append(self._walker.observables.egocentric_camera)
+ for observable in enabled_observables:
+ observable.enabled = True
+
+ self._vel = target_velocity
+ self._contact_termination = contact_termination
+ self._terminate_at_height = terminate_at_height
+
+ self.set_timesteps(
+ physics_timestep=physics_timestep, control_timestep=control_timestep)
+
+ @property
+ def root_entity(self):
+ return self._arena
+
+ def initialize_episode_mjcf(self, random_state):
+ self._arena.regenerate(random_state)
+ self._arena.mjcf_model.visual.map.znear = 0.00025
+ self._arena.mjcf_model.visual.map.zfar = 4.
+
+ def initialize_episode(self, physics, random_state):
+ self._walker.reinitialize_pose(physics, random_state)
+ if self._walker_spawn_rotation:
+ rotation = variation.evaluate(
+ self._walker_spawn_rotation, random_state=random_state)
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ else:
+ quat = None
+ self._walker.shift_pose(
+ physics,
+ position=variation.evaluate(
+ self._walker_spawn_position, random_state=random_state),
+ quaternion=quat,
+ rotate_velocity=True)
+
+ self._failure_termination = False
+ walker_foot_geoms = set(self._walker.ground_contact_geoms)
+ walker_nonfoot_geoms = [
+ geom for geom in self._walker.mjcf_model.find_all('geom')
+ if geom not in walker_foot_geoms]
+ self._walker_nonfoot_geomids = set(
+ physics.bind(walker_nonfoot_geoms).element_id)
+ self._ground_geomids = set(
+ physics.bind(self._arena.ground_geoms).element_id)
+
+ def _is_disallowed_contact(self, contact):
+ set1, set2 = self._walker_nonfoot_geomids, self._ground_geomids
+ return ((contact.geom1 in set1 and contact.geom2 in set2) or
+ (contact.geom1 in set2 and contact.geom2 in set1))
+
+ def before_step(self, physics, action, random_state):
+ self._walker.apply_action(physics, action, random_state)
+
+ def after_step(self, physics, random_state):
+ self._failure_termination = False
+ if self._contact_termination:
+ for c in physics.data.contact:
+ if self._is_disallowed_contact(c):
+ self._failure_termination = True
+ break
+ if self._terminate_at_height is not None:
+ if any(physics.bind(self._walker.end_effectors).xpos[:, -1] <
+ self._terminate_at_height):
+ self._failure_termination = True
+
+ def get_reward(self, physics):
+ walker_xvel = physics.bind(self._walker.root_body).subtree_linvel[0]
+ xvel_term = rewards.tolerance(
+ walker_xvel, (self._vel, self._vel),
+ margin=self._vel,
+ sigmoid='linear',
+ value_at_margin=0.0)
+ return xvel_term
+
+ def should_terminate_episode(self, physics):
+ return self._failure_termination
+
+ def get_discount(self, physics):
+ if self._failure_termination:
+ return 0.
+ else:
+ return 1.
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/corridors_test.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/corridors_test.py
new file mode 100644
index 0000000..c7ca7c9
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/corridors_test.py
@@ -0,0 +1,138 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.locomotion.tasks.corridors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.variation import deterministic
+from dm_control.composer.variation import rotations
+from dm_control.locomotion.arenas import corridors as corridor_arenas
+from dm_control.locomotion.tasks import corridors as corridor_tasks
+from dm_control.locomotion.walkers import cmu_humanoid
+import numpy as np
+from six.moves import range
+
+
+class CorridorsTest(parameterized.TestCase):
+
+ @parameterized.parameters(
+ dict(position_offset=(0, 0, 0),
+ rotate_180_degrees=False,
+ use_variations=False),
+ dict(position_offset=(1, 2, 3),
+ rotate_180_degrees=True,
+ use_variations=True))
+ def test_walker_is_correctly_reinitialized(
+ self, position_offset, rotate_180_degrees, use_variations):
+ walker_spawn_position = position_offset
+
+ if not rotate_180_degrees:
+ walker_spawn_rotation = None
+ else:
+ walker_spawn_rotation = np.pi
+
+ if use_variations:
+ walker_spawn_position = deterministic.Constant(position_offset)
+ walker_spawn_rotation = deterministic.Constant(walker_spawn_rotation)
+
+ walker = cmu_humanoid.CMUHumanoid()
+ arena = corridor_arenas.EmptyCorridor()
+ task = corridor_tasks.RunThroughCorridor(
+ walker=walker,
+ arena=arena,
+ walker_spawn_position=walker_spawn_position,
+ walker_spawn_rotation=walker_spawn_rotation)
+
+ # Randomize the initial pose and joint positions in order to check that they
+ # are set correctly by `initialize_episode`.
+ random_state = np.random.RandomState(12345)
+ task.initialize_episode_mjcf(random_state)
+ physics = mjcf.Physics.from_mjcf_model(task.root_entity.mjcf_model)
+
+ walker_joints = walker.mjcf_model.find_all('joint')
+ physics.bind(walker_joints).qpos = random_state.uniform(
+ size=len(walker_joints))
+ walker.set_pose(physics,
+ position=random_state.uniform(size=3),
+ quaternion=rotations.UniformQuaternion()(random_state))
+
+ task.initialize_episode(physics, random_state)
+ physics.forward()
+
+ with self.subTest('Correct joint positions'):
+ walker_qpos = physics.bind(walker_joints).qpos
+ if walker.upright_pose.qpos is not None:
+ np.testing.assert_array_equal(walker_qpos, walker.upright_pose.qpos)
+ else:
+ walker_qpos0 = physics.bind(walker_joints).qpos0
+ np.testing.assert_array_equal(walker_qpos, walker_qpos0)
+
+ walker_xpos, walker_xquat = walker.get_pose(physics)
+
+ with self.subTest('Correct position'):
+ expected_xpos = walker.upright_pose.xpos + np.array(position_offset)
+ np.testing.assert_array_equal(walker_xpos, expected_xpos)
+
+ with self.subTest('Correct orientation'):
+ upright_xquat = walker.upright_pose.xquat.copy()
+ upright_xquat /= np.linalg.norm(walker.upright_pose.xquat)
+ if rotate_180_degrees:
+ expected_xquat = (-upright_xquat[3], -upright_xquat[2],
+ upright_xquat[1], upright_xquat[0])
+ else:
+ expected_xquat = upright_xquat
+ np.testing.assert_allclose(walker_xquat, expected_xquat)
+
+ def test_termination_and_discount(self):
+ walker = cmu_humanoid.CMUHumanoid()
+ arena = corridor_arenas.EmptyCorridor()
+ task = corridor_tasks.RunThroughCorridor(walker, arena)
+
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ env.reset()
+
+ zero_action = np.zeros_like(env.physics.data.ctrl)
+
+ # Walker starts in upright position.
+ # Should not trigger failure termination in the first few steps.
+ for _ in range(5):
+ env.step(zero_action)
+ self.assertFalse(task.should_terminate_episode(env.physics))
+ self.assertEqual(task.get_discount(env.physics), 1)
+
+ # Rotate the walker upside down and run the physics until it makes contact.
+ current_time = env.physics.data.time
+ walker.shift_pose(env.physics, position=(0, 0, 10), quaternion=(0, 1, 0, 0))
+ env.physics.forward()
+ while env.physics.data.ncon == 0:
+ env.physics.step()
+ env.physics.data.time = current_time
+
+ # Should now trigger a failure termination.
+ env.step(zero_action)
+ self.assertTrue(task.should_terminate_episode(env.physics))
+ self.assertEqual(task.get_discount(env.physics), 0)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/escape.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/escape.py
new file mode 100644
index 0000000..1651a34
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/escape.py
@@ -0,0 +1,188 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Escape locomotion tasks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation import observable as base_observable
+from dm_control.rl import control
+from dm_control.utils import rewards
+
+import numpy as np
+
+# Constants related to terrain generation.
+_HEIGHTFIELD_ID = 0
+
+
+class Escape(composer.Task):
+ """A task solved by escaping a starting area (e.g. bowl-shaped terrain)."""
+
+ def __init__(self,
+ walker,
+ arena,
+ walker_spawn_position=(0, 0, 0),
+ walker_spawn_rotation=None,
+ physics_timestep=0.005,
+ control_timestep=0.025):
+ """Initializes this task.
+
+ Args:
+ walker: an instance of `locomotion.walkers.base.Walker`.
+ arena: an instance of `locomotion.arenas`.
+ walker_spawn_position: a sequence of 3 numbers, or a `composer.Variation`
+ instance that generates such sequences, specifying the position at
+ which the walker is spawned at the beginning of an episode.
+ walker_spawn_rotation: a number, or a `composer.Variation` instance that
+ generates a number, specifying the yaw angle offset (in radians) that is
+ applied to the walker at the beginning of an episode.
+ physics_timestep: a number specifying the timestep (in seconds) of the
+ physics simulation.
+ control_timestep: a number specifying the timestep (in seconds) at which
+ the agent applies its control inputs (in seconds).
+ """
+
+ self._arena = arena
+ self._walker = walker
+ self._walker.create_root_joints(self._arena.attach(self._walker))
+ self._walker_spawn_position = walker_spawn_position
+ self._walker_spawn_rotation = walker_spawn_rotation
+
+ enabled_observables = []
+ enabled_observables += self._walker.observables.proprioception
+ enabled_observables += self._walker.observables.kinematic_sensors
+ enabled_observables += self._walker.observables.dynamic_sensors
+ enabled_observables.append(self._walker.observables.sensors_touch)
+ enabled_observables.append(self._walker.observables.egocentric_camera)
+ for observable in enabled_observables:
+ observable.enabled = True
+
+ if 'CMUHumanoid' in str(type(self._walker)):
+ core_body = 'walker/root'
+ self._reward_body = 'walker/root'
+ elif 'Rat' in str(type(self._walker)):
+ core_body = 'walker/torso'
+ self._reward_body = 'walker/head'
+ else:
+ raise ValueError('Expects Rat or CMUHumanoid.')
+
+ def _origin(physics):
+ """Returns origin position in the torso frame."""
+ torso_frame = physics.named.data.xmat[core_body].reshape(3, 3)
+ torso_pos = physics.named.data.xpos[core_body]
+ return -torso_pos.dot(torso_frame)
+
+ self._walker.observables.add_observable(
+ 'origin', base_observable.Generic(_origin))
+
+ self.set_timesteps(
+ physics_timestep=physics_timestep, control_timestep=control_timestep)
+
+ @property
+ def root_entity(self):
+ return self._arena
+
+ def initialize_episode_mjcf(self, random_state):
+ if hasattr(self._arena, 'regenerate'):
+ self._arena.regenerate(random_state)
+ self._arena.mjcf_model.visual.map.znear = 0.00025
+ self._arena.mjcf_model.visual.map.zfar = 50.
+
+ def initialize_episode(self, physics, random_state):
+ super(Escape, self).initialize_episode(physics, random_state)
+
+ # Initial configuration.
+ orientation = random_state.randn(4)
+ orientation /= np.linalg.norm(orientation)
+ _find_non_contacting_height(physics, self._walker, orientation)
+
+ def get_reward(self, physics):
+ # Escape reward term.
+ terrain_size = physics.model.hfield_size[_HEIGHTFIELD_ID, 0]
+ escape_reward = rewards.tolerance(
+ np.asarray(np.linalg.norm(
+ physics.named.data.site_xpos[self._reward_body])),
+ bounds=(terrain_size, float('inf')),
+ margin=terrain_size,
+ value_at_margin=0,
+ sigmoid='linear')
+ upright_reward = _upright_reward(physics, self._walker, deviation_angle=30)
+ return upright_reward * escape_reward
+
+ def get_discount(self, physics):
+ return 1.
+
+
+def _find_non_contacting_height(physics, walker, orientation,
+ x_pos=0.0, y_pos=0.0, maxiter=1000):
+ """Find a height with no contacts given a body orientation.
+
+ Args:
+ physics: An instance of `Physics`.
+ walker: the focal walker.
+ orientation: A quaternion.
+ x_pos: A float. Position along global x-axis.
+ y_pos: A float. Position along global y-axis.
+ maxiter: maximum number of iterations to try
+ """
+ z_pos = 0.0 # Start embedded in the floor.
+ num_contacts = 1
+ count = 1
+ # Move up in 1cm increments until no contacts.
+ while num_contacts > 0:
+ try:
+ with physics.reset_context():
+ freejoint = mjcf.get_frame_freejoint(walker.mjcf_model)
+ physics.bind(freejoint).qpos[:3] = x_pos, y_pos, z_pos
+ physics.bind(freejoint).qpos[3:] = orientation
+ except control.PhysicsError:
+ # We may encounter a PhysicsError here due to filling the contact
+ # buffer, in which case we simply increment the height and continue.
+ pass
+ num_contacts = physics.data.ncon
+ z_pos += 0.01
+ count += 1
+ if count > maxiter:
+ raise ValueError(
+ 'maxiter reached: possibly contacts in null pose of body.'
+ )
+
+
+def _upright_reward(physics, walker, deviation_angle=0):
+ """Returns a reward proportional to how upright the torso is.
+
+ Args:
+ physics: an instance of `Physics`.
+ walker: the focal walker.
+ deviation_angle: A float, in degrees. The reward is 0 when the torso is
+ exactly upside-down and 1 when the torso's z-axis is less than
+ `deviation_angle` away from the global z-axis.
+ """
+ deviation = np.cos(np.deg2rad(deviation_angle))
+ upright_torso = physics.bind(walker.root_body).xmat[-1]
+ if hasattr(walker, 'pelvis_body'):
+ upright_pelvis = physics.bind(walker.pelvis_body).xmat[-1]
+ upright_zz = np.stack([upright_torso, upright_pelvis])
+ else:
+ upright_zz = upright_torso
+ upright = rewards.tolerance(upright_zz,
+ bounds=(deviation, float('inf')),
+ sigmoid='linear',
+ margin=1 + deviation,
+ value_at_margin=0)
+ return np.min(upright)
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/escape_test.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/escape_test.py
new file mode 100644
index 0000000..1efecca
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/escape_test.py
@@ -0,0 +1,90 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for locomotion.tasks.escape."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import absltest
+
+from dm_control import composer
+from dm_control.locomotion.arenas import bowl
+from dm_control.locomotion.tasks import escape
+from dm_control.locomotion.walkers import rodent
+
+import numpy as np
+from six.moves import range
+
+_CONTROL_TIMESTEP = .02
+_PHYSICS_TIMESTEP = 0.001
+
+
+class EscapeTest(absltest.TestCase):
+
+ def test_observables(self):
+ walker = rodent.Rat()
+
+ # Build a corridor-shaped arena that is obstructed by walls.
+ arena = bowl.Bowl(
+ size=(20., 20.),
+ aesthetic='outdoor_natural')
+
+ # Build a task that rewards the agent for running down the corridor at a
+ # specific velocity.
+ task = escape.Escape(
+ walker=walker,
+ arena=arena,
+ physics_timestep=_PHYSICS_TIMESTEP,
+ control_timestep=_CONTROL_TIMESTEP)
+
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ timestep = env.reset()
+
+ self.assertIn('walker/joints_pos', timestep.observation)
+
+ def test_contact(self):
+ walker = rodent.Rat()
+
+ # Build a corridor-shaped arena that is obstructed by walls.
+ arena = bowl.Bowl(
+ size=(20., 20.),
+ aesthetic='outdoor_natural')
+
+ # Build a task that rewards the agent for running down the corridor at a
+ # specific velocity.
+ task = escape.Escape(
+ walker=walker,
+ arena=arena,
+ physics_timestep=_PHYSICS_TIMESTEP,
+ control_timestep=_CONTROL_TIMESTEP)
+
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ env.reset()
+
+ zero_action = np.zeros_like(env.physics.data.ctrl)
+
+ # Walker starts in upright position.
+ # Should not trigger failure termination in the first few steps.
+ for _ in range(5):
+ env.step(zero_action)
+ self.assertFalse(task.should_terminate_episode(env.physics))
+ np.testing.assert_array_equal(task.get_discount(env.physics), 1)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/go_to_target.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/go_to_target.py
new file mode 100644
index 0000000..173608d
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/go_to_target.py
@@ -0,0 +1,220 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Task for a walker to move to a target."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from dm_control import composer
+from dm_control.composer import variation
+from dm_control.composer.observation import observable
+from dm_control.composer.variation import distributions
+import numpy as np
+
+DEFAULT_DISTANCE_TOLERANCE_TO_TARGET = 1.0
+
+
+class GoToTarget(composer.Task):
+ """A task that requires a walker to move towards a target."""
+
+ def __init__(self,
+ walker,
+ arena,
+ moving_target=False,
+ target_relative=False,
+ target_relative_dist=1.5,
+ steps_before_moving_target=10,
+ distance_tolerance=DEFAULT_DISTANCE_TOLERANCE_TO_TARGET,
+ target_spawn_position=None,
+ walker_spawn_position=None,
+ walker_spawn_rotation=None,
+ physics_timestep=0.005,
+ control_timestep=0.025):
+ """Initializes this task.
+
+ Args:
+ walker: an instance of `locomotion.walkers.base.Walker`.
+ arena: an instance of `locomotion.arenas.floors.Floor`.
+ moving_target: bool, Whether the target should move after receiving the
+ walker reaches it.
+ target_relative: bool, Whether the target be set relative to its current
+ position.
+ target_relative_dist: float, new target distance range if
+ using target_relative.
+ steps_before_moving_target: int, the number of steps before the target
+ moves, if moving_target==True.
+ distance_tolerance: Accepted to distance to the target position before
+ providing reward.
+ target_spawn_position: a sequence of 2 numbers, or a `composer.Variation`
+ instance that generates such sequences, specifying the position at
+ which the target is spawned at the beginning of an episode.
+ If None, the entire arena is used to generate random target positions.
+ walker_spawn_position: a sequence of 2 numbers, or a `composer.Variation`
+ instance that generates such sequences, specifying the position at
+ which the walker is spawned at the beginning of an episode.
+ If None, the entire arena is used to generate random spawn positions.
+ walker_spawn_rotation: a number, or a `composer.Variation` instance that
+ generates a number, specifying the yaw angle offset (in radians) that is
+ applied to the walker at the beginning of an episode.
+ physics_timestep: a number specifying the timestep (in seconds) of the
+ physics simulation.
+ control_timestep: a number specifying the timestep (in seconds) at which
+ the agent applies its control inputs (in seconds).
+ """
+
+ self._arena = arena
+ self._walker = walker
+ self._walker.create_root_joints(self._arena.attach(self._walker))
+
+ arena_position = distributions.Uniform(
+ low=-np.array(arena.size) / 2, high=np.array(arena.size) / 2)
+ if target_spawn_position is not None:
+ self._target_spawn_position = target_spawn_position
+ else:
+ self._target_spawn_position = arena_position
+
+ if walker_spawn_position is not None:
+ self._walker_spawn_position = walker_spawn_position
+ else:
+ self._walker_spawn_position = arena_position
+
+ self._walker_spawn_rotation = walker_spawn_rotation
+
+ self._distance_tolerance = distance_tolerance
+ self._moving_target = moving_target
+ self._target_relative = target_relative
+ self._target_relative_dist = target_relative_dist
+ self._steps_before_moving_target = steps_before_moving_target
+ self._reward_step_counter = 0
+
+ self._target = self.root_entity.mjcf_model.worldbody.add(
+ 'site',
+ name='target',
+ type='sphere',
+ pos=(0., 0., 0.),
+ size=(0.1,),
+ rgba=(0.9, 0.6, 0.6, 1.0))
+
+ enabled_observables = []
+ enabled_observables += self._walker.observables.proprioception
+ enabled_observables += self._walker.observables.kinematic_sensors
+ enabled_observables += self._walker.observables.dynamic_sensors
+ enabled_observables.append(self._walker.observables.sensors_touch)
+ for obs in enabled_observables:
+ obs.enabled = True
+
+ walker.observables.add_egocentric_vector(
+ 'target',
+ observable.MJCFFeature('pos', self._target),
+ origin_callable=lambda physics: physics.bind(walker.root_body).xpos)
+
+ self.set_timesteps(
+ physics_timestep=physics_timestep, control_timestep=control_timestep)
+
+ @property
+ def root_entity(self):
+ return self._arena
+
+ def target_position(self, physics):
+ return np.array(physics.bind(self._target).pos)
+
+ def initialize_episode_mjcf(self, random_state):
+ self._arena.regenerate(random_state=random_state)
+
+ target_x, target_y = variation.evaluate(
+ self._target_spawn_position, random_state=random_state)
+ self._target.pos = [target_x, target_y, 0.]
+
+ def initialize_episode(self, physics, random_state):
+ self._walker.reinitialize_pose(physics, random_state)
+ if self._walker_spawn_rotation:
+ rotation = variation.evaluate(
+ self._walker_spawn_rotation, random_state=random_state)
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+ else:
+ quat = None
+ walker_x, walker_y = variation.evaluate(
+ self._walker_spawn_position, random_state=random_state)
+ self._walker.shift_pose(
+ physics,
+ position=[walker_x, walker_y, 0.],
+ quaternion=quat,
+ rotate_velocity=True)
+
+ self._failure_termination = False
+ walker_foot_geoms = set(self._walker.ground_contact_geoms)
+ walker_nonfoot_geoms = [
+ geom for geom in self._walker.mjcf_model.find_all('geom')
+ if geom not in walker_foot_geoms]
+ self._walker_nonfoot_geomids = set(
+ physics.bind(walker_nonfoot_geoms).element_id)
+ self._ground_geomids = set(
+ physics.bind(self._arena.ground_geoms).element_id)
+ self._ground_geomids.add(physics.bind(self._target).element_id)
+
+ def _is_disallowed_contact(self, contact):
+ set1, set2 = self._walker_nonfoot_geomids, self._ground_geomids
+ return ((contact.geom1 in set1 and contact.geom2 in set2) or
+ (contact.geom1 in set2 and contact.geom2 in set1))
+
+ def should_terminate_episode(self, physics):
+ return self._failure_termination
+
+ def get_discount(self, physics):
+ if self._failure_termination:
+ return 0.
+ else:
+ return 1.
+
+ def get_reward(self, physics):
+ reward = 0.
+ distance = np.linalg.norm(
+ physics.bind(self._target).pos[:2] -
+ physics.bind(self._walker.root_body).xpos[:2])
+ if distance < self._distance_tolerance:
+ reward = 1.
+ if self._moving_target:
+ self._reward_step_counter += 1
+ return reward
+
+ def before_step(self, physics, action, random_state):
+ self._walker.apply_action(physics, action, random_state)
+
+ def after_step(self, physics, random_state):
+ self._failure_termination = False
+ for contact in physics.data.contact:
+ if self._is_disallowed_contact(contact):
+ self._failure_termination = True
+ break
+ if (self._moving_target and
+ self._reward_step_counter >= self._steps_before_moving_target):
+
+ # Reset the target position.
+ if self._target_relative:
+ walker_pos = physics.bind(self._walker.root_body).xpos[:2]
+ target_x, target_y = random_state.uniform(
+ -np.array([self._target_relative_dist, self._target_relative_dist]),
+ np.array([self._target_relative_dist, self._target_relative_dist]))
+ target_x += walker_pos[0]
+ target_y += walker_pos[1]
+ else:
+ target_x, target_y = variation.evaluate(
+ self._target_spawn_position, random_state=random_state)
+ physics.bind(self._target).pos = [target_x, target_y, 0.]
+
+ # Reset the number of steps at the target for the moving target.
+ self._reward_step_counter = 0
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/go_to_target_test.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/go_to_target_test.py
new file mode 100644
index 0000000..183f0df
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/go_to_target_test.py
@@ -0,0 +1,160 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for locomotion.tasks.go_to_target."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import absltest
+
+from dm_control import composer
+from dm_control.locomotion.arenas import floors
+from dm_control.locomotion.tasks import go_to_target
+from dm_control.locomotion.walkers import cmu_humanoid
+import numpy as np
+from six.moves import range
+
+
+class GoToTargetTest(absltest.TestCase):
+
+ def test_observables(self):
+ walker = cmu_humanoid.CMUHumanoid()
+ arena = floors.Floor()
+ task = go_to_target.GoToTarget(
+ walker=walker, arena=arena, moving_target=False)
+
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ timestep = env.reset()
+
+ self.assertIn('walker/target', timestep.observation)
+
+ def test_target_position_randomized_on_reset(self):
+ walker = cmu_humanoid.CMUHumanoid()
+ arena = floors.Floor()
+ task = go_to_target.GoToTarget(
+ walker=walker, arena=arena, moving_target=False)
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ env.reset()
+ first_target_position = task.target_position(env.physics)
+ env.reset()
+ second_target_position = task.target_position(env.physics)
+ self.assertFalse(np.all(first_target_position == second_target_position),
+ 'Target positions are unexpectedly identical.')
+
+ def test_reward_fixed_target(self):
+ walker = cmu_humanoid.CMUHumanoid()
+ arena = floors.Floor()
+ task = go_to_target.GoToTarget(
+ walker=walker, arena=arena, moving_target=False)
+
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ env.reset()
+
+ target_position = task.target_position(env.physics)
+ zero_action = np.zeros_like(env.physics.data.ctrl)
+ for _ in range(2):
+ timestep = env.step(zero_action)
+ self.assertEqual(timestep.reward, 0)
+ walker_pos = env.physics.bind(walker.root_body).xpos
+ walker.set_pose(
+ env.physics,
+ position=[target_position[0], target_position[1], walker_pos[2]])
+ env.physics.forward()
+
+ # Receive reward while the agent remains at that location.
+ timestep = env.step(zero_action)
+ self.assertEqual(timestep.reward, 1)
+
+ # Target position should not change.
+ np.testing.assert_array_equal(target_position,
+ task.target_position(env.physics))
+
+ def test_reward_moving_target(self):
+ walker = cmu_humanoid.CMUHumanoid()
+ arena = floors.Floor()
+
+ steps_before_moving_target = 2
+ task = go_to_target.GoToTarget(
+ walker=walker,
+ arena=arena,
+ moving_target=True,
+ steps_before_moving_target=steps_before_moving_target)
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ env.reset()
+
+ target_position = task.target_position(env.physics)
+ zero_action = np.zeros_like(env.physics.data.ctrl)
+ for _ in range(2):
+ timestep = env.step(zero_action)
+ self.assertEqual(timestep.reward, 0)
+
+ walker_pos = env.physics.bind(walker.root_body).xpos
+ walker.set_pose(
+ env.physics,
+ position=[target_position[0], target_position[1], walker_pos[2]])
+ env.physics.forward()
+
+ # Receive reward while the agent remains at that location.
+ for _ in range(steps_before_moving_target):
+ timestep = env.step(zero_action)
+ self.assertEqual(timestep.reward, 1)
+ np.testing.assert_array_equal(target_position,
+ task.target_position(env.physics))
+
+ # After taking > steps_before_moving_target, the target should move and
+ # reward should be 0.
+ timestep = env.step(zero_action)
+ self.assertEqual(timestep.reward, 0)
+
+ def test_termination_and_discount(self):
+ walker = cmu_humanoid.CMUHumanoid()
+ arena = floors.Floor()
+ task = go_to_target.GoToTarget(walker=walker, arena=arena)
+
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ env.reset()
+
+ zero_action = np.zeros_like(env.physics.data.ctrl)
+
+ # Walker starts in upright position.
+ # Should not trigger failure termination in the first few steps.
+ for _ in range(5):
+ env.step(zero_action)
+ self.assertFalse(task.should_terminate_episode(env.physics))
+ np.testing.assert_array_equal(task.get_discount(env.physics), 1)
+
+ # Rotate the walker upside down and run the physics until it makes contact.
+ current_time = env.physics.data.time
+ walker.shift_pose(env.physics, position=(0, 0, 10), quaternion=(0, 1, 0, 0))
+ env.physics.forward()
+ while env.physics.data.ncon == 0:
+ env.physics.step()
+ env.physics.data.time = current_time
+
+ # Should now trigger a failure termination.
+ env.step(zero_action)
+ self.assertTrue(task.should_terminate_episode(env.physics))
+ np.testing.assert_array_equal(task.get_discount(env.physics), 0)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/random_goal_maze.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/random_goal_maze.py
new file mode 100644
index 0000000..6e8ef52
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/random_goal_maze.py
@@ -0,0 +1,546 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""A task consisting of finding goals/targets in a random maze."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import itertools
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation import observable as observable_lib
+from dm_control.locomotion.props import target_sphere
+from dm_control.mujoco.wrapper import mjbindings
+import numpy as np
+from six.moves import range
+from six.moves import zip
+
+_NUM_RAYS = 10
+
+# Aliveness in [-1., 0.].
+DEFAULT_ALIVE_THRESHOLD = -0.5
+
+DEFAULT_PHYSICS_TIMESTEP = 0.001
+DEFAULT_CONTROL_TIMESTEP = 0.025
+
+
+class NullGoalMaze(composer.Task):
+ """A base task for maze with goals."""
+
+ def __init__(self,
+ walker,
+ maze_arena,
+ randomize_spawn_position=True,
+ randomize_spawn_rotation=True,
+ rotation_bias_factor=0,
+ aliveness_reward=0.0,
+ aliveness_threshold=DEFAULT_ALIVE_THRESHOLD,
+ contact_termination=True,
+ enable_global_task_observables=False,
+ physics_timestep=DEFAULT_PHYSICS_TIMESTEP,
+ control_timestep=DEFAULT_CONTROL_TIMESTEP):
+ """Initializes goal-directed maze task.
+
+ Args:
+ walker: The body to navigate the maze.
+ maze_arena: The physical maze arena object.
+ randomize_spawn_position: Flag to randomize position of spawning.
+ randomize_spawn_rotation: Flag to randomize orientation of spawning.
+ rotation_bias_factor: A non-negative number that concentrates initial
+ orientation away from walls. When set to zero, the initial orientation
+ is uniformly random. The larger the value of this number, the more
+ likely it is that the initial orientation would face the direction that
+ is farthest away from a wall.
+ aliveness_reward: Reward for being alive.
+ aliveness_threshold: Threshold if should terminate based on walker
+ aliveness feature.
+ contact_termination: whether to terminate if a non-foot geom touches the
+ ground.
+ enable_global_task_observables: Flag to provide task observables that
+ contain global information, including map layout.
+ physics_timestep: timestep of simulation.
+ control_timestep: timestep at which agent changes action.
+ """
+ self._walker = walker
+ self._maze_arena = maze_arena
+ self._walker.create_root_joints(self._maze_arena.attach(self._walker))
+
+ self._randomize_spawn_position = randomize_spawn_position
+ self._randomize_spawn_rotation = randomize_spawn_rotation
+ self._rotation_bias_factor = rotation_bias_factor
+
+ self._aliveness_reward = aliveness_reward
+ self._aliveness_threshold = aliveness_threshold
+ self._contact_termination = contact_termination
+ self._discount = 1.0
+
+ self.set_timesteps(
+ physics_timestep=physics_timestep, control_timestep=control_timestep)
+
+ self._walker.observables.egocentric_camera.height = 64
+ self._walker.observables.egocentric_camera.width = 64
+
+ for observable in (self._walker.observables.proprioception +
+ self._walker.observables.kinematic_sensors +
+ self._walker.observables.dynamic_sensors):
+ observable.enabled = True
+ self._walker.observables.egocentric_camera.enabled = True
+
+ if enable_global_task_observables:
+ # Reveal maze text map as observable.
+ maze_obs = observable_lib.Generic(
+ lambda _: self._maze_arena.maze.entity_layer)
+ maze_obs.enabled = True
+
+ # absolute walker position
+ def get_walker_pos(physics):
+ walker_pos = physics.bind(self._walker.root_body).xpos
+ return walker_pos
+ absolute_position = observable_lib.Generic(get_walker_pos)
+ absolute_position.enabled = True
+
+ # absolute walker orientation
+ def get_walker_ori(physics):
+ walker_ori = np.reshape(
+ physics.bind(self._walker.root_body).xmat, (3, 3))
+ return walker_ori
+ absolute_orientation = observable_lib.Generic(get_walker_ori)
+ absolute_orientation.enabled = True
+
+ # grid element of player in maze cell: i,j cell in maze layout
+ def get_walker_ij(physics):
+ walker_xypos = physics.bind(self._walker.root_body).xpos[:-1]
+ walker_rel_origin = (
+ (walker_xypos +
+ np.sign(walker_xypos) * self._maze_arena.xy_scale / 2) /
+ (self._maze_arena.xy_scale)).astype(int)
+ x_offset = (self._maze_arena.maze.width - 1) / 2
+ y_offset = (self._maze_arena.maze.height - 1) / 2
+ walker_ij = walker_rel_origin + np.array([x_offset, y_offset])
+ return walker_ij
+ absolute_position_discrete = observable_lib.Generic(get_walker_ij)
+ absolute_position_discrete.enabled = True
+
+ self._task_observables = collections.OrderedDict({
+ 'maze_layout': maze_obs,
+ 'absolute_position': absolute_position,
+ 'absolute_orientation': absolute_orientation,
+ 'location_in_maze': absolute_position_discrete, # from bottom left
+ })
+ else:
+ self._task_observables = collections.OrderedDict({})
+
+ @property
+ def task_observables(self):
+ return self._task_observables
+
+ @property
+ def name(self):
+ return 'goal_maze'
+
+ @property
+ def root_entity(self):
+ return self._maze_arena
+
+ def initialize_episode_mjcf(self, unused_random_state):
+ self._maze_arena.regenerate()
+
+ def _respawn(self, physics, random_state):
+ self._walker.reinitialize_pose(physics, random_state)
+
+ if self._randomize_spawn_position:
+ self._spawn_position = self._maze_arena.spawn_positions[
+ random_state.randint(0, len(self._maze_arena.spawn_positions))]
+
+ if self._randomize_spawn_rotation:
+ # Move walker up out of the way before raycasting.
+ self._walker.shift_pose(physics, [0.0, 0.0, 100.0])
+
+ distances = []
+ geomid_out = np.array([-1], dtype=np.intc)
+ for i in range(_NUM_RAYS):
+ theta = 2 * np.pi * i / _NUM_RAYS
+ pos = np.array([self._spawn_position[0], self._spawn_position[1], 0.1],
+ dtype=np.float64)
+ vec = np.array([np.cos(theta), np.sin(theta), 0], dtype=np.float64)
+ dist = mjbindings.mjlib.mj_ray(
+ physics.model.ptr, physics.data.ptr, pos, vec,
+ None, 1, -1, geomid_out)
+ distances.append(dist)
+
+ def remap_with_bias(x):
+ """Remaps values [-1, 1] -> [-1, 1] with bias."""
+ return np.tanh((1 + self._rotation_bias_factor) * np.arctanh(x))
+
+ max_theta = 2 * np.pi * np.argmax(distances) / _NUM_RAYS
+ rotation = max_theta + np.pi * (
+ 1 + remap_with_bias(random_state.uniform(-1, 1)))
+
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+
+ # Move walker back down.
+ self._walker.shift_pose(physics, [0.0, 0.0, -100.0])
+ else:
+ quat = None
+
+ self._walker.shift_pose(
+ physics, [self._spawn_position[0], self._spawn_position[1], 0.0],
+ quat,
+ rotate_velocity=True)
+
+ def initialize_episode(self, physics, random_state):
+ super(NullGoalMaze, self).initialize_episode(physics, random_state)
+ self._respawn(physics, random_state)
+ self._discount = 1.0
+
+ walker_foot_geoms = set(self._walker.ground_contact_geoms)
+ walker_nonfoot_geoms = [
+ geom for geom in self._walker.mjcf_model.find_all('geom')
+ if geom not in walker_foot_geoms]
+ self._walker_nonfoot_geomids = set(
+ physics.bind(walker_nonfoot_geoms).element_id)
+ self._ground_geomids = set(
+ physics.bind(self._maze_arena.ground_geoms).element_id)
+
+ def _is_disallowed_contact(self, contact):
+ set1, set2 = self._walker_nonfoot_geomids, self._ground_geomids
+ return ((contact.geom1 in set1 and contact.geom2 in set2) or
+ (contact.geom1 in set2 and contact.geom2 in set1))
+
+ def after_step(self, physics, random_state):
+ self._failure_termination = False
+ if self._contact_termination:
+ for c in physics.data.contact:
+ if self._is_disallowed_contact(c):
+ self._failure_termination = True
+ break
+
+ def should_terminate_episode(self, physics):
+ if self._walker.aliveness(physics) < self._aliveness_threshold:
+ self._failure_termination = True
+ if self._failure_termination:
+ self._discount = 0.0
+ return True
+ else:
+ return False
+
+ def get_reward(self, physics):
+ del physics
+ return self._aliveness_reward
+
+ def get_discount(self, physics):
+ del physics
+ return self._discount
+
+
+class RepeatSingleGoalMaze(NullGoalMaze):
+ """Requires an agent to repeatedly find the same goal in a maze."""
+
+ def __init__(self,
+ walker,
+ maze_arena,
+ target=target_sphere.TargetSphere(),
+ target_reward_scale=1.0,
+ randomize_spawn_position=True,
+ randomize_spawn_rotation=True,
+ rotation_bias_factor=0,
+ aliveness_reward=0.0,
+ aliveness_threshold=DEFAULT_ALIVE_THRESHOLD,
+ contact_termination=True,
+ max_repeats=0,
+ enable_global_task_observables=False,
+ physics_timestep=DEFAULT_PHYSICS_TIMESTEP,
+ control_timestep=DEFAULT_CONTROL_TIMESTEP):
+ super(RepeatSingleGoalMaze, self).__init__(
+ walker=walker,
+ maze_arena=maze_arena,
+ randomize_spawn_position=randomize_spawn_position,
+ randomize_spawn_rotation=randomize_spawn_rotation,
+ rotation_bias_factor=rotation_bias_factor,
+ aliveness_reward=aliveness_reward,
+ aliveness_threshold=aliveness_threshold,
+ contact_termination=contact_termination,
+ enable_global_task_observables=enable_global_task_observables,
+ physics_timestep=physics_timestep,
+ control_timestep=control_timestep)
+ self._target = target
+ self._rewarded_this_step = False
+ self._maze_arena.attach(target)
+ self._target_reward_scale = target_reward_scale
+ self._max_repeats = max_repeats
+ self._targets_obtained = 0
+
+ if enable_global_task_observables:
+ xpos_origin_callable = lambda phys: phys.bind(walker.root_body).xpos
+
+ def _target_pos(physics, target=target):
+ return physics.bind(target.geom).xpos
+
+ walker.observables.add_egocentric_vector(
+ 'target_0',
+ observable_lib.Generic(_target_pos),
+ origin_callable=xpos_origin_callable)
+
+ def initialize_episode_mjcf(self, random_state):
+ super(RepeatSingleGoalMaze, self).initialize_episode_mjcf(random_state)
+ self._target_position = self._maze_arena.target_positions[
+ random_state.randint(0, len(self._maze_arena.target_positions))]
+ mjcf.get_attachment_frame(
+ self._target.mjcf_model).pos = self._target_position
+
+ def initialize_episode(self, physics, random_state):
+ super(RepeatSingleGoalMaze, self).initialize_episode(physics, random_state)
+ self._rewarded_this_step = False
+ self._targets_obtained = 0
+
+ def after_step(self, physics, random_state):
+ super(RepeatSingleGoalMaze, self).after_step(physics, random_state)
+ if self._target.activated:
+ self._rewarded_this_step = True
+ self._targets_obtained += 1
+ if self._targets_obtained <= self._max_repeats:
+ self._respawn(physics, random_state)
+ self._target.reset(physics)
+ else:
+ self._rewarded_this_step = False
+
+ def should_terminate_episode(self, physics):
+ if super(RepeatSingleGoalMaze, self).should_terminate_episode(physics):
+ return True
+ if self._targets_obtained > self._max_repeats:
+ return True
+
+ def get_reward(self, physics):
+ del physics
+ if self._rewarded_this_step:
+ target_reward = self._target_reward_scale
+ else:
+ target_reward = 0.0
+ return target_reward + self._aliveness_reward
+
+
+class ManyHeterogeneousGoalsMaze(NullGoalMaze):
+ """Requires an agent to find multiple goals with different rewards."""
+
+ def __init__(self,
+ walker,
+ maze_arena,
+ target_builders,
+ target_type_rewards,
+ target_type_proportions,
+ shuffle_target_builders=False,
+ randomize_spawn_position=True,
+ randomize_spawn_rotation=True,
+ rotation_bias_factor=0,
+ aliveness_reward=0.0,
+ aliveness_threshold=DEFAULT_ALIVE_THRESHOLD,
+ contact_termination=True,
+ physics_timestep=DEFAULT_PHYSICS_TIMESTEP,
+ control_timestep=DEFAULT_CONTROL_TIMESTEP):
+ super(ManyHeterogeneousGoalsMaze, self).__init__(
+ walker=walker,
+ maze_arena=maze_arena,
+ randomize_spawn_position=randomize_spawn_position,
+ randomize_spawn_rotation=randomize_spawn_rotation,
+ rotation_bias_factor=rotation_bias_factor,
+ aliveness_reward=aliveness_reward,
+ aliveness_threshold=aliveness_threshold,
+ contact_termination=contact_termination,
+ physics_timestep=physics_timestep,
+ control_timestep=control_timestep)
+ self._active_targets = []
+ self._target_builders = target_builders
+ self._target_type_rewards = tuple(target_type_rewards)
+ self._target_type_fractions = (
+ np.array(target_type_proportions, dtype=float) /
+ np.sum(target_type_proportions))
+ self._shuffle_target_builders = shuffle_target_builders
+
+ def _get_targets(self, total_target_count, random_state):
+ # Multiply total target count by the fraction for each type, rounded down.
+ target_numbers = np.array([int(frac * total_target_count)
+ for frac in self._target_type_fractions])
+
+ # Calculate deviations from the ideal ratio incurred by rounding.
+ errors = (self._target_type_fractions -
+ target_numbers / float(total_target_count))
+
+ # Sort the target types by deviations from ideal ratios.
+ target_types_sorted_by_errors = list(np.argsort(errors))
+
+ # Top up individual target classes until we reach the desired total,
+ # starting from the class that is furthest away from the ideal ratio.
+ current_total = np.sum(target_numbers)
+ while current_total < total_target_count:
+ target_numbers[target_types_sorted_by_errors.pop()] += 1
+ current_total += 1
+
+ if self._shuffle_target_builders:
+ random_state.shuffle(self._target_builders)
+
+ all_targets = []
+ for target_type, num in enumerate(target_numbers):
+ targets = []
+ target_builder = self._target_builders[target_type]
+ for i in range(num):
+ target = target_builder(name='target_{}_{}'.format(target_type, i))
+ targets.append(target)
+ all_targets.append(targets)
+ return all_targets
+
+ def initialize_episode_mjcf(self, random_state):
+ super(
+ ManyHeterogeneousGoalsMaze, self).initialize_episode_mjcf(random_state)
+ for target in itertools.chain(*self._active_targets):
+ target.detach()
+ target_positions = list(self._maze_arena.target_positions)
+ random_state.shuffle(target_positions)
+ all_targets = self._get_targets(len(target_positions), random_state)
+ for pos, target in zip(target_positions, itertools.chain(*all_targets)):
+ self._maze_arena.attach(target)
+ mjcf.get_attachment_frame(target.mjcf_model).pos = pos
+ target.initialize_episode_mjcf(random_state)
+ self._active_targets = all_targets
+ self._target_rewarded = [[False] * len(targets) for targets in all_targets]
+
+ def get_reward(self, physics):
+ del physics
+ reward = self._aliveness_reward
+ for target_type, targets in enumerate(self._active_targets):
+ for i, target in enumerate(targets):
+ if target.activated and not self._target_rewarded[target_type][i]:
+ reward += self._target_type_rewards[target_type]
+ self._target_rewarded[target_type][i] = True
+ return reward
+
+ def should_terminate_episode(self, physics):
+ if super(ManyHeterogeneousGoalsMaze,
+ self).should_terminate_episode(physics):
+ return True
+ else:
+ for target in itertools.chain(*self._active_targets):
+ if not target.activated:
+ return False
+ # All targets have been activated: successful termination.
+ return True
+
+
+class ManyGoalsMaze(ManyHeterogeneousGoalsMaze):
+ """Requires an agent to find all goals in a random maze."""
+
+ def __init__(self,
+ walker,
+ maze_arena,
+ target_builder,
+ target_reward_scale=1.0,
+ randomize_spawn_position=True,
+ randomize_spawn_rotation=True,
+ rotation_bias_factor=0,
+ aliveness_reward=0.0,
+ aliveness_threshold=DEFAULT_ALIVE_THRESHOLD,
+ contact_termination=True,
+ physics_timestep=DEFAULT_PHYSICS_TIMESTEP,
+ control_timestep=DEFAULT_CONTROL_TIMESTEP):
+ super(ManyGoalsMaze, self).__init__(
+ walker=walker,
+ maze_arena=maze_arena,
+ target_builders=[target_builder],
+ target_type_rewards=[target_reward_scale],
+ target_type_proportions=[1],
+ randomize_spawn_position=randomize_spawn_position,
+ randomize_spawn_rotation=randomize_spawn_rotation,
+ rotation_bias_factor=rotation_bias_factor,
+ aliveness_reward=aliveness_reward,
+ aliveness_threshold=aliveness_threshold,
+ contact_termination=contact_termination,
+ physics_timestep=physics_timestep,
+ control_timestep=control_timestep)
+
+
+class RepeatSingleGoalMazeAugmentedWithTargets(RepeatSingleGoalMaze):
+ """Augments the single goal maze with many lower reward targets."""
+
+ def __init__(self,
+ walker,
+ main_target,
+ maze_arena,
+ num_subtargets=20,
+ target_reward_scale=10.0,
+ subtarget_reward_scale=1.0,
+ subtarget_colors=((0, 0, 0.4), (0, 0, 0.7)),
+ randomize_spawn_position=True,
+ randomize_spawn_rotation=True,
+ rotation_bias_factor=0,
+ aliveness_reward=0.0,
+ aliveness_threshold=DEFAULT_ALIVE_THRESHOLD,
+ contact_termination=True,
+ physics_timestep=DEFAULT_PHYSICS_TIMESTEP,
+ control_timestep=DEFAULT_CONTROL_TIMESTEP):
+ super(RepeatSingleGoalMazeAugmentedWithTargets, self).__init__(
+ walker=walker,
+ target=main_target,
+ maze_arena=maze_arena,
+ target_reward_scale=target_reward_scale,
+ randomize_spawn_position=randomize_spawn_position,
+ randomize_spawn_rotation=randomize_spawn_rotation,
+ rotation_bias_factor=rotation_bias_factor,
+ aliveness_reward=aliveness_reward,
+ aliveness_threshold=aliveness_threshold,
+ contact_termination=contact_termination,
+ physics_timestep=physics_timestep,
+ control_timestep=control_timestep)
+ self._subtarget_reward_scale = subtarget_reward_scale
+ self._subtargets = []
+ for i in range(num_subtargets):
+ subtarget = target_sphere.TargetSphere(
+ radius=0.4, rgb1=subtarget_colors[0], rgb2=subtarget_colors[1],
+ name='subtarget_{}'.format(i)
+ )
+ self._subtargets.append(subtarget)
+ self._maze_arena.attach(subtarget)
+ self._subtarget_rewarded = None
+
+ def initialize_episode_mjcf(self, random_state):
+ super(RepeatSingleGoalMazeAugmentedWithTargets,
+ self).initialize_episode_mjcf(random_state)
+ subtarget_positions = self._maze_arena.target_positions
+ for pos, subtarget in zip(subtarget_positions, self._subtargets):
+ mjcf.get_attachment_frame(subtarget.mjcf_model).pos = pos
+ self._subtarget_rewarded = [False] * len(self._subtargets)
+
+ def get_reward(self, physics):
+ main_reward = super(RepeatSingleGoalMazeAugmentedWithTargets,
+ self).get_reward(physics)
+ subtarget_reward = 0
+ for i, subtarget in enumerate(self._subtargets):
+ if subtarget.activated and not self._subtarget_rewarded[i]:
+ subtarget_reward += 1
+ self._subtarget_rewarded[i] = True
+ subtarget_reward *= self._subtarget_reward_scale
+ return main_reward + subtarget_reward
+
+ def should_terminate_episode(self, physics):
+ if super(RepeatSingleGoalMazeAugmentedWithTargets,
+ self).should_terminate_episode(physics):
+ return True
+ else:
+ for subtarget in self._subtargets:
+ if not subtarget.activated:
+ return False
+ # All subtargets have been activated.
+ return True
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/random_goal_maze_test.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/random_goal_maze_test.py
new file mode 100644
index 0000000..bef643c
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/random_goal_maze_test.py
@@ -0,0 +1,135 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for locomotion.tasks.random_goal_maze."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+from absl.testing import absltest
+
+from dm_control import composer
+from dm_control.locomotion.arenas import labmaze_textures
+from dm_control.locomotion.arenas import mazes
+from dm_control.locomotion.props import target_sphere
+from dm_control.locomotion.tasks import random_goal_maze
+from dm_control.locomotion.walkers import cmu_humanoid
+
+import numpy as np
+from six.moves import range
+
+
+class RandomGoalMazeTest(absltest.TestCase):
+
+ def test_observables(self):
+ walker = cmu_humanoid.CMUHumanoid()
+
+ # Build a maze with rooms and targets.
+ skybox_texture = labmaze_textures.SkyBox(style='sky_03')
+ wall_textures = labmaze_textures.WallTextures(style='style_01')
+ floor_textures = labmaze_textures.FloorTextures(style='style_01')
+ arena = mazes.RandomMazeWithTargets(
+ x_cells=11,
+ y_cells=11,
+ xy_scale=3,
+ max_rooms=4,
+ room_min_size=4,
+ room_max_size=5,
+ spawns_per_room=1,
+ targets_per_room=3,
+ skybox_texture=skybox_texture,
+ wall_textures=wall_textures,
+ floor_textures=floor_textures,
+ )
+
+ task = random_goal_maze.ManyGoalsMaze(
+ walker=walker,
+ maze_arena=arena,
+ target_builder=functools.partial(
+ target_sphere.TargetSphere,
+ radius=0.4,
+ rgb1=(0, 0, 0.4),
+ rgb2=(0, 0, 0.7)),
+ control_timestep=.03,
+ physics_timestep=.005,
+ )
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ timestep = env.reset()
+
+ self.assertIn('walker/joints_pos', timestep.observation)
+
+ def test_termination_and_discount(self):
+ walker = cmu_humanoid.CMUHumanoid()
+
+ # Build a maze with rooms and targets.
+ skybox_texture = labmaze_textures.SkyBox(style='sky_03')
+ wall_textures = labmaze_textures.WallTextures(style='style_01')
+ floor_textures = labmaze_textures.FloorTextures(style='style_01')
+ arena = mazes.RandomMazeWithTargets(
+ x_cells=11,
+ y_cells=11,
+ xy_scale=3,
+ max_rooms=4,
+ room_min_size=4,
+ room_max_size=5,
+ spawns_per_room=1,
+ targets_per_room=3,
+ skybox_texture=skybox_texture,
+ wall_textures=wall_textures,
+ floor_textures=floor_textures,
+ )
+
+ task = random_goal_maze.ManyGoalsMaze(
+ walker=walker,
+ maze_arena=arena,
+ target_builder=functools.partial(
+ target_sphere.TargetSphere,
+ radius=0.4,
+ rgb1=(0, 0, 0.4),
+ rgb2=(0, 0, 0.7)),
+ control_timestep=.03,
+ physics_timestep=.005,
+ )
+
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ env.reset()
+
+ zero_action = np.zeros_like(env.physics.data.ctrl)
+
+ # Walker starts in upright position.
+ # Should not trigger failure termination in the first few steps.
+ for _ in range(5):
+ env.step(zero_action)
+ self.assertFalse(task.should_terminate_episode(env.physics))
+ np.testing.assert_array_equal(task.get_discount(env.physics), 1)
+
+ # Rotate the walker upside down and run the physics until it makes contact.
+ current_time = env.physics.data.time
+ walker.shift_pose(env.physics, position=(0, 0, 10), quaternion=(0, 1, 0, 0))
+ env.physics.forward()
+ while env.physics.data.ncon == 0:
+ env.physics.step()
+ env.physics.data.time = current_time
+
+ # Should now trigger a failure termination.
+ env.step(zero_action)
+ self.assertTrue(task.should_terminate_episode(env.physics))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/reach.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/reach.py
new file mode 100644
index 0000000..e575691
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/reach.py
@@ -0,0 +1,294 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""A (visuomotor) task consisting of reaching to targets for reward."""
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import itertools
+
+from dm_control import composer
+from dm_control.composer.observation import observable as dm_observable
+import enum
+import numpy as np
+from six.moves import range
+from six.moves import zip
+
+DEFAULT_ALIVE_THRESHOLD = -1.0
+DEFAULT_PHYSICS_TIMESTEP = 0.005
+DEFAULT_CONTROL_TIMESTEP = 0.03
+
+
+class TwoTouchState(enum.IntEnum):
+ PRE_TOUCH = 0
+ TOUCHED_ONCE = 1
+ TOUCHED_TWICE = 2 # at appropriate time
+ TOUCHED_TOO_SOON = 3
+ NO_SECOND_TOUCH = 4
+
+
+class TwoTouch(composer.Task):
+ """Task with target to tap with short delay (for Rat)."""
+
+ def __init__(self,
+ walker,
+ arena,
+ target_builders,
+ target_type_rewards,
+ shuffle_target_builders=False,
+ randomize_spawn_position=False,
+ randomize_spawn_rotation=True,
+ rotation_bias_factor=0,
+ aliveness_reward=0.0,
+ touch_interval=0.8,
+ interval_tolerance=0.1, # consider making a curriculum
+ failure_timeout=1.2,
+ reset_delay=0.,
+ z_height=.14, # 5.5" in real experiments
+ target_area=(),
+ physics_timestep=DEFAULT_PHYSICS_TIMESTEP,
+ control_timestep=DEFAULT_CONTROL_TIMESTEP):
+ self._walker = walker
+ self._arena = arena
+ self._walker.create_root_joints(self._arena.attach(self._walker))
+ if 'CMUHumanoid' in str(type(self._walker)):
+ self._lhand_body = walker.mjcf_model.find('body', 'lhand')
+ self._rhand_body = walker.mjcf_model.find('body', 'rhand')
+ elif 'Rat' in str(type(self._walker)):
+ self._lhand_body = walker.mjcf_model.find('body', 'hand_L')
+ self._rhand_body = walker.mjcf_model.find('body', 'hand_R')
+ else:
+ raise ValueError('Expects Rat or CMUHumanoid.')
+ self._lhand_geoms = self._lhand_body.find_all('geom')
+ self._rhand_geoms = self._rhand_body.find_all('geom')
+
+ self._targets = []
+ self._target_builders = target_builders
+ self._target_type_rewards = tuple(target_type_rewards)
+ self._shuffle_target_builders = shuffle_target_builders
+
+ self._randomize_spawn_position = randomize_spawn_position
+ self._spawn_position = [0.0, 0.0] # x, y
+ self._randomize_spawn_rotation = randomize_spawn_rotation
+ self._rotation_bias_factor = rotation_bias_factor
+
+ self._aliveness_reward = aliveness_reward
+ self._discount = 1.0
+
+ self._touch_interval = touch_interval
+ self._interval_tolerance = interval_tolerance
+ self._failure_timeout = failure_timeout
+ self._reset_delay = reset_delay
+ self._target_positions = []
+ self._state_logic = TwoTouchState.PRE_TOUCH
+
+ self._z_height = z_height
+ arena_size = self._arena.size
+ if target_area:
+ self._target_area = target_area
+ else:
+ self._target_area = [1/2*arena_size[0], 1/2*arena_size[1]]
+ target_x = 1.
+ target_y = 1.
+ self._target_positions.append((target_x, target_y, self._z_height))
+
+ self.set_timesteps(
+ physics_timestep=physics_timestep, control_timestep=control_timestep)
+
+ self._task_observables = collections.OrderedDict()
+ def task_state(physics):
+ del physics
+ return np.array([self._state_logic])
+ self._task_observables['task_logic'] = dm_observable.Generic(task_state)
+
+ self._walker.observables.egocentric_camera.height = 64
+ self._walker.observables.egocentric_camera.width = 64
+
+ for observable in (self._walker.observables.proprioception +
+ self._walker.observables.kinematic_sensors +
+ self._walker.observables.dynamic_sensors +
+ list(self._task_observables.values())):
+ observable.enabled = True
+ self._walker.observables.egocentric_camera.enabled = True
+
+ def _get_targets(self, total_target_count, random_state):
+ # Multiply total target count by the fraction for each type, rounded down.
+ target_numbers = np.array([1, len(self._target_positions)-1])
+
+ if self._shuffle_target_builders:
+ random_state.shuffle(self._target_builders)
+
+ all_targets = []
+ for target_type, num in enumerate(target_numbers):
+ targets = []
+ if num < 1:
+ break
+ target_builder = self._target_builders[target_type]
+ for i in range(num):
+ target = target_builder(name='target_{}_{}'.format(target_type, i))
+ targets.append(target)
+ all_targets.append(targets)
+ return all_targets
+
+ @property
+ def name(self):
+ return 'two_touch'
+
+ @property
+ def task_observables(self):
+ return self._task_observables
+
+ @property
+ def root_entity(self):
+ return self._arena
+
+ def _randomize_targets(self, physics, random_state=np.random):
+ for ii in range(len(self._target_positions)):
+ target_x = self._target_area[0]*random_state.uniform(-1., 1.)
+ target_y = self._target_area[1]*random_state.uniform(-1., 1.)
+ self._target_positions[ii] = (target_x, target_y, self._z_height)
+ target_positions = np.copy(self._target_positions)
+ random_state.shuffle(target_positions)
+ all_targets = self._targets
+ for pos, target in zip(target_positions, itertools.chain(*all_targets)):
+ target.reset(physics)
+ physics.bind(target.geom).pos = pos
+ self._targets = all_targets
+ self._target_rewarded_once = [
+ [False] * len(targets) for targets in all_targets]
+ self._target_rewarded_twice = [
+ [False] * len(targets) for targets in all_targets]
+ self._first_touch_time = None
+ self._second_touch_time = None
+ self._do_time_out = False
+ self._state_logic = TwoTouchState.PRE_TOUCH
+
+ def initialize_episode_mjcf(self, random_state):
+ self._arena.regenerate(random_state)
+ for target in itertools.chain(*self._targets):
+ target.detach()
+ target_positions = np.copy(self._target_positions)
+ random_state.shuffle(target_positions)
+ all_targets = self._get_targets(len(self._target_positions), random_state)
+ for pos, target in zip(target_positions, itertools.chain(*all_targets)):
+ self._arena.attach(target)
+ target.geom.pos = pos
+ target.initialize_episode_mjcf(random_state)
+ self._targets = all_targets
+
+ def _respawn_walker(self, physics, random_state):
+ self._walker.reinitialize_pose(physics, random_state)
+
+ if self._randomize_spawn_position:
+ self._spawn_position = self._arena.spawn_positions[
+ random_state.randint(0, len(self._arena.spawn_positions))]
+
+ if self._randomize_spawn_rotation:
+ rotation = 2*np.pi*np.random.uniform()
+ quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)]
+
+ self._walker.shift_pose(
+ physics,
+ [self._spawn_position[0], self._spawn_position[1], 0.0],
+ quat,
+ rotate_velocity=True)
+
+ def initialize_episode(self, physics, random_state):
+ super(TwoTouch, self).initialize_episode(physics, random_state)
+ self._respawn_walker(physics, random_state)
+ self._state_logic = TwoTouchState.PRE_TOUCH
+ self._discount = 1.0
+ self._lhand_geomids = set(physics.bind(self._lhand_geoms).element_id)
+ self._rhand_geomids = set(physics.bind(self._rhand_geoms).element_id)
+ self._hand_geomids = self._lhand_geomids | self._rhand_geomids
+ self._randomize_targets(physics)
+ self._must_randomize_targets = False
+ for target in itertools.chain(*self._targets):
+ target._specific_collision_geom_ids = self._hand_geomids # pylint: disable=protected-access
+
+ def before_step(self, physics, action, random_state):
+ super(TwoTouch, self).before_step(physics, action, random_state)
+ if self._must_randomize_targets:
+ self._randomize_targets(physics)
+ self._must_randomize_targets = False
+
+ def should_terminate_episode(self, physics):
+ failure_termination = False
+ if failure_termination:
+ self._discount = 0.0
+ return True
+ else:
+ return False
+
+ def get_reward(self, physics):
+ reward = self._aliveness_reward
+ lhand_pos = physics.bind(self._lhand_body).xpos
+ rhand_pos = physics.bind(self._rhand_body).xpos
+ target_pos = physics.bind(self._targets[0][0].geom).xpos
+ lhand_rew = np.exp(-3.*sum(np.abs(lhand_pos-target_pos)))
+ rhand_rew = np.exp(-3.*sum(np.abs(rhand_pos-target_pos)))
+ closeness_reward = np.maximum(lhand_rew, rhand_rew)
+ reward += .01*closeness_reward*self._target_type_rewards[0]
+ if self._state_logic == TwoTouchState.PRE_TOUCH:
+ # touch the first time
+ for target_type, targets in enumerate(self._targets):
+ for i, target in enumerate(targets):
+ if (target.activated[0] and
+ not self._target_rewarded_once[target_type][i]):
+ self._first_touch_time = physics.time()
+ self._state_logic = TwoTouchState.TOUCHED_ONCE
+ self._target_rewarded_once[target_type][i] = True
+ reward += self._target_type_rewards[target_type]
+ elif self._state_logic == TwoTouchState.TOUCHED_ONCE:
+ for target_type, targets in enumerate(self._targets):
+ for i, target in enumerate(targets):
+ if (target.activated[1] and
+ not self._target_rewarded_twice[target_type][i]):
+ self._second_touch_time = physics.time()
+ self._state_logic = TwoTouchState.TOUCHED_TWICE
+ self._target_rewarded_twice[target_type][i] = True
+ # check if touched too soon
+ if ((self._second_touch_time - self._first_touch_time) <
+ (self._touch_interval - self._interval_tolerance)):
+ self._do_time_out = True
+ self._state_logic = TwoTouchState.TOUCHED_TOO_SOON
+ # check if touched at correct time
+ elif ((self._second_touch_time - self._first_touch_time) <=
+ (self._touch_interval + self._interval_tolerance)):
+ reward += self._target_type_rewards[target_type]
+ # check if no second touch within time interval
+ if ((physics.time() - self._first_touch_time) >
+ (self._touch_interval + self._interval_tolerance)):
+ self._do_time_out = True
+ self._state_logic = TwoTouchState.NO_SECOND_TOUCH
+ self._second_touch_time = physics.time()
+ elif (self._state_logic == TwoTouchState.TOUCHED_TWICE or
+ self._state_logic == TwoTouchState.TOUCHED_TOO_SOON or
+ self._state_logic == TwoTouchState.NO_SECOND_TOUCH):
+ # hold here due to timeout
+ if self._do_time_out:
+ if physics.time() > (self._second_touch_time + self._failure_timeout):
+ self._do_time_out = False
+ # reset/re-randomize
+ elif physics.time() > (self._second_touch_time + self._reset_delay):
+ self._must_randomize_targets = True
+ return reward
+
+ def get_discount(self, physics):
+ del physics
+ return self._discount
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/reach_test.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/reach_test.py
new file mode 100644
index 0000000..2132409
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/reach_test.py
@@ -0,0 +1,66 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for locomotion.tasks.reach."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+from absl.testing import absltest
+
+from dm_control import composer
+from dm_control.locomotion.arenas import floors
+from dm_control.locomotion.props import target_sphere
+from dm_control.locomotion.tasks import reach
+from dm_control.locomotion.walkers import rodent
+
+import numpy as np
+
+_CONTROL_TIMESTEP = .02
+_PHYSICS_TIMESTEP = 0.001
+
+
+class ReachTest(absltest.TestCase):
+
+ def test_observables(self):
+ walker = rodent.Rat()
+
+ arena = floors.Floor(
+ size=(10., 10.),
+ aesthetic='outdoor_natural')
+
+ task = reach.TwoTouch(
+ walker=walker,
+ arena=arena,
+ target_builders=[
+ functools.partial(target_sphere.TargetSphereTwoTouch, radius=0.025),
+ ],
+ randomize_spawn_rotation=True,
+ target_type_rewards=[25.],
+ shuffle_target_builders=False,
+ target_area=(1.5, 1.5),
+ physics_timestep=_PHYSICS_TIMESTEP,
+ control_timestep=_CONTROL_TIMESTEP,
+ )
+ random_state = np.random.RandomState(12345)
+ env = composer.Environment(task, random_state=random_state)
+ timestep = env.reset()
+
+ self.assertIn('walker/joints_pos', timestep.observation)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/__init__.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/__init__.py
new file mode 100644
index 0000000..939a03b
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/__init__.py
@@ -0,0 +1,17 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Walkers for Locomotion tasks."""
+
+from dm_control.locomotion.walkers.rodent import Rat
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/assets/humanoid_CMU.xml b/DMC/src/env/dm_control/dm_control/locomotion/walkers/assets/humanoid_CMU.xml
new file mode 100644
index 0000000..64c9db9
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/assets/humanoid_CMU.xml
@@ -0,0 +1,297 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/assets/rodent.xml b/DMC/src/env/dm_control/dm_control/locomotion/walkers/assets/rodent.xml
new file mode 100644
index 0000000..7268f3c
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/assets/rodent.xml
@@ -0,0 +1,609 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/assets/rodent_walker_skin.skn b/DMC/src/env/dm_control/dm_control/locomotion/walkers/assets/rodent_walker_skin.skn
new file mode 100644
index 0000000..53a7e7a
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/walkers/assets/rodent_walker_skin.skn differ
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/base.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/base.py
new file mode 100644
index 0000000..8890dab
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/base.py
@@ -0,0 +1,196 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Base class for Walkers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import collections
+
+from dm_control import composer
+from dm_control.composer.observation import observable
+
+from dm_env import specs
+import numpy as np
+import six
+
+
+def _make_readonly_float64_copy(value):
+ if np.isscalar(value):
+ return np.float64(value)
+ else:
+ out = np.array(value, dtype=np.float64)
+ out.flags.writeable = False
+ return out
+
+
+class WalkerPose(collections.namedtuple(
+ 'WalkerPose', ('qpos', 'xpos', 'xquat'))):
+ """A named tuple representing a walker's joint and Cartesian pose."""
+
+ __slots__ = ()
+
+ def __new__(cls, qpos=None, xpos=(0, 0, 0), xquat=(1, 0, 0, 0)):
+ """Creates a new WalkerPose.
+
+ Args:
+ qpos: The joint position for the pose, or `None` if the `qpos0` values in
+ the `mjModel` should be used.
+ xpos: A Cartesian displacement, for example if the walker should be lifted
+ or lowered by a specific amount for this pose.
+ xquat: A quaternion displacement for the root body.
+
+ Returns:
+ A new instance of `WalkerPose`.
+ """
+ return super(WalkerPose, cls).__new__(
+ cls,
+ qpos=_make_readonly_float64_copy(qpos) if qpos is not None else None,
+ xpos=_make_readonly_float64_copy(xpos),
+ xquat=_make_readonly_float64_copy(xquat))
+
+ def __eq__(self, other):
+ return (np.all(self.qpos == other.qpos) and
+ np.all(self.xpos == other.xpos) and
+ np.all(self.xquat == other.xquat))
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Walker(composer.Robot):
+ """Abstract base class for Walker robots."""
+
+ def create_root_joints(self, attachment_frame):
+ attachment_frame.add('freejoint')
+
+ def _build_observables(self):
+ return WalkerObservables(self)
+
+ def transform_vec_to_egocentric_frame(self, physics, vec_in_world_frame):
+ """Linearly transforms a world-frame vector into walker's egocentric frame.
+
+ Note that this function does not perform an affine transformation of the
+ vector. In other words, the input vector is assumed to be specified with
+ respect to the same origin as this walker's egocentric frame. This function
+ can also be applied to matrices whose innermost dimensions are either 2 or
+ 3. In this case, a matrix with the same leading dimensions is returned
+ where the innermost vectors are replaced by their values computed in the
+ egocentric frame.
+
+ Args:
+ physics: An `mjcf.Physics` instance.
+ vec_in_world_frame: A NumPy array with last dimension of shape (2,) or
+ (3,) that represents a vector quantity in the world frame.
+
+ Returns:
+ The same quantity as `vec_in_world_frame` but reexpressed in this
+ entity's egocentric frame. The returned np.array has the same shape as
+ np.asarray(vec_in_world_frame).
+
+ Raises:
+ ValueError: if `vec_in_world_frame` does not have shape ending with (2,)
+ or (3,).
+ """
+ return super(Walker, self).global_vector_to_local_frame(
+ physics, vec_in_world_frame)
+
+ def transform_xmat_to_egocentric_frame(self, physics, xmat):
+ """Transforms another entity's `xmat` into this walker's egocentric frame.
+
+ This function takes another entity's (E) xmat, which is an SO(3) matrix
+ from E's frame to the world frame, and turns it to a matrix that transforms
+ from E's frame into this walker's egocentric frame.
+
+ Args:
+ physics: An `mjcf.Physics` instance.
+ xmat: A NumPy array of shape (3, 3) or (9,) that represents another
+ entity's xmat.
+
+ Returns:
+ The `xmat` reexpressed in this entity's egocentric frame. The returned
+ np.array has the same shape as np.asarray(xmat).
+
+ Raises:
+ ValueError: if `xmat` does not have shape (3, 3) or (9,).
+ """
+ return super(Walker, self).global_xmat_to_local_frame(physics, xmat)
+
+ @abc.abstractproperty
+ def root_body(self):
+ raise NotImplementedError
+
+ @abc.abstractproperty
+ def observable_joints(self):
+ raise NotImplementedError
+
+ @property
+ def action_spec(self):
+ minimum, maximum = zip(*[
+ a.ctrlrange if a.ctrlrange is not None else (-1., 1.)
+ for a in self.actuators
+ ])
+ return specs.BoundedArray(
+ shape=(len(self.actuators),),
+ dtype=np.float,
+ minimum=minimum,
+ maximum=maximum,
+ name='\t'.join([actuator.name for actuator in self.actuators]))
+
+ def apply_action(self, physics, action, random_state):
+ """Apply action to walker's actuators."""
+ del random_state
+ physics.bind(self.actuators).ctrl = action
+
+
+class WalkerObservables(composer.Observables):
+ """Base class for Walker obserables."""
+
+ @composer.observable
+ def joints_pos(self):
+ return observable.MJCFFeature('qpos', self._entity.observable_joints)
+
+ @composer.observable
+ def sensors_gyro(self):
+ return observable.MJCFFeature('sensordata',
+ self._entity.mjcf_model.sensor.gyro)
+
+ @composer.observable
+ def sensors_accelerometer(self):
+ return observable.MJCFFeature('sensordata',
+ self._entity.mjcf_model.sensor.accelerometer)
+
+ # Semantic groupings of Walker observables.
+ def _collect_from_attachments(self, attribute_name):
+ out = []
+ for entity in self._entity.iter_entities(exclude_self=True):
+ out.extend(getattr(entity.observables, attribute_name, []))
+ return out
+
+ @property
+ def proprioception(self):
+ return ([self.joints_pos] +
+ self._collect_from_attachments('proprioception'))
+
+ @property
+ def kinematic_sensors(self):
+ return ([self.sensors_gyro, self.sensors_accelerometer] +
+ self._collect_from_attachments('kinematic_sensors'))
+
+ @property
+ def dynamic_sensors(self):
+ return self._collect_from_attachments('dynamic_sensors')
+
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/base_test.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/base_test.py
new file mode 100644
index 0000000..0794cec
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/base_test.py
@@ -0,0 +1,98 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for dm_control.locomotion.walkers.base."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import absltest
+from dm_control import mjcf
+from dm_control.locomotion.walkers import base
+
+import numpy as np
+
+
+class FakeWalker(base.Walker):
+
+ def _build(self):
+ self._mjcf_root = mjcf.RootElement(model='walker')
+ self._torso_body = self._mjcf_root.worldbody.add(
+ 'body', name='torso', xyaxes=[0, 1, 0, -1, 0, 0])
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @property
+ def actuators(self):
+ return []
+
+ @property
+ def root_body(self):
+ return self._torso_body
+
+ @property
+ def observable_joints(self):
+ return []
+
+
+class BaseWalkerTest(absltest.TestCase):
+
+ def testTransformVectorToEgocentricFrame(self):
+ walker = FakeWalker()
+ physics = mjcf.Physics.from_mjcf_model(walker.mjcf_model)
+
+ # 3D vectors
+ np.testing.assert_allclose(
+ walker.transform_vec_to_egocentric_frame(physics, [0, 1, 0]), [1, 0, 0],
+ atol=1e-10)
+ np.testing.assert_allclose(
+ walker.transform_vec_to_egocentric_frame(physics, [-1, 0, 0]),
+ [0, 1, 0],
+ atol=1e-10)
+ np.testing.assert_allclose(
+ walker.transform_vec_to_egocentric_frame(physics, [0, 0, 1]), [0, 0, 1],
+ atol=1e-10)
+
+ # 2D vectors; z-component is ignored
+ np.testing.assert_allclose(
+ walker.transform_vec_to_egocentric_frame(physics, [0, 1]), [1, 0],
+ atol=1e-10)
+ np.testing.assert_allclose(
+ walker.transform_vec_to_egocentric_frame(physics, [-1, 0]), [0, 1],
+ atol=1e-10)
+
+ def testTransformMatrixToEgocentricFrame(self):
+ walker = FakeWalker()
+ physics = mjcf.Physics.from_mjcf_model(walker.mjcf_model)
+
+ rotation_atob = np.array([[0, 1, 0], [0, 0, -1], [-1, 0, 0]])
+ ego_rotation_atob = np.array([[0, 0, -1], [0, -1, 0], [-1, 0, 0]])
+
+ np.testing.assert_allclose(
+ walker.transform_xmat_to_egocentric_frame(physics, rotation_atob),
+ ego_rotation_atob, atol=1e-10)
+
+ flat_rotation_atob = np.reshape(rotation_atob, -1)
+ flat_rotation_ego_atob = np.reshape(ego_rotation_atob, -1)
+ np.testing.assert_allclose(
+ walker.transform_xmat_to_egocentric_frame(physics, flat_rotation_atob),
+ flat_rotation_ego_atob, atol=1e-10)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/cmu_humanoid.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/cmu_humanoid.py
new file mode 100644
index 0000000..a8000b4
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/cmu_humanoid.py
@@ -0,0 +1,352 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""A CMU humanoid walker."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import collections
+import os
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation import observable
+from dm_control.locomotion.walkers import base
+from dm_control.locomotion.walkers import legacy_base
+from dm_control.locomotion.walkers import scaled_actuators
+from dm_control.mujoco import wrapper as mj_wrapper
+import numpy as np
+import six
+from six.moves import zip
+
+_XML_PATH = os.path.join(os.path.dirname(__file__), 'assets/humanoid_CMU.xml')
+
+_WALKER_GEOM_GROUP = 2
+
+_CMU_MOCAP_JOINTS = (
+ 'lfemurrz', 'lfemurry', 'lfemurrx', 'ltibiarx', 'lfootrz', 'lfootrx',
+ 'ltoesrx', 'rfemurrz', 'rfemurry', 'rfemurrx', 'rtibiarx', 'rfootrz',
+ 'rfootrx', 'rtoesrx', 'lowerbackrz', 'lowerbackry', 'lowerbackrx',
+ 'upperbackrz', 'upperbackry', 'upperbackrx', 'thoraxrz', 'thoraxry',
+ 'thoraxrx', 'lowerneckrz', 'lowerneckry', 'lowerneckrx', 'upperneckrz',
+ 'upperneckry', 'upperneckrx', 'headrz', 'headry', 'headrx', 'lclaviclerz',
+ 'lclaviclery', 'lhumerusrz', 'lhumerusry', 'lhumerusrx', 'lradiusrx',
+ 'lwristry', 'lhandrz', 'lhandrx', 'lfingersrx', 'lthumbrz', 'lthumbrx',
+ 'rclaviclerz', 'rclaviclery', 'rhumerusrz', 'rhumerusry', 'rhumerusrx',
+ 'rradiusrx', 'rwristry', 'rhandrz', 'rhandrx', 'rfingersrx', 'rthumbrz',
+ 'rthumbrx')
+
+
+# pylint: disable=bad-whitespace
+PositionActuatorParams = collections.namedtuple(
+ 'PositionActuatorParams', ['name', 'forcerange', 'kp'])
+_POSITION_ACTUATORS = [
+ PositionActuatorParams('headrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('headry', [-20, 20 ], 20 ),
+ PositionActuatorParams('headrz', [-20, 20 ], 20 ),
+ PositionActuatorParams('lclaviclery', [-20, 20 ], 20 ),
+ PositionActuatorParams('lclaviclerz', [-20, 20 ], 20 ),
+ PositionActuatorParams('lfemurrx', [-120, 120], 120),
+ PositionActuatorParams('lfemurry', [-80, 80 ], 80 ),
+ PositionActuatorParams('lfemurrz', [-80, 80 ], 80 ),
+ PositionActuatorParams('lfingersrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('lfootrx', [-50, 50 ], 50 ),
+ PositionActuatorParams('lfootrz', [-50, 50 ], 50 ),
+ PositionActuatorParams('lhandrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('lhandrz', [-20, 20 ], 20 ),
+ PositionActuatorParams('lhumerusrx', [-60, 60 ], 60 ),
+ PositionActuatorParams('lhumerusry', [-60, 60 ], 60 ),
+ PositionActuatorParams('lhumerusrz', [-60, 60 ], 60 ),
+ PositionActuatorParams('lowerbackrx', [-120, 120], 150),
+ PositionActuatorParams('lowerbackry', [-120, 120], 150),
+ PositionActuatorParams('lowerbackrz', [-120, 120], 150),
+ PositionActuatorParams('lowerneckrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('lowerneckry', [-20, 20 ], 20 ),
+ PositionActuatorParams('lowerneckrz', [-20, 20 ], 20 ),
+ PositionActuatorParams('lradiusrx', [-60, 60 ], 60 ),
+ PositionActuatorParams('lthumbrx', [-20, 20 ], 20) ,
+ PositionActuatorParams('lthumbrz', [-20, 20 ], 20 ),
+ PositionActuatorParams('ltibiarx', [-80, 80 ], 80 ),
+ PositionActuatorParams('ltoesrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('lwristry', [-20, 20 ], 20 ),
+ PositionActuatorParams('rclaviclery', [-20, 20 ], 20 ),
+ PositionActuatorParams('rclaviclerz', [-20, 20 ], 20 ),
+ PositionActuatorParams('rfemurrx', [-120, 120], 120),
+ PositionActuatorParams('rfemurry', [-80, 80 ], 80 ),
+ PositionActuatorParams('rfemurrz', [-80, 80 ], 80 ),
+ PositionActuatorParams('rfingersrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('rfootrx', [-50, 50 ], 50 ),
+ PositionActuatorParams('rfootrz', [-50, 50 ], 50 ),
+ PositionActuatorParams('rhandrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('rhandrz', [-20, 20 ], 20 ),
+ PositionActuatorParams('rhumerusrx', [-60, 60 ], 60 ),
+ PositionActuatorParams('rhumerusry', [-60, 60 ], 60 ),
+ PositionActuatorParams('rhumerusrz', [-60, 60 ], 60 ),
+ PositionActuatorParams('rradiusrx', [-60, 60 ], 60 ),
+ PositionActuatorParams('rthumbrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('rthumbrz', [-20, 20 ], 20 ),
+ PositionActuatorParams('rtibiarx', [-80, 80 ], 80 ),
+ PositionActuatorParams('rtoesrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('rwristry', [-20, 20 ], 20 ),
+ PositionActuatorParams('thoraxrx', [-80, 80 ], 100),
+ PositionActuatorParams('thoraxry', [-80, 80 ], 100),
+ PositionActuatorParams('thoraxrz', [-80, 80 ], 100),
+ PositionActuatorParams('upperbackrx', [-80, 80 ], 80 ),
+ PositionActuatorParams('upperbackry', [-80, 80 ], 80 ),
+ PositionActuatorParams('upperbackrz', [-80, 80 ], 80 ),
+ PositionActuatorParams('upperneckrx', [-20, 20 ], 20 ),
+ PositionActuatorParams('upperneckry', [-20, 20 ], 20 ),
+ PositionActuatorParams('upperneckrz', [-20, 20 ], 20 ),
+]
+# pylint: enable=bad-whitespace
+
+_UPRIGHT_POS = (0.0, 0.0, 0.94)
+_UPRIGHT_QUAT = (0.859, 1.0, 1.0, 0.859)
+
+# Height of head above which the humanoid is considered standing.
+_STAND_HEIGHT = 1.5
+
+_TORQUE_THRESHOLD = 60
+
+
+@six.add_metaclass(abc.ABCMeta)
+class _CMUHumanoidBase(legacy_base.Walker):
+ """The abstract base class for walkers compatible with the CMU humanoid."""
+
+ def _build(self,
+ name='walker',
+ marker_rgba=None,
+ initializer=None):
+ self._mjcf_root = mjcf.from_path(self._xml_path)
+ if name:
+ self._mjcf_root.model = name
+
+ # Set corresponding marker color if specified.
+ if marker_rgba is not None:
+ for geom in self.marker_geoms:
+ geom.set_attributes(rgba=marker_rgba)
+
+ self._actuator_order = np.argsort(_CMU_MOCAP_JOINTS)
+ self._inverse_order = np.argsort(self._actuator_order)
+
+ super(_CMUHumanoidBase, self)._build(initializer=initializer)
+
+ def _build_observables(self):
+ return CMUHumanoidObservables(self)
+
+ @abc.abstractproperty
+ def _xml_path(self):
+ raise NotImplementedError
+
+ @composer.cached_property
+ def mocap_joints(self):
+ return tuple(
+ self._mjcf_root.find('joint', name) for name in _CMU_MOCAP_JOINTS)
+
+ @property
+ def actuator_order(self):
+ """Index of joints from the CMU mocap dataset sorted alphabetically by name.
+
+ Actuators in this walkers are ordered alphabetically by name. This property
+ provides a mapping between from actuator ordering to canonical CMU ordering.
+
+ Returns:
+ A list of integers corresponding to joint indices from the CMU dataset.
+ Specifically, the n-th element in the list is the index of the CMU joint
+ index that corresponds to the n-th actuator in this walker.
+ """
+ return self._actuator_order
+
+ @property
+ def actuator_to_joint_order(self):
+ """Index of actuators corresponding to each CMU mocap joint.
+
+ Actuators in this walkers are ordered alphabetically by name. This property
+ provides a mapping between from canonical CMU ordering to actuator ordering.
+
+ Returns:
+ A list of integers corresponding to actuator indices within this walker.
+ Specifically, the n-th element in the list is the index of the actuator
+ in this walker that corresponds to the n-th joint from the CMU mocap
+ dataset.
+ """
+ return self._inverse_order
+
+ @property
+ def upright_pose(self):
+ return base.WalkerPose(xpos=_UPRIGHT_POS, xquat=_UPRIGHT_QUAT)
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @composer.cached_property
+ def actuators(self):
+ return tuple(self._mjcf_root.find_all('actuator'))
+
+ @composer.cached_property
+ def root_body(self):
+ return self._mjcf_root.find('body', 'root')
+
+ @composer.cached_property
+ def head(self):
+ return self._mjcf_root.find('body', 'head')
+
+ @composer.cached_property
+ def left_arm_root(self):
+ return self._mjcf_root.find('body', 'lclavicle')
+
+ @composer.cached_property
+ def right_arm_root(self):
+ return self._mjcf_root.find('body', 'rclavicle')
+
+ @composer.cached_property
+ def ground_contact_geoms(self):
+ return tuple(self._mjcf_root.find('body', 'lfoot').find_all('geom') +
+ self._mjcf_root.find('body', 'rfoot').find_all('geom'))
+
+ @composer.cached_property
+ def standing_height(self):
+ return _STAND_HEIGHT
+
+ @composer.cached_property
+ def end_effectors(self):
+ return (self._mjcf_root.find('body', 'rradius'),
+ self._mjcf_root.find('body', 'lradius'),
+ self._mjcf_root.find('body', 'rfoot'),
+ self._mjcf_root.find('body', 'lfoot'))
+
+ @composer.cached_property
+ def observable_joints(self):
+ return tuple(actuator.joint for actuator in self.actuators
+ if actuator.joint is not None)
+
+ @composer.cached_property
+ def bodies(self):
+ return tuple(self._mjcf_root.find_all('body'))
+
+ @composer.cached_property
+ def egocentric_camera(self):
+ return self._mjcf_root.find('camera', 'egocentric')
+
+ @composer.cached_property
+ def body_camera(self):
+ return self._mjcf_root.find('camera', 'bodycam')
+
+ @property
+ def marker_geoms(self):
+ return (self._mjcf_root.find('geom', 'rradius'),
+ self._mjcf_root.find('geom', 'lradius'))
+
+
+class CMUHumanoid(_CMUHumanoidBase):
+ """A CMU humanoid walker."""
+
+ @property
+ def _xml_path(self):
+ return _XML_PATH
+
+
+class CMUHumanoidPositionControlled(CMUHumanoid):
+ """A position-controlled CMU humanoid with control range scaled to [-1, 1]."""
+
+ def _build(self, *args, **kwargs):
+ super(CMUHumanoidPositionControlled, self)._build(*args, **kwargs)
+ self._mjcf_root.default.general.forcelimited = 'true'
+ self._mjcf_root.actuator.motor.clear()
+ for actuator_params in _POSITION_ACTUATORS:
+ associated_joint = self._mjcf_root.find('joint', actuator_params.name)
+ scaled_actuators.add_position_actuator(
+ name=actuator_params.name,
+ target=associated_joint,
+ kp=actuator_params.kp,
+ qposrange=associated_joint.range,
+ ctrlrange=(-1, 1),
+ forcerange=actuator_params.forcerange)
+ limits = zip(*(actuator.joint.range for actuator in self.actuators)) # pylint: disable=not-an-iterable
+ lower, upper = (np.array(limit) for limit in limits)
+ self._scale = upper - lower
+ self._offset = upper + lower
+
+ def cmu_pose_to_actuation(self, target_pose):
+ """Creates the control signal corresponding a CMU mocap joints pose.
+
+ Args:
+ target_pose: An array containing the target position for each joint.
+ These must be given in "canonical CMU order" rather than "qpos order",
+ i.e. the order of `target_pose[self.actuator_order]` should correspond
+ to the order of `physics.bind(self.actuators).ctrl`.
+
+ Returns:
+ An array of the same shape as `target_pose` containing inputs for position
+ controllers. Writing these values into `physics.bind(self.actuators).ctrl`
+ will cause the actuators to drive joints towards `target_pose`.
+ """
+ return (2 * target_pose[self.actuator_order] - self._offset) / self._scale
+
+
+class CMUHumanoidObservables(legacy_base.WalkerObservables):
+ """Observables for the Humanoid."""
+
+ @composer.observable
+ def body_camera(self):
+ options = mj_wrapper.MjvOption()
+
+ # Don't render this walker's geoms.
+ options.geomgroup[_WALKER_GEOM_GROUP] = 0
+ return observable.MJCFCamera(
+ self._entity.body_camera, width=64, height=64, scene_option=options)
+
+ @composer.observable
+ def head_height(self):
+ return observable.MJCFFeature('xpos', self._entity.head)[2]
+
+ @composer.observable
+ def sensors_torque(self):
+ return observable.MJCFFeature(
+ 'sensordata', self._entity.mjcf_model.sensor.torque,
+ corruptor=lambda v, random_state: np.tanh(2 * v / _TORQUE_THRESHOLD))
+
+ @composer.observable
+ def actuator_activation(self):
+ return observable.MJCFFeature('act',
+ self._entity.mjcf_model.find_all('actuator'))
+
+ @composer.observable
+ def appendages_pos(self):
+ """Equivalent to `end_effectors_pos` with the head's position appended."""
+ def relative_pos_in_egocentric_frame(physics):
+ end_effectors_with_head = (
+ self._entity.end_effectors + (self._entity.head,))
+ end_effector = physics.bind(end_effectors_with_head).xpos
+ torso = physics.bind(self._entity.root_body).xpos
+ xmat = np.reshape(physics.bind(self._entity.root_body).xmat, (3, 3))
+ return np.reshape(np.dot(end_effector - torso, xmat), -1)
+ return observable.Generic(relative_pos_in_egocentric_frame)
+
+ @property
+ def proprioception(self):
+ return [
+ self.joints_pos,
+ self.joints_vel,
+ self.actuator_activation,
+ self.body_height,
+ self.end_effectors_pos,
+ self.appendages_pos,
+ self.world_zaxis
+ ] + self._collect_from_attachments('proprioception')
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/cmu_humanoid_test.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/cmu_humanoid_test.py
new file mode 100644
index 0000000..677bd1a
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/cmu_humanoid_test.py
@@ -0,0 +1,166 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for the CMU humanoid."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import mjcf
+from dm_control.composer.observation.observable import base as observable_base
+from dm_control.locomotion.walkers import cmu_humanoid
+import numpy as np
+from six.moves import range
+from six.moves import zip
+
+
+class CMUHumanoidTest(parameterized.TestCase):
+
+ @parameterized.parameters([
+ cmu_humanoid.CMUHumanoid,
+ cmu_humanoid.CMUHumanoidPositionControlled,
+ ])
+ def test_can_compile_and_step_simulation(self, walker_type):
+ walker = walker_type()
+ physics = mjcf.Physics.from_mjcf_model(walker.mjcf_model)
+ for _ in range(100):
+ physics.step()
+
+ @parameterized.parameters([
+ cmu_humanoid.CMUHumanoid,
+ cmu_humanoid.CMUHumanoidPositionControlled,
+ ])
+ def test_actuators_sorted_alphabetically(self, walker_type):
+ walker = walker_type()
+ actuator_names = [
+ actuator.name for actuator in walker.mjcf_model.find_all('actuator')]
+ np.testing.assert_array_equal(actuator_names, sorted(actuator_names))
+
+ def test_actuator_to_mocap_joint_mapping(self):
+ walker = cmu_humanoid.CMUHumanoid()
+
+ with self.subTest('Forward mapping'):
+ for actuator_num, cmu_mocap_joint_num in enumerate(walker.actuator_order):
+ self.assertEqual(walker.actuator_to_joint_order[cmu_mocap_joint_num],
+ actuator_num)
+
+ with self.subTest('Inverse mapping'):
+ for cmu_mocap_joint_num, actuator_num in enumerate(
+ walker.actuator_to_joint_order):
+ self.assertEqual(walker.actuator_order[actuator_num],
+ cmu_mocap_joint_num)
+
+ def test_cmu_humanoid_position_controlled_has_correct_actuators(self):
+ walker_torque = cmu_humanoid.CMUHumanoid()
+ walker_pos = cmu_humanoid.CMUHumanoidPositionControlled()
+
+ actuators_torque = walker_torque.mjcf_model.find_all('actuator')
+ actuators_pos = walker_pos.mjcf_model.find_all('actuator')
+
+ actuator_pos_params = {
+ params.name: params for params in cmu_humanoid._POSITION_ACTUATORS}
+
+ self.assertEqual(len(actuators_torque), len(actuators_pos))
+
+ for actuator_torque, actuator_pos in zip(actuators_torque, actuators_pos):
+ self.assertEqual(actuator_pos.name, actuator_torque.name)
+ self.assertEqual(actuator_pos.joint.full_identifier,
+ actuator_torque.joint.full_identifier)
+ self.assertEqual(actuator_pos.tag, 'general')
+ self.assertEqual(actuator_pos.ctrllimited, 'true')
+ np.testing.assert_array_equal(actuator_pos.ctrlrange, (-1, 1))
+
+ expected_params = actuator_pos_params[actuator_pos.name]
+ self.assertEqual(actuator_pos.biasprm[1], -expected_params.kp)
+ np.testing.assert_array_equal(actuator_pos.forcerange,
+ expected_params.forcerange)
+
+ @parameterized.parameters([
+ 'body_camera',
+ 'egocentric_camera',
+ 'head',
+ 'left_arm_root',
+ 'right_arm_root',
+ 'root_body',
+ ])
+ def test_get_element_property(self, name):
+ attribute_value = getattr(cmu_humanoid.CMUHumanoid(), name)
+ self.assertIsInstance(attribute_value, mjcf.Element)
+
+ @parameterized.parameters([
+ 'actuators',
+ 'bodies',
+ 'end_effectors',
+ 'marker_geoms',
+ 'mocap_joints',
+ 'observable_joints',
+ ])
+ def test_get_element_tuple_property(self, name):
+ attribute_value = getattr(cmu_humanoid.CMUHumanoid(), name)
+ self.assertNotEmpty(attribute_value)
+ for item in attribute_value:
+ self.assertIsInstance(item, mjcf.Element)
+
+ def test_set_name(self):
+ name = 'fred'
+ walker = cmu_humanoid.CMUHumanoid(name=name)
+ self.assertEqual(walker.mjcf_model.model, name)
+
+ def test_set_marker_rgba(self):
+ marker_rgba = (1., 0., 1., 0.5)
+ walker = cmu_humanoid.CMUHumanoid(marker_rgba=marker_rgba)
+ for marker_geom in walker.marker_geoms:
+ np.testing.assert_array_equal(marker_geom.rgba, marker_rgba)
+
+ @parameterized.parameters(
+ 'actuator_activation',
+ 'appendages_pos',
+ 'body_camera',
+ 'head_height',
+ 'sensors_torque',
+ )
+ def test_evaluate_observable(self, name):
+ walker = cmu_humanoid.CMUHumanoid()
+ observable = getattr(walker.observables, name)
+ physics = mjcf.Physics.from_mjcf_model(walker.mjcf_model)
+ observation = observable(physics)
+ self.assertIsInstance(observation, (float, np.ndarray))
+
+ def test_proprioception(self):
+ walker = cmu_humanoid.CMUHumanoid()
+ for item in walker.observables.proprioception:
+ self.assertIsInstance(item, observable_base.Observable)
+
+ def test_cmu_pose_to_actuation(self):
+ walker = cmu_humanoid.CMUHumanoidPositionControlled()
+ random_state = np.random.RandomState(123)
+
+ expected_actuation = random_state.uniform(-1, 1, len(walker.actuator_order))
+
+ cmu_limits = zip(*(joint.range for joint in walker.mocap_joints))
+ cmu_lower, cmu_upper = (np.array(limit) for limit in cmu_limits)
+ cmu_pose = cmu_lower + (cmu_upper - cmu_lower) * (
+ 1 + expected_actuation[walker.actuator_to_joint_order]) / 2
+
+ actual_actuation = walker.cmu_pose_to_actuation(cmu_pose)
+
+ np.testing.assert_allclose(actual_actuation, expected_actuation)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/initializers/__init__.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/initializers/__init__.py
new file mode 100644
index 0000000..766fa7c
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/initializers/__init__.py
@@ -0,0 +1,61 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Initializers for the locomotion walkers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import numpy as np
+import six
+
+
+@six.add_metaclass(abc.ABCMeta)
+class WalkerInitializer(object):
+ """The abstract base class for a walker initializer."""
+
+ @abc.abstractmethod
+ def initialize_pose(self, physics, walker, random_state):
+ raise NotImplementedError
+
+
+class UprightInitializer(WalkerInitializer):
+ """An initializer that uses the walker-declared upright pose."""
+
+ def initialize_pose(self, physics, walker, random_state):
+ all_joints_binding = physics.bind(walker.mjcf_model.find_all('joint'))
+ qpos, xpos, xquat = walker.upright_pose
+ if qpos is None:
+ all_joints_binding.qpos = all_joints_binding.qpos0
+ else:
+ all_joints_binding.qpos = qpos
+ walker.set_pose(physics, position=xpos, quaternion=xquat)
+ walker.set_velocity(
+ physics, velocity=np.zeros(3), angular_velocity=np.zeros(3))
+
+
+class RandomlySampledInitializer(WalkerInitializer):
+ """Initializer that random selects between many initializers."""
+
+ def __init__(self, initializers):
+ self._initializers = initializers
+ self.num_initializers = len(initializers)
+
+ def initialize_pose(self, physics, walker, random_state):
+ random_initalizer_idx = np.random.randint(0, self.num_initializers)
+ self._initializers[random_initalizer_idx].initialize_pose(
+ physics, walker, random_state)
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/legacy_base.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/legacy_base.py
new file mode 100644
index 0000000..995fe02
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/legacy_base.py
@@ -0,0 +1,347 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Base class for Walkers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+from dm_control import composer
+from dm_control.composer.observation import observable
+from dm_control.locomotion.walkers import base
+from dm_control.locomotion.walkers import initializers
+from dm_control.mujoco.wrapper.mjbindings import mjlib
+
+import numpy as np
+
+_RANGEFINDER_SCALE = 10.0
+_TOUCH_THRESHOLD = 1e-3
+
+
+class Walker(base.Walker):
+ """Legacy base class for Walker robots."""
+
+ def _build(self, initializer=None):
+ try:
+ self._initializers = tuple(initializer)
+ except TypeError:
+ self._initializers = (initializer or initializers.UprightInitializer(),)
+
+ @property
+ def upright_pose(self):
+ return base.WalkerPose()
+
+ def _build_observables(self):
+ return WalkerObservables(self)
+
+ def reinitialize_pose(self, physics, random_state):
+ for initializer in self._initializers:
+ initializer.initialize_pose(physics, self, random_state)
+
+ def aliveness(self, physics):
+ """A measure of the aliveness of the walker.
+
+ Aliveness measure could be used for deciding on termination (ant flipped
+ over and it's impossible for it to recover), or used as a shaping reward
+ to maintain an alive pose that we desired (humanoids remaining upright).
+
+ Args:
+ physics: an instance of `Physics`.
+
+ Returns:
+ a `float` in the range of [-1., 0.] where -1 means not alive and 0. means
+ alive. In walkers for which the concept of aliveness does not make sense,
+ the default implementation is to always return 0.0.
+ """
+ return 0.
+
+ @abc.abstractproperty
+ def ground_contact_geoms(self):
+ """Geoms in this walker that are expected to be in contact with the ground.
+
+ This property is used by some tasks to determine contact-based failure
+ termination. It should only contain geoms that are expected to be in
+ contact with the ground during "normal" locomotion. For example, for a
+ humanoid model, this property would be expected to contain only the geoms
+ that make up the two feet.
+
+ Note that certain specialized tasks may also allow geoms that are not listed
+ here to be in contact with the ground. For example, a humanoid cartwheel
+ task would also allow the hands to touch the ground in addition to the feet.
+ """
+ raise NotImplementedError
+
+ def after_compile(self, physics, unused_random_state):
+ super(Walker, self).after_compile(physics, unused_random_state)
+ self._end_effector_geom_ids = set()
+ for eff_body in self.end_effectors:
+ eff_geom = eff_body.find_all('geom')
+ self._end_effector_geom_ids |= set(physics.bind(eff_geom).element_id)
+ self._body_geom_ids = set(
+ physics.bind(geom).element_id
+ for geom in self.mjcf_model.find_all('geom'))
+ self._body_geom_ids.difference_update(self._end_effector_geom_ids)
+
+ @property
+ def end_effector_geom_ids(self):
+ return self._end_effector_geom_ids
+
+ @property
+ def body_geom_ids(self):
+ return self._body_geom_ids
+
+ def end_effector_contacts(self, physics):
+ """Collect the contacts with the end effectors.
+
+ This function returns any contacts being made with any of the end effectors,
+ both the other geom with which contact is being made as well as the
+ magnitude.
+
+ Args:
+ physics: an instance of `Physics`.
+
+ Returns:
+ a dict with as key a tuple of geom ids, of which one is an end effector,
+ and as value the total magnitude of all contacts between these geoms
+ """
+ return self.collect_contacts(physics, self._end_effector_geom_ids)
+
+ def body_contacts(self, physics):
+ """Collect the contacts with the body.
+
+ This function returns any contacts being made with any of body geoms, except
+ the end effectors, both the other geom with which contact is being made as
+ well as the magnitude.
+
+ Args:
+ physics: an instance of `Physics`.
+
+ Returns:
+ a dict with as key a tuple of geom ids, of which one is a body geom,
+ and as value the total magnitude of all contacts between these geoms
+ """
+ return self.collect_contacts(physics, self._body_geom_ids)
+
+ def collect_contacts(self, physics, geom_ids):
+ contacts = {}
+ forcetorque = np.zeros(6)
+ for i, contact in enumerate(physics.data.contact):
+ if ((contact.geom1 in geom_ids) or
+ (contact.geom2 in geom_ids)) and contact.dist < contact.includemargin:
+ mjlib.mj_contactForce(physics.model.ptr, physics.data.ptr, i,
+ forcetorque)
+ contacts[(contact.geom1, contact.geom2)] = (forcetorque[0]
+ + contacts.get(
+ (contact.geom1,
+ contact.geom2), 0.))
+ return contacts
+
+ @abc.abstractproperty
+ def end_effectors(self):
+ raise NotImplementedError
+
+ @abc.abstractproperty
+ def egocentric_camera(self):
+ raise NotImplementedError
+
+ @composer.cached_property
+ def touch_sensors(self):
+ return self._mjcf_root.sensor.get_children('touch')
+
+ @property
+ def prev_action(self):
+ """Returns the actuation actions applied in the previous step.
+
+ Concrete walker implementations should provide caching mechanism themselves
+ in order to access this observable (for example, through `apply_action`).
+ """
+ raise NotImplementedError
+
+ def after_substep(self, physics, random_state):
+ del random_state # Unused.
+ # As of MuJoCo v2.0, updates to `mjData->subtree_linvel` will be skipped
+ # unless these quantities are needed by the simulation. We need these in
+ # order to calculate `torso_{x,y}vel`, so we therefore call `mj_subtreeVel`
+ # explicitly.
+ # TODO(b/123065920): Consider using a `subtreelinvel` sensor instead.
+ mjlib.mj_subtreeVel(physics.model.ptr, physics.data.ptr)
+
+
+class WalkerObservables(base.WalkerObservables):
+ """Legacy base class for Walker obserables."""
+
+ @composer.observable
+ def joints_vel(self):
+ return observable.MJCFFeature('qvel', self._entity.observable_joints)
+
+ @composer.observable
+ def body_height(self):
+ return observable.MJCFFeature('xpos', self._entity.root_body)[2]
+
+ @composer.observable
+ def end_effectors_pos(self):
+ """Position of end effectors relative to torso, in the egocentric frame."""
+ def relative_pos_in_egocentric_frame(physics):
+ end_effector = physics.bind(self._entity.end_effectors).xpos
+ torso = physics.bind(self._entity.root_body).xpos
+ xmat = np.reshape(physics.bind(self._entity.root_body).xmat, (3, 3))
+ return np.reshape(np.dot(end_effector - torso, xmat), -1)
+ return observable.Generic(relative_pos_in_egocentric_frame)
+
+ @composer.observable
+ def world_zaxis(self):
+ """The world's z-vector in this Walker's torso frame."""
+ return observable.MJCFFeature('xmat', self._entity.root_body)[6:]
+
+ @composer.observable
+ def sensors_velocimeter(self):
+ return observable.MJCFFeature('sensordata',
+ self._entity.mjcf_model.sensor.velocimeter)
+
+ @composer.observable
+ def sensors_force(self):
+ return observable.MJCFFeature('sensordata',
+ self._entity.mjcf_model.sensor.force)
+
+ @composer.observable
+ def sensors_torque(self):
+ return observable.MJCFFeature('sensordata',
+ self._entity.mjcf_model.sensor.torque)
+
+ @composer.observable
+ def sensors_touch(self):
+ return observable.MJCFFeature(
+ 'sensordata',
+ self._entity.mjcf_model.sensor.touch,
+ corruptor=
+ lambda v, random_state: np.array(v > _TOUCH_THRESHOLD, dtype=np.float))
+
+ @composer.observable
+ def sensors_rangefinder(self):
+ def tanh_rangefinder(physics):
+ raw = physics.bind(self._entity.mjcf_model.sensor.rangefinder).sensordata
+ raw = np.array(raw)
+ raw[raw == -1.0] = np.inf
+ return _RANGEFINDER_SCALE * np.tanh(raw / _RANGEFINDER_SCALE)
+ return observable.Generic(tanh_rangefinder)
+
+ @composer.observable
+ def egocentric_camera(self):
+ return observable.MJCFCamera(self._entity.egocentric_camera,
+ width=64, height=64)
+
+ @composer.observable
+ def position(self):
+ return observable.MJCFFeature('xpos', self._entity.root_body)
+
+ @composer.observable
+ def orientation(self):
+ return observable.MJCFFeature('xmat', self._entity.root_body)
+
+ def add_egocentric_vector(self,
+ name,
+ world_frame_observable,
+ enabled=True,
+ origin_callable=None,
+ **kwargs):
+
+ def _egocentric(physics, origin_callable=origin_callable):
+ vec = world_frame_observable.observation_callable(physics)()
+ origin_callable = origin_callable or (lambda physics: np.zeros(vec.size))
+ delta = vec - origin_callable(physics)
+ return self._entity.transform_vec_to_egocentric_frame(physics, delta)
+
+ self._observables[name] = observable.Generic(_egocentric, **kwargs)
+ self._observables[name].enabled = enabled
+
+ def add_egocentric_xmat(self, name, xmat_observable, enabled=True, **kwargs):
+
+ def _egocentric(physics):
+ return self._entity.transform_xmat_to_egocentric_frame(
+ physics,
+ xmat_observable.observation_callable(physics)())
+
+ self._observables[name] = observable.Generic(_egocentric, **kwargs)
+ self._observables[name].enabled = enabled
+
+ # Semantic groupings of Walker observables.
+ def _collect_from_attachments(self, attribute_name):
+ out = []
+ for entity in self._entity.iter_entities(exclude_self=True):
+ out.extend(getattr(entity.observables, attribute_name, []))
+ return out
+
+ @property
+ def proprioception(self):
+ return ([self.joints_pos, self.joints_vel,
+ self.body_height, self.end_effectors_pos, self.world_zaxis] +
+ self._collect_from_attachments('proprioception'))
+
+ @property
+ def kinematic_sensors(self):
+ return ([self.sensors_gyro, self.sensors_velocimeter,
+ self.sensors_accelerometer] +
+ self._collect_from_attachments('kinematic_sensors'))
+
+ @property
+ def dynamic_sensors(self):
+ return ([self.sensors_force, self.sensors_torque, self.sensors_touch] +
+ self._collect_from_attachments('dynamic_sensors'))
+
+ # Convenience observables for defining rewards and terminations.
+ @composer.observable
+ def veloc_strafe(self):
+ return observable.MJCFFeature(
+ 'sensordata', self._entity.mjcf_model.sensor.velocimeter)[1]
+
+ @composer.observable
+ def veloc_up(self):
+ return observable.MJCFFeature(
+ 'sensordata', self._entity.mjcf_model.sensor.velocimeter)[2]
+
+ @composer.observable
+ def veloc_forward(self):
+ return observable.MJCFFeature(
+ 'sensordata', self._entity.mjcf_model.sensor.velocimeter)[0]
+
+ @composer.observable
+ def gyro_backward_roll(self):
+ return observable.MJCFFeature(
+ 'sensordata', self._entity.mjcf_model.sensor.gyro)[0]
+
+ @composer.observable
+ def gyro_rightward_roll(self):
+ return observable.MJCFFeature(
+ 'sensordata', self._entity.mjcf_model.sensor.gyro)[1]
+
+ @composer.observable
+ def gyro_anticlockwise_spin(self):
+ return observable.MJCFFeature(
+ 'sensordata', self._entity.mjcf_model.sensor.gyro)[2]
+
+ @composer.observable
+ def torso_xvel(self):
+ return observable.MJCFFeature('subtree_linvel', self._entity.root_body)[0]
+
+ @composer.observable
+ def torso_yvel(self):
+ return observable.MJCFFeature('subtree_linvel', self._entity.root_body)[1]
+
+ @composer.observable
+ def prev_action(self):
+ return observable.Generic(lambda _: self._entity.prev_action)
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/rodent.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/rodent.py
new file mode 100644
index 0000000..9842718
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/rodent.py
@@ -0,0 +1,320 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""A Rodent walker."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import re
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation import observable
+from dm_control.locomotion.walkers import base
+from dm_control.locomotion.walkers import legacy_base
+from dm_control.mujoco import wrapper as mj_wrapper
+import numpy as np
+
+_XML_PATH = os.path.join(os.path.dirname(__file__),
+ 'assets/rodent.xml')
+
+_RAT_MOCAP_JOINTS = [
+ 'vertebra_1_extend', 'vertebra_2_bend', 'vertebra_3_twist',
+ 'vertebra_4_extend', 'vertebra_5_bend', 'vertebra_6_twist',
+ 'hip_L_supinate', 'hip_L_abduct', 'hip_L_extend', 'knee_L', 'ankle_L',
+ 'toe_L', 'hip_R_supinate', 'hip_R_abduct', 'hip_R_extend', 'knee_R',
+ 'ankle_R', 'toe_R', 'vertebra_C1_extend', 'vertebra_C1_bend',
+ 'vertebra_C2_extend', 'vertebra_C2_bend', 'vertebra_C3_extend',
+ 'vertebra_C3_bend', 'vertebra_C4_extend', 'vertebra_C4_bend',
+ 'vertebra_C5_extend', 'vertebra_C5_bend', 'vertebra_C6_extend',
+ 'vertebra_C6_bend', 'vertebra_C7_extend', 'vertebra_C9_bend',
+ 'vertebra_C11_extend', 'vertebra_C13_bend', 'vertebra_C15_extend',
+ 'vertebra_C17_bend', 'vertebra_C19_extend', 'vertebra_C21_bend',
+ 'vertebra_C23_extend', 'vertebra_C25_bend', 'vertebra_C27_extend',
+ 'vertebra_C29_bend', 'vertebra_cervical_5_extend',
+ 'vertebra_cervical_4_bend', 'vertebra_cervical_3_twist',
+ 'vertebra_cervical_2_extend', 'vertebra_cervical_1_bend',
+ 'vertebra_axis_twist', 'vertebra_atlant_extend', 'atlas', 'mandible',
+ 'scapula_L_supinate', 'scapula_L_abduct', 'scapula_L_extend', 'shoulder_L',
+ 'shoulder_sup_L', 'elbow_L', 'wrist_L', 'finger_L', 'scapula_R_supinate',
+ 'scapula_R_abduct', 'scapula_R_extend', 'shoulder_R', 'shoulder_sup_R',
+ 'elbow_R', 'wrist_R', 'finger_R'
+]
+
+
+_UPRIGHT_POS = (0.0, 0.0, 0.0)
+_UPRIGHT_QUAT = (1., 0., 0., 0.)
+_TORQUE_THRESHOLD = 60
+
+
+class Rat(legacy_base.Walker):
+ """A position-controlled rat with control range scaled to [-1, 1]."""
+
+ def _build(self,
+ params=None,
+ name='walker',
+ initializer=None):
+ self.params = params
+ self._mjcf_root = mjcf.from_path(_XML_PATH)
+ if name:
+ self._mjcf_root.model = name
+
+ self.body_sites = []
+ super(Rat, self)._build(initializer=initializer)
+
+ @property
+ def upright_pose(self):
+ """Reset pose to upright position."""
+ return base.WalkerPose(xpos=_UPRIGHT_POS, xquat=_UPRIGHT_QUAT)
+
+ @property
+ def mjcf_model(self):
+ """Return the model root."""
+ return self._mjcf_root
+
+ @composer.cached_property
+ def actuators(self):
+ """Return all actuators."""
+ return tuple(self._mjcf_root.find_all('actuator'))
+
+ @composer.cached_property
+ def root_body(self):
+ """Return the body."""
+ return self._mjcf_root.find('body', 'torso')
+
+ @composer.cached_property
+ def pelvis_body(self):
+ """Return the body."""
+ return self._mjcf_root.find('body', 'pelvis')
+
+ @composer.cached_property
+ def head(self):
+ """Return the head."""
+ return self._mjcf_root.find('body', 'skull')
+
+ @composer.cached_property
+ def left_arm_root(self):
+ """Return the left arm."""
+ return self._mjcf_root.find('body', 'scapula_L')
+
+ @composer.cached_property
+ def right_arm_root(self):
+ """Return the right arm."""
+ return self._mjcf_root.find('body', 'scapula_R')
+
+ @composer.cached_property
+ def ground_contact_geoms(self):
+ """Return ground contact geoms."""
+ return tuple(
+ self._mjcf_root.find('body', 'foot_L').find_all('geom') +
+ self._mjcf_root.find('body', 'foot_R').find_all('geom'))
+
+ @composer.cached_property
+ def standing_height(self):
+ """Return standing height."""
+ return self.params['_STAND_HEIGHT']
+
+ @composer.cached_property
+ def end_effectors(self):
+ """Return end effectors."""
+ return (self._mjcf_root.find('body', 'lower_arm_R'),
+ self._mjcf_root.find('body', 'lower_arm_L'),
+ self._mjcf_root.find('body', 'foot_R'),
+ self._mjcf_root.find('body', 'foot_L'))
+
+ @composer.cached_property
+ def observable_joints(self):
+ """Return observable joints."""
+ return tuple(actuator.joint
+ for actuator in self.actuators # This lint is mistaken; pylint: disable=not-an-iterable
+ if actuator.joint is not None)
+
+ @composer.cached_property
+ def observable_tendons(self):
+ return self._mjcf_root.find_all('tendon')
+
+ @composer.cached_property
+ def mocap_joints(self):
+ return tuple(
+ self._mjcf_root.find('joint', name) for name in _RAT_MOCAP_JOINTS)
+
+ @composer.cached_property
+ def mocap_joint_order(self):
+ return tuple([jnt.name for jnt in self.mocap_joints]) # This lint is mistaken; pylint: disable=not-an-iterable
+
+ @composer.cached_property
+ def bodies(self):
+ """Return all bodies."""
+ return tuple(self._mjcf_root.find_all('body'))
+
+ @composer.cached_property
+ def mocap_bodies(self):
+ """Return bodies for mocap comparison."""
+ return tuple(body for body in self._mjcf_root.find_all('body')
+ if not re.match(r'(vertebra|hand|toe)', body.name))
+
+ @composer.cached_property
+ def primary_joints(self):
+ """Return primary (non-vertebra) joints."""
+ return tuple(jnt for jnt in self._mjcf_root.find_all('joint')
+ if 'vertebra' not in jnt.name)
+
+ @composer.cached_property
+ def vertebra_joints(self):
+ """Return vertebra joints."""
+ return tuple(jnt for jnt in self._mjcf_root.find_all('joint')
+ if 'vertebra' in jnt.name)
+
+ @composer.cached_property
+ def primary_joint_order(self):
+ joint_names = self.mocap_joint_order
+ primary_names = tuple([jnt.name for jnt in self.primary_joints]) # pylint: disable=not-an-iterable
+ primary_order = []
+ for nm in primary_names:
+ primary_order.append(joint_names.index(nm))
+ return primary_order
+
+ @composer.cached_property
+ def vertebra_joint_order(self):
+ joint_names = self.mocap_joint_order
+ vertebra_names = tuple([jnt.name for jnt in self.vertebra_joints]) # pylint: disable=not-an-iterable
+ vertebra_order = []
+ for nm in vertebra_names:
+ vertebra_order.append(joint_names.index(nm))
+ return vertebra_order
+
+ @composer.cached_property
+ def egocentric_camera(self):
+ """Return the egocentric camera."""
+ return self._mjcf_root.find('camera', 'egocentric')
+
+ @property
+ def _xml_path(self):
+ """Return the path to th model .xml file."""
+ return self.params['_XML_PATH']
+
+ @composer.cached_property
+ def joint_actuators(self):
+ """Return all joint actuators."""
+ return tuple([act for act in self._mjcf_root.find_all('actuator')
+ if act.joint])
+
+ @composer.cached_property
+ def joint_actuators_range(self):
+ act_joint_range = []
+ for act in self.joint_actuators: # This lint is mistaken; pylint: disable=not-an-iterable
+ associated_joint = self._mjcf_root.find('joint', act.name)
+ act_range = associated_joint.dclass.joint.range
+ act_joint_range.append(act_range)
+ return act_joint_range
+
+ def pose_to_actuation(self, pose):
+ # holds for joint actuators, find desired torque = 0
+ # u_ref = [2 q_ref - (r_low + r_up) ]/(r_up - r_low)
+ r_lower = np.array([ajr[0] for ajr in self.joint_actuators_range]) # This lint is mistaken; pylint: disable=not-an-iterable
+ r_upper = np.array([ajr[1] for ajr in self.joint_actuators_range]) # This lint is mistaken; pylint: disable=not-an-iterable
+ num_tendon_actuators = len(self.actuators) - len(self.joint_actuators)
+ tendon_actions = np.zeros(num_tendon_actuators)
+ return np.hstack([tendon_actions, (2*pose[self.joint_actuator_order]-
+ (r_lower+r_upper))/(r_upper-r_lower)])
+
+ @composer.cached_property
+ def joint_actuator_order(self):
+ joint_names = self.mocap_joint_order
+ joint_actuator_names = tuple([act.name for act in self.joint_actuators]) # This lint is mistaken; pylint: disable=not-an-iterable
+ actuator_order = []
+ for nm in joint_actuator_names:
+ actuator_order.append(joint_names.index(nm))
+ return actuator_order
+
+ def _build_observables(self):
+ return RodentObservables(self)
+
+
+class RodentObservables(legacy_base.WalkerObservables):
+ """Observables for the Rat."""
+
+ @composer.observable
+ def head_height(self):
+ """Observe the head height."""
+ return observable.MJCFFeature('xpos', self._entity.head)[2]
+
+ @composer.observable
+ def sensors_torque(self):
+ """Observe the torque sensors."""
+ return observable.MJCFFeature(
+ 'sensordata',
+ self._entity.mjcf_model.sensor.torque,
+ corruptor=lambda v, random_state: np.tanh(2 * v / _TORQUE_THRESHOLD)
+ )
+
+ @composer.observable
+ def tendons_pos(self):
+ return observable.MJCFFeature('length', self._entity.observable_tendons)
+
+ @composer.observable
+ def tendons_vel(self):
+ return observable.MJCFFeature('velocity', self._entity.observable_tendons)
+
+ @composer.observable
+ def actuator_activation(self):
+ """Observe the actuator activation."""
+ model = self._entity.mjcf_model
+ return observable.MJCFFeature('act', model.find_all('actuator'))
+
+ @composer.observable
+ def appendages_pos(self):
+ """Equivalent to `end_effectors_pos` with head's position appended."""
+
+ def relative_pos_in_egocentric_frame(physics):
+ end_effectors_with_head = (
+ self._entity.end_effectors + (self._entity.head,))
+ end_effector = physics.bind(end_effectors_with_head).xpos
+ torso = physics.bind(self._entity.root_body).xpos
+ xmat = \
+ np.reshape(physics.bind(self._entity.root_body).xmat, (3, 3))
+ return np.reshape(np.dot(end_effector - torso, xmat), -1)
+
+ return observable.Generic(relative_pos_in_egocentric_frame)
+
+ @property
+ def proprioception(self):
+ """Return proprioceptive information."""
+ return [
+ self.joints_pos, self.joints_vel,
+ self.tendons_pos, self.tendons_vel,
+ self.actuator_activation,
+ self.body_height, self.end_effectors_pos, self.appendages_pos,
+ self.world_zaxis
+ ] + self._collect_from_attachments('proprioception')
+
+ @composer.observable
+ def egocentric_camera(self):
+ """Observable of the egocentric camera."""
+
+ if not hasattr(self, '_scene_options'):
+ # Don't render this walker's geoms.
+ self._scene_options = mj_wrapper.MjvOption()
+ collision_geom_group = 2
+ self._scene_options.geomgroup[collision_geom_group] = 0
+ cosmetic_geom_group = 1
+ self._scene_options.geomgroup[cosmetic_geom_group] = 0
+
+ return observable.MJCFCamera(self._entity.egocentric_camera,
+ width=64, height=64,
+ scene_option=self._scene_options
+ )
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/rodent_test.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/rodent_test.py
new file mode 100644
index 0000000..42b31d2
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/rodent_test.py
@@ -0,0 +1,116 @@
+# Copyright 2020 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for the Rodent."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation.observable import base as observable_base
+from dm_control.locomotion.arenas import corridors as corr_arenas
+from dm_control.locomotion.tasks import corridors as corr_tasks
+from dm_control.locomotion.walkers import rodent
+
+import numpy as np
+from six.moves import range
+
+_CONTROL_TIMESTEP = .02
+_PHYSICS_TIMESTEP = 0.001
+
+
+def _get_rat_corridor_physics():
+ walker = rodent.Rat()
+ arena = corr_arenas.EmptyCorridor()
+ task = corr_tasks.RunThroughCorridor(
+ walker=walker,
+ arena=arena,
+ walker_spawn_position=(5, 0, 0),
+ walker_spawn_rotation=0,
+ physics_timestep=_PHYSICS_TIMESTEP,
+ control_timestep=_CONTROL_TIMESTEP)
+
+ env = composer.Environment(
+ time_limit=30,
+ task=task,
+ strip_singleton_obs_buffer_dim=True)
+
+ return walker, env
+
+
+class RatTest(parameterized.TestCase):
+
+ def test_can_compile_and_step_simulation(self):
+ _, env = _get_rat_corridor_physics()
+ physics = env.physics
+ for _ in range(100):
+ physics.step()
+
+ @parameterized.parameters([
+ 'egocentric_camera',
+ 'head',
+ 'left_arm_root',
+ 'right_arm_root',
+ 'root_body',
+ 'pelvis_body',
+ ])
+ def test_get_element_property(self, name):
+ attribute_value = getattr(rodent.Rat(), name)
+ self.assertIsInstance(attribute_value, mjcf.Element)
+
+ @parameterized.parameters([
+ 'actuators',
+ 'bodies',
+ 'mocap_bodies',
+ 'end_effectors',
+ 'mocap_joints',
+ 'observable_joints',
+ ])
+ def test_get_element_tuple_property(self, name):
+ attribute_value = getattr(rodent.Rat(), name)
+ self.assertNotEmpty(attribute_value)
+ for item in attribute_value:
+ self.assertIsInstance(item, mjcf.Element)
+
+ def test_set_name(self):
+ name = 'fred'
+ walker = rodent.Rat(name=name)
+ self.assertEqual(walker.mjcf_model.model, name)
+
+ @parameterized.parameters(
+ 'tendons_pos',
+ 'tendons_vel',
+ 'actuator_activation',
+ 'appendages_pos',
+ 'head_height',
+ 'sensors_torque',
+ )
+ def test_evaluate_observable(self, name):
+ walker, env = _get_rat_corridor_physics()
+ physics = env.physics
+ observable = getattr(walker.observables, name)
+ observation = observable(physics)
+ self.assertIsInstance(observation, (float, np.ndarray))
+
+ def test_proprioception(self):
+ walker = rodent.Rat()
+ for item in walker.observables.proprioception:
+ self.assertIsInstance(item, observable_base.Observable)
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/scaled_actuators.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/scaled_actuators.py
new file mode 100644
index 0000000..806f5c3
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/scaled_actuators.py
@@ -0,0 +1,131 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Position & velocity actuators whose controls are scaled to a given range."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+_DISALLOWED_KWARGS = frozenset(
+ ['biastype', 'gainprm', 'biasprm', 'ctrllimited',
+ 'joint', 'tendon', 'site', 'slidersite', 'cranksite'])
+_ALLOWED_TAGS = frozenset(['joint', 'tendon', 'site'])
+
+_GOT_INVALID_KWARGS = 'Received invalid keyword argument(s): {}'
+_GOT_INVALID_TARGET = '`target` tag type should be one of {}: got {{}}'.format(
+ sorted(_ALLOWED_TAGS))
+
+
+def _check_target_and_kwargs(target, **kwargs):
+ invalid_kwargs = _DISALLOWED_KWARGS.intersection(kwargs)
+ if invalid_kwargs:
+ raise TypeError(_GOT_INVALID_KWARGS.format(sorted(invalid_kwargs)))
+ if target.tag not in _ALLOWED_TAGS:
+ raise TypeError(_GOT_INVALID_TARGET.format(target))
+
+
+def add_position_actuator(target, qposrange, ctrlrange=(-1, 1),
+ kp=1.0, **kwargs):
+ """Adds a scaled position actuator that is bound to the specified element.
+
+ This is equivalent to MuJoCo's built-in `` actuator where an affine
+ transformation is pre-applied to the control signal, such that the minimum
+ control value corresponds to the minimum desired position, and the
+ maximum control value corresponds to the maximum desired position.
+
+ Args:
+ target: A PyMJCF joint, tendon, or site element object that is to be
+ controlled.
+ qposrange: A sequence of two numbers specifying the allowed range of target
+ position.
+ ctrlrange: A sequence of two numbers specifying the allowed range of
+ this actuator's control signal.
+ kp: The gain parameter of this position actuator.
+ **kwargs: Additional MJCF attributes for this actuator element.
+ The following attributes are disallowed: `['biastype', 'gainprm',
+ 'biasprm', 'ctrllimited', 'joint', 'tendon', 'site',
+ 'slidersite', 'cranksite']`.
+
+ Returns:
+ A PyMJCF actuator element that has been added to the MJCF model containing
+ the specified `target`.
+
+ Raises:
+ TypeError: `kwargs` contains an unrecognized or disallowed MJCF attribute,
+ or `target` is not an allowed MJCF element type.
+ """
+ _check_target_and_kwargs(target, **kwargs)
+ kwargs[target.tag] = target
+
+ slope = (qposrange[1] - qposrange[0]) / (ctrlrange[1] - ctrlrange[0])
+ g0 = kp * slope
+ b0 = kp * (qposrange[0] - slope * ctrlrange[0])
+ b1 = -kp
+ b2 = 0
+ return target.root.actuator.add('general',
+ biastype='affine',
+ gainprm=[g0],
+ biasprm=[b0, b1, b2],
+ ctrllimited=True,
+ ctrlrange=ctrlrange,
+ **kwargs)
+
+
+def add_velocity_actuator(target, qvelrange, ctrlrange=(-1, 1),
+ kv=1.0, **kwargs):
+ """Adds a scaled velocity actuator that is bound to the specified element.
+
+ This is equivalent to MuJoCo's built-in `` actuator where an affine
+ transformation is pre-applied to the control signal, such that the minimum
+ control value corresponds to the minimum desired velocity, and the
+ maximum control value corresponds to the maximum desired velocity.
+
+ Args:
+ target: A PyMJCF joint, tendon, or site element object that is to be
+ controlled.
+ qvelrange: A sequence of two numbers specifying the allowed range of target
+ velocity.
+ ctrlrange: A sequence of two numbers specifying the allowed range of
+ this actuator's control signal.
+ kv: The gain parameter of this velocity actuator.
+ **kwargs: Additional MJCF attributes for this actuator element.
+ The following attributes are disallowed: `['biastype', 'gainprm',
+ 'biasprm', 'ctrllimited', 'joint', 'tendon', 'site',
+ 'slidersite', 'cranksite']`.
+
+ Returns:
+ A PyMJCF actuator element that has been added to the MJCF model containing
+ the specified `target`.
+
+ Raises:
+ TypeError: `kwargs` contains an unrecognized or disallowed MJCF attribute,
+ or `target` is not an allowed MJCF element type.
+ """
+ _check_target_and_kwargs(target, **kwargs)
+ kwargs[target.tag] = target
+
+ slope = (qvelrange[1] - qvelrange[0]) / (ctrlrange[1] - ctrlrange[0])
+ g0 = kv * slope
+ b0 = kv * (qvelrange[0] - slope * ctrlrange[0])
+ b1 = 0
+ b2 = -kv
+ return target.root.actuator.add('general',
+ biastype='affine',
+ gainprm=[g0],
+ biasprm=[b0, b1, b2],
+ ctrllimited=True,
+ ctrlrange=ctrlrange,
+ **kwargs)
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/scaled_actuators_test.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/scaled_actuators_test.py
new file mode 100644
index 0000000..1ff0f20
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/scaled_actuators_test.py
@@ -0,0 +1,135 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for scaled actuators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import absltest
+from dm_control import mjcf
+from dm_control.locomotion.walkers import scaled_actuators
+import numpy as np
+from six.moves import range
+
+
+class ScaledActuatorsTest(absltest.TestCase):
+
+ def setUp(self):
+ super(ScaledActuatorsTest, self).setUp()
+ self._mjcf_model = mjcf.RootElement()
+ self._min = -1.4
+ self._max = 2.3
+ self._gain = 1.7
+ self._scaled_min = -0.8
+ self._scaled_max = 1.3
+ self._range = self._max - self._min
+ self._scaled_range = self._scaled_max - self._scaled_min
+ self._joints = []
+ for _ in range(2):
+ body = self._mjcf_model.worldbody.add('body')
+ body.add('geom', type='sphere', size=[1])
+ self._joints.append(body.add('joint', type='hinge'))
+ self._scaled_actuator_joint = self._joints[0]
+ self._standard_actuator_joint = self._joints[1]
+ self._random_state = np.random.RandomState(3474)
+
+ def _set_actuator_controls(self, physics, normalized_ctrl,
+ scaled_actuator=None, standard_actuator=None):
+ if scaled_actuator is not None:
+ physics.bind(scaled_actuator).ctrl = (
+ normalized_ctrl * self._scaled_range + self._scaled_min)
+ if standard_actuator is not None:
+ physics.bind(standard_actuator).ctrl = (
+ normalized_ctrl * self._range + self._min)
+
+ def _assert_same_qfrc_actuator(self, physics, joint1, joint2):
+ np.testing.assert_allclose(physics.bind(joint1).qfrc_actuator,
+ physics.bind(joint2).qfrc_actuator)
+
+ def test_position_actuator(self):
+ scaled_actuator = scaled_actuators.add_position_actuator(
+ target=self._scaled_actuator_joint, kp=self._gain,
+ qposrange=(self._min, self._max),
+ ctrlrange=(self._scaled_min, self._scaled_max))
+ standard_actuator = self._mjcf_model.actuator.add(
+ 'position', joint=self._standard_actuator_joint, kp=self._gain,
+ ctrllimited=True, ctrlrange=(self._min, self._max))
+ physics = mjcf.Physics.from_mjcf_model(self._mjcf_model)
+
+ # Zero torque.
+ physics.bind(self._scaled_actuator_joint).qpos = (
+ 0.2345 * self._range + self._min)
+ self._set_actuator_controls(physics, 0.2345, scaled_actuator)
+ np.testing.assert_allclose(
+ physics.bind(self._scaled_actuator_joint).qfrc_actuator, 0, atol=1e-15)
+
+ for _ in range(100):
+ normalized_ctrl = self._random_state.uniform()
+ physics.bind(self._joints).qpos = (
+ self._random_state.uniform() * self._range + self._min)
+ self._set_actuator_controls(physics, normalized_ctrl,
+ scaled_actuator, standard_actuator)
+ self._assert_same_qfrc_actuator(
+ physics, self._scaled_actuator_joint, self._standard_actuator_joint)
+
+ def test_velocity_actuator(self):
+ scaled_actuator = scaled_actuators.add_velocity_actuator(
+ target=self._scaled_actuator_joint, kv=self._gain,
+ qvelrange=(self._min, self._max),
+ ctrlrange=(self._scaled_min, self._scaled_max))
+ standard_actuator = self._mjcf_model.actuator.add(
+ 'velocity', joint=self._standard_actuator_joint, kv=self._gain,
+ ctrllimited=True, ctrlrange=(self._min, self._max))
+ physics = mjcf.Physics.from_mjcf_model(self._mjcf_model)
+
+ # Zero torque.
+ physics.bind(self._scaled_actuator_joint).qvel = (
+ 0.5432 * self._range + self._min)
+ self._set_actuator_controls(physics, 0.5432, scaled_actuator)
+ np.testing.assert_allclose(
+ physics.bind(self._scaled_actuator_joint).qfrc_actuator, 0, atol=1e-15)
+
+ for _ in range(100):
+ normalized_ctrl = self._random_state.uniform()
+ physics.bind(self._joints).qvel = (
+ self._random_state.uniform() * self._range + self._min)
+ self._set_actuator_controls(physics, normalized_ctrl,
+ scaled_actuator, standard_actuator)
+ self._assert_same_qfrc_actuator(
+ physics, self._scaled_actuator_joint, self._standard_actuator_joint)
+
+ def test_invalid_kwargs(self):
+ invalid_kwargs = dict(joint=self._scaled_actuator_joint, ctrllimited=False)
+ with self.assertRaisesWithLiteralMatch(
+ TypeError,
+ scaled_actuators._GOT_INVALID_KWARGS.format(sorted(invalid_kwargs))):
+ scaled_actuators.add_position_actuator(
+ target=self._scaled_actuator_joint,
+ qposrange=(self._min, self._max),
+ **invalid_kwargs)
+
+ def test_invalid_target(self):
+ invalid_target = self._mjcf_model.worldbody
+ with self.assertRaisesWithLiteralMatch(
+ TypeError,
+ scaled_actuators._GOT_INVALID_TARGET.format(invalid_target)):
+ scaled_actuators.add_position_actuator(
+ target=invalid_target, qposrange=(self._min, self._max))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walls.png b/DMC/src/env/dm_control/dm_control/locomotion/walls.png
new file mode 100644
index 0000000..fc15386
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/walls.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mjcf/README.md b/DMC/src/env/dm_control/dm_control/mjcf/README.md
new file mode 100644
index 0000000..8418b31
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mjcf/README.md
@@ -0,0 +1,498 @@
+# PyMJCF
+
+IMPORTANT: If you find yourself stuck while using PyMJCF, check out the various
+IMPORTANT boxes on this page and the [Common gotchas](#common-gotchas) section
+at the bottom to see if any of them is relevant.
+
+This library provides a Python object model for MuJoCo's XML-based
+[MJCF](http://www.mujoco.org/book/modeling.html) physics modeling language. The
+goal of the library is to allow users to easily interact with and modify MJCF
+models in Python, similarly to what the JavaScript DOM does for HTML.
+
+A key feature of this library is the ability to easily compose multiple separate
+MJCF models into a larger one. Disambiguation of duplicated names from different
+models, or multiple instances of the same model, is handled automatically.
+
+The following snippet provides a quick example of this library's typical use
+case. Here, the `UpperBody` class can simply instantiate two copies of `Arm`,
+thus reducing code duplication. The names of bodies, joints, or geoms of each
+`Arm` are automatically prefixed by their parent's names, and so no name
+collision occurs.
+
+```python
+from dm_control import mjcf
+
+class Arm(object):
+
+ def __init__(self, name):
+ self.mjcf_model = mjcf.RootElement(model=name)
+
+ self.upper_arm = self.mjcf_model.worldbody.add('body', name='upper_arm')
+ self.shoulder = self.upper_arm.add('joint', name='shoulder', type='ball')
+ self.upper_arm.add('geom', name='upper_arm', type='capsule',
+ pos=[0, 0, -0.15], size=[0.045, 0.15])
+
+ self.forearm = self.upper_arm.add('body', name='forearm', pos=[0, 0, -0.3])
+ self.elbow = self.forearm.add('joint', name='elbow',
+ type='hinge', axis=[0, 1, 0])
+ self.forearm.add('geom', name='forearm', type='capsule',
+ pos=[0, 0, -0.15], size=[0.045, 0.15])
+
+class UpperBody(object):
+
+ def __init__(self):
+ self.mjcf_model = mjcf.RootElement()
+ self.mjcf_model.worldbody.add(
+ 'geom', name='torso', type='box', size=[0.15, 0.045, 0.25])
+ left_shoulder_site = self.mjcf_model.worldbody.add(
+ 'site', size=[1e-6]*3, pos=[-0.15, 0, 0.25])
+ right_shoulder_site = self.mjcf_model.worldbody.add(
+ 'site', size=[1e-6]*3, pos=[0.15, 0, 0.25])
+
+ self.left_arm = Arm(name='left_arm')
+ left_shoulder_site.attach(self.left_arm.mjcf_model)
+
+ self.right_arm = Arm(name='right_arm')
+ right_shoulder_site.attach(self.right_arm.mjcf_model)
+
+body = UpperBody()
+physics = mjcf.Physics.from_mjcf_model(body.mjcf_model)
+```
+
+## Basic operations
+
+### Creating an MJCF model
+
+In PyMJCF, the basic building block of a model is an `mjcf.Element`. This
+corresponds to an element in the generated XML. However, user code _cannot_
+instantiate a generic `mjcf.Element` object directly.
+
+A valid model always consists of a single root `` element. This is
+represented as the special `mjcf.RootElement` type in PyMJCF, which _can_ be
+instantiated in user code to create an empty model.
+
+```python
+from dm_control import mjcf
+
+mjcf_model = mjcf.RootElement()
+print(mjcf_model) # MJCF Element:
+```
+
+### Adding new elements
+
+Attributes of the new element can be passed as kwargs:
+
+```python
+my_box = mjcf_model.worldbody.add('geom', name='my_box',
+ type='box', pos=[0, .1, 0])
+print(my_box) # MJCF Element:
+```
+
+### Parsing an existing XML document
+
+Alternatively, if an existing XML file already exists, PyMJCF can parse it to
+create a Python object:
+
+```python
+from dm_control import mjcf
+
+# Parse from path
+mjcf_model = mjcf.from_path(filename)
+
+# Parse from file
+with open(filename) as f:
+ mjcf_model = mjcf.from_file(f)
+
+# Parse from string
+with open(filename) as f:
+ xml_string = f.read()
+mjcf_model = mjcf.from_xml_string(xml_string)
+
+print(type(mjcf_model)) #
+```
+
+### Traversing through a model
+
+Consider the following MJCF model:
+
+```xml
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+```
+
+The child elements and XML attributes of an `Element` object are exposed as
+Python attributes. These attributes all have the same names as their XML
+counterparts, with one exception: the `class` XML attribute is named `dclass` in
+order to avoid a clash with the Python `class` keyword:
+
+```python
+my_geom = mjcf_model.worldbody.body['foo'].body['bar'].geom['my_geom']
+print(isinstance(mjcf_model, mjcf.Element)) # True
+print(my_geom.name) # 'my_geom'
+print(my_geom.pos) # np.array([0., 1., 2.], dtype=np.float)
+print(my_geom.class) # SyntaxError
+print(my_geom.dclass) # 'brick'
+```
+
+Note that attribute values in the object model are **not** affected by defaults:
+
+```python
+print(mjcf_model.default.default['brick'].geom.rgba) # [1, 0, 0, 1]
+print(my_geom.rgba) # None
+```
+
+### Finding elements without traversing
+
+We can also find elements directly without having to traverse through the object
+hierarchy:
+
+```python
+found_geom = mjcf_model.find('geom', 'my_geom')
+print(found_geom == my_geom) # True
+```
+
+Find all elements of a given type:
+
+```python
+# Note that is also considered a joint
+joints = mjcf_model.find_all('joint')
+print(len(joints)) # 2
+print(joints[0] == mjcf_model.worldbody.body['foo'].freejoint) # True
+print(joints[1] == mjcf_model.worldbody.body['foo'].body['bar'].joint[0]) # True
+```
+
+Note that the order of elements returned by `find_all` is the same as the order
+in which they are declared in the model.
+
+### Modifying XML attributes
+
+Attributes can be modified, added, or removed:
+
+```python
+my_geom.pos = [1, 2, 3]
+print(my_geom.pos) # np.array([1., 2., 3.], dtype=np.float)
+my_geom.quat = [0, 1, 0, 0]
+print(my_geom.quat) # np.array([0., 1., 0., 0.], dtype=np.float)
+del my_geom.quat
+print(my_geom.quat) # None
+```
+
+Schema violations result in errors:
+
+```python
+print(my_geom.poss) # raise AttributeError (no child or attribute called poss)
+my_geom.pos = 'invalid' # raise ValueError (assigning string to array)
+my_geom.pos = [1, 2, 3, 4, 5, 6] # raise ValueError (array length is too long)
+
+# raise ValueError (mass is a required attribute of )
+del mjcf_model.find('body', 'foo').inertial.mass
+```
+
+### Uniqueness of identifiers
+
+PyMJCF enforces the uniqueness of "identifier" attributes within a model.
+Identifiers consist of the `class` attribute of a ``, and all `name`
+attributes. Their uniqueness is only enforced within a particular namespace. For
+example, a `` is allowed to have the same name as a ``, whereas
+`` and `` actuators cannot have the same name.
+
+```python
+mjcf_model.worldbody.add('geom', name='my_geom')
+foo = mjcf_model.worldbody.find('body', 'foo')
+foo.add('my_geom') # Error, duplicated geom name
+foo.add('foo') # OK, a geom can have the same name as a body
+mjcf_model.find('geom', 'foo').name = 'my_geom' # Error, duplicated geom name
+```
+
+### Reference attributes
+
+Some attributes are references to other elements. For example, the `joint`
+attribute of an actuator refers to a `` element in the model.
+
+An `mjcf.Element` can be directly assigned to these reference attributes:
+
+```python
+my_hinge = mjcf_model.find('joint', 'my_hinge')
+my_actuator = mjcf_model.actuator.add('velocity', joint=my_hinge)
+```
+
+This is the recommended way to assign reference attributes, since it guarantees
+that the reference is not invalidated if the referenced element is renamed.
+Alternatively, a string can also be assigned to reference attributes. In this
+case, PyMJCF does **not** attempt to verify that the named element actually
+exists in the model.
+
+IMPORTANT: If the element being referenced is in a different model to the
+reference attribute (e.g. in an attached model), the reference **must** be
+created by directly assigning an `mjcf.Element` object to the attribute rather
+than a string. Strings assigned to reference attributes cannot contain '/',
+since they are automatically scoped by PyMJCF upon attachment.
+
+## Attaching models
+
+In this section we will refer to an `mjcf.RootElement` simply as a "model".
+Models can be _attached_ to other models in order to create compositional
+scenes.
+
+```python
+arena = mjcf.RootElement()
+arena.worldbody.add('geom', name='ground', type='plane', size=[10, 10, 1])
+
+robot = mjcf.from_xml_file('robot.xml')
+arena.attach(robot)
+```
+
+We refer to `arena` as the _parent model_, and `robot` as the _child model_ (or
+the _attached model_).
+
+### Attachment frames
+
+When a model is attached to a site, an empty body is created in the parent
+model. This empty body is called an _attachment frame_.
+
+The attachment frame is created as a child of the body that contains the
+attachment site, and it has the same position and orientation as the site. When
+the XML is generated, the attachment frame's contents shadow the contents of the
+attached model's ``. The attachment frame's name in the generated XML
+is the child's `fully/qualified/prefix/`. The trailing slash ensures that the
+attachment frame's name never collides with a user-defined body.
+
+More concretely, if we have the following parent and child models:
+
+```xml
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+```
+
+Then the final generated XML will be:
+
+```xml
+
+
+
+
+
+
+
+
+
+
+
+
+```
+
+IMPORTANT: The attachment frame is created _transparently_ to the user. In
+particular, it is NOT treated as a regular `body` by PyMJCF. Its name in the
+generated XML should be considered implementation detail and should NOT be
+relied on.
+
+Having said that, it is sometimes necessary to access the attachment frame, for
+example to add a joint between the parent and the child model. The easiest way
+to do this is to hold a reference to the object returned by a call to `attach`:
+
+```python
+attachment_frame = parent_model.attach('child')
+attachment_frame.add('freejoint')
+```
+
+Alternatively, if a model has already been attached, the `find` function can be
+used with the `attachment_frame` namespace in order to retrieve the attachment
+frame. The `get_attachment_frame` convenience function in `mjcf.traversal_utils`
+can find the child model's attachment frame without needing access to the parent
+model.
+
+```python
+frame_1 = parent_model.find('attachment_frame', 'child')
+
+# Convenience function: get the attachment frame directly from a child model
+frame_2 = mjcf.traversal_utils.get_attachment_frame(child_model)
+print(frame_1 == frame_2) # True
+```
+
+IMPORTANT: To encourage good modeling practices, the only allowed direct
+children of an attachment frame are `` and ``. Other types of
+elements should instead add be added to the `` of the attached model.
+
+### Element ownership
+
+IMPORTANT: Elements of child models do **not** appear when traversing through
+the parent model.
+
+### Default classes
+
+PyMJCF ensures that default classes of a parent model _never_ affect any of its
+child models. This minimises the possibility that two models become subtly
+"incompatible", as a model always behaves in the same way regardless of what it
+is attached to.
+
+The way that PyMJCF achieves this in practice is to move everything in a model's
+global `` context into a default class named `/`. In other words, a
+PyMJCF-generated model never has anything in the global default context.
+Instead, the generated model always looks like:
+
+```xml
+
+
+
+
+
+
+
+
+
+```
+
+IMPORTANT: This transformation is _transparent_ to the user. Within Python, the
+above geom rgba setting is accessed as if it were a global default, i.e.
+`mjcf_model.default.geom.rgba`. Generally speaking, users should never have to
+worry about PyMJCF's internal handling of defaults.
+
+When a model is attached, its `/` default class turns into
+`fully/qualified/prefix/`. The trailing slash ensures that this transformation
+never conflicts with a user-named default class. More specifically, if we have
+the following parent and child models:
+
+```xml
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+```
+
+Then the final generated XML will be:
+
+```xml
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+```
+
+### Global options
+
+A model cannot be attached to another model if _any_ of the global options are
+different. Global options consist of attributes of ``, `
+ """.format(name=name, quat=quat)
+ physics = engine.Physics.from_xml_string(xml_string)
+ rotation_matrix = getattr(physics.named.data, data_field_name)[name]
+ self.assertAlmostEqual(np.linalg.det(rotation_matrix.reshape(3, 3)), 1)
+
+ @parameterized.parameters(['xmat', 'geom_xmat', 'site_xmat'])
+ def testValidRotationMatrixIfQuatNotNormalizedInXML(self, field_name):
+ self._check_valid_rotation_matrix(field_name)
+
+ # TODO(b/123918714): Update this once the bug has been fixed in MuJoCo.
+ @unittest.expectedFailure
+ def testValidCameraRotationMatrixIfQuatNotNormalizedInXML(self):
+ self._check_valid_rotation_matrix('cam_xmat')
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/index.py b/DMC/src/env/dm_control/dm_control/mujoco/index.py
new file mode 100644
index 0000000..5bfd428
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/index.py
@@ -0,0 +1,667 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Mujoco functions to support named indexing.
+
+The Mujoco name structure works as follows:
+
+In mjxmacro.h, each "X" entry denotes a type (a), a field name (b) and a list
+of dimension size metadata (c) which may contain both numbers and names, for
+example
+
+ X(int, name_bodyadr, nbody, 1) // or
+ X(mjtNum, body_pos, nbody, 3)
+ a b c ----->
+
+The second declaration states that the field `body_pos` has type `mjtNum` and
+dimension sizes `(nbody, 3)`, i.e. the first axis is indexed by body number.
+These and other named dimensions are sized based on the loaded model. This
+information is parsed and stored in `mjbindings.sizes`.
+
+In mjmodel.h, the struct mjModel contains an array of element name addresses
+for each size name.
+
+ int* name_bodyadr; // body name pointers (nbody x 1)
+
+By iterating over each of these element name address arrays, we first obtain a
+mapping from size names to a list of element names.
+
+ {'nbody': ['cart', 'pole'], 'njnt': ['free', 'ball', 'hinge'], ...}
+
+In addition to the element names that are derived from the mjModel struct at
+runtime, we also assign hard-coded names to certain dimensions where there is an
+established naming convention (e.g. 'x', 'y', 'z' for dimensions that correspond
+to Cartesian positions).
+
+For some dimensions, a single element name maps to multiple indices within the
+underlying field. For example, a single joint name corresponds to a variable
+number of indices within `qpos` that depends on the number of degrees of freedom
+associated with that joint type. These are referred to as "ragged" dimensions.
+
+In such cases we determine the size of each named element by examining the
+address arrays (e.g. `jnt_qposadr`), and construct a mapping from size name to
+element sizes:
+
+ {'nq': [7, 3, 1], 'nv': [6, 3, 1], ...}
+
+Given these two dictionaries, we then create an `Axis` instance for each size
+name. These objects have a `convert_key_item` method that handles the conversion
+from indexing expressions containing element names to valid numpy indices.
+Different implementations of `Axis` are used to handle "ragged" and "non-ragged"
+dimensions.
+
+ {'nbody': RegularNamedAxis(names=['cart', 'pole']),
+ 'nq': RaggedNamedAxis(names=['free', 'ball', 'hinge'], sizes=[7, 4, 1])}
+
+We construct this dictionary once using `make_axis_indexers`.
+
+Finally, for each field we construct a `FieldIndexer` class. A `FieldIndexer`
+instance encapsulates a field together with a list of `Axis` instances (one per
+dimension), and implements the named indexing logic by calling their respective
+`convert_key_item` methods.
+
+Summary of terminology:
+
+* _size name_ or _size_ A dimension size name, e.g. `nbody` or `ngeom`.
+* _element name_ or _name_ A named element in a Mujoco model, e.g. 'cart' or
+ 'pole'.
+* _element index_ or _index_ The index of an element name, for a specific size
+ name.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import collections
+import weakref
+
+from dm_control.mujoco.wrapper import util
+from dm_control.mujoco.wrapper.mjbindings import sizes
+import numpy as np
+import six
+
+
+# Mapping from {size_name: address_field_name} for ragged dimensions.
+_RAGGED_ADDRS = {
+ 'nq': 'jnt_qposadr',
+ 'nv': 'jnt_dofadr',
+ 'nsensordata': 'sensor_adr',
+ 'nnumericdata': 'numeric_adr',
+}
+
+# Names of columns.
+_COLUMN_NAMES = {
+ 'xyz': ['x', 'y', 'z'],
+ 'quat': ['qw', 'qx', 'qy', 'qz'],
+ 'mat': ['xx', 'xy', 'xz',
+ 'yx', 'yy', 'yz',
+ 'zx', 'zy', 'zz'],
+ 'rgba': ['r', 'g', 'b', 'a'],
+}
+
+# Mapping from keys of _COLUMN_NAMES to sets of field names whose columns are
+# addressable using those names.
+_COLUMN_ID_TO_FIELDS = {
+ 'xyz': set([
+ 'body_pos',
+ 'body_ipos',
+ 'body_inertia',
+ 'jnt_pos',
+ 'jnt_axis',
+ 'geom_size',
+ 'geom_pos',
+ 'site_size',
+ 'site_pos',
+ 'cam_pos',
+ 'cam_poscom0',
+ 'cam_pos0',
+ 'light_pos',
+ 'light_dir',
+ 'light_poscom0',
+ 'light_pos0',
+ 'light_dir0',
+ 'mesh_vert',
+ 'mesh_normal',
+ 'mocap_pos',
+ 'xpos',
+ 'xipos',
+ 'xanchor',
+ 'xaxis',
+ 'geom_xpos',
+ 'site_xpos',
+ 'cam_xpos',
+ 'light_xpos',
+ 'light_xdir',
+ 'subtree_com',
+ 'wrap_xpos',
+ 'subtree_linvel',
+ 'subtree_angmom',
+ ]),
+ 'quat': set([
+ 'body_quat',
+ 'body_iquat',
+ 'geom_quat',
+ 'site_quat',
+ 'cam_quat',
+ 'mocap_quat',
+ 'xquat',
+ ]),
+ 'mat': set([
+ 'cam_mat0',
+ 'xmat',
+ 'ximat',
+ 'geom_xmat',
+ 'site_xmat',
+ 'cam_xmat',
+ ]),
+ 'rgba': set([
+ 'geom_rgba',
+ 'site_rgba',
+ 'skin_rgba',
+ 'mat_rgba',
+ 'tendon_rgba',
+ ])
+}
+
+
+def _get_size_name_to_element_names(model):
+ """Returns a dict that maps size names to element names.
+
+ Args:
+ model: An instance of `mjbindings.mjModelWrapper`.
+
+ Returns:
+ A `dict` mapping from a size name (e.g. `'nbody'`) to a list of element
+ names.
+ """
+
+ names = model.names[:model.nnames]
+ size_name_to_element_names = {}
+
+ for field_name in dir(model):
+ if not _is_name_pointer(field_name):
+ continue
+
+ # Get addresses of element names in `model.names` array, e.g.
+ # field name: `name_nbodyadr` and name_addresses: `[86, 92, 101]`, and skip
+ # when there are no elements for this type in the model.
+ name_addresses = getattr(model, field_name).ravel()
+ if not name_addresses.size:
+ continue
+
+ # Get the element names.
+ element_names = []
+ for start_index in name_addresses:
+ name = names[start_index:names.find(b'\0', start_index)]
+ element_names.append(util.to_native_string(name))
+
+ # String identifier for the size of the first dimension, e.g. 'nbody'.
+ size_name = _get_size_name(field_name)
+
+ size_name_to_element_names[size_name] = element_names
+
+ # Add custom element names for certain columns.
+ for size_name, element_names in six.iteritems(_COLUMN_NAMES):
+ size_name_to_element_names[size_name] = element_names
+
+ # "Ragged" axes inherit their element names from other "non-ragged" axes.
+ # For example, the element names for "nv" axis come from "njnt".
+ for size_name, address_field_name in six.iteritems(_RAGGED_ADDRS):
+ donor = 'n' + address_field_name.split('_')[0]
+ if donor in size_name_to_element_names:
+ size_name_to_element_names[size_name] = size_name_to_element_names[donor]
+
+ # Mocap bodies are a special subset of bodies.
+ mocap_body_names = [None] * model.nmocap
+ for body_id, body_name in enumerate(size_name_to_element_names['nbody']):
+ body_mocapid = model.body_mocapid[body_id]
+ if body_mocapid != -1:
+ mocap_body_names[body_mocapid] = body_name
+ assert None not in mocap_body_names
+ size_name_to_element_names['nmocap'] = mocap_body_names
+
+ # Arrays with dimension `na` correspond to stateful actuators. MuJoCo's
+ # compiler requires that these are always defined after stateless actuators,
+ # so we only need the final `na` elements in the list of all actuator names.
+ if model.na:
+ all_actuator_names = size_name_to_element_names['nu']
+ size_name_to_element_names['na'] = all_actuator_names[-model.na:]
+
+ return size_name_to_element_names
+
+
+def _get_size_name_to_element_sizes(model):
+ """Returns a dict that maps size names to element sizes for ragged axes.
+
+ Args:
+ model: An instance of `mjbindings.mjModelWrapper`.
+
+ Returns:
+ A `dict` mapping from a size name (e.g. `'nv'`) to a numpy array of element
+ sizes. Size names corresponding to non-ragged axes are omitted.
+ """
+
+ size_name_to_element_sizes = {}
+
+ for size_name, address_field_name in six.iteritems(_RAGGED_ADDRS):
+ addresses = getattr(model, address_field_name).ravel()
+ total_length = getattr(model, size_name)
+ element_sizes = np.diff(np.r_[addresses, total_length])
+ size_name_to_element_sizes[size_name] = element_sizes
+
+ return size_name_to_element_sizes
+
+
+def make_axis_indexers(model):
+ """Returns a dict that maps size names to `Axis` indexers.
+
+ Args:
+ model: An instance of `mjbindings.MjModelWrapper`.
+
+ Returns:
+ A `dict` mapping from a size name (e.g. `'nbody'`) to an `Axis` instance.
+ """
+
+ size_name_to_element_names = _get_size_name_to_element_names(model)
+ size_name_to_element_sizes = _get_size_name_to_element_sizes(model)
+
+ # Unrecognized size names are treated as unnamed axes.
+ axis_indexers = collections.defaultdict(UnnamedAxis)
+
+ for size_name in size_name_to_element_names:
+ element_names = size_name_to_element_names[size_name]
+ if size_name in _RAGGED_ADDRS:
+ element_sizes = size_name_to_element_sizes[size_name]
+ indexer = RaggedNamedAxis(element_names, element_sizes)
+ else:
+ indexer = RegularNamedAxis(element_names)
+ axis_indexers[size_name] = indexer
+
+ return axis_indexers
+
+
+def _is_name_pointer(field_name):
+ """Returns True for name pointer field names such as `name_bodyadr`."""
+ # Denotes name pointer fields in mjModel.
+ prefix, suffix = 'name_', 'adr'
+ return field_name.startswith(prefix) and field_name.endswith(suffix)
+
+
+def _get_size_name(field_name, struct_name='mjmodel'):
+ # Look up size name in metadata.
+ return sizes.array_sizes[struct_name][field_name][0]
+
+
+def _validate_key_item(key_item):
+ if isinstance(key_item, (list, np.ndarray)):
+ for sub in key_item:
+ _validate_key_item(sub) # Recurse into nested arrays and lists.
+ elif key_item is Ellipsis:
+ raise IndexError('Ellipsis indexing not supported.')
+ elif key_item is None:
+ raise IndexError('None indexing not supported.')
+ elif key_item in (b'', u''):
+ raise IndexError('Empty strings are not allowed.')
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Axis(object):
+ """Handles the conversion of named indexing expressions into numpy indices."""
+
+ @abc.abstractmethod
+ def convert_key_item(self, key_item):
+ """Converts a (possibly named) indexing expression to a numpy index."""
+
+
+class UnnamedAxis(Axis):
+ """An object representing an axis where the elements are not named."""
+
+ def convert_key_item(self, key_item):
+ """Validate the indexing expression and return it unmodified."""
+ _validate_key_item(key_item)
+ return key_item
+
+
+class RegularNamedAxis(Axis):
+ """Represents an axis where each named element has a fixed size of 1."""
+
+ def __init__(self, names):
+ """Initializes a new `RegularNamedAxis` instance.
+
+ Args:
+ names: A list or array of element names.
+ """
+ self._names = names
+ self._names_to_offsets = {name: offset
+ for offset, name in enumerate(names) if name}
+
+ def convert_key_item(self, key_item):
+ """Converts a named indexing expression to a numpy-friendly index."""
+
+ _validate_key_item(key_item)
+
+ if isinstance(key_item, six.string_types):
+ key_item = self._names_to_offsets[util.to_native_string(key_item)]
+
+ elif isinstance(key_item, (list, np.ndarray)):
+ # Cast lists to numpy arrays.
+ key_item = np.array(key_item, copy=False)
+ original_shape = key_item.shape
+
+ # We assume that either all or none of the items in the array are strings
+ # representing names. If there is a mix, we will let NumPy throw an error
+ # when trying to index with the returned item.
+ if isinstance(key_item.flat[0], six.string_types):
+ key_item = np.array([self._names_to_offsets[util.to_native_string(k)]
+ for k in key_item.flat])
+ # Ensure the output shape is the same as that of the input.
+ key_item.shape = original_shape
+
+ return key_item
+
+ @property
+ def names(self):
+ """Returns a list of element names."""
+ return self._names
+
+
+class RaggedNamedAxis(Axis):
+ """Represents an axis where the named elements may vary in size."""
+
+ def __init__(self, element_names, element_sizes):
+ """Initializes a new `RaggedNamedAxis` instance.
+
+ Args:
+ element_names: A list or array containing the element names.
+ element_sizes: A list or array containing the size of each element.
+ """
+ names_to_slices = {}
+ names_to_indices = {}
+
+ offset = 0
+ for name, size in zip(element_names, element_sizes):
+ # Don't add unnamed elements to the dicts.
+ if name:
+ names_to_slices[name] = slice(offset, offset + size)
+ names_to_indices[name] = range(offset, offset + size)
+ offset += size
+
+ self._names = element_names
+ self._sizes = element_sizes
+ self._names_to_slices = names_to_slices
+ self._names_to_indices = names_to_indices
+
+ def convert_key_item(self, key):
+ """Converts a named indexing expression to a numpy-friendly index."""
+
+ _validate_key_item(key)
+
+ if isinstance(key, six.string_types):
+ key = self._names_to_slices[util.to_native_string(key)]
+
+ elif isinstance(key, (list, np.ndarray)):
+ # We assume that either all or none of the items in the sequence are
+ # strings representing names. If there is a mix, we will let NumPy throw
+ # an error when trying to index with the returned key.
+ if isinstance(key[0], six.string_types):
+ new_key = []
+ for k in key:
+ idx = self._names_to_indices[util.to_native_string(k)]
+ if isinstance(idx, int):
+ new_key.append(idx)
+ else:
+ new_key.extend(idx)
+ key = new_key
+
+ return key
+
+ @property
+ def names(self):
+ """Returns a list of element names."""
+ return self._names
+
+
+Axes = collections.namedtuple('Axes', ['row', 'col'])
+Axes.__new__.__defaults__ = (None,) # Default value for optional 'col' field
+
+
+class FieldIndexer(object):
+ """An array-like object providing named access to a field in a MuJoCo struct.
+
+ FieldIndexers expose the same attributes and methods as an `np.ndarray`.
+
+ They may be indexed with strings or lists of strings corresponding to element
+ names. They also support standard numpy indexing expressions, with the
+ exception of indices containing `Ellipsis` or `None`.
+ """
+
+ __slots__ = ('_field_name', '_field', '_axes')
+
+ def __init__(self,
+ parent_struct,
+ field_name,
+ axis_indexers):
+ """Initializes a new `FieldIndexer`.
+
+ Args:
+ parent_struct: Wrapped ctypes structure, as generated by `mjbindings`.
+ field_name: String containing field name in `parent_struct`.
+ axis_indexers: A list of `Axis` instances, one per dimension.
+ """
+ self._field_name = field_name
+ self._field = weakref.proxy(getattr(parent_struct, field_name))
+ self._axes = Axes(*axis_indexers)
+
+ def __dir__(self):
+ # Enables IPython tab completion
+ return sorted(set(dir(type(self)) + dir(self._field)))
+
+ def __getattr__(self, name):
+ return getattr(self._field, name)
+
+ def _convert_key(self, key):
+ """Convert a (possibly named) indexing expression to a valid numpy index."""
+ return_tuple = isinstance(key, tuple)
+ if not return_tuple:
+ key = (key,)
+ if len(key) > self._field.ndim:
+ raise IndexError('Index tuple has {} elements, but array has only {} '
+ 'dimensions.'.format(len(key), self._field.ndim))
+ new_key = tuple(axis.convert_key_item(key_item)
+ for axis, key_item in zip(self._axes, key))
+ if not return_tuple:
+ new_key = new_key[0]
+ return new_key
+
+ def __getitem__(self, key):
+ """Converts the key to a numeric index and returns the indexed array.
+
+ Args:
+ key: Indexing expression.
+
+ Raises:
+ IndexError: If an indexing tuple has too many elements, or if it contains
+ `Ellipsis`, `None`, or an empty string.
+
+ Returns:
+ The indexed array.
+ """
+ return self._field[self._convert_key(key)]
+
+ def __setitem__(self, key, value):
+ """Converts the key and assigns to the indexed array.
+
+ Args:
+ key: Indexing expression.
+ value: Value to assign.
+
+ Raises:
+ IndexError: If an indexing tuple has too many elements, or if it contains
+ `Ellipsis`, `None`, or an empty string.
+ """
+ self._field[self._convert_key(key)] = value
+
+ @property
+ def axes(self):
+ """A namedtuple containing the row and column indexers for this field."""
+ return self._axes
+
+ def __repr__(self):
+ """Returns a pretty string representation of the `FieldIndexer`."""
+
+ def get_name_arr_and_len(dim_idx):
+ """Returns a string array of element names and the max name length."""
+ axis = self._axes[dim_idx]
+ size = self._field.shape[dim_idx]
+ try:
+ name_len = max(len(name) for name in axis.names)
+ name_arr = np.zeros(size, dtype='S{}'.format(name_len))
+ for name in axis.names:
+ if name:
+ # Use the `Axis` object to convert the name into a numpy index, then
+ # use this index to write into name_arr.
+ name_arr[axis.convert_key_item(name)] = name
+ except AttributeError:
+ name_arr = np.zeros(size, dtype='S0') # An array of zero-length strings
+ name_len = 0
+ return name_arr, name_len
+
+ row_name_arr, row_name_len = get_name_arr_and_len(0)
+ if self._field.ndim > 1:
+ col_name_arr, col_name_len = get_name_arr_and_len(1)
+ else:
+ col_name_arr, col_name_len = np.zeros(1, dtype='S0'), 0
+
+ idx_len = int(np.log10(max(self._field.shape[0], 1))) + 1
+
+ cls_template = '{class_name:}({field_name:}):'
+ col_template = '{padding:}{col_names:}'
+ row_template = '{idx:{idx_len:}} {row_name:>{row_name_len:}} {row_vals:}'
+
+ lines = []
+
+ # Write the class name and field name.
+ lines.append(cls_template.format(class_name=self.__class__.__name__,
+ field_name=self._field_name))
+
+ # Write a header line containing the column names (if there are any).
+ if col_name_len:
+ col_width = max(col_name_len, 9) + 1
+ extra_indent = 4
+ padding = ' ' * (idx_len + row_name_len + extra_indent)
+ col_names = ''.join(
+ '{name:<{width:}}'
+ .format(name=util.to_native_string(name), width=col_width)
+ for name in col_name_arr)
+ lines.append(col_template.format(padding=padding, col_names=col_names))
+
+ # Write the row names (if there are any) and the formatted array values.
+ if not self._field.shape[0]:
+ lines.append('(empty)')
+ else:
+ for idx, row in enumerate(self._field):
+ row_vals = np.array2string(
+ np.atleast_1d(row),
+ suppress_small=True,
+ formatter={'float_kind': '{: < 9.3g}'.format})
+ lines.append(row_template.format(
+ idx=idx,
+ idx_len=idx_len,
+ row_name=util.to_native_string(row_name_arr[idx]),
+ row_name_len=row_name_len,
+ row_vals=row_vals))
+ return '\n'.join(lines)
+
+
+def struct_indexer(struct, struct_name, size_to_axis_indexer):
+ """Returns an object with a `FieldIndexer` attribute for each dynamic field.
+
+ Usage example
+
+ ```python
+ named_data = struct_indexer(mjdata, 'mjdata', size_to_axis_indexer)
+ fingertip_xpos = named_data.xpos['fingertip']
+ elbow_qvel = named_data.qvel['elbow']
+ ```
+
+ Args:
+ struct: Wrapped ctypes structure as generated by `mjbindings`.
+ struct_name: String containing corresponding Mujoco name of struct.
+ size_to_axis_indexer: dict that maps size names to `Axis` instances.
+
+ Returns:
+ An object with a field for every dynamically sized array field, mapping to a
+ `FieldIndexer`. The returned object is immutable and has an `_asdict`
+ method.
+
+ Raises:
+ ValueError: If `struct_name` is not recognized.
+ """
+ struct_name = struct_name.lower()
+ if struct_name not in sizes.array_sizes:
+ raise ValueError('Unrecognized struct name ' + struct_name)
+
+ array_sizes = sizes.array_sizes[struct_name]
+
+ # Used to create the namedtuple.
+ field_names = []
+ field_indexers = {}
+
+ for field_name in array_sizes:
+
+ # Skip over structured arrays and fields that have sizes but aren't numpy
+ # arrays, such as text fields and contacts (b/34805932).
+ attr = getattr(struct, field_name)
+ if not isinstance(attr, np.ndarray) or attr.dtype.fields:
+ continue
+
+ size_names = sizes.array_sizes[struct_name][field_name]
+
+ # Here we override the size name in order to enable named column indexing
+ # for certain fields, e.g. 3 becomes "xyz" for field name "xpos".
+ for new_col_size, field_set in six.iteritems(_COLUMN_ID_TO_FIELDS):
+ if field_name in field_set:
+ size_names = (size_names[0], new_col_size)
+ break
+
+ axis_indexers = []
+ for size_name in size_names:
+ axis_indexers.append(size_to_axis_indexer[size_name])
+
+ field_indexers[field_name] = FieldIndexer(
+ parent_struct=struct,
+ field_name=field_name,
+ axis_indexers=axis_indexers)
+
+ field_names.append(field_name)
+
+ return make_struct_indexer(field_indexers)
+
+
+def make_struct_indexer(field_indexers):
+ """Returns an immutable container exposing named indexers as attributes."""
+
+ class StructIndexer(object):
+ __slots__ = ()
+
+ def _asdict(self):
+ return field_indexers.copy()
+
+ for name, indexer in six.iteritems(field_indexers):
+ setattr(StructIndexer, name, indexer)
+
+ return StructIndexer()
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/index_test.py b/DMC/src/env/dm_control/dm_control/mujoco/index_test.py
new file mode 100644
index 0000000..7a52c65
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/index_test.py
@@ -0,0 +1,360 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for index."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from dm_control.mujoco import index
+from dm_control.mujoco import wrapper
+from dm_control.mujoco.testing import assets
+from dm_control.mujoco.wrapper.mjbindings import sizes
+
+import numpy as np
+import six
+
+MODEL = assets.get_contents('cartpole.xml')
+MODEL_NO_NAMES = assets.get_contents('cartpole_no_names.xml')
+MODEL_3RD_ORDER_ACTUATORS = assets.get_contents(
+ 'model_with_third_order_actuators.xml')
+MODEL_INCORRECT_ACTUATOR_ORDER = assets.get_contents(
+ 'model_incorrect_actuator_order.xml')
+
+FIELD_REPR = {
+ 'act': ('FieldIndexer(act):\n'
+ '(empty)'),
+ 'qM': ('FieldIndexer(qM):\n'
+ '0 [ 0 ]\n'
+ '1 [ 1 ]\n'
+ '2 [ 2 ]'),
+ 'sensordata': ('FieldIndexer(sensordata):\n'
+ '0 accelerometer [ 0 ]\n'
+ '1 accelerometer [ 1 ]\n'
+ '2 accelerometer [ 2 ]\n'
+ '3 collision [ 3 ]'),
+ 'xpos': ('FieldIndexer(xpos):\n'
+ ' x y z \n'
+ '0 world [ 0 1 2 ]\n'
+ '1 cart [ 3 4 5 ]\n'
+ '2 pole [ 6 7 8 ]\n'
+ '3 mocap1 [ 9 10 11 ]\n'
+ '4 mocap2 [ 12 13 14 ]'),
+}
+
+
+class MujocoIndexTest(parameterized.TestCase):
+
+ def setUp(self):
+ super(MujocoIndexTest, self).setUp()
+ self._model = wrapper.MjModel.from_xml_string(MODEL)
+ self._data = wrapper.MjData(self._model)
+
+ self._size_to_axis_indexer = index.make_axis_indexers(self._model)
+
+ self._model_indexers = index.struct_indexer(self._model, 'mjmodel',
+ self._size_to_axis_indexer)
+
+ self._data_indexers = index.struct_indexer(self._data, 'mjdata',
+ self._size_to_axis_indexer)
+
+ def assertIndexExpressionEqual(self, expected, actual):
+ try:
+ if isinstance(expected, tuple):
+ self.assertLen(actual, len(expected))
+ for expected_item, actual_item in zip(expected, actual):
+ self.assertIndexExpressionEqual(expected_item, actual_item)
+ elif isinstance(expected, (list, np.ndarray)):
+ np.testing.assert_array_equal(expected, actual)
+ else:
+ self.assertEqual(expected, actual)
+ except AssertionError:
+ self.fail('Indexing expressions are not equal.\n'
+ 'expected: {!r}\nactual: {!r}'.format(expected, actual))
+
+ @parameterized.parameters(
+ # (field name, named index key, expected integer index key)
+ ('actuator_gear', 'slide', 0),
+ ('geom_rgba', ('mocap_sphere', 'g'), (6, 1)),
+ ('dof_armature', 'slider', slice(0, 1, None)),
+ ('dof_armature', ['slider', 'hinge'], [0, 1]),
+ ('numeric_data', 'three_numbers', slice(1, 4, None)),
+ ('numeric_data', ['three_numbers', 'control_timestep'], [1, 2, 3, 0]))
+ def testModelNamedIndexing(self, field_name, key, numeric_key):
+
+ indexer = getattr(self._model_indexers, field_name)
+ field = getattr(self._model, field_name)
+
+ converted_key = indexer._convert_key(key)
+
+ # Explicit check that the converted key matches the numeric key.
+ converted_key = indexer._convert_key(key)
+ self.assertIndexExpressionEqual(numeric_key, converted_key)
+
+ # This writes unique values to the underlying buffer to prevent false
+ # negatives.
+ field.flat[:] = np.arange(field.size)
+
+ # Check that the result of named indexing matches the result of numeric
+ # indexing.
+ np.testing.assert_array_equal(field[numeric_key], indexer[key])
+
+ @parameterized.parameters(
+ # (field name, named index key, expected integer index key)
+ ('xpos', 'pole', 2),
+ ('xpos', ['pole', 'cart'], [2, 1]),
+ ('sensordata', 'accelerometer', slice(0, 3, None)),
+ ('sensordata', 'collision', slice(3, 4, None)),
+ ('sensordata', ['accelerometer', 'collision'], [0, 1, 2, 3]),
+ # Slices.
+ ('xpos', (slice(None), 0), (slice(None), 0)),
+ # Custom fixed-size columns.
+ ('xpos', ('pole', 'y'), (2, 1)),
+ ('xmat', ('cart', ['yy', 'zz']), (1, [4, 8])),
+ # Custom indexers for mocap bodies.
+ ('mocap_quat', 'mocap1', 0),
+ ('mocap_pos', (['mocap2', 'mocap1'], 'z'), ([1, 0], 2)),
+ # Two-dimensional named indexing.
+ ('xpos', (['pole', 'cart'], ['x', 'z']), ([2, 1], [0, 2])),
+ ('xpos', ([['pole'], ['cart']], ['x', 'z']), ([[2], [1]], [0, 2])))
+ def testDataNamedIndexing(self, field_name, key, numeric_key):
+
+ indexer = getattr(self._data_indexers, field_name)
+ field = getattr(self._data, field_name)
+
+ # Explicit check that the converted key matches the numeric key.
+ converted_key = indexer._convert_key(key)
+ self.assertIndexExpressionEqual(numeric_key, converted_key)
+
+ # This writes unique values to the underlying buffer to prevent false
+ # negatives.
+ field.flat[:] = np.arange(field.size)
+
+ # Check that the result of named indexing matches the result of numeric
+ # indexing.
+ np.testing.assert_array_equal(field[numeric_key], indexer[key])
+
+ @parameterized.parameters(
+ # (field name, named index key, expected integer index key)
+ ('act', 'cylinder', 0),
+ ('act_dot', 'general', 1),
+ ('act', ['general', 'cylinder', 'general'], [1, 0, 1]))
+ def testIndexThirdOrderActuators(self, field_name, key, numeric_key):
+ model = wrapper.MjModel.from_xml_string(MODEL_3RD_ORDER_ACTUATORS)
+ data = wrapper.MjData(model)
+ size_to_axis_indexer = index.make_axis_indexers(model)
+ data_indexers = index.struct_indexer(data, 'mjdata', size_to_axis_indexer)
+
+ indexer = getattr(data_indexers, field_name)
+ field = getattr(data, field_name)
+
+ # Explicit check that the converted key matches the numeric key.
+ converted_key = indexer._convert_key(key)
+ self.assertIndexExpressionEqual(numeric_key, converted_key)
+
+ # This writes unique values to the underlying buffer to prevent false
+ # negatives.
+ field.flat[:] = np.arange(field.size)
+
+ # Check that the result of named indexing matches the result of numeric
+ # indexing.
+ np.testing.assert_array_equal(field[numeric_key], indexer[key])
+
+ def testIncorrectActuatorOrder(self):
+ # Our indexing of third-order actuators relies on an undocumented
+ # requirement of MuJoCo's compiler that all third-order actuators come after
+ # all second-order actuators. This test ensures that the rule still holds
+ # (e.g. in future versions of MuJoCo).
+ with six.assertRaisesRegex(
+ self, wrapper.Error, '2nd-order actuators must come before 3rd-order'):
+ wrapper.MjModel.from_xml_string(MODEL_INCORRECT_ACTUATOR_ORDER)
+
+ @parameterized.parameters(
+ # (field name, named index key)
+ ('xpos', 'pole'),
+ ('xpos', ['pole', 'cart']),
+ ('xpos', (['pole', 'cart'], 'y')),
+ ('xpos', (['pole', 'cart'], ['x', 'z'])),
+ ('qpos', 'slider'),
+ ('qvel', ['slider', 'hinge']),)
+ def testDataAssignment(self, field_name, key):
+
+ indexer = getattr(self._data_indexers, field_name)
+ field = getattr(self._data, field_name)
+
+ # The result of the indexing expression is either an array or a scalar.
+ index_result = indexer[key]
+ try:
+ # Write a sequence of unique values to prevent false negatives.
+ new_values = np.arange(index_result.size).reshape(index_result.shape)
+ except AttributeError:
+ new_values = 99
+ indexer[key] = new_values
+
+ # Check that the new value(s) can be read back from the underlying buffer.
+ converted_key = indexer._convert_key(key)
+ np.testing.assert_array_equal(new_values, field[converted_key])
+
+ @parameterized.parameters(
+ # (field name, first index key, second index key)
+ ('sensordata', 'accelerometer', 0),
+ ('sensordata', 'accelerometer', [0, 2]),
+ ('sensordata', 'accelerometer', slice(None)),)
+ def testChainedAssignment(self, field_name, first_key, second_key):
+
+ indexer = getattr(self._data_indexers, field_name)
+ field = getattr(self._data, field_name)
+
+ # The result of the indexing expression is either an array or a scalar.
+ index_result = indexer[first_key][second_key]
+ try:
+ # Write a sequence of unique values to prevent false negatives.
+ new_values = np.arange(index_result.size).reshape(index_result.shape)
+ except AttributeError:
+ new_values = 99
+ indexer[first_key][second_key] = new_values
+
+ # Check that the new value(s) can be read back from the underlying buffer.
+ converted_key = indexer._convert_key(first_key)
+ np.testing.assert_array_equal(new_values, field[converted_key][second_key])
+
+ def testNamedColumnFieldNames(self):
+
+ all_fields = set()
+ for struct in six.itervalues(sizes.array_sizes):
+ all_fields.update(struct.keys())
+
+ named_col_fields = set()
+ for field_set in six.itervalues(index._COLUMN_ID_TO_FIELDS):
+ named_col_fields.update(field_set)
+
+ # Check that all of the "named column" fields specified in index are
+ # also found in mjbindings.sizes.
+ self.assertContainsSubset(named_col_fields, all_fields)
+
+ @parameterized.parameters('xpos', 'xmat') # field name
+ def testTooManyIndices(self, field_name):
+ indexer = getattr(self._data_indexers, field_name)
+ with six.assertRaisesRegex(self, IndexError, 'Index tuple'):
+ _ = indexer[:, :, :, 'too', 'many', 'elements']
+
+ @parameterized.parameters(
+ # bad item, exception regexp
+ (Ellipsis, 'Ellipsis'),
+ (None, 'None'),
+ (np.newaxis, 'None'),
+ (b'', 'Empty string'),
+ (u'', 'Empty string'))
+ def testBadIndexItems(self, bad_index_item, exception_regexp):
+ indexer = getattr(self._data_indexers, 'xpos')
+ expressions = [
+ bad_index_item,
+ (0, bad_index_item),
+ [bad_index_item],
+ [[bad_index_item]],
+ (0, [bad_index_item]),
+ (0, [[bad_index_item]]),
+ np.array([bad_index_item]),
+ (0, np.array([bad_index_item])),
+ (0, np.array([[bad_index_item]]))
+ ]
+ for expression in expressions:
+ with six.assertRaisesRegex(self, IndexError, exception_regexp):
+ _ = indexer[expression]
+
+ @parameterized.parameters('act', 'qM', 'sensordata', 'xpos') # field name
+ def testFieldIndexerRepr(self, field_name):
+
+ indexer = getattr(self._data_indexers, field_name)
+ field = getattr(self._data, field_name)
+
+ # Write a sequence of unique values to prevent false negatives.
+ field.flat[:] = np.arange(field.size)
+
+ # Check that the string representation is as expected.
+ self.assertEqual(FIELD_REPR[field_name], repr(indexer))
+
+ @parameterized.parameters(MODEL, MODEL_NO_NAMES)
+ def testBuildIndexersForEdgeCases(self, xml_string):
+ model = wrapper.MjModel.from_xml_string(xml_string)
+ data = wrapper.MjData(model)
+
+ size_to_axis_indexer = index.make_axis_indexers(model)
+
+ index.struct_indexer(model, 'mjmodel', size_to_axis_indexer)
+ index.struct_indexer(data, 'mjdata', size_to_axis_indexer)
+
+ @parameterized.parameters(
+ name for name in dir(np.ndarray)
+ if not name.startswith('_') # Exclude 'private' attributes
+ and name not in ('ctypes', 'flat') # Can't compare via identity/equality
+ )
+ def testFieldIndexerDelegatesNDArrayAttributes(self, name):
+ field = self._data.xpos
+ field_indexer = self._data_indexers.xpos
+ actual = getattr(field_indexer, name)
+ expected = getattr(field, name)
+ if isinstance(expected, np.ndarray):
+ np.testing.assert_array_equal(actual, expected)
+ else:
+ self.assertEqual(actual, expected)
+
+ # FieldIndexer attributes should be read-only
+ with six.assertRaisesRegex(self, AttributeError, name):
+ setattr(field_indexer, name, expected)
+
+ def testFieldIndexerDir(self):
+ expected_subset = dir(self._data.xpos)
+ actual_set = dir(self._data_indexers.xpos)
+ self.assertContainsSubset(expected_subset, actual_set)
+
+
+def _iter_indexers(model, data):
+ size_to_axis_indexer = index.make_axis_indexers(model)
+ for struct, struct_name in ((model, 'mjmodel'), (data, 'mjdata')):
+ indexer = index.struct_indexer(struct, struct_name, size_to_axis_indexer)
+ for field_name, field_indexer in six.iteritems(indexer._asdict()):
+ yield field_name, field_indexer
+
+
+class AllFieldsTest(parameterized.TestCase):
+ """Generic tests covering each FieldIndexer in model and data."""
+
+ # NB: the class must hold references to the model and data instances or they
+ # may be garbage-collected before any indexing is attempted.
+ model = wrapper.MjModel.from_xml_string(MODEL)
+ data = wrapper.MjData(model)
+
+ # Iterates over ('field_name', FieldIndexer) pairs
+ @parameterized.named_parameters(_iter_indexers(model, data))
+ def testReadWrite_(self, field):
+ # Read the contents of the FieldIndexer as a numpy array.
+ old_contents = field[:]
+ # Write unique values to the FieldIndexer and read them back again.
+ # Don't write to non-float fields since these might contain pointers.
+ if np.issubdtype(old_contents.dtype, np.floating):
+ new_contents = np.arange(old_contents.size, dtype=old_contents.dtype)
+ new_contents.shape = old_contents.shape
+ field[:] = new_contents
+ np.testing.assert_array_equal(new_contents, field[:])
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/math.py b/DMC/src/env/dm_control/dm_control/mujoco/math.py
new file mode 100644
index 0000000..4be5da3
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/math.py
@@ -0,0 +1,79 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Utility functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from dm_control.mujoco.wrapper.mjbindings import mjlib
+
+import numpy as np
+
+
+def euler2quat(ax, ay, az):
+ """Converts euler angles to a quaternion.
+
+ Note: rotation order is zyx
+
+ Args:
+ ax: Roll angle (deg)
+ ay: Pitch angle (deg).
+ az: Yaw angle (deg).
+
+ Returns:
+ A numpy array representing the rotation as a quaternion.
+ """
+ r1 = az
+ r2 = ay
+ r3 = ax
+
+ c1 = np.cos(np.deg2rad(r1 / 2))
+ s1 = np.sin(np.deg2rad(r1 / 2))
+ c2 = np.cos(np.deg2rad(r2 / 2))
+ s2 = np.sin(np.deg2rad(r2 / 2))
+ c3 = np.cos(np.deg2rad(r3 / 2))
+ s3 = np.sin(np.deg2rad(r3 / 2))
+
+ q0 = c1 * c2 * c3 + s1 * s2 * s3
+ q1 = c1 * c2 * s3 - s1 * s2 * c3
+ q2 = c1 * s2 * c3 + s1 * c2 * s3
+ q3 = s1 * c2 * c3 - c1 * s2 * s3
+
+ return np.array([q0, q1, q2, q3])
+
+
+def mj_quatprod(q, r):
+ quaternion = np.zeros(4)
+ mjlib.mju_mulQuat(quaternion, np.ascontiguousarray(q),
+ np.ascontiguousarray(r))
+ return quaternion
+
+
+def mj_quat2vel(q, dt):
+ vel = np.zeros(3)
+ mjlib.mju_quat2Vel(vel, np.ascontiguousarray(q), dt)
+ return vel
+
+
+def mj_quatneg(q):
+ quaternion = np.zeros(4)
+ mjlib.mju_negQuat(quaternion, np.ascontiguousarray(q))
+ return quaternion
+
+
+def mj_quatdiff(source, target):
+ return mj_quatprod(mj_quatneg(source), np.ascontiguousarray(target))
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/math_test.py b/DMC/src/env/dm_control/dm_control/mujoco/math_test.py
new file mode 100644
index 0000000..502f17b
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/math_test.py
@@ -0,0 +1,61 @@
+# Copyright 2019 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for index."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.mujoco import math as mjmath
+import numpy as np
+
+
+class MathTest(parameterized.TestCase):
+
+ def testQuatProd(self):
+ np.testing.assert_allclose(
+ mjmath.mj_quatprod([0., 1., 0., 0.], [0., 0., 1., 0.]),
+ [0., 0., 0., 1.])
+ np.testing.assert_allclose(
+ mjmath.mj_quatprod([0., 0., 1., 0.], [0., 0., 0., 1.]),
+ [0., 1., 0., 0.])
+ np.testing.assert_allclose(
+ mjmath.mj_quatprod([0., 0., 0., 1.], [0., 1., 0., 0.]),
+ [0., 0., 1., 0.])
+
+ def testQuat2Vel(self):
+ np.testing.assert_allclose(
+ mjmath.mj_quat2vel([0., 1., 0., 0.], 0.1), [math.pi / 0.1, 0., 0.])
+
+ def testQuatNeg(self):
+ np.testing.assert_allclose(
+ mjmath.mj_quatneg([math.sqrt(0.5), math.sqrt(0.5), 0., 0.]),
+ [math.sqrt(0.5), -math.sqrt(0.5), 0., 0.])
+
+ def testQuatDiff(self):
+ np.testing.assert_allclose(
+ mjmath.mj_quatdiff([0., 1., 0., 0.], [0., 0., 1., 0.]),
+ [0., 0., 0., -1.])
+
+ def testEuler2Quat(self):
+ np.testing.assert_allclose(
+ mjmath.euler2quat(0., 0., 0.), [1., 0., 0., 0.])
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/render_test.py b/DMC/src/env/dm_control/dm_control/mujoco/render_test.py
new file mode 100644
index 0000000..84bdabb
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/render_test.py
@@ -0,0 +1,89 @@
+# Copyright 2017-2018 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Integration tests for rendering."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import platform
+
+# Internal dependencies.
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import _render
+from dm_control import mujoco
+from dm_control.mujoco.testing import decorators
+from dm_control.mujoco.testing import image_utils
+from six.moves import range
+from six.moves import zip
+
+
+DEBUG_IMAGE_DIR = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
+ absltest.get_default_test_tmpdir())
+
+# Context creation with GLFW is not threadsafe.
+if _render.BACKEND == 'glfw':
+ # On Linux we are able to create a GLFW window in a single thread that is not
+ # the main thread.
+ # On Mac we are only allowed to create windows on the main thread, so we
+ # disable the `run_threaded` wrapper entirely.
+ NUM_THREADS = None if platform.system() == 'Darwin' else 1
+else:
+ NUM_THREADS = 4
+CALLS_PER_THREAD = 1
+
+
+class RenderTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(image_utils.SEQUENCES.items())
+ @image_utils.save_images_on_failure(output_dir=DEBUG_IMAGE_DIR)
+ @decorators.run_threaded(num_threads=NUM_THREADS,
+ calls_per_thread=CALLS_PER_THREAD)
+ def test_render(self, sequence):
+ for expected, actual in zip(sequence.iter_load(), sequence.iter_render()):
+ image_utils.assert_images_close(expected, actual)
+
+ @decorators.run_threaded(num_threads=NUM_THREADS,
+ calls_per_thread=CALLS_PER_THREAD)
+ @image_utils.save_images_on_failure(output_dir=DEBUG_IMAGE_DIR)
+ def test_render_multiple_physics_per_thread(self):
+ cartpole = image_utils.cartpole
+ humanoid = image_utils.humanoid
+ cartpole_frames = []
+ humanoid_frames = []
+ for cartpole_frame, humanoid_frame in zip(cartpole.iter_render(),
+ humanoid.iter_render()):
+ cartpole_frames.append(cartpole_frame)
+ humanoid_frames.append(humanoid_frame)
+
+ for expected, actual in zip(cartpole.iter_load(), cartpole_frames):
+ image_utils.assert_images_close(expected, actual)
+
+ for expected, actual in zip(humanoid.iter_load(), humanoid_frames):
+ image_utils.assert_images_close(expected, actual)
+
+ @decorators.run_threaded(num_threads=NUM_THREADS, calls_per_thread=1)
+ def test_repeatedly_create_and_destroy_rendering_contexts(self):
+ # Tests for errors that may occur due to per-thread GL resource leakage.
+ physics = mujoco.Physics.from_xml_string('')
+ for _ in range(500):
+ physics._make_rendering_contexts()
+ physics._free_rendering_contexts()
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/__init__.py b/DMC/src/env/dm_control/dm_control/mujoco/testing/__init__.py
new file mode 100644
index 0000000..1ebb270
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/testing/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/__init__.py b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/__init__.py
new file mode 100644
index 0000000..9d0aca3
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Assets used for testing the MuJoCo bindings."""
+
+import os
+
+from dm_control.utils import io as resources
+
+_ASSETS_DIR = os.path.dirname(__file__)
+
+
+def get_contents(filename):
+ """Returns the contents of an asset as a string."""
+ return resources.GetResource(os.path.join(_ASSETS_DIR, filename))
+
+
+def get_path(filename):
+ """Returns the path to an asset."""
+ return resources.GetResourceFilename(os.path.join(_ASSETS_DIR, filename))
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/arm.xml b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/arm.xml
new file mode 100644
index 0000000..905d63e
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/arm.xml
@@ -0,0 +1,92 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/cartpole.xml b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/cartpole.xml
new file mode 100644
index 0000000..b15370b
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/cartpole.xml
@@ -0,0 +1,69 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/cube.stl b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/cube.stl
new file mode 100644
index 0000000..a5bc825
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/cube.stl differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/deepmind.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/deepmind.png
new file mode 100644
index 0000000..1586759
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/deepmind.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_000.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_000.png
new file mode 100644
index 0000000..94832e4
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_000.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_001.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_001.png
new file mode 100644
index 0000000..ea76f01
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_001.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_002.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_002.png
new file mode 100644
index 0000000..46e2940
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_002.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_003.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_003.png
new file mode 100644
index 0000000..ef98511
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_003.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_004.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_004.png
new file mode 100644
index 0000000..33ea89d
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_004.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_005.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_005.png
new file mode 100644
index 0000000..0d79d5c
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_005.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_006.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_006.png
new file mode 100644
index 0000000..4e22f40
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_006.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_007.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_007.png
new file mode 100644
index 0000000..e8b6457
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_007.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_008.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_008.png
new file mode 100644
index 0000000..b23e1a8
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_008.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_009.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_009.png
new file mode 100644
index 0000000..85c3923
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_009.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_010.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_010.png
new file mode 100644
index 0000000..83a301b
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_010.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_011.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_011.png
new file mode 100644
index 0000000..9ec9921
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_011.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_012.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_012.png
new file mode 100644
index 0000000..009467b
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_012.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_013.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_013.png
new file mode 100644
index 0000000..0bb65be
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_013.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_014.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_014.png
new file mode 100644
index 0000000..61cfa74
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_014.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_015.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_015.png
new file mode 100644
index 0000000..dc32fa2
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_015.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_016.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_016.png
new file mode 100644
index 0000000..8013c56
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_016.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_017.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_017.png
new file mode 100644
index 0000000..e632980
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_017.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_018.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_018.png
new file mode 100644
index 0000000..d5ac61c
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_018.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_019.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_019.png
new file mode 100644
index 0000000..463f386
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_hardware/frame_019.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_000.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_000.png
new file mode 100644
index 0000000..09f62be
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_000.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_001.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_001.png
new file mode 100644
index 0000000..679f77c
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_001.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_002.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_002.png
new file mode 100644
index 0000000..0bb1f15
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_002.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_003.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_003.png
new file mode 100644
index 0000000..5a6c336
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_003.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_004.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_004.png
new file mode 100644
index 0000000..a9215f7
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_004.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_005.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_005.png
new file mode 100644
index 0000000..b9caa13
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_005.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_006.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_006.png
new file mode 100644
index 0000000..4c426d5
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_006.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_007.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_007.png
new file mode 100644
index 0000000..d4622a0
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_007.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_008.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_008.png
new file mode 100644
index 0000000..d77c619
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_008.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_009.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_009.png
new file mode 100644
index 0000000..5e499f1
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_009.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_010.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_010.png
new file mode 100644
index 0000000..291ecc5
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_010.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_011.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_011.png
new file mode 100644
index 0000000..f8fa9b8
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_011.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_012.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_012.png
new file mode 100644
index 0000000..192f7cc
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_012.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_013.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_013.png
new file mode 100644
index 0000000..a8ed6d5
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_013.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_014.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_014.png
new file mode 100644
index 0000000..b8b19de
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_014.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_015.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_015.png
new file mode 100644
index 0000000..ec5f11e
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_015.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_016.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_016.png
new file mode 100644
index 0000000..21b359e
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_016.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_017.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_017.png
new file mode 100644
index 0000000..a34a274
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_017.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_018.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_018.png
new file mode 100644
index 0000000..addcf71
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_018.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_019.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_019.png
new file mode 100644
index 0000000..32c693f
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/cartpole_seed_0_camera_0_320x240_software/frame_019.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_000.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_000.png
new file mode 100644
index 0000000..f5e137f
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_000.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_001.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_001.png
new file mode 100644
index 0000000..68211d2
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_001.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_002.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_002.png
new file mode 100644
index 0000000..8843e29
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_002.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_003.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_003.png
new file mode 100644
index 0000000..55a2988
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_003.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_004.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_004.png
new file mode 100644
index 0000000..20f3d3d
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_004.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_005.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_005.png
new file mode 100644
index 0000000..edc1d31
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_005.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_006.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_006.png
new file mode 100644
index 0000000..9cf2daf
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_006.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_007.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_007.png
new file mode 100644
index 0000000..d66f7ca
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_007.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_008.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_008.png
new file mode 100644
index 0000000..ea0ebb1
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_008.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_009.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_009.png
new file mode 100644
index 0000000..bfeca8f
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_009.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_010.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_010.png
new file mode 100644
index 0000000..b5d2e0f
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_010.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_011.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_011.png
new file mode 100644
index 0000000..cd1dbd0
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_011.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_012.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_012.png
new file mode 100644
index 0000000..b4bffa7
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_012.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_013.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_013.png
new file mode 100644
index 0000000..75c56a8
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_013.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_014.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_014.png
new file mode 100644
index 0000000..3d184cb
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_014.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_015.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_015.png
new file mode 100644
index 0000000..295efb0
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_015.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_016.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_016.png
new file mode 100644
index 0000000..d62897d
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_016.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_017.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_017.png
new file mode 100644
index 0000000..50f2d22
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_017.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_018.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_018.png
new file mode 100644
index 0000000..fbb9dc2
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_018.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_019.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_019.png
new file mode 100644
index 0000000..c0d3bf2
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_hardware/frame_019.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_000.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_000.png
new file mode 100644
index 0000000..d94c22f
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_000.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_001.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_001.png
new file mode 100644
index 0000000..7caa3eb
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_001.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_002.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_002.png
new file mode 100644
index 0000000..e77cee5
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_002.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_003.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_003.png
new file mode 100644
index 0000000..1bfac02
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_003.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_004.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_004.png
new file mode 100644
index 0000000..a342230
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_004.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_005.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_005.png
new file mode 100644
index 0000000..82f8c9e
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_005.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_006.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_006.png
new file mode 100644
index 0000000..14e4cbb
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_006.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_007.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_007.png
new file mode 100644
index 0000000..f54cb18
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_007.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_008.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_008.png
new file mode 100644
index 0000000..4b8f85d
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_008.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_009.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_009.png
new file mode 100644
index 0000000..0a59baf
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_009.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_010.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_010.png
new file mode 100644
index 0000000..59da995
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_010.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_011.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_011.png
new file mode 100644
index 0000000..624a6a6
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_011.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_012.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_012.png
new file mode 100644
index 0000000..f760469
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_012.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_013.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_013.png
new file mode 100644
index 0000000..28f9603
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_013.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_014.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_014.png
new file mode 100644
index 0000000..b17433c
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_014.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_015.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_015.png
new file mode 100644
index 0000000..e050360
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_015.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_016.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_016.png
new file mode 100644
index 0000000..eabe5d2
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_016.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_017.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_017.png
new file mode 100644
index 0000000..8464db5
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_017.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_018.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_018.png
new file mode 100644
index 0000000..b70d1f1
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_018.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_019.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_019.png
new file mode 100644
index 0000000..7114678
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_0_240x320_software/frame_019.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_000.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_000.png
new file mode 100644
index 0000000..c2a6ab2
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_000.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_001.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_001.png
new file mode 100644
index 0000000..a6a016a
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_001.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_002.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_002.png
new file mode 100644
index 0000000..bd413c5
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_002.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_003.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_003.png
new file mode 100644
index 0000000..73d9291
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_003.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_004.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_004.png
new file mode 100644
index 0000000..706a1fa
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_004.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_005.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_005.png
new file mode 100644
index 0000000..8c31c31
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_005.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_006.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_006.png
new file mode 100644
index 0000000..1a6fc8c
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_006.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_007.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_007.png
new file mode 100644
index 0000000..f6173ab
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_007.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_008.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_008.png
new file mode 100644
index 0000000..d74d7f8
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_008.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_009.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_009.png
new file mode 100644
index 0000000..b527baa
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_009.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_010.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_010.png
new file mode 100644
index 0000000..93368ed
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_010.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_011.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_011.png
new file mode 100644
index 0000000..01da06f
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_011.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_012.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_012.png
new file mode 100644
index 0000000..7c9ace8
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_012.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_013.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_013.png
new file mode 100644
index 0000000..9e92fdc
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_013.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_014.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_014.png
new file mode 100644
index 0000000..cb4a08e
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_014.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_015.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_015.png
new file mode 100644
index 0000000..5b93ff5
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_015.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_016.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_016.png
new file mode 100644
index 0000000..1eb9171
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_016.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_017.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_017.png
new file mode 100644
index 0000000..7a68dd2
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_017.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_018.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_018.png
new file mode 100644
index 0000000..95493f4
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_018.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_019.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_019.png
new file mode 100644
index 0000000..5564632
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_hardware/frame_019.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_000.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_000.png
new file mode 100644
index 0000000..3dee659
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_000.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_001.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_001.png
new file mode 100644
index 0000000..f15593c
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_001.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_002.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_002.png
new file mode 100644
index 0000000..86fb6d6
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_002.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_003.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_003.png
new file mode 100644
index 0000000..60ca5c6
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_003.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_004.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_004.png
new file mode 100644
index 0000000..fa19ca6
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_004.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_005.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_005.png
new file mode 100644
index 0000000..9915796
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_005.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_006.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_006.png
new file mode 100644
index 0000000..e2ff894
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_006.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_007.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_007.png
new file mode 100644
index 0000000..fc6fbbb
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_007.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_008.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_008.png
new file mode 100644
index 0000000..7a44914
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_008.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_009.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_009.png
new file mode 100644
index 0000000..2c061fd
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_009.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_010.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_010.png
new file mode 100644
index 0000000..353dc8c
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_010.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_011.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_011.png
new file mode 100644
index 0000000..c600cfc
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_011.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_012.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_012.png
new file mode 100644
index 0000000..21951a7
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_012.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_013.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_013.png
new file mode 100644
index 0000000..7b06354
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_013.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_014.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_014.png
new file mode 100644
index 0000000..fb976a3
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_014.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_015.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_015.png
new file mode 100644
index 0000000..39ecf8f
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_015.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_016.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_016.png
new file mode 100644
index 0000000..56886d6
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_016.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_017.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_017.png
new file mode 100644
index 0000000..641f7dc
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_017.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_018.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_018.png
new file mode 100644
index 0000000..9786372
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_018.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_019.png b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_019.png
new file mode 100644
index 0000000..4752a61
Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/frames/humanoid_seed_0_camera_head_track_64x64_software/frame_019.png differ
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/humanoid.xml b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/humanoid.xml
new file mode 100644
index 0000000..adb8f39
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/humanoid.xml
@@ -0,0 +1,122 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/model_incorrect_actuator_order.xml b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/model_incorrect_actuator_order.xml
new file mode 100644
index 0000000..06d7157
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/model_incorrect_actuator_order.xml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/model_with_assets.xml b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/model_with_assets.xml
new file mode 100644
index 0000000..adfd45b
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/model_with_assets.xml
@@ -0,0 +1,13 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/model_with_ball_joints.xml b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/model_with_ball_joints.xml
new file mode 100644
index 0000000..9791d9b
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/model_with_ball_joints.xml
@@ -0,0 +1,21 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/model_with_third_order_actuators.xml b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/model_with_third_order_actuators.xml
new file mode 100644
index 0000000..b86cd2d
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/model_with_third_order_actuators.xml
@@ -0,0 +1,16 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/sphere.xml b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/sphere.xml
new file mode 100644
index 0000000..0999191
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/testing/assets/sphere.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/decorators.py b/DMC/src/env/dm_control/dm_control/mujoco/testing/decorators.py
new file mode 100644
index 0000000..4fc1db0
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/testing/decorators.py
@@ -0,0 +1,69 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Decorators used in MuJoCo tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import sys
+import threading
+
+import six
+from six.moves import range
+
+
+def run_threaded(num_threads=4, calls_per_thread=10):
+ """A decorator that executes the same test repeatedly in multiple threads.
+
+ Note: `setUp` and `tearDown` methods will only be called once from the main
+ thread, so all thread-local setup must be done within the test method.
+
+ Args:
+ num_threads: Number of concurrent threads to spawn. If None then the wrapped
+ method will be executed in the main thread instead.
+ calls_per_thread: Number of times each thread should call the test method.
+ Returns:
+ Decorated test method.
+ """
+ def decorator(test_method):
+ """Decorator around the test method."""
+ @functools.wraps(test_method) # Needed for `named_parameters` to work.
+ def decorated_method(self, *args, **kwargs):
+ """Actual method this factory will return."""
+ exceptions = []
+ def worker():
+ try:
+ for _ in range(calls_per_thread):
+ test_method(self, *args, **kwargs)
+ except: # pylint: disable=bare-except
+ # Appending to Python list is thread-safe:
+ # http://effbot.org/pyfaq/what-kinds-of-global-value-mutation-are-thread-safe.htm
+ exceptions.append(sys.exc_info())
+ if num_threads is not None:
+ threads = [threading.Thread(target=worker, name='thread_{}'.format(i))
+ for i in range(num_threads)]
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+ else:
+ worker()
+ for exc_class, old_exc, tb in exceptions:
+ six.reraise(exc_class, old_exc, tb)
+ return decorated_method
+ return decorator
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/decorators_test.py b/DMC/src/env/dm_control/dm_control/mujoco/testing/decorators_test.py
new file mode 100644
index 0000000..60d3cad
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/testing/decorators_test.py
@@ -0,0 +1,86 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests of the decorators module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+
+from absl.testing import absltest
+
+from dm_control.mujoco.testing import decorators
+import mock
+from six.moves import range
+
+
+class RunThreadedTest(absltest.TestCase):
+
+ @mock.patch(decorators.__name__ + ".threading")
+ def test_number_of_threads(self, mock_threading):
+ num_threads = 5
+
+ mock_threads = [mock.MagicMock() for _ in range(num_threads)]
+ for thread in mock_threads:
+ thread.start = mock.MagicMock()
+ thread.join = mock.MagicMock()
+
+ mock_threading.Thread = mock.MagicMock(side_effect=mock_threads)
+
+ test_decorator = decorators.run_threaded(num_threads=num_threads)
+ tested_method = mock.MagicMock()
+ tested_method.__name__ = "foo"
+ test_runner = test_decorator(tested_method)
+ test_runner(self)
+
+ for thread in mock_threads:
+ thread.start.assert_called_once()
+ thread.join.assert_called_once()
+
+ def test_number_of_iterations(self):
+ calls_per_thread = 5
+
+ tested_method = mock.MagicMock()
+ tested_method.__name__ = "foo"
+ test_decorator = decorators.run_threaded(
+ num_threads=1, calls_per_thread=calls_per_thread)
+ test_runner = test_decorator(tested_method)
+ test_runner(self)
+
+ self.assertEqual(calls_per_thread, tested_method.call_count)
+
+ @mock.patch(decorators.__name__ + ".threading")
+ def test_using_the_main_thread(self, mock_threading):
+ mock_thread = mock.MagicMock()
+ mock_thread.start = mock.MagicMock()
+ mock_thread.join = mock.MagicMock()
+ mock_threading.current_thread = mock.MagicMock(return_value=mock_thread)
+
+ test_decorator = decorators.run_threaded(num_threads=None,
+ calls_per_thread=1)
+ tested_method = mock.MagicMock()
+ tested_method.__name__ = "foo"
+ test_runner = test_decorator(tested_method)
+ test_runner(self)
+
+ tested_method.assert_called_once()
+ mock_thread.start.assert_not_called()
+ mock_thread.join.assert_not_called()
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/generate_frames.py b/DMC/src/env/dm_control/dm_control/mujoco/testing/generate_frames.py
new file mode 100644
index 0000000..6c4f66a
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/testing/generate_frames.py
@@ -0,0 +1,34 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Script for generating pre-rendered frames used in integration tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Internal dependencies.
+from absl import app
+from dm_control.mujoco.testing import image_utils
+import six
+
+
+def main(argv):
+ del argv # Unused.
+ for sequence in six.itervalues(image_utils.SEQUENCES):
+ sequence.save()
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/image_utils.py b/DMC/src/env/dm_control/dm_control/mujoco/testing/image_utils.py
new file mode 100644
index 0000000..5f89138
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/testing/image_utils.py
@@ -0,0 +1,222 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Utilities for testing rendering."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import functools
+import os
+import sys
+from dm_control import _render
+from dm_control import mujoco
+from dm_control.mujoco.testing import assets
+import numpy as np
+from PIL import Image
+import six
+from six.moves import range
+from six.moves import zip
+
+
+BACKEND_STRING = 'hardware' if _render.USING_GPU else 'software'
+
+
+class ImagesNotCloseError(AssertionError):
+ """Exception raised when two images are not sufficiently close."""
+
+ def __init__(self, message, expected, actual):
+ super(ImagesNotCloseError, self).__init__(message)
+ self.expected = expected
+ self.actual = actual
+
+
+_CameraSpec = collections.namedtuple(
+ '_CameraSpec', ['height', 'width', 'camera_id'])
+
+
+class _FrameSequence(object):
+ """A sequence of pre-rendered frames used in integration tests."""
+
+ _ASSETS_DIR = 'assets'
+ _FRAMES_DIR = 'frames'
+ _SUBDIR_TEMPLATE = (
+ '{name}_seed_{seed}_camera_{camera_id}_{width}x{height}_{backend_string}')
+ _FILENAME_TEMPLATE = 'frame_{frame_num:03}.png'
+
+ def __init__(self,
+ name,
+ xml_string,
+ camera_specs,
+ num_frames=20,
+ steps_per_frame=10,
+ seed=0):
+ """Initializes a new `_FrameSequence`.
+
+ Args:
+ name: A string containing the name to be used for the sequence.
+ xml_string: An MJCF XML string containing the model to be rendered.
+ camera_specs: A list of `_CameraSpec` instances specifying the cameras to
+ render on each frame.
+ num_frames: The number of frames to render.
+ steps_per_frame: The interval between frames, in simulation steps.
+ seed: Integer or None, used to initialize the random number generator for
+ generating actions.
+ """
+ self._name = name
+ self._xml_string = xml_string
+ self._camera_specs = camera_specs
+ self._num_frames = num_frames
+ self._steps_per_frame = steps_per_frame
+ self._seed = seed
+
+ def iter_render(self):
+ """Returns an iterator that yields newly rendered frames as numpy arrays."""
+ random_state = np.random.RandomState(self._seed)
+ physics = mujoco.Physics.from_xml_string(self._xml_string)
+ action_spec = mujoco.action_spec(physics)
+ for _ in range(self._num_frames):
+ for _ in range(self._steps_per_frame):
+ actions = random_state.uniform(action_spec.minimum, action_spec.maximum)
+ physics.set_control(actions)
+ physics.step()
+ for camera_spec in self._camera_specs:
+ yield physics.render(**camera_spec._asdict())
+
+ def iter_load(self):
+ """Returns an iterator that yields saved frames as numpy arrays."""
+ for directory, filename in self._iter_paths():
+ path = os.path.join(directory, filename)
+ yield _load_pixels(path)
+
+ def save(self):
+ """Saves a new set of golden output frames to disk."""
+ for pixels, (relative_to_assets, filename) in zip(self.iter_render(),
+ self._iter_paths()):
+ full_directory_path = os.path.join(self._ASSETS_DIR, relative_to_assets)
+ if not os.path.exists(full_directory_path):
+ os.makedirs(full_directory_path)
+ path = os.path.join(full_directory_path, filename)
+ _save_pixels(pixels, path)
+
+ def _iter_paths(self):
+ """Returns an iterator over paths to the reference images."""
+ for frame_num in range(self._num_frames):
+ filename = self._FILENAME_TEMPLATE.format(frame_num=frame_num)
+ for camera_spec in self._camera_specs:
+ subdir_name = self._SUBDIR_TEMPLATE.format(
+ name=self._name,
+ seed=self._seed,
+ backend_string=BACKEND_STRING,
+ **camera_spec._asdict())
+ directory = os.path.join(self._FRAMES_DIR, subdir_name)
+ yield directory, filename
+
+
+cartpole = _FrameSequence(
+ name='cartpole',
+ xml_string=assets.get_contents('cartpole.xml'),
+ camera_specs=[_CameraSpec(width=320, height=240, camera_id=0)],
+ steps_per_frame=5)
+
+humanoid = _FrameSequence(
+ name='humanoid',
+ xml_string=assets.get_contents('humanoid.xml'),
+ camera_specs=[
+ _CameraSpec(width=240, height=320, camera_id=0),
+ _CameraSpec(width=64, height=64, camera_id='head_track'),
+ ])
+
+
+SEQUENCES = {
+ 'cartpole': cartpole,
+ 'humanoid': humanoid,
+}
+
+
+def _save_pixels(pixels, path):
+ image = Image.fromarray(pixels)
+ image.save(path)
+
+
+def _load_pixels(path):
+ image_bytes = assets.get_contents(path)
+ image = Image.open(six.BytesIO(image_bytes))
+ return np.array(image)
+
+
+def compute_rms(image1, image2):
+ """Computes the RMS difference between two images."""
+ abs_diff = np.abs(image1.astype(np.int16) - image2)
+ values, counts = np.unique(abs_diff, return_counts=True)
+ sum_of_squares = np.sum(counts * values.astype(np.int64) ** 2)
+ return np.sqrt(float(sum_of_squares) / abs_diff.size)
+
+
+def assert_images_close(expected, actual, tolerance=10.):
+ """Tests whether two images are almost equal.
+
+ Args:
+ expected: A numpy array, the expected image.
+ actual: A numpy array, the actual image.
+ tolerance: A float specifying the maximum allowable RMS error between the
+ expected and actual images.
+
+ Raises:
+ ImagesNotCloseError: If the images are not sufficiently close.
+ """
+ rms = compute_rms(expected, actual)
+ if rms > tolerance:
+ message = 'RMS error exceeds tolerance ({} > {})'.format(rms, tolerance)
+ raise ImagesNotCloseError(message, expected=expected, actual=actual)
+
+
+def save_images_on_failure(output_dir):
+ """Decorator that saves debugging images if `ImagesNotCloseError` is raised.
+
+ Args:
+ output_dir: Path to the directory where the output images will be saved.
+
+ Returns:
+ A decorator function.
+ """
+ def decorator(test_method):
+ """Decorator, saves debugging images if `ImagesNotCloseError` is raised."""
+ method_name = test_method.__name__
+ @functools.wraps(test_method)
+ def decorated_method(*args, **kwargs):
+ """Call test method, save images if `ImagesNotCloseError` is raised."""
+ try:
+ test_method(*args, **kwargs)
+ except ImagesNotCloseError as e:
+ _, _, tb = sys.exc_info()
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ difference = e.actual.astype(np.double) - e.expected
+ difference = (0.5 * (difference + 255)).astype(np.uint8)
+ base_name = os.path.join(output_dir, method_name)
+ _save_pixels(e.expected, base_name + '-expected.png')
+ _save_pixels(e.actual, base_name + '-actual.png')
+ _save_pixels(difference, base_name + '-difference.png')
+ msg = ('{}. Debugging images saved to '
+ '{}-{{expected,actual,difference}}.png.'.format(e, base_name))
+ new_e = ImagesNotCloseError(msg, expected=e.expected, actual=e.actual)
+ # Reraise the exception with the original traceback.
+ six.reraise(ImagesNotCloseError, new_e, tb)
+
+ return decorated_method
+ return decorator
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/testing/image_utils_test.py b/DMC/src/env/dm_control/dm_control/mujoco/testing/image_utils_test.py
new file mode 100644
index 0000000..d926bb6
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/testing/image_utils_test.py
@@ -0,0 +1,90 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for image_utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+# Internal dependencies.
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control.mujoco.testing import image_utils
+import mock
+import numpy as np
+from PIL import Image
+import six
+
+SEED = 0
+
+
+class ImageUtilsTest(parameterized.TestCase):
+
+ @parameterized.parameters(
+ (0, 0, 0.0),
+ (0, 2, 23.241),
+ (0, 18, 55.666))
+ def test_compute_rms(self, index1, index2, expected_rms):
+ # Force loading of the software rendering reference images regardless of the
+ # actual GL backend, since these should match the expected RMS values.
+ with mock.patch.object(image_utils, 'BACKEND_STRING', new='software'):
+ frames = list(image_utils.humanoid.iter_load())
+ image1 = frames[index1]
+ image2 = frames[index2]
+ rms = image_utils.compute_rms(image1, image2)
+ self.assertAlmostEqual(rms, expected_rms, places=3)
+
+ def test_assert_images_close(self):
+ random_state = np.random.RandomState(SEED)
+ image1 = random_state.randint(0, 255, size=(64, 64, 3), dtype=np.uint8)
+ image2 = random_state.randint(0, 255, size=(64, 64, 3), dtype=np.uint8)
+ image_utils.assert_images_close(image1, image1, tolerance=0)
+ with six.assertRaisesRegex(self, image_utils.ImagesNotCloseError,
+ 'RMS error exceeds tolerance'):
+ image_utils.assert_images_close(image1, image2)
+
+ def test_save_images_on_failure(self):
+ random_state = np.random.RandomState(SEED)
+ image1 = random_state.randint(0, 255, size=(64, 64, 3), dtype=np.uint8)
+ image2 = random_state.randint(0, 255, size=(64, 64, 3), dtype=np.uint8)
+ diff = (0.5 * (image2.astype(np.int16) - image1 + 255)).astype(np.uint8)
+ message = 'exception message'
+ output_dir = absltest.get_default_test_tmpdir()
+
+ @image_utils.save_images_on_failure(output_dir=output_dir)
+ def func():
+ raise image_utils.ImagesNotCloseError(message, image1, image2)
+
+ with six.assertRaisesRegex(self, image_utils.ImagesNotCloseError,
+ '{}.*'.format(message)):
+ func()
+
+ def validate_saved_file(name, expected_contents):
+ path = os.path.join(output_dir, '{}-{}.png'.format('func', name))
+ self.assertTrue(os.path.isfile(path))
+ image = Image.open(path)
+ actual_contents = np.array(image)
+ np.testing.assert_array_equal(expected_contents, actual_contents)
+
+ validate_saved_file('expected', image1)
+ validate_saved_file('actual', image2)
+ validate_saved_file('difference', diff)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/thread_safety_test.py b/DMC/src/env/dm_control/dm_control/mujoco/thread_safety_test.py
new file mode 100644
index 0000000..a59e5f6
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/thread_safety_test.py
@@ -0,0 +1,110 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests to check whether methods of `mujoco.Physics` are threadsafe."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import platform
+
+# Internal dependencies.
+
+from absl.testing import absltest
+from dm_control import _render
+from dm_control.mujoco import engine
+from dm_control.mujoco.testing import assets
+from dm_control.mujoco.testing import decorators
+
+from six.moves import range
+
+MODEL = assets.get_contents('cartpole.xml')
+NUM_STEPS = 10
+
+# Context creation with GLFW is not threadsafe.
+if _render.BACKEND == 'glfw':
+ # On Linux we are able to create a GLFW window in a single thread that is not
+ # the main thread.
+ # On Mac we are only allowed to create windows on the main thread, so we
+ # disable the `run_threaded` wrapper entirely.
+ NUM_THREADS = None if platform.system() == 'Darwin' else 1
+else:
+ NUM_THREADS = 4
+
+
+class ThreadSafetyTest(absltest.TestCase):
+
+ @decorators.run_threaded(num_threads=NUM_THREADS)
+ def test_load_physics_from_string(self):
+ engine.Physics.from_xml_string(MODEL)
+
+ @decorators.run_threaded(num_threads=NUM_THREADS)
+ def test_load_and_reload_physics_from_string(self):
+ physics = engine.Physics.from_xml_string(MODEL)
+ physics.reload_from_xml_string(MODEL)
+
+ @decorators.run_threaded(num_threads=NUM_THREADS)
+ def test_load_and_step_physics(self):
+ physics = engine.Physics.from_xml_string(MODEL)
+ for _ in range(NUM_STEPS):
+ physics.step()
+
+ @decorators.run_threaded(num_threads=NUM_THREADS)
+ def test_load_and_step_multiple_physics_parallel(self):
+ physics1 = engine.Physics.from_xml_string(MODEL)
+ physics2 = engine.Physics.from_xml_string(MODEL)
+ for _ in range(NUM_STEPS):
+ physics1.step()
+ physics2.step()
+
+ @decorators.run_threaded(num_threads=NUM_THREADS)
+ def test_load_and_step_multiple_physics_sequential(self):
+ physics1 = engine.Physics.from_xml_string(MODEL)
+ for _ in range(NUM_STEPS):
+ physics1.step()
+ del physics1
+ physics2 = engine.Physics.from_xml_string(MODEL)
+ for _ in range(NUM_STEPS):
+ physics2.step()
+
+ @decorators.run_threaded(num_threads=NUM_THREADS, calls_per_thread=5)
+ def test_load_physics_and_render(self):
+ physics = engine.Physics.from_xml_string(MODEL)
+
+ # Check that frames aren't repeated - make the cartpole move.
+ physics.set_control([1.0])
+
+ unique_frames = set()
+ for _ in range(NUM_STEPS):
+ physics.step()
+ frame = physics.render(width=320, height=240, camera_id=0)
+ unique_frames.add(frame.tostring())
+
+ self.assertLen(unique_frames, NUM_STEPS)
+
+ @decorators.run_threaded(num_threads=NUM_THREADS, calls_per_thread=5)
+ def test_render_multiple_physics_instances_per_thread_parallel(self):
+ physics1 = engine.Physics.from_xml_string(MODEL)
+ physics2 = engine.Physics.from_xml_string(MODEL)
+ for _ in range(NUM_STEPS):
+ physics1.step()
+ physics1.render(width=320, height=240, camera_id=0)
+ physics2.step()
+ physics2.render(width=320, height=240, camera_id=0)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/wrapper/README.md b/DMC/src/env/dm_control/dm_control/mujoco/wrapper/README.md
new file mode 100644
index 0000000..948b460
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/wrapper/README.md
@@ -0,0 +1,12 @@
+# MuJoCo Wrapper
+
+This package contains Python bindings for the [MuJoCo physics engine][1] using
+[`ctypes`][2]. The bindings and some higher-level wrapper code are automatically
+generated from MuJoCo's header files by `dm_control/autowrap/autowrap.py`.
+
+The main entry point for users of the generated bindings is
+[`dm_control.mujoco`][3].
+
+[1]: http://mujoco.org/
+[2]: https://docs.python.org/2/library/ctypes.html
+[3]: ../README.md
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/wrapper/__init__.py b/DMC/src/env/dm_control/dm_control/mujoco/wrapper/__init__.py
new file mode 100644
index 0000000..4eba41f
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/wrapper/__init__.py
@@ -0,0 +1,44 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Python bindings and wrapper classes for MuJoCo."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from dm_control.mujoco.wrapper import mjbindings
+
+from dm_control.mujoco.wrapper.core import callback_context
+from dm_control.mujoco.wrapper.core import enable_timer
+
+from dm_control.mujoco.wrapper.core import Error
+
+from dm_control.mujoco.wrapper.core import get_schema
+
+from dm_control.mujoco.wrapper.core import MjData
+from dm_control.mujoco.wrapper.core import MjModel
+from dm_control.mujoco.wrapper.core import MjrContext
+from dm_control.mujoco.wrapper.core import MjvCamera
+from dm_control.mujoco.wrapper.core import MjvFigure
+from dm_control.mujoco.wrapper.core import MjvOption
+from dm_control.mujoco.wrapper.core import MjvPerturb
+from dm_control.mujoco.wrapper.core import MjvScene
+
+from dm_control.mujoco.wrapper.core import save_last_parsed_model_to_xml
+from dm_control.mujoco.wrapper.core import set_callback
+
+from dm_control.mujoco.wrapper.core import UnmanagedMjrContext
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/wrapper/core.py b/DMC/src/env/dm_control/dm_control/mujoco/wrapper/core.py
new file mode 100644
index 0000000..de6488a
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/wrapper/core.py
@@ -0,0 +1,848 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Main user-facing classes and utility functions for loading MuJoCo models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import ctypes
+import os
+import weakref
+
+from absl import logging
+
+from dm_control.mujoco.wrapper import util
+from dm_control.mujoco.wrapper.mjbindings import constants
+from dm_control.mujoco.wrapper.mjbindings import enums
+from dm_control.mujoco.wrapper.mjbindings import functions
+from dm_control.mujoco.wrapper.mjbindings import mjlib
+from dm_control.mujoco.wrapper.mjbindings import types
+from dm_control.mujoco.wrapper.mjbindings import wrappers
+
+import six
+
+# Internal analytics import.
+# Unused internal import: resources.
+
+_NULL = b"\00"
+_FAKE_XML_FILENAME = b"model.xml"
+_FAKE_BINARY_FILENAME = b"model.mjb"
+
+# Although `mjMAXVFSNAME` from `mjmodel.h` specifies a limit of 100 bytes
+# (including the terminal null byte), the actual limit seems to be 99 bytes
+# (98 characters).
+_MAX_VFS_FILENAME_CHARACTERS = 98
+_VFS_FILENAME_TOO_LONG = (
+ "Filename length {length} exceeds {limit} character limit: {filename}")
+_INVALID_FONT_SCALE = ("`font_scale` must be one of {}, got {{}}."
+ .format(enums.mjtFontScale))
+
+# Global cache used to store finalizers for freeing ctypes pointers.
+# Contains {pointer_address: weakref_object} pairs.
+_FINALIZERS = {}
+
+
+class Error(Exception):
+ """Base class for MuJoCo exceptions."""
+ pass
+
+
+if constants.mjVERSION_HEADER != mjlib.mj_version():
+ raise Error("MuJoCo library version ({0}) does not match header version "
+ "({1})".format(constants.mjVERSION_HEADER, mjlib.mj_version()))
+
+_REGISTERED = False
+_ERROR_BUFSIZE = 1000
+
+# This is used to keep track of the `MJMODEL` pointer that was most recently
+# loaded by `_get_model_ptr_from_xml`. Only this model can be saved to XML.
+_LAST_PARSED_MODEL_PTR = None
+
+_NOT_LAST_PARSED_ERROR = (
+ "Only the model that was most recently loaded from an XML file or string "
+ "can be saved to an XML file.")
+
+import time
+
+# NB: Python functions that are called from C are defined at module-level to
+# ensure they won't be garbage-collected before they are called.
+@ctypes.CFUNCTYPE(None, ctypes.c_char_p)
+def _warning_callback(message):
+ logging.warning(util.to_native_string(message))
+
+
+@ctypes.CFUNCTYPE(None, ctypes.c_char_p)
+def _error_callback(message):
+ logging.fatal(util.to_native_string(message))
+
+
+# Override MuJoCo's callbacks for handling warnings and errors.
+mjlib.mju_user_warning = ctypes.c_void_p.in_dll(mjlib, "mju_user_warning")
+mjlib.mju_user_error = ctypes.c_void_p.in_dll(mjlib, "mju_user_error")
+mjlib.mju_user_warning.value = ctypes.cast(
+ _warning_callback, ctypes.c_void_p).value
+mjlib.mju_user_error.value = ctypes.cast(
+ _error_callback, ctypes.c_void_p).value
+
+
+def enable_timer(enabled=True):
+ if enabled:
+ set_callback("mjcb_time", time.time)
+ else:
+ set_callback("mjcb_time", None)
+
+
+def _maybe_register_license(path=None):
+ """Registers the MuJoCo license if not already registered.
+
+ Args:
+ path: Optional custom path to license key file.
+
+ Raises:
+ Error: If the license could not be registered.
+ """
+ global _REGISTERED
+ if not _REGISTERED:
+ if path is None:
+ path = util.get_mjkey_path()
+ result = mjlib.mj_activate(util.to_binary_string(path))
+ if result == 1:
+ _REGISTERED = True
+ # Internal analytics of mj_activate.
+ elif result == 0:
+ raise Error("Could not register license.")
+ else:
+ raise Error("Unknown registration error (code: {})".format(result))
+
+
+def _str2type(type_str):
+ type_id = mjlib.mju_str2Type(util.to_binary_string(type_str))
+ if not type_id:
+ raise Error("{!r} is not a valid object type name.".format(type_str))
+ return type_id
+
+
+def _type2str(type_id):
+ type_str_ptr = mjlib.mju_type2Str(type_id)
+ if not type_str_ptr:
+ raise Error("{!r} is not a valid object type ID.".format(type_id))
+ return ctypes.string_at(type_str_ptr)
+
+
+def set_callback(name, new_callback=None):
+ """Sets a user-defined callback function to modify MuJoCo's behavior.
+
+ Callback functions should have the following signature:
+ func(const_mjmodel_ptr, mjdata_ptr) -> None
+
+ Args:
+ name: Name of the callback to set. Must be a field in
+ `functions.function_pointers`.
+ new_callback: The new callback. This can be one of the following:
+ * A Python callable
+ * A C function exposed by a `ctypes.CDLL` object
+ * An integer specifying the address of a callback function
+ * None, in which case any existing callback of that name is removed
+ """
+ setattr(functions.callbacks, name, new_callback)
+
+
+@contextlib.contextmanager
+def callback_context(name, new_callback=None):
+ """Context manager that temporarily overrides a MuJoCo callback function.
+
+ On exit, the callback will be restored to its original value (None if the
+ callback was not already overridden when the context was entered).
+
+ Args:
+ name: Name of the callback to set. Must be a field in
+ `mjbindings.function_pointers`.
+ new_callback: The new callback. This can be one of the following:
+ * A Python callable
+ * A C function exposed by a `ctypes.CDLL` object
+ * An integer specifying the address of a callback function
+ * None, in which case any existing callback of that name is removed
+
+ Yields:
+ None
+ """
+ old_callback = getattr(functions.callbacks, name)
+ set_callback(name, new_callback)
+ try:
+ yield
+ finally:
+ # Ensure that the callback is reset on exit, even if an exception is raised.
+ set_callback(name, old_callback)
+
+
+def get_schema():
+ """Returns a string containing the schema used by the MuJoCo XML parser."""
+ buf = ctypes.create_string_buffer(100000)
+ mjlib.mj_printSchema(None, buf, len(buf), 0, 0)
+ return buf.value
+
+
+@contextlib.contextmanager
+def _temporary_vfs(filenames_and_contents):
+ """Creates a temporary VFS containing one or more files.
+
+ Args:
+ filenames_and_contents: A dict containing `{filename: contents}` pairs.
+ The length of each filename must not exceed 98 characters.
+
+ Yields:
+ A `types.MJVFS` instance.
+
+ Raises:
+ Error: If a file cannot be added to the VFS, or if an error occurs when
+ looking up the filename.
+ ValueError: If the length of a filename exceeds 98 characters.
+ """
+ vfs = types.MJVFS()
+ mjlib.mj_defaultVFS(vfs)
+ for filename, contents in six.iteritems(filenames_and_contents):
+ if len(filename) > _MAX_VFS_FILENAME_CHARACTERS:
+ raise ValueError(
+ _VFS_FILENAME_TOO_LONG.format(
+ length=len(filename),
+ limit=_MAX_VFS_FILENAME_CHARACTERS,
+ filename=filename))
+ filename = util.to_binary_string(filename)
+ contents = util.to_binary_string(contents)
+ _, extension = os.path.splitext(filename)
+ # For XML files we need to append a NULL byte, otherwise MuJoCo's parser
+ # can sometimes read past the end of the string. However, we should *not*
+ # do this for other file types (in particular for STL meshes, where this
+ # causes MuJoCo's compiler to complain that the file size is incorrect).
+ append_null = extension.lower() == b".xml"
+ num_bytes = len(contents) + append_null
+ retcode = mjlib.mj_makeEmptyFileVFS(vfs, filename, num_bytes)
+ if retcode == 1:
+ raise Error("Failed to create {!r}: VFS is full.".format(filename))
+ elif retcode == 2:
+ raise Error("Failed to create {!r}: duplicate filename.".format(filename))
+ file_index = mjlib.mj_findFileVFS(vfs, filename)
+ if file_index == -1:
+ raise Error("Could not find {!r} in the VFS".format(filename))
+ vf = vfs.filedata[file_index]
+ vf_as_char_arr = ctypes.cast(vf, ctypes.POINTER(ctypes.c_char * num_bytes))
+ vf_as_char_arr.contents[:len(contents)] = contents
+ if append_null:
+ vf_as_char_arr.contents[-1] = _NULL
+ try:
+ yield vfs
+ finally:
+ mjlib.mj_deleteVFS(vfs) # Ensure that we free the VFS afterwards.
+
+
+def _create_finalizer(ptr, free_func):
+ """Creates a finalizer for a ctypes pointer.
+
+ Args:
+ ptr: A `ctypes.POINTER` to be freed.
+ free_func: A callable that frees the pointer. It will be called with `ptr`
+ as its only argument when `ptr` is garbage collected.
+ """
+ ptr_type = type(ptr)
+ address = ctypes.addressof(ptr.contents)
+
+ if address not in _FINALIZERS: # Only one finalizer needed per address.
+
+ logging.debug("Allocated %s at %x", ptr_type.__name__, address)
+
+ def callback(dead_ptr_ref):
+ """A weakref callback that frees the resource held by a pointer."""
+ del dead_ptr_ref # Unused weakref to the dead ctypes pointer object.
+ if address not in _FINALIZERS:
+ # Someone had already explicitly called `call_finalizer_for_pointer`.
+ return
+ else:
+ # Turn the address back into a pointer to be freed.
+ temp_ptr = ctypes.cast(address, ptr_type)
+ free_func(temp_ptr)
+ logging.debug("Freed %s at %x", ptr_type.__name__, address)
+ del _FINALIZERS[address] # Remove the weakref from the global cache.
+
+ # Store weakrefs in a global cache so that they don't get garbage collected
+ # before their referents.
+ _FINALIZERS[address] = (weakref.ref(ptr, callback), callback)
+
+
+def _finalize(ptr):
+ """Calls the finalizer for the specified pointer to free allocated memory."""
+ address = ctypes.addressof(ptr.contents)
+ try:
+ ptr_ref, callback = _FINALIZERS[address]
+ callback(ptr_ref)
+ except KeyError:
+ pass
+
+
+def _load_xml(filename, vfs_or_none):
+ """Invokes `mj_loadXML` with logging/error handling."""
+ error_buf = ctypes.create_string_buffer(_ERROR_BUFSIZE)
+ model_ptr = mjlib.mj_loadXML(
+ util.to_binary_string(filename),
+ vfs_or_none,
+ error_buf,
+ _ERROR_BUFSIZE)
+ if not model_ptr:
+ raise Error(util.to_native_string(error_buf.value))
+ elif error_buf.value:
+ logging.warning(util.to_native_string(error_buf.value))
+
+ # Free resources when the ctypes pointer is garbage collected.
+ _create_finalizer(model_ptr, mjlib.mj_deleteModel)
+
+ return model_ptr
+
+
+def _get_model_ptr_from_xml(xml_path=None, xml_string=None, assets=None):
+ """Parses a model XML file, compiles it, and returns a pointer to an mjModel.
+
+ Args:
+ xml_path: None or a path to a model XML file in MJCF or URDF format.
+ xml_string: None or an XML string containing an MJCF or URDF model
+ description.
+ assets: None or a dict containing external assets referenced by the model
+ (such as additional XML files, textures, meshes etc.), in the form of
+ `{filename: contents_string}` pairs. The keys should correspond to the
+ filenames specified in the model XML. Ignored if `xml_string` is None.
+
+ One of `xml_path` or `xml_string` must be specified.
+
+ Returns:
+ A `ctypes.POINTER` to a new `mjbindings.types.MJMODEL` instance.
+
+ Raises:
+ TypeError: If both or neither of `xml_path` and `xml_string` are specified.
+ Error: If the model is not created successfully.
+ """
+ if xml_path is None and xml_string is None:
+ raise TypeError(
+ "At least one of `xml_path` or `xml_string` must be specified.")
+ elif xml_path is not None and xml_string is not None:
+ raise TypeError(
+ "Only one of `xml_path` or `xml_string` may be specified.")
+
+ _maybe_register_license()
+
+ if xml_string is not None:
+ assets = {} if assets is None else assets.copy()
+ # Ensure that the fake XML filename doesn't overwrite an existing asset.
+ xml_path = _FAKE_XML_FILENAME
+ while xml_path in assets:
+ xml_path = "_" + xml_path
+ assets[xml_path] = xml_string
+ with _temporary_vfs(assets) as vfs:
+ ptr = _load_xml(xml_path, vfs)
+ else:
+ ptr = _load_xml(xml_path, None)
+
+ global _LAST_PARSED_MODEL_PTR
+ _LAST_PARSED_MODEL_PTR = ptr
+
+ return ptr
+
+
+def save_last_parsed_model_to_xml(xml_path, check_model=None):
+ """Writes a description of the most recently loaded model to an MJCF XML file.
+
+ Args:
+ xml_path: Path to the output XML file.
+ check_model: Optional `MjModel` instance. If specified, this model will be
+ checked to see if it is the most recently parsed one, and a ValueError
+ will be raised otherwise.
+ Raises:
+ Error: If MuJoCo encounters an error while writing the XML file.
+ ValueError: If `check_model` was passed, and this model is not the most
+ recently parsed one.
+ """
+ if check_model and check_model.ptr is not _LAST_PARSED_MODEL_PTR:
+ raise ValueError(_NOT_LAST_PARSED_ERROR)
+ error_buf = ctypes.create_string_buffer(_ERROR_BUFSIZE)
+ mjlib.mj_saveLastXML(util.to_binary_string(xml_path),
+ _LAST_PARSED_MODEL_PTR,
+ error_buf,
+ _ERROR_BUFSIZE)
+ if error_buf.value:
+ raise Error(error_buf.value)
+
+
+def _get_model_ptr_from_binary(binary_path=None, byte_string=None):
+ """Returns a pointer to an mjModel from the contents of a MuJoCo model binary.
+
+ Args:
+ binary_path: Path to an MJB file (as produced by MjModel.save_binary).
+ byte_string: String of bytes (as returned by MjModel.to_bytes).
+
+ One of `binary_path` or `byte_string` must be specified.
+
+ Returns:
+ A `ctypes.POINTER` to a new `mjbindings.types.MJMODEL` instance.
+
+ Raises:
+ TypeError: If both or neither of `byte_string` and `binary_path`
+ are specified.
+ """
+ if binary_path is None and byte_string is None:
+ raise TypeError(
+ "At least one of `byte_string` or `binary_path` must be specified.")
+ elif binary_path is not None and byte_string is not None:
+ raise TypeError(
+ "Only one of `byte_string` or `binary_path` may be specified.")
+
+ _maybe_register_license()
+
+ if byte_string is not None:
+ with _temporary_vfs({_FAKE_BINARY_FILENAME: byte_string}) as vfs:
+ ptr = mjlib.mj_loadModel(_FAKE_BINARY_FILENAME, vfs)
+ else:
+ ptr = mjlib.mj_loadModel(util.to_binary_string(binary_path), None)
+
+ # Free resources when the ctypes pointer is garbage collected.
+ _create_finalizer(ptr, mjlib.mj_deleteModel)
+
+ return ptr
+
+
+# Subclasses implementing constructors/destructors for low-level wrappers.
+# ------------------------------------------------------------------------------
+
+
+class MjModel(wrappers.MjModelWrapper):
+ """Wrapper class for a MuJoCo 'mjModel' instance.
+
+ MjModel encapsulates features of the model that are expected to remain
+ constant. It also contains simulation and visualization options which may be
+ changed occasionally, although this is done explicitly by the user.
+ """
+
+ def __init__(self, model_ptr):
+ """Creates a new MjModel instance from a ctypes pointer.
+
+ Args:
+ model_ptr: A `ctypes.POINTER` to an `mjbindings.types.MJMODEL` instance.
+ """
+ super(MjModel, self).__init__(ptr=model_ptr)
+
+ def __getstate__(self):
+ # All of MjModel's state is assumed to reside within the MuJoCo C struct.
+ # However there is no mechanism to prevent users from adding arbitrary
+ # Python attributes to an MjModel instance - these would not be serialized.
+ return self.to_bytes()
+
+ def __setstate__(self, byte_string):
+ model_ptr = _get_model_ptr_from_binary(byte_string=byte_string)
+ self.__init__(model_ptr)
+
+ def __copy__(self):
+ new_model_ptr = mjlib.mj_copyModel(None, self.ptr)
+ return self.__class__(new_model_ptr)
+
+ @classmethod
+ def from_xml_string(cls, xml_string, assets=None):
+ """Creates an `MjModel` instance from a model description XML string.
+
+ Args:
+ xml_string: String containing an MJCF or URDF model description.
+ assets: Optional dict containing external assets referenced by the model
+ (such as additional XML files, textures, meshes etc.), in the form of
+ `{filename: contents_string}` pairs. The keys should correspond to the
+ filenames specified in the model XML.
+
+ Returns:
+ An `MjModel` instance.
+ """
+ model_ptr = _get_model_ptr_from_xml(xml_string=xml_string, assets=assets)
+ return cls(model_ptr)
+
+ @classmethod
+ def from_byte_string(cls, byte_string):
+ """Creates an MjModel instance from a model binary as a string of bytes."""
+ model_ptr = _get_model_ptr_from_binary(byte_string=byte_string)
+ return cls(model_ptr)
+
+ @classmethod
+ def from_xml_path(cls, xml_path):
+ """Creates an MjModel instance from a path to a model XML file."""
+ model_ptr = _get_model_ptr_from_xml(xml_path=xml_path)
+ return cls(model_ptr)
+
+ @classmethod
+ def from_binary_path(cls, binary_path):
+ """Creates an MjModel instance from a path to a compiled model binary."""
+ model_ptr = _get_model_ptr_from_binary(binary_path=binary_path)
+ return cls(model_ptr)
+
+ def save_binary(self, binary_path):
+ """Saves the MjModel instance to a binary file."""
+ mjlib.mj_saveModel(self.ptr, util.to_binary_string(binary_path), None, 0)
+
+ def to_bytes(self):
+ """Serialize the model to a string of bytes."""
+ bufsize = mjlib.mj_sizeModel(self.ptr)
+ buf = ctypes.create_string_buffer(bufsize)
+ mjlib.mj_saveModel(self.ptr, None, buf, bufsize)
+ return buf.raw
+
+ def copy(self):
+ """Returns a copy of this MjModel instance."""
+ return self.__copy__()
+
+ def free(self):
+ """Frees the native resources held by this MjModel.
+
+ This is an advanced feature for use when manual memory management is
+ necessary. This MjModel object MUST NOT be used after this function has
+ been called.
+ """
+ _finalize(self._ptr)
+ del self._ptr
+
+ def name2id(self, name, object_type):
+ """Returns the integer ID of a specified MuJoCo object.
+
+ Args:
+ name: String specifying the name of the object to query.
+ object_type: The type of the object. Can be either a lowercase string
+ (e.g. 'body', 'geom') or an `mjtObj` enum value.
+
+ Returns:
+ An integer object ID.
+
+ Raises:
+ Error: If `object_type` is not a valid MuJoCo object type, or if no object
+ with the corresponding name and type was found.
+ """
+ if not isinstance(object_type, int):
+ object_type = _str2type(object_type)
+ obj_id = mjlib.mj_name2id(
+ self.ptr, object_type, util.to_binary_string(name))
+ if obj_id == -1:
+ raise Error("Object of type {!r} with name {!r} does not exist.".format(
+ _type2str(object_type), name))
+ return obj_id
+
+ def id2name(self, object_id, object_type):
+ """Returns the name associated with a MuJoCo object ID, if there is one.
+
+ Args:
+ object_id: Integer ID.
+ object_type: The type of the object. Can be either a lowercase string
+ (e.g. 'body', 'geom') or an `mjtObj` enum value.
+
+ Returns:
+ A string containing the object name, or an empty string if the object ID
+ either doesn't exist or has no name.
+
+ Raises:
+ Error: If `object_type` is not a valid MuJoCo object type.
+ """
+ if not isinstance(object_type, int):
+ object_type = _str2type(object_type)
+ name_ptr = mjlib.mj_id2name(self.ptr, object_type, object_id)
+ if not name_ptr:
+ return ""
+ return util.to_native_string(ctypes.string_at(name_ptr))
+
+ @contextlib.contextmanager
+ def disable(self, *flags):
+ """Context manager for temporarily disabling MuJoCo flags.
+
+ Args:
+ *flags: Positional arguments specifying flags to disable. Can be either
+ lowercase strings (e.g. 'gravity', 'contact') or `mjtDisableBit` enum
+ values.
+
+ Yields:
+ None
+
+ Raises:
+ ValueError: If any item in `flags` is neither a valid name nor a value
+ from `enums.mjtDisableBit`.
+ """
+ old_bitmask = self.opt.disableflags
+ new_bitmask = old_bitmask
+ for flag in flags:
+ if isinstance(flag, six.string_types):
+ try:
+ field_name = "mjDSBL_" + flag.upper()
+ bitmask = getattr(enums.mjtDisableBit, field_name)
+ except AttributeError:
+ valid_names = [field_name.split("_")[1].lower()
+ for field_name in enums.mjtDisableBit._fields[:-1]]
+ raise ValueError("'{}' is not a valid flag name. Valid names: {}"
+ .format(flag, ", ".join(valid_names)))
+ else:
+ if flag not in enums.mjtDisableBit[:-1]:
+ raise ValueError("'{}' is not a value in `enums.mjtDisableBit`. "
+ "Valid values: {}"
+ .format(flag, tuple(enums.mjtDisableBit[:-1])))
+ bitmask = flag
+ new_bitmask |= bitmask
+ self.opt.disableflags = new_bitmask
+ try:
+ yield
+ finally:
+ self.opt.disableflags = old_bitmask
+
+ @property
+ def name(self):
+ """Returns the name of the model."""
+ # The model name is the first null-terminated string in the `names` buffer.
+ return util.to_native_string(
+ ctypes.string_at(ctypes.addressof(self.names.contents)))
+
+
+class MjData(wrappers.MjDataWrapper):
+ """Wrapper class for a MuJoCo 'mjData' instance.
+
+ MjData contains all of the dynamic variables and intermediate results produced
+ by the simulation. These are expected to change on each simulation timestep.
+ """
+
+ def __init__(self, model):
+ """Construct a new MjData instance.
+
+ Args:
+ model: An MjModel instance.
+ """
+ self._model = model
+
+ # Allocate resources for mjData.
+ data_ptr = mjlib.mj_makeData(model.ptr)
+
+ # Free resources when the ctypes pointer is garbage collected.
+ _create_finalizer(data_ptr, mjlib.mj_deleteData)
+
+ super(MjData, self).__init__(data_ptr, model)
+
+ def __getstate__(self):
+ # Note: we can replace this once a `saveData` MJAPI function exists.
+ # To reconstruct an MjData instance we need three things:
+ # 1. Its parent MjModel instance
+ # 2. A subset of its fixed-size fields whose values aren't determined by
+ # the model
+ # 3. The contents of its internal buffer (all of its pointer fields point
+ # into this)
+ struct_fields = {}
+ for name in ["solver", "timer", "warning"]:
+ struct_fields[name] = getattr(self, name).copy()
+ scalar_field_names = ["ncon", "time", "energy"]
+ scalar_fields = {name: getattr(self, name) for name in scalar_field_names}
+ static_fields = {"struct_fields": struct_fields,
+ "scalar_fields": scalar_fields}
+ buffer_contents = ctypes.string_at(self.buffer_, self.nbuffer)
+ return (self._model, static_fields, buffer_contents)
+
+ def __setstate__(self, state_tuple):
+ # Replace this once a `loadData` MJAPI function exists.
+ self._model, static_fields, buffer_contents = state_tuple
+ self.__init__(self.model)
+ for name, contents in six.iteritems(static_fields["struct_fields"]):
+ getattr(self, name)[:] = contents
+
+ for name, value in six.iteritems(static_fields["scalar_fields"]):
+ # Array and scalar values must be handled separately.
+ try:
+ getattr(self, name)[:] = value
+ except TypeError:
+ setattr(self, name, value)
+ buf_ptr = (ctypes.c_char * self.nbuffer).from_address(self.buffer_)
+ buf_ptr[:] = buffer_contents
+
+ def __copy__(self):
+ # This makes a shallow copy that shares the same parent MjModel instance.
+ new_obj = self.__class__(self.model)
+ mjlib.mj_copyData(new_obj.ptr, self.model.ptr, self.ptr)
+ return new_obj
+
+ def copy(self):
+ """Returns a copy of this MjData instance with the same parent MjModel."""
+ return self.__copy__()
+
+ def free(self):
+ """Frees the native resources held by this MjData.
+
+ This is an advanced feature for use when manual memory management is
+ necessary. This MjData object MUST NOT be used after this function has
+ been called.
+ """
+ _finalize(self._ptr)
+ del self._ptr
+
+ @property
+ def model(self):
+ """The parent MjModel for this MjData instance."""
+ return self._model
+
+ @util.CachedProperty
+ def _contact_buffer(self):
+ """Cached structured array containing the full contact buffer."""
+ contact_array = util.buf_to_npy(
+ super(MjData, self).contact, shape=(self._model.nconmax,))
+ return contact_array
+
+ @property
+ def contact(self):
+ """Variable-length recarray containing all current contacts."""
+ return self._contact_buffer[:self.ncon]
+
+
+# Docstrings for these subclasses are inherited from their Wrapper parent class.
+
+
+class MjvCamera(wrappers.MjvCameraWrapper):
+
+ def __init__(self):
+ ptr = ctypes.pointer(types.MJVCAMERA())
+ mjlib.mjv_defaultCamera(ptr)
+ super(MjvCamera, self).__init__(ptr)
+
+
+class MjvOption(wrappers.MjvOptionWrapper):
+
+ def __init__(self):
+ ptr = ctypes.pointer(types.MJVOPTION())
+ mjlib.mjv_defaultOption(ptr)
+ # Do not visualize rangefinder lines by default:
+ ptr.contents.flags[enums.mjtVisFlag.mjVIS_RANGEFINDER] = False
+ super(MjvOption, self).__init__(ptr)
+
+
+class UnmanagedMjrContext(wrappers.MjrContextWrapper):
+ """A wrapper for MjrContext that does not manage the native object's lifetime.
+
+ This wrapper is provided for API backward-compatibility reasons, since the
+ creating and destruction of an MjrContext requires an OpenGL context to be
+ provided.
+ """
+
+ def __init__(self):
+ ptr = ctypes.pointer(types.MJRCONTEXT())
+ mjlib.mjr_defaultContext(ptr)
+ super(UnmanagedMjrContext, self).__init__(ptr)
+
+
+class MjrContext(wrappers.MjrContextWrapper): # pylint: disable=missing-docstring
+
+ def __init__(self,
+ model,
+ gl_context,
+ font_scale=enums.mjtFontScale.mjFONTSCALE_150):
+ """Initializes this MjrContext instance.
+
+ Args:
+ model: An `MjModel` instance.
+ gl_context: A `render.ContextBase` instance.
+ font_scale: Integer controlling the font size for text. Must be a value
+ in `mjbindings.enums.mjtFontScale`.
+
+ Raises:
+ ValueError: If `font_scale` is invalid.
+ """
+ if font_scale not in enums.mjtFontScale:
+ raise ValueError(_INVALID_FONT_SCALE.format(font_scale))
+
+ ptr = ctypes.pointer(types.MJRCONTEXT())
+ mjlib.mjr_defaultContext(ptr)
+
+ with gl_context.make_current() as ctx:
+ ctx.call(mjlib.mjr_makeContext, model.ptr, ptr, font_scale)
+ ctx.call(mjlib.mjr_setBuffer, enums.mjtFramebuffer.mjFB_OFFSCREEN, ptr)
+ gl_context.increment_refcount()
+
+ # Free resources when the ctypes pointer is garbage collected.
+ def finalize_mjr_context(ptr):
+ if not gl_context.terminated:
+ with gl_context.make_current() as ctx:
+ ctx.call(mjlib.mjr_freeContext, ptr)
+ gl_context.decrement_refcount()
+
+ _create_finalizer(ptr, finalize_mjr_context)
+
+ super(MjrContext, self).__init__(ptr)
+
+ def free(self):
+ """Frees the native resources held by this MjrContext.
+
+ This is an advanced feature for use when manual memory management is
+ necessary. This MjrContext object MUST NOT be used after this function has
+ been called.
+ """
+ _finalize(self._ptr)
+ del self._ptr
+
+
+class MjvScene(wrappers.MjvSceneWrapper): # pylint: disable=missing-docstring
+
+ def __init__(self, model=None, max_geom=1000):
+ """Initializes a new `MjvScene` instance.
+
+ Args:
+ model: (optional) An `MjModel` instance.
+ max_geom: (optional) An integer specifying the maximum number of geoms
+ that can be represented in the scene.
+ """
+ model_ptr = model.ptr if model is not None else None
+ scene_ptr = ctypes.pointer(types.MJVSCENE())
+
+ # Allocate and initialize resources for the abstract scene.
+ mjlib.mjv_makeScene(model_ptr, scene_ptr, max_geom)
+
+ # Free resources when the ctypes pointer is garbage collected.
+ _create_finalizer(scene_ptr, mjlib.mjv_freeScene)
+
+ super(MjvScene, self).__init__(scene_ptr)
+
+ def free(self):
+ """Frees the native resources held by this MjvScene.
+
+ This is an advanced feature for use when manual memory management is
+ necessary. This MjvScene object MUST NOT be used after this function has
+ been called.
+ """
+ _finalize(self._ptr)
+ del self._ptr
+
+ @util.CachedProperty
+ def _geoms_buffer(self):
+ """Cached recarray containing the full geom buffer."""
+ return util.buf_to_npy(super(MjvScene, self).geoms, shape=(self.maxgeom,))
+
+ @property
+ def geoms(self):
+ """Variable-length recarray containing all geoms currently in the buffer."""
+ return self._geoms_buffer[:self.ngeom]
+
+
+class MjvPerturb(wrappers.MjvPerturbWrapper):
+
+ def __init__(self):
+ ptr = ctypes.pointer(types.MJVPERTURB())
+ mjlib.mjv_defaultPerturb(ptr)
+ super(MjvPerturb, self).__init__(ptr)
+
+
+class MjvFigure(wrappers.MjvFigureWrapper):
+
+ def __init__(self):
+ ptr = ctypes.pointer(types.MJVFIGURE())
+ mjlib.mjv_defaultFigure(ptr)
+ super(MjvFigure, self).__init__(ptr)
diff --git a/DMC/src/env/dm_control/dm_control/mujoco/wrapper/core_test.py b/DMC/src/env/dm_control/dm_control/mujoco/wrapper/core_test.py
new file mode 100644
index 0000000..9fd942a
--- /dev/null
+++ b/DMC/src/env/dm_control/dm_control/mujoco/wrapper/core_test.py
@@ -0,0 +1,552 @@
+# Copyright 2017 The dm_control Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Tests for core.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ctypes
+import gc
+import os
+
+# Internal dependencies.
+from absl.testing import absltest
+from absl.testing import parameterized
+from dm_control import _render
+from dm_control.mujoco.testing import assets
+from dm_control.mujoco.wrapper import core
+from dm_control.mujoco.wrapper import mjbindings
+from dm_control.mujoco.wrapper.mjbindings import enums
+import mock
+import numpy as np
+import six
+from six.moves import cPickle
+from six.moves import range
+
+mjlib = mjbindings.mjlib
+
+HUMANOID_XML_PATH = assets.get_path("humanoid.xml")
+MODEL_WITH_ASSETS = assets.get_contents("model_with_assets.xml")
+ASSETS = {
+ "texture.png": assets.get_contents("deepmind.png"),
+ "mesh.stl": assets.get_contents("cube.stl"),
+ "included.xml": assets.get_contents("sphere.xml")
+}
+
+SCALAR_TYPES = (int, float)
+ARRAY_TYPES = (np.ndarray,)
+
+OUT_DIR = absltest.get_default_test_tmpdir()
+if not os.path.exists(OUT_DIR):
+ os.makedirs(OUT_DIR) # Ensure that the output directory exists.
+
+
+class CoreTest(parameterized.TestCase):
+
+ def setUp(self):
+ super(CoreTest, self).setUp()
+ self.model = core.MjModel.from_xml_path(HUMANOID_XML_PATH)
+ self.data = core.MjData(self.model)
+
+ def _assert_attributes_equal(self, actual_obj, expected_obj, attr_to_compare):
+ for name in attr_to_compare:
+ actual_value = getattr(actual_obj, name)
+ expected_value = getattr(expected_obj, name)
+ try:
+ if isinstance(expected_value, np.ndarray):
+ np.testing.assert_array_equal(actual_value, expected_value)
+ else:
+ self.assertEqual(actual_value, expected_value)
+ except AssertionError as e:
+ self.fail("Attribute '{}' differs from expected value: {}"
+ .format(name, str(e)))
+
+ def testLoadXML(self):
+ with open(HUMANOID_XML_PATH, "r") as f:
+ xml_string = f.read()
+ model = core.MjModel.from_xml_string(xml_string)
+ core.MjData(model)
+ with self.assertRaises(TypeError):
+ core.MjModel()
+ with self.assertRaises(core.Error):
+ core.MjModel.from_xml_path("/path/to/nonexistent/model/file.xml")
+
+ xml_with_warning = """
+
+
+
+
+
+
+
+
+
+
+
+ """
+
+ # This model should compile successfully, but raise a warning on the first
+ # simulation step.
+ model = core.MjModel.from_xml_string(xml_with_warning)
+ data = core.MjData(model)
+ with mock.patch.object(core, "logging") as mock_logging:
+ mjlib.mj_step(model.ptr, data.ptr)
+ mock_logging.warning.assert_called_once_with(
+ "Pre-allocated constraint buffer is full. Increase njmax above 2. "
+ "Time = 0.0000.")
+
+ def testLoadXMLWithAssetsFromString(self):
+ core.MjModel.from_xml_string(MODEL_WITH_ASSETS, assets=ASSETS)
+ with self.assertRaises(core.Error):
+ # Should fail to load without the assets
+ core.MjModel.from_xml_string(MODEL_WITH_ASSETS)
+
+ def testVFSFilenameTooLong(self):
+ limit = core._MAX_VFS_FILENAME_CHARACTERS
+ contents = "fake contents"
+ valid_filename = "a" * limit
+ with core._temporary_vfs({valid_filename: contents}):
+ pass
+ invalid_filename = "a" * (limit + 1)
+ expected_message = core._VFS_FILENAME_TOO_LONG.format(
+ length=(limit + 1), limit=limit, filename=invalid_filename)
+ with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
+ with core._temporary_vfs({invalid_filename: contents}):
+ pass
+
+ def testSaveLastParsedModelToXML(self):
+ save_xml_path = os.path.join(OUT_DIR, "tmp_humanoid.xml")
+
+ not_last_parsed = core.MjModel.from_xml_path(HUMANOID_XML_PATH)
+ last_parsed = core.MjModel.from_xml_path(HUMANOID_XML_PATH)
+
+ # Modify the model before saving it in order to confirm that the changes are
+ # written to the XML.
+ last_parsed.geom_pos.flat[:] = np.arange(last_parsed.geom_pos.size)
+
+ core.save_last_parsed_model_to_xml(save_xml_path, check_model=last_parsed)
+
+ loaded = core.MjModel.from_xml_path(save_xml_path)
+ self._assert_attributes_equal(last_parsed, loaded, ["geom_pos"])
+ core.MjData(loaded)
+
+ # Test that `check_model` results in a ValueError if it is not the most
+ # recently parsed model.
+ with self.assertRaisesWithLiteralMatch(
+ ValueError, core._NOT_LAST_PARSED_ERROR):
+ core.save_last_parsed_model_to_xml(save_xml_path,
+ check_model=not_last_parsed)
+
+ def testBinaryIO(self):
+ bin_path = os.path.join(OUT_DIR, "tmp_humanoid.mjb")
+ self.model.save_binary(bin_path)
+ core.MjModel.from_binary_path(bin_path)
+ byte_string = self.model.to_bytes()
+ core.MjModel.from_byte_string(byte_string)
+
+ def testDimensions(self):
+ self.assertEqual(self.data.qpos.shape[0], self.model.nq)
+ self.assertEqual(self.data.qvel.shape[0], self.model.nv)
+ self.assertEqual(self.model.body_pos.shape, (self.model.nbody, 3))
+
+ def testStep(self):
+ t0 = self.data.time
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ self.assertEqual(self.data.time, t0 + self.model.opt.timestep)
+ self.assertTrue(np.all(np.isfinite(self.data.qpos[:])))
+ self.assertTrue(np.all(np.isfinite(self.data.qvel[:])))
+
+ def testMultipleData(self):
+ data2 = core.MjData(self.model)
+ self.assertNotEqual(self.data.ptr, data2.ptr)
+ t0 = self.data.time
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ self.assertEqual(self.data.time, t0 + self.model.opt.timestep)
+ self.assertEqual(data2.time, 0)
+
+ def testMultipleModel(self):
+ model2 = core.MjModel.from_xml_path(HUMANOID_XML_PATH)
+ self.assertNotEqual(self.model.ptr, model2.ptr)
+ self.model.opt.timestep += 0.001
+ self.assertEqual(self.model.opt.timestep, model2.opt.timestep + 0.001)
+
+ def testModelName(self):
+ self.assertEqual(self.model.name, "humanoid")
+
+ @parameterized.named_parameters(
+ ("_copy", lambda x: x.copy()),
+ ("_pickle_unpickle", lambda x: cPickle.loads(cPickle.dumps(x))),)
+ def testCopyOrPickleModel(self, func):
+ timestep = 0.12345
+ self.model.opt.timestep = timestep
+ body_pos = self.model.body_pos + 1
+ self.model.body_pos[:] = body_pos
+ model2 = func(self.model)
+ self.assertNotEqual(model2.ptr, self.model.ptr)
+ self.assertEqual(model2.opt.timestep, timestep)
+ np.testing.assert_array_equal(model2.body_pos, body_pos)
+
+ @parameterized.named_parameters(
+ ("_copy", lambda x: x.copy()),
+ ("_pickle_unpickle", lambda x: cPickle.loads(cPickle.dumps(x))),)
+ def testCopyOrPickleData(self, func):
+ for _ in range(10):
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ data2 = func(self.data)
+ attr_to_compare = ("time", "energy", "qpos", "xpos")
+ self.assertNotEqual(data2.ptr, self.data.ptr)
+ self._assert_attributes_equal(data2, self.data, attr_to_compare)
+ for _ in range(10):
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ mjlib.mj_step(data2.model.ptr, data2.ptr)
+ self._assert_attributes_equal(data2, self.data, attr_to_compare)
+
+ @parameterized.named_parameters(
+ ("_copy", lambda x: x.copy()),
+ ("_pickle_unpickle", lambda x: cPickle.loads(cPickle.dumps(x))),)
+ def testCopyOrPickleStructs(self, func):
+ for _ in range(10):
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ data2 = func(self.data)
+ self.assertNotEqual(data2.ptr, self.data.ptr)
+ attr_to_compare = ("warning", "timer", "solver")
+ self._assert_attributes_equal(self.data, data2, attr_to_compare)
+ for _ in range(10):
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ mjlib.mj_step(data2.model.ptr, data2.ptr)
+ self._assert_attributes_equal(self.data, data2, attr_to_compare)
+
+ @parameterized.parameters(
+ ("right_foot", "body", 6),
+ ("right_foot", enums.mjtObj.mjOBJ_BODY, 6),
+ ("left_knee", "joint", 11),
+ ("left_knee", enums.mjtObj.mjOBJ_JOINT, 11))
+ def testNamesIds(self, name, object_type, object_id):
+ output_id = self.model.name2id(name, object_type)
+ self.assertEqual(object_id, output_id)
+ output_name = self.model.id2name(object_id, object_type)
+ self.assertEqual(name, output_name)
+
+ def testNamesIdsExceptions(self):
+ with six.assertRaisesRegex(self, core.Error, "does not exist"):
+ self.model.name2id("nonexistent_body_name", "body")
+ with six.assertRaisesRegex(self, core.Error, "is not a valid object type"):
+ self.model.name2id("right_foot", "nonexistent_type_name")
+
+ def testNamelessObject(self):
+ # The model in humanoid.xml contains a single nameless camera.
+ name = self.model.id2name(0, "camera")
+ self.assertEqual("", name)
+
+ def testWarningCallback(self):
+ self.data.qpos[0] = np.inf
+ with mock.patch.object(core, "logging") as mock_logging:
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ mock_logging.warning.assert_called_once_with(
+ "Nan, Inf or huge value in QPOS at DOF 0. The simulation is unstable. "
+ "Time = 0.0000.")
+
+ def testErrorCallback(self):
+ with mock.patch.object(core, "logging") as mock_logging:
+ mjlib.mj_activate(b"nonexistent_activation_key")
+ mock_logging.fatal.assert_called_once_with(
+ "Could not open activation key file nonexistent_activation_key")
+
+ def testSingleCallbackContext(self):
+
+ callback_was_called = [False]
+
+ def callback(unused_model, unused_data):
+ callback_was_called[0] = True
+
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ self.assertFalse(callback_was_called[0])
+
+ class DummyError(RuntimeError):
+ pass
+
+ try:
+ with core.callback_context("mjcb_passive", callback):
+
+ # Stepping invokes the `mjcb_passive` callback.
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ self.assertTrue(callback_was_called[0])
+
+ # Exceptions should not prevent `mjcb_passive` from being reset.
+ raise DummyError("Simulated exception.")
+
+ except DummyError:
+ pass
+
+ # `mjcb_passive` should have been reset to None.
+ callback_was_called[0] = False
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ self.assertFalse(callback_was_called[0])
+
+ def testNestedCallbackContexts(self):
+
+ last_called = [None]
+ outer_called = "outer called"
+ inner_called = "inner called"
+
+ def outer(unused_model, unused_data):
+ last_called[0] = outer_called
+
+ def inner(unused_model, unused_data):
+ last_called[0] = inner_called
+
+ with core.callback_context("mjcb_passive", outer):
+
+ # This should execute `outer` a few times.
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ self.assertEqual(last_called[0], outer_called)
+
+ with core.callback_context("mjcb_passive", inner):
+
+ # This should execute `inner` a few times.
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ self.assertEqual(last_called[0], inner_called)
+
+ # When we exit the inner context, the `mjcb_passive` callback should be
+ # reset to `outer`.
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ self.assertEqual(last_called[0], outer_called)
+
+ # When we exit the outer context, the `mjcb_passive` callback should be
+ # reset to None, and stepping should not affect `last_called`.
+ last_called[0] = None
+ mjlib.mj_step(self.model.ptr, self.data.ptr)
+ self.assertIsNone(last_called[0])
+
+ def testDisableFlags(self):
+ xml_string = """
+
+