diff --git a/DMC/LICENSE b/DMC/LICENSE new file mode 100644 index 0000000..a6418fc --- /dev/null +++ b/DMC/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Anonymous + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/DMC/README.md b/DMC/README.md new file mode 100644 index 0000000..0d88872 --- /dev/null +++ b/DMC/README.md @@ -0,0 +1,53 @@ + +# Overview +![Perf](figures/sgqnarchi.png) +![Perf](figures/sgqn_perf.png) + +## Setup +We assume that you have access to a GPU with CUDA >=9.2 support. All dependencies can then be installed with the following commands: + +``` +conda env create -f setup/conda.yml +conda activate dmcgb +sh setup/install_envs.sh +``` + +If you don't have the right mujoco version installed: +``` +sh setup/install_mujoco_deps.sh +sh setup/prepare_dm_control_xp.sh +``` + + +## Datasets + +``` +wget http://data.csail.mit.edu/places/places365/places365standard_easyformat.tar +``` +After downloading and extracting the data, add your dataset directory to the datasets list in `setup/config.cfg`. + +# Run `SGQN` training: +```bash +python src/train.py --algorithm sgsac --sgqn_quantile [QUANTILE] --seed [SEED] --eval_mode video_easy --domain_name [DOMAIN] --task_name [TASK]; +``` +You can also run `SVEA`, `SODA`, `RAD`, `SAC`. + + +# DMControl Generalization Benchmark + +Benchmark for generalization in continuous control from pixels, based on [DMControl](https://github.com/deepmind/dm_control). + +## Test environments + +The DMControl Generalization Benchmark provides two distinct benchmarks for visual generalization, *random colors* and *video backgrounds*: + +![environment samples](figures/environments.png) + +Both benchmarks are offered in *easy* and *hard* variants. Samples are shown below. + +**video_easy**
+![video_easy](figures/video_easy.png) + +**video_hard**
+![video_hard](figures/video_hard.png) + diff --git a/DMC/figures/color_easy.png b/DMC/figures/color_easy.png new file mode 100644 index 0000000..9b30e84 Binary files /dev/null and b/DMC/figures/color_easy.png differ diff --git a/DMC/figures/color_hard.png b/DMC/figures/color_hard.png new file mode 100644 index 0000000..c517d55 Binary files /dev/null and b/DMC/figures/color_hard.png differ diff --git a/DMC/figures/environments.png b/DMC/figures/environments.png new file mode 100644 index 0000000..5563ad9 Binary files /dev/null and b/DMC/figures/environments.png differ diff --git a/DMC/figures/results_table.png b/DMC/figures/results_table.png new file mode 100644 index 0000000..3ef8d29 Binary files /dev/null and b/DMC/figures/results_table.png differ diff --git a/DMC/figures/sgqn_perf.png b/DMC/figures/sgqn_perf.png new file mode 100644 index 0000000..f4f4901 Binary files /dev/null and b/DMC/figures/sgqn_perf.png differ diff --git a/DMC/figures/sgqnarchi.png b/DMC/figures/sgqnarchi.png new file mode 100644 index 0000000..497f1f5 Binary files /dev/null and b/DMC/figures/sgqnarchi.png differ diff --git a/DMC/figures/video_easy.png b/DMC/figures/video_easy.png new file mode 100644 index 0000000..a81692c Binary files /dev/null and b/DMC/figures/video_easy.png differ diff --git a/DMC/figures/video_hard.png b/DMC/figures/video_hard.png new file mode 100644 index 0000000..095aea2 Binary files /dev/null and b/DMC/figures/video_hard.png differ diff --git a/DMC/scripts/curl.sh b/DMC/scripts/curl.sh new file mode 100644 index 0000000..4995c07 --- /dev/null +++ b/DMC/scripts/curl.sh @@ -0,0 +1,4 @@ +CUDA_VISIBLE_DEVICES=0 python3 src/train.py \ + --algorithm curl \ + --aux_update_freq 1 \ + --seed 0 \ No newline at end of file diff --git a/DMC/scripts/drq.sh b/DMC/scripts/drq.sh new file mode 100644 index 0000000..a8ffe56 --- /dev/null +++ b/DMC/scripts/drq.sh @@ -0,0 +1,3 @@ +CUDA_VISIBLE_DEVICES=0 python3 src/train.py \ + --algorithm drq \ + --seed 0 \ No newline at end of file diff --git a/DMC/scripts/eval/curl.sh b/DMC/scripts/eval/curl.sh new file mode 100644 index 0000000..0b6b820 --- /dev/null +++ b/DMC/scripts/eval/curl.sh @@ -0,0 +1,4 @@ +CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \ + --algorithm curl \ + --eval_episodes 100 \ + --seed 0 \ No newline at end of file diff --git a/DMC/scripts/eval/drq.sh b/DMC/scripts/eval/drq.sh new file mode 100644 index 0000000..7dace5c --- /dev/null +++ b/DMC/scripts/eval/drq.sh @@ -0,0 +1,4 @@ +CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \ + --algorithm drq \ + --eval_episodes 100 \ + --seed 0 \ No newline at end of file diff --git a/DMC/scripts/eval/pad.sh b/DMC/scripts/eval/pad.sh new file mode 100644 index 0000000..38afc8f --- /dev/null +++ b/DMC/scripts/eval/pad.sh @@ -0,0 +1,6 @@ +CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \ + --algorithm pad \ + --num_shared_layers 8 \ + --num_head_layers 3 \ + --eval_episodes 100 \ + --seed 0 \ No newline at end of file diff --git a/DMC/scripts/eval/rad.sh b/DMC/scripts/eval/rad.sh new file mode 100644 index 0000000..70c55a7 --- /dev/null +++ b/DMC/scripts/eval/rad.sh @@ -0,0 +1,4 @@ +CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \ + --algorithm rad \ + --eval_episodes 100 \ + --seed 0 \ No newline at end of file diff --git a/DMC/scripts/eval/sac.sh b/DMC/scripts/eval/sac.sh new file mode 100644 index 0000000..130dca0 --- /dev/null +++ b/DMC/scripts/eval/sac.sh @@ -0,0 +1,4 @@ +CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \ + --algorithm sac \ + --eval_episodes 100 \ + --seed 0 \ No newline at end of file diff --git a/DMC/scripts/eval/soda.sh b/DMC/scripts/eval/soda.sh new file mode 100644 index 0000000..49f3acf --- /dev/null +++ b/DMC/scripts/eval/soda.sh @@ -0,0 +1,4 @@ +CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \ + --algorithm soda \ + --eval_episodes 100 \ + --seed 0 \ No newline at end of file diff --git a/DMC/scripts/eval/svea.sh b/DMC/scripts/eval/svea.sh new file mode 100644 index 0000000..dabd734 --- /dev/null +++ b/DMC/scripts/eval/svea.sh @@ -0,0 +1,4 @@ +CUDA_VISIBLE_DEVICES=0 python3 src/eval.py \ + --algorithm svea \ + --eval_episodes 100 \ + --seed 0 \ No newline at end of file diff --git a/DMC/scripts/pad.sh b/DMC/scripts/pad.sh new file mode 100644 index 0000000..02093d1 --- /dev/null +++ b/DMC/scripts/pad.sh @@ -0,0 +1,5 @@ +CUDA_VISIBLE_DEVICES=0 python3 src/train.py \ + --algorithm pad \ + --num_shared_layers 8 \ + --num_head_layers 3 \ + --seed 0 \ No newline at end of file diff --git a/DMC/scripts/rad.sh b/DMC/scripts/rad.sh new file mode 100644 index 0000000..97b5a0a --- /dev/null +++ b/DMC/scripts/rad.sh @@ -0,0 +1,3 @@ +CUDA_VISIBLE_DEVICES=0 python3 src/train.py \ + --algorithm rad \ + --seed 0 \ No newline at end of file diff --git a/DMC/scripts/sac.sh b/DMC/scripts/sac.sh new file mode 100644 index 0000000..0a39f65 --- /dev/null +++ b/DMC/scripts/sac.sh @@ -0,0 +1,3 @@ +CUDA_VISIBLE_DEVICES=0 python3 src/train.py \ + --algorithm sac \ + --seed 0 \ No newline at end of file diff --git a/DMC/scripts/sgsac.sh b/DMC/scripts/sgsac.sh new file mode 100644 index 0000000..79ac13b --- /dev/null +++ b/DMC/scripts/sgsac.sh @@ -0,0 +1,3 @@ +CUDA_VISIBLE_DEVICES=0 python3 src/train.py \ + --algorithm sgsac \ + --seed 0 --eval_mode all --domain_name cartpole --task_name swingup --sgqn_quantile 0.98 \ No newline at end of file diff --git a/DMC/scripts/soda.sh b/DMC/scripts/soda.sh new file mode 100644 index 0000000..33cca72 --- /dev/null +++ b/DMC/scripts/soda.sh @@ -0,0 +1,4 @@ +CUDA_VISIBLE_DEVICES=0 python3 src/train.py \ + --algorithm soda \ + --aux_lr 3e-4 \ + --seed 0 \ No newline at end of file diff --git a/DMC/scripts/svea.sh b/DMC/scripts/svea.sh new file mode 100644 index 0000000..fbc7157 --- /dev/null +++ b/DMC/scripts/svea.sh @@ -0,0 +1,3 @@ +CUDA_VISIBLE_DEVICES=0 python3 src/train.py \ + --algorithm svea \ + --seed 0 \ No newline at end of file diff --git a/DMC/setup/conda.yaml b/DMC/setup/conda.yaml new file mode 100644 index 0000000..bd24780 --- /dev/null +++ b/DMC/setup/conda.yaml @@ -0,0 +1,24 @@ +name: dmcgb +channels: + - defaults +dependencies: + - python=3.7.6 + - cudatoolkit=9.2 + - absl-py + - pyparsing + - pip + - pip: + - numpy==1.19.5 + - torch==1.7.1 + - torchvision==0.8.2 + - pillow==6.2.0 + - termcolor + - imageio + - imageio-ffmpeg + - opencv-python + - xmltodict + - tqdm + - einops + - kornia + - captum + - tensorboard diff --git a/DMC/setup/config.cfg b/DMC/setup/config.cfg new file mode 100644 index 0000000..71f5502 --- /dev/null +++ b/DMC/setup/config.cfg @@ -0,0 +1,5 @@ +{ + "datasets": [ + "places365_standard/" + ] +} diff --git a/DMC/setup/install_envs.sh b/DMC/setup/install_envs.sh new file mode 100644 index 0000000..8ba6329 --- /dev/null +++ b/DMC/setup/install_envs.sh @@ -0,0 +1,10 @@ +cd src/env/dm_control +pip install -e . + +cd ../dmc2gym +pip install -e . + +cd ../../.. + +curl https://codeload.github.com/nicklashansen/dmcontrol-generalization-benchmark/tar.gz/main | tar -xz --strip=3 dmcontrol-generalization-benchmark-main/src/env/data +mv data src/env/ \ No newline at end of file diff --git a/DMC/setup/install_mujoco_deps.sh b/DMC/setup/install_mujoco_deps.sh new file mode 100644 index 0000000..b784d5c --- /dev/null +++ b/DMC/setup/install_mujoco_deps.sh @@ -0,0 +1,5 @@ +cd $HOME/.mujoco +wget https://roboti.us/download/mujoco200_linux.zip +unzip mujoco200_linux.zip +wget https://roboti.us/file/mjkey.txt +mv mjkey.txt mujoco200_linux/bin \ No newline at end of file diff --git a/DMC/setup/prepare_dm_control_xp.sh b/DMC/setup/prepare_dm_control_xp.sh new file mode 100644 index 0000000..7094a47 --- /dev/null +++ b/DMC/setup/prepare_dm_control_xp.sh @@ -0,0 +1,2 @@ +export MJKEY_PATH=$HOME/.mujoco/mujoco200_linux/bin/mjkey.txt +export MUJOCO_GL=egl \ No newline at end of file diff --git a/DMC/src/algorithms/curl.py b/DMC/src/algorithms/curl.py new file mode 100644 index 0000000..32b9466 --- /dev/null +++ b/DMC/src/algorithms/curl.py @@ -0,0 +1,57 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from copy import deepcopy +import utils +import algorithms.modules as m +from algorithms.sac import SAC + + +class CURL(SAC): + def __init__(self, obs_shape, action_shape, args): + super().__init__(obs_shape, action_shape, args) + self.aux_update_freq = args.aux_update_freq + + self.curl_head = m.CURLHead(self.critic.encoder).cuda() + + self.curl_optimizer = torch.optim.Adam( + self.curl_head.parameters(), lr=args.aux_lr, betas=(args.aux_beta, 0.999) + ) + self.train() + + def train(self, training=True): + super().train(training) + if hasattr(self, 'curl_head'): + self.curl_head.train(training) + + def update_curl(self, x, x_pos, L=None, step=None): + assert x.size(-1) == 84 and x_pos.size(-1) == 84 + + z_a = self.curl_head.encoder(x) + with torch.no_grad(): + z_pos = self.critic_target.encoder(x_pos) + + logits = self.curl_head.compute_logits(z_a, z_pos) + labels = torch.arange(logits.shape[0]).long().cuda() + curl_loss = F.cross_entropy(logits, labels) + + self.curl_optimizer.zero_grad() + curl_loss.backward() + self.curl_optimizer.step() + if L is not None: + L.log('train/aux_loss', curl_loss, step) + + def update(self, replay_buffer, L, step): + obs, action, reward, next_obs, not_done, pos = replay_buffer.sample_curl() + + self.update_critic(obs, action, reward, next_obs, not_done, L, step) + + if step % self.actor_update_freq == 0: + self.update_actor_and_alpha(obs, L, step) + + if step % self.critic_target_update_freq == 0: + self.soft_update_critic_target() + + if step % self.aux_update_freq == 0: + self.update_curl(obs, pos, L, step) diff --git a/DMC/src/algorithms/drq.py b/DMC/src/algorithms/drq.py new file mode 100644 index 0000000..51999f8 --- /dev/null +++ b/DMC/src/algorithms/drq.py @@ -0,0 +1,24 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from copy import deepcopy +import utils +import algorithms.modules as m +from algorithms.sac import SAC + + +class DrQ(SAC): # [K=1, M=1] + def __init__(self, obs_shape, action_shape, args): + super().__init__(obs_shape, action_shape, args) + + def update(self, replay_buffer, L, step): + obs, action, reward, next_obs, not_done = replay_buffer.sample_drq() + + self.update_critic(obs, action, reward, next_obs, not_done, L, step) + + if step % self.actor_update_freq == 0: + self.update_actor_and_alpha(obs, L, step) + + if step % self.critic_target_update_freq == 0: + self.soft_update_critic_target() diff --git a/DMC/src/algorithms/factory.py b/DMC/src/algorithms/factory.py new file mode 100644 index 0000000..af3b374 --- /dev/null +++ b/DMC/src/algorithms/factory.py @@ -0,0 +1,23 @@ +from algorithms.sac import SAC +from algorithms.rad import RAD +from algorithms.curl import CURL +from algorithms.pad import PAD +from algorithms.soda import SODA +from algorithms.drq import DrQ +from algorithms.svea import SVEA +from algorithms.sgsac import SGSAC + +algorithm = { + "sac": SAC, + "rad": RAD, + "curl": CURL, + "pad": PAD, + "soda": SODA, + "drq": DrQ, + "svea": SVEA, + "sgsac": SGSAC, +} + + +def make_agent(obs_shape, action_shape, args): + return algorithm[args.algorithm](obs_shape, action_shape, args) diff --git a/DMC/src/algorithms/modules.py b/DMC/src/algorithms/modules.py new file mode 100644 index 0000000..8837dc2 --- /dev/null +++ b/DMC/src/algorithms/modules.py @@ -0,0 +1,354 @@ +from turtle import forward +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +from functools import partial + + +def _get_out_shape_cuda(in_shape, layers): + x = torch.randn(*in_shape).cuda().unsqueeze(0) + return layers(x).squeeze(0).shape + + +def _get_out_shape(in_shape, layers): + x = torch.randn(*in_shape).unsqueeze(0) + return layers(x).squeeze(0).shape + + +def gaussian_logprob(noise, log_std): + """Compute Gaussian log probability""" + residual = (-0.5 * noise.pow(2) - log_std).sum(-1, keepdim=True) + return residual - 0.5 * np.log(2 * np.pi) * noise.size(-1) + + +def squash(mu, pi, log_pi): + """Apply squashing function, see appendix C from https://arxiv.org/pdf/1812.05905.pdf""" + mu = torch.tanh(mu) + if pi is not None: + pi = torch.tanh(pi) + if log_pi is not None: + log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True) + return mu, pi, log_pi + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + """Truncated normal distribution, see https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf""" + + def norm_cdf(x): + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + with torch.no_grad(): + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + tensor.uniform_(2 * l - 1, 2 * u - 1) + tensor.erfinv_() + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + tensor.clamp_(min=a, max=b) + return tensor + + +def weight_init(m): + """Custom weight init for Conv2D and Linear layers""" + if isinstance(m, nn.Linear): + nn.init.orthogonal_(m.weight.data) + if hasattr(m.bias, "data"): + m.bias.data.fill_(0.0) + elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): + # delta-orthogonal init from https://arxiv.org/pdf/1806.05393.pdf + assert m.weight.size(2) == m.weight.size(3) + m.weight.data.fill_(0.0) + if hasattr(m.bias, "data"): + m.bias.data.fill_(0.0) + mid = m.weight.size(2) // 2 + gain = nn.init.calculate_gain("relu") + nn.init.orthogonal_(m.weight.data[:, :, mid, mid], gain) + + +class CenterCrop(nn.Module): + def __init__(self, size): + super().__init__() + assert size in {84, 100}, f"unexpected size: {size}" + self.size = size + + def forward(self, x): + assert x.ndim == 4, "input must be a 4D tensor" + if x.size(2) == self.size and x.size(3) == self.size: + return x + assert x.size(3) == 100, f"unexpected size: {x.size(3)}" + if self.size == 84: + p = 8 + return x[:, :, p:-p, p:-p] + + +class NormalizeImg(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x / 255.0 + + +class Flatten(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.view(x.size(0), -1) + + +class RLProjection(nn.Module): + def __init__(self, in_shape, out_dim): + super().__init__() + self.out_dim = out_dim + self.projection = nn.Sequential( + nn.Linear(in_shape[0], out_dim), nn.LayerNorm(out_dim), nn.Tanh() + ) + self.apply(weight_init) + + def forward(self, x): + y = self.projection(x) + return y + + +class SODAMLP(nn.Module): + def __init__(self, projection_dim, hidden_dim, out_dim): + super().__init__() + self.out_dim = out_dim + self.mlp = nn.Sequential( + nn.Linear(projection_dim, hidden_dim), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, out_dim), + ) + self.apply(weight_init) + + def forward(self, x): + return self.mlp(x) + + +class SharedCNN(nn.Module): + def __init__(self, obs_shape, num_layers=11, num_filters=32): + super().__init__() + assert len(obs_shape) == 3 + self.num_layers = num_layers + self.num_filters = num_filters + + self.layers = [ + CenterCrop(size=84), + NormalizeImg(), + nn.Conv2d(obs_shape[0], num_filters, 3, stride=2), + ] + for _ in range(1, num_layers): + self.layers.append(nn.ReLU()) + self.layers.append(nn.Conv2d(num_filters, num_filters, 3, stride=1)) + self.layers = nn.Sequential(*self.layers) + self.out_shape = _get_out_shape(obs_shape, self.layers) + self.apply(weight_init) + + def forward(self, x): + return self.layers(x) + + +class HeadCNN(nn.Module): + def __init__(self, in_shape, num_layers=0, num_filters=32): + super().__init__() + self.layers = [] + for _ in range(0, num_layers): + self.layers.append(nn.ReLU()) + self.layers.append(nn.Conv2d(num_filters, num_filters, 3, stride=1)) + self.layers.append(Flatten()) + self.layers = nn.Sequential(*self.layers) + self.out_shape = _get_out_shape(in_shape, self.layers) + self.apply(weight_init) + + def forward(self, x): + return self.layers(x) + + +class Encoder(nn.Module): + def __init__(self, shared_cnn, head_cnn, projection): + super().__init__() + self.shared_cnn = shared_cnn + self.head_cnn = head_cnn + self.projection = projection + self.out_dim = projection.out_dim + + def forward(self, x, detach=False): + x = self.shared_cnn(x) + x = self.head_cnn(x) + if detach: + x = x.detach() + return self.projection(x) + + +class Actor(nn.Module): + def __init__(self, encoder, action_shape, hidden_dim, log_std_min, log_std_max): + super().__init__() + self.encoder = encoder + self.log_std_min = log_std_min + self.log_std_max = log_std_max + self.mlp = nn.Sequential( + nn.Linear(self.encoder.out_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, 2 * action_shape[0]), + ) + self.mlp.apply(weight_init) + + def forward( + self, + x, + compute_pi=True, + compute_log_pi=True, + detach=False, + compute_attrib=False, + ): + x = self.encoder(x, detach) + mu, log_std = self.mlp(x).chunk(2, dim=-1) + log_std = torch.tanh(log_std) + log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * ( + log_std + 1 + ) + + if compute_pi: + std = log_std.exp() + noise = torch.randn_like(mu) + pi = mu + noise * std + else: + pi = None + entropy = None + + if compute_log_pi: + log_pi = gaussian_logprob(noise, log_std) + else: + log_pi = None + + mu, pi, log_pi = squash(mu, pi, log_pi) + + return mu, pi, log_pi, log_std + + +class QFunction(nn.Module): + def __init__(self, obs_dim, action_dim, hidden_dim): + super().__init__() + self.trunk = nn.Sequential( + nn.Linear(obs_dim + action_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, 1), + ) + self.apply(weight_init) + + def forward(self, obs, action): + assert obs.size(0) == action.size(0) + return self.trunk(torch.cat([obs, action], dim=1)) + + +class Critic(nn.Module): + def __init__(self, encoder, action_shape, hidden_dim): + super().__init__() + self.encoder = encoder + self.Q1 = QFunction(self.encoder.out_dim, action_shape[0], hidden_dim) + self.Q2 = QFunction(self.encoder.out_dim, action_shape[0], hidden_dim) + + def forward(self, x, action, detach=False): + x = self.encoder(x, detach) + return self.Q1(x, action), self.Q2(x, action) + + +class CURLHead(nn.Module): + def __init__(self, encoder): + super().__init__() + self.encoder = encoder + self.W = nn.Parameter(torch.rand(encoder.out_dim, encoder.out_dim)) + + def compute_logits(self, z_a, z_pos): + """ + Uses logits trick for CURL: + - compute (B,B) matrix z_a (W z_pos.T) + - positives are all diagonal elements + - negatives are all other elements + - to compute loss use multiclass cross entropy with identity matrix for labels + """ + Wz = torch.matmul(self.W, z_pos.T) # (z_dim,B) + logits = torch.matmul(z_a, Wz) # (B,B) + logits = logits - torch.max(logits, 1)[0][:, None] + return logits + + +class InverseDynamics(nn.Module): + def __init__(self, encoder, action_shape, hidden_dim): + super().__init__() + self.encoder = encoder + self.mlp = nn.Sequential( + nn.Linear(2 * encoder.out_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, action_shape[0]), + ) + self.apply(weight_init) + + def forward(self, x, x_next): + h = self.encoder(x) + h_next = self.encoder(x_next) + joint_h = torch.cat([h, h_next], dim=1) + return self.mlp(joint_h) + + +class SODAPredictor(nn.Module): + def __init__(self, encoder, hidden_dim): + super().__init__() + self.encoder = encoder + self.mlp = SODAMLP(encoder.out_dim, hidden_dim, encoder.out_dim) + self.apply(weight_init) + + def forward(self, x): + return self.mlp(self.encoder(x)) + + +class AttributionDecoder(nn.Module): + def __init__(self,action_shape, emb_dim=100) -> None: + super().__init__() + self.proj = nn.Linear(in_features=emb_dim+action_shape, out_features=14112) + self.conv1 = nn.Conv2d( + in_channels=32, out_channels=128, kernel_size=3, padding=1 + ) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2d( + in_channels=128, out_channels=64, kernel_size=3, padding=1 + ) + self.conv3 = nn.Conv2d(in_channels=64, out_channels=9, kernel_size=3, padding=1) + + def forward(self, x, action): + x = torch.cat([x,action],dim=1) + x = self.proj(x).view(-1, 32, 21, 21) + x = self.relu(x) + x = self.conv1(x) + x = F.upsample(x, scale_factor=2) + x = self.relu(x) + x = self.conv2(x) + x = F.upsample(x, scale_factor=2) + x = self.relu(x) + x = self.conv3(x) + return x + + + +class AttributionPredictor(nn.Module): + def __init__(self, action_shape,encoder, emb_dim=100): + super().__init__() + self.encoder = encoder + self.decoder = AttributionDecoder(action_shape,encoder.out_dim) + self.features_decoder = nn.Sequential( + nn.Linear(emb_dim, 256), nn.ReLU(), nn.Linear(256, emb_dim) + ) + + def forward(self, x,action): + x = self.encoder(x) + return self.decoder(x,action) diff --git a/DMC/src/algorithms/pad.py b/DMC/src/algorithms/pad.py new file mode 100644 index 0000000..4a1d9a7 --- /dev/null +++ b/DMC/src/algorithms/pad.py @@ -0,0 +1,63 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from copy import deepcopy +import utils +import algorithms.modules as m +from algorithms.sac import SAC + + +class PAD(SAC): + def __init__(self, obs_shape, action_shape, args): + super().__init__(obs_shape, action_shape, args) + self.aux_update_freq = args.aux_update_freq + self.aux_lr = args.aux_lr + self.aux_beta = args.aux_beta + + shared_cnn = self.critic.encoder.shared_cnn + aux_cnn = m.HeadCNN(shared_cnn.out_shape, args.num_head_layers, args.num_filters).cuda() + aux_encoder = m.Encoder( + shared_cnn, + aux_cnn, + m.RLProjection(aux_cnn.out_shape, args.projection_dim) + ) + self.pad_head = m.InverseDynamics(aux_encoder, action_shape, args.hidden_dim).cuda() + self.init_pad_optimizer() + self.train() + + def train(self, training=True): + super().train(training) + if hasattr(self, 'pad_head'): + self.pad_head.train(training) + + def init_pad_optimizer(self): + self.pad_optimizer = torch.optim.Adam( + self.pad_head.parameters(), lr=self.aux_lr, betas=(self.aux_beta, 0.999) + ) + + def update_inverse_dynamics(self, obs, obs_next, action, L=None, step=None): + assert obs.shape[-1] == 84 and obs_next.shape[-1] == 84 + + pred_action = self.pad_head(obs, obs_next) + pad_loss = F.mse_loss(pred_action, action) + + self.pad_optimizer.zero_grad() + pad_loss.backward() + self.pad_optimizer.step() + if L is not None: + L.log('train/aux_loss', pad_loss, step) + + def update(self, replay_buffer, L, step): + obs, action, reward, next_obs, not_done = replay_buffer.sample() + + self.update_critic(obs, action, reward, next_obs, not_done, L, step) + + if step % self.actor_update_freq == 0: + self.update_actor_and_alpha(obs, L, step) + + if step % self.critic_target_update_freq == 0: + self.soft_update_critic_target() + + if step % self.aux_update_freq == 0: + self.update_inverse_dynamics(obs, next_obs, action, L, step) diff --git a/DMC/src/algorithms/rad.py b/DMC/src/algorithms/rad.py new file mode 100644 index 0000000..89b257e --- /dev/null +++ b/DMC/src/algorithms/rad.py @@ -0,0 +1,13 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from copy import deepcopy +import utils +import algorithms.modules as m +from algorithms.sac import SAC + + +class RAD(SAC): + def __init__(self, obs_shape, action_shape, args): + super().__init__(obs_shape, action_shape, args) diff --git a/DMC/src/algorithms/rl_utils.py b/DMC/src/algorithms/rl_utils.py new file mode 100644 index 0000000..1bb27ad --- /dev/null +++ b/DMC/src/algorithms/rl_utils.py @@ -0,0 +1,107 @@ +import torch +from torchvision.utils import make_grid + +from captum.attr import GuidedBackprop, GuidedGradCam + + +class HookFeatures: + def __init__(self, module): + self.feature_hook = module.register_forward_hook(self.feature_hook_fn) + + def feature_hook_fn(self, module, input, output): + self.features = output.clone().detach() + self.gradient_hook = output.register_hook(self.gradient_hook_fn) + + def gradient_hook_fn(self, grad): + self.gradients = grad + + def close(self): + self.feature_hook.remove() + self.gradient_hook.remove() + + +class ModelWrapper(torch.nn.Module): + def __init__(self, model, action=None): + super(ModelWrapper, self).__init__() + self.model = model + self.action = action + + def forward(self, obs): + if self.action is None: + return self.model(obs)[0] + return self.model(obs, self.action)[0] + + +def compute_guided_backprop(obs, action, model): + model = ModelWrapper(model, action=action) + gbp = GuidedBackprop(model) + attribution = gbp.attribute(obs) + return attribution + +def compute_guided_gradcam(obs, action, model): + obs.requires_grad_() + obs.retain_grad() + model = ModelWrapper(model, action=action) + gbp = GuidedGradCam(model,layer=model.model.encoder.head_cnn.layers) + attribution = gbp.attribute(obs,attribute_to_layer_input=True) + return attribution + +def compute_vanilla_grad(critic_target, obs, action): + obs.requires_grad_() + obs.retain_grad() + q, q2 = critic_target(obs, action.detach()) + q.sum().backward() + return obs.grad + + +def compute_attribution(model, obs, action=None,method="guided_backprop"): + if method == "guided_backprop": + return compute_guided_backprop(obs, action, model) + if method == 'guided_gradcam': + return compute_guided_gradcam(obs,action,model) + return compute_vanilla_grad(model, obs, action) + + +def compute_features_attribution(critic_target, obs, action): + obs.requires_grad_() + obs.retain_grad() + hook = HookFeatures(critic_target.encoder) + q, _ = critic_target(obs, action.detach()) + q.sum().backward() + features_gardients = hook.gradients + hook.close() + return obs.grad, features_gardients + + +def compute_attribution_mask(obs_grad, quantile=0.95): + mask = [] + for i in [0, 3, 6]: + attributions = obs_grad[:, i : i + 3].abs().max(dim=1)[0] + q = torch.quantile(attributions.flatten(1), quantile, 1) + mask.append((attributions >= q[:, None, None]).unsqueeze(1).repeat(1, 3, 1, 1)) + return torch.cat(mask, dim=1) + + +def make_obs_grid(obs, n=4): + sample = [] + for i in range(n): + for j in range(0, 9, 3): + sample.append(obs[i, j : j + 3].unsqueeze(0)) + sample = torch.cat(sample, 0) + return make_grid(sample, nrow=3) / 255.0 + + +def make_attribution_pred_grid(attribution_pred, n=4): + return make_grid(attribution_pred[:n], nrow=1) + + +def make_obs_grad_grid(obs_grad, n=4): + sample = [] + for i in range(n): + for j in range(0, 9, 3): + channel_attribution, _ = torch.max(obs_grad[i, j : j + 3], dim=0) + sample.append(channel_attribution[(None,) * 2] / channel_attribution.max()) + sample = torch.cat(sample, 0) + q = torch.quantile(sample.flatten(1), 0.97, 1) + sample[sample <= q[:, None, None, None]] = 0 + return make_grid(sample, nrow=3) diff --git a/DMC/src/algorithms/sac.py b/DMC/src/algorithms/sac.py new file mode 100644 index 0000000..38e9c66 --- /dev/null +++ b/DMC/src/algorithms/sac.py @@ -0,0 +1,169 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from copy import deepcopy +import utils +import algorithms.modules as m + + +class FeaturesHook: + def __init__(self, module): + self.hook = module.register_forward_hook(self.hook_fn) + + def hook_fn(self, module, input, output): + self.features = output + + def close(self): + self.hook.remove() + + +class SAC(object): + def __init__(self, obs_shape, action_shape, args): + self.discount = args.discount + self.critic_tau = args.critic_tau + self.encoder_tau = args.encoder_tau + self.actor_update_freq = args.actor_update_freq + self.critic_target_update_freq = args.critic_target_update_freq + + shared_cnn = m.SharedCNN( + obs_shape, args.num_shared_layers, args.num_filters + ).cuda() + head_cnn = m.HeadCNN( + shared_cnn.out_shape, args.num_head_layers, args.num_filters + ).cuda() + actor_encoder = m.Encoder( + shared_cnn, + head_cnn, + m.RLProjection(head_cnn.out_shape, args.projection_dim), + ) + critic_encoder = m.Encoder( + shared_cnn, + head_cnn, + m.RLProjection(head_cnn.out_shape, args.projection_dim), + ) + + self.actor = m.Actor( + actor_encoder, + action_shape, + args.hidden_dim, + args.actor_log_std_min, + args.actor_log_std_max, + ).cuda() + self.critic = m.Critic(critic_encoder, action_shape, args.hidden_dim).cuda() + self.critic_target = deepcopy(self.critic) + + self.log_alpha = torch.tensor(np.log(args.init_temperature)).cuda() + self.log_alpha.requires_grad = True + self.target_entropy = -np.prod(action_shape) + + self.actor_optimizer = torch.optim.Adam( + self.actor.parameters(), lr=args.actor_lr, betas=(args.actor_beta, 0.999) + ) + self.critic_optimizer = torch.optim.Adam( + self.critic.parameters(), lr=args.critic_lr, betas=(args.critic_beta, 0.999),weight_decay=args.critic_weight_decay, + ) + self.log_alpha_optimizer = torch.optim.Adam( + [self.log_alpha], lr=args.alpha_lr, betas=(args.alpha_beta, 0.999) + ) + + self.train() + self.critic_target.train() + self.hook = FeaturesHook(self.critic.encoder.head_cnn) + + def train(self, training=True): + self.training = training + self.actor.train(training) + self.critic.train(training) + + def eval(self): + self.train(False) + + @property + def alpha(self): + return self.log_alpha.exp() + + def _obs_to_input(self, obs): + if isinstance(obs, utils.LazyFrames): + _obs = np.array(obs) + else: + _obs = obs + _obs = torch.FloatTensor(_obs).cuda() + _obs = _obs.unsqueeze(0) + return _obs + + def select_action(self, obs): + _obs = self._obs_to_input(obs) + with torch.no_grad(): + mu, _, _, _ = self.actor(_obs, compute_pi=False, compute_log_pi=False) + return mu.cpu().data.numpy().flatten() + + def sample_action(self, obs): + _obs = self._obs_to_input(obs) + with torch.no_grad(): + mu, pi, _, _ = self.actor(_obs, compute_log_pi=False) + return pi.cpu().data.numpy().flatten() + + def update_critic(self, obs, action, reward, next_obs, not_done, L=None, step=None): + with torch.no_grad(): + _, policy_action, log_pi, _ = self.actor(next_obs) + target_Q1, target_Q2 = self.critic_target(next_obs, policy_action) + target_V = torch.min(target_Q1, target_Q2) - self.alpha.detach() * log_pi + target_Q = reward + (not_done * self.discount * target_V) + + current_Q1, current_Q2 = self.critic(obs, action) + critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss( + current_Q2, target_Q + ) + if L is not None: + L.log("train_critic/loss", critic_loss, step) + + self.critic_optimizer.zero_grad() + critic_loss.backward() + self.critic_optimizer.step() + + def update_actor_and_alpha(self, obs, L=None, step=None, update_alpha=True): + _, pi, log_pi, log_std = self.actor(obs, detach=True) + actor_Q1, actor_Q2 = self.critic(obs, pi, detach=True) + + actor_Q = torch.min(actor_Q1, actor_Q2) + actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean() + + if L is not None: + L.log("train_actor/loss", actor_loss, step) + entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi)) + log_std.sum( + dim=-1 + ) + + self.actor_optimizer.zero_grad() + actor_loss.backward() + self.actor_optimizer.step() + + if update_alpha: + self.log_alpha_optimizer.zero_grad() + alpha_loss = (self.alpha * (-log_pi - self.target_entropy).detach()).mean() + + if L is not None: + L.log("train_alpha/loss", alpha_loss, step) + L.log("train_alpha/value", self.alpha, step) + + alpha_loss.backward() + self.log_alpha_optimizer.step() + + def soft_update_critic_target(self): + utils.soft_update_params(self.critic.Q1, self.critic_target.Q1, self.critic_tau) + utils.soft_update_params(self.critic.Q2, self.critic_target.Q2, self.critic_tau) + utils.soft_update_params( + self.critic.encoder, self.critic_target.encoder, self.encoder_tau + ) + + def update(self, replay_buffer, L, step): + obs, action, reward, next_obs, not_done = replay_buffer.sample() + + self.update_critic(obs, action, reward, next_obs, not_done, L, step) + + if step % self.actor_update_freq == 0: + self.update_actor_and_alpha(obs, L, step) + + if step % self.critic_target_update_freq == 0: + self.soft_update_critic_target() diff --git a/DMC/src/algorithms/sgsac.py b/DMC/src/algorithms/sgsac.py new file mode 100644 index 0000000..9dc63aa --- /dev/null +++ b/DMC/src/algorithms/sgsac.py @@ -0,0 +1,137 @@ +import os +from copy import deepcopy +import numpy as np +import torch +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter + +import utils +import augmentations +import algorithms.modules as m +from algorithms.sac import SAC + +from .rl_utils import ( + compute_attribution, + compute_attribution_mask, + make_attribution_pred_grid, + make_obs_grid, + make_obs_grad_grid, +) +import random + + +class SGSAC(SAC): + def __init__(self, obs_shape, action_shape, args): + super().__init__(obs_shape, action_shape, args) + + self.attribution_predictor = m.AttributionPredictor(action_shape[0],self.critic.encoder).cuda() + self.quantile = args.sgqn_quantile + self.aux_update_freq = args.aux_update_freq + self.consistency = args.consistency + + self.aux_optimizer = torch.optim.Adam( + self.attribution_predictor.parameters(), + lr=args.aux_lr, + betas=(args.aux_beta, 0.999), + ) + + tb_dir = os.path.join( + args.log_dir, + args.domain_name + "_" + args.task_name, + args.algorithm, + str(args.seed), + "tensorboard", + ) + self.writer = SummaryWriter(tb_dir) + + + def update_critic(self, obs, action, reward, next_obs, not_done, L=None, step=None): + with torch.no_grad(): + _, policy_action, log_pi, _ = self.actor(next_obs) + target_Q1, target_Q2 = self.critic_target(next_obs, policy_action) + target_V = torch.min(target_Q1, target_Q2) - self.alpha.detach() * log_pi + target_Q = reward + (not_done * self.discount * target_V) + + current_Q1, current_Q2 = self.critic(obs, action) + + + critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss( + current_Q2, target_Q + ) + if self.consistency: + obs_grad = compute_attribution(self.critic,obs,action.detach()) + mask = compute_attribution_mask(obs_grad,self.quantile) + masked_obs = obs*mask + masked_obs[mask<1] = random.uniform(obs.view(-1).min(),obs.view(-1).max()) + masked_Q1,masked_Q2 = self.critic(masked_obs,action) + critic_loss += 0.5 *(F.mse_loss(current_Q1,masked_Q1) + F.mse_loss(current_Q2,masked_Q2)) + if L is not None: + L.log("train_critic/loss", critic_loss, step) + + self.critic_optimizer.zero_grad() + critic_loss.backward() + self.critic_optimizer.step() + + def update_aux(self, obs, action, obs_grad, mask, step=None, L=None): + mask = compute_attribution_mask(obs_grad,self.quantile) + s_prime = augmentations.attribution_augmentation(obs.clone(), mask.float()) + + s_tilde = augmentations.random_overlay(obs.clone()) + self.aux_optimizer.zero_grad() + pred_attrib, aux_loss = self.compute_attribution_loss(s_tilde,action, mask) + aux_loss.backward() + self.aux_optimizer.step() + + if L is not None: + L.log("train/aux_loss", aux_loss, step) + + if step % 10000 == 0: + self.log_tensorboard(obs, action, step, prefix="original") + self.log_tensorboard(s_tilde, action, step, prefix="augmented") + self.log_tensorboard(s_prime, action, step, prefix="super_augmented") + + def log_tensorboard(self, obs, action, step, prefix="original"): + obs_grad = compute_attribution(self.critic, obs, action.detach()) + mask = compute_attribution_mask(obs_grad, quantile=self.quantile) + attrib = self.attribution_predictor(obs.detach(),action.detach()) + grid = make_obs_grid(obs) + self.writer.add_image(prefix + "/observation", grid, global_step=step) + grad_grid = make_obs_grad_grid(obs_grad.data.abs()) + self.writer.add_image(prefix + "/attributions", grad_grid, global_step=step) + mask = torch.sigmoid(attrib) + mask = (mask > 0.5).float() + masked_obs = make_obs_grid(obs * mask) + self.writer.add_image(prefix + "/masked_obs{}", masked_obs, global_step=step) + attrib_grid = make_obs_grad_grid(torch.sigmoid(attrib)) + self.writer.add_image( + prefix + "/predicted_attrib", attrib_grid, global_step=step + ) + for q in [0.95, 0.975, 0.9, 0.995, 0.999]: + mask = compute_attribution_mask(obs_grad, quantile=q) + masked_obs = make_obs_grid(obs * mask) + self.writer.add_image( + prefix + "/attrib_q{}".format(q), masked_obs, global_step=step + ) + + def compute_attribution_loss(self, obs,action, mask): + mask = mask.float() + attrib = self.attribution_predictor(obs.detach(),action.detach()) + aux_loss = F.binary_cross_entropy_with_logits(attrib, mask.detach()) + return attrib, aux_loss + + + def update(self, replay_buffer, L, step): + obs, action, reward, next_obs, not_done = replay_buffer.sample_drq() + + self.update_critic(obs, action, reward, next_obs, not_done, L, step) + obs_grad = compute_attribution(self.critic, obs, action.detach()) + mask = compute_attribution_mask(obs_grad, quantile=self.quantile) + + if step % self.actor_update_freq == 0: + self.update_actor_and_alpha(obs, L, step) + + if step % self.critic_target_update_freq == 0: + self.soft_update_critic_target() + + if step % self.aux_update_freq == 0: + self.update_aux(obs, action, obs_grad, mask, step, L) diff --git a/DMC/src/algorithms/soda.py b/DMC/src/algorithms/soda.py new file mode 100644 index 0000000..21dcd68 --- /dev/null +++ b/DMC/src/algorithms/soda.py @@ -0,0 +1,84 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from copy import deepcopy +import utils +import algorithms.modules as m +from algorithms.sac import SAC +import augmentations + + +class SODA(SAC): + def __init__(self, obs_shape, action_shape, args): + super().__init__(obs_shape, action_shape, args) + self.aux_update_freq = args.aux_update_freq + self.soda_batch_size = args.soda_batch_size + self.soda_tau = args.soda_tau + + shared_cnn = self.critic.encoder.shared_cnn + aux_cnn = self.critic.encoder.head_cnn + soda_encoder = m.Encoder( + shared_cnn, + aux_cnn, + m.SODAMLP(aux_cnn.out_shape[0], args.projection_dim, args.projection_dim) + ) + + self.predictor = m.SODAPredictor(soda_encoder, args.projection_dim).cuda() + self.predictor_target = deepcopy(self.predictor) + + self.soda_optimizer = torch.optim.Adam( + self.predictor.parameters(), lr=args.aux_lr, betas=(args.aux_beta, 0.999) + ) + self.train() + + def train(self, training=True): + super().train(training) + if hasattr(self, 'soda_predictor'): + self.soda_predictor.train(training) + + def compute_soda_loss(self, x0, x1): + h0 = self.predictor(x0) + with torch.no_grad(): + h1 = self.predictor_target.encoder(x1) + h0 = F.normalize(h0, p=2, dim=1) + h1 = F.normalize(h1, p=2, dim=1) + + return F.mse_loss(h0, h1) + + def update_soda(self, replay_buffer, L=None, step=None): + x = replay_buffer.sample_soda(self.soda_batch_size) + assert x.size(-1) == 100 + + aug_x = x.clone() + + x = augmentations.random_crop(x) + aug_x = augmentations.random_crop(aug_x) + aug_x = augmentations.random_overlay(aug_x) + + soda_loss = self.compute_soda_loss(aug_x, x) + + self.soda_optimizer.zero_grad() + soda_loss.backward() + self.soda_optimizer.step() + if L is not None: + L.log('train/aux_loss', soda_loss, step) + + utils.soft_update_params( + self.predictor, self.predictor_target, + self.soda_tau + ) + + def update(self, replay_buffer, L, step): + obs, action, reward, next_obs, not_done = replay_buffer.sample() + + self.update_critic(obs, action, reward, next_obs, not_done, L, step) + + if step % self.actor_update_freq == 0: + self.update_actor_and_alpha(obs, L, step) + + if step % self.critic_target_update_freq == 0: + self.soft_update_critic_target() + + if step % self.aux_update_freq == 0: + self.update_soda(replay_buffer, L, step) diff --git a/DMC/src/algorithms/svea.py b/DMC/src/algorithms/svea.py new file mode 100644 index 0000000..e337b20 --- /dev/null +++ b/DMC/src/algorithms/svea.py @@ -0,0 +1,63 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from copy import deepcopy +import utils +import augmentations +import algorithms.modules as m +from algorithms.sac import SAC + + +class SVEA(SAC): + def __init__(self, obs_shape, action_shape, args): + super().__init__(obs_shape, action_shape, args) + self.svea_alpha = args.svea_alpha + self.svea_beta = args.svea_beta + + def update_critic(self, obs, action, reward, next_obs, not_done, L=None, step=None): + with torch.no_grad(): + _, policy_action, log_pi, _ = self.actor(next_obs) + target_Q1, target_Q2 = self.critic_target(next_obs, policy_action) + target_V = torch.min(target_Q1, target_Q2) - self.alpha.detach() * log_pi + target_Q = reward + (not_done * self.discount * target_V) + + if self.svea_alpha == self.svea_beta: + obs = utils.cat(obs, augmentations.random_overlay(obs.clone())) + action = utils.cat(action, action) + target_Q = utils.cat(target_Q, target_Q) + + current_Q1, current_Q2 = self.critic(obs, action) + critic_loss = (self.svea_alpha + self.svea_beta) * ( + F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) + ) + else: + current_Q1, current_Q2 = self.critic(obs, action) + critic_loss = self.svea_alpha * ( + F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q) + ) + + obs_aug = augmentations.random_overlay(obs.clone()) + current_Q1_aug, current_Q2_aug = self.critic(obs_aug, action) + critic_loss += self.svea_beta * ( + F.mse_loss(current_Q1_aug, target_Q) + + F.mse_loss(current_Q2_aug, target_Q) + ) + + if L is not None: + L.log("train_critic/loss", critic_loss, step) + + self.critic_optimizer.zero_grad() + critic_loss.backward() + self.critic_optimizer.step() + + def update(self, replay_buffer, L, step): + obs, action, reward, next_obs, not_done = replay_buffer.sample_drq() + + self.update_critic(obs, action, reward, next_obs, not_done, L, step) + + if step % self.actor_update_freq == 0: + self.update_actor_and_alpha(obs, L, step) + + if step % self.critic_target_update_freq == 0: + self.soft_update_critic_target() diff --git a/DMC/src/arguments.py b/DMC/src/arguments.py new file mode 100644 index 0000000..1359266 --- /dev/null +++ b/DMC/src/arguments.py @@ -0,0 +1,127 @@ +import argparse +import numpy as np + + +def parse_args(): + parser = argparse.ArgumentParser() + + # environment + parser.add_argument("--domain_name", default="walker") + parser.add_argument("--task_name", default="walk") + parser.add_argument("--frame_stack", default=3, type=int) + parser.add_argument("--action_repeat", default=4, type=int) + parser.add_argument("--episode_length", default=1000, type=int) + parser.add_argument("--eval_mode", default="color_hard", type=str) + + # agent + parser.add_argument("--algorithm", default="sgsac", type=str) + parser.add_argument("--train_steps", default="500k", type=str) + parser.add_argument("--discount", default=0.99, type=float) + parser.add_argument("--init_steps", default=1000, type=int) + parser.add_argument("--batch_size", default=128, type=int) + parser.add_argument("--hidden_dim", default=1024, type=int) + + # actor + parser.add_argument("--actor_lr", default=1e-3, type=float) + parser.add_argument("--actor_beta", default=0.9, type=float) + parser.add_argument("--actor_log_std_min", default=-10, type=float) + parser.add_argument("--actor_log_std_max", default=2, type=float) + parser.add_argument("--actor_update_freq", default=2, type=int) + + # critic + parser.add_argument("--critic_lr", default=1e-3, type=float) + parser.add_argument("--critic_beta", default=0.9, type=float) + parser.add_argument("--critic_tau", default=0.01, type=float) + parser.add_argument("--critic_target_update_freq", default=2, type=int) + parser.add_argument("--critic_weight_decay", default=0, type=float) + + + # architecture + parser.add_argument("--num_shared_layers", default=11, type=int) + parser.add_argument("--num_head_layers", default=0, type=int) + parser.add_argument("--num_filters", default=32, type=int) + parser.add_argument("--projection_dim", default=100, type=int) + parser.add_argument("--encoder_tau", default=0.05, type=float) + + # entropy maximization + parser.add_argument("--init_temperature", default=0.1, type=float) + parser.add_argument("--alpha_lr", default=1e-4, type=float) + parser.add_argument("--alpha_beta", default=0.5, type=float) + + # auxiliary tasks + parser.add_argument("--aux_lr", default=3e-4, type=float) + parser.add_argument("--aux_beta", default=0.9, type=float) + parser.add_argument("--aux_update_freq", default=2, type=int) + + # soda + parser.add_argument("--soda_batch_size", default=256, type=int) + parser.add_argument("--soda_tau", default=0.005, type=float) + + # svea + parser.add_argument("--svea_alpha", default=0.5, type=float) + parser.add_argument("--svea_beta", default=0.5, type=float) + parser.add_argument("--sgqn_quantile", default=0.90, type=float) + parser.add_argument("--svea_contrastive_coeff", default=0.1, type=float) + parser.add_argument("--svea_norm_coeff", default=0.1, type=float) + parser.add_argument("--attrib_coeff", default=0.25, type=float) + parser.add_argument("--consistency", default=1, type=int) + + # eval + parser.add_argument("--save_freq", default="100k", type=str) + parser.add_argument("--eval_freq", default="10k", type=str) + parser.add_argument("--eval_episodes", default=30, type=int) + parser.add_argument("--distracting_cs_intensity", default=0.0, type=float) + + # misc + parser.add_argument("--seed", default=10081, type=int) + parser.add_argument("--log_dir", default="logs", type=str) + parser.add_argument("--save_video", default=False, action="store_true") + + args = parser.parse_args() + + assert args.algorithm in { + "sac", + "rad", + "curl", + "pad", + "soda", + "drq", + "svea", + "saca", + "sacfa", + "sgsac", + }, f'specified algorithm "{args.algorithm}" is not supported' + + assert args.eval_mode in { + "train", + "color_easy", + "color_hard", + "video_easy", + "video_hard", + "distracting_cs", + "all", + "none", + }, f'specified mode "{args.eval_mode}" is not supported' + assert args.seed is not None, "must provide seed for experiment" + assert args.log_dir is not None, "must provide a log directory for experiment" + + intensities = {0.0, 0.025, 0.05, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5} + assert ( + args.distracting_cs_intensity in intensities + ), f"distracting_cs has only been implemented for intensities: {intensities}" + + args.train_steps = int(args.train_steps.replace("k", "000")) + args.save_freq = int(args.save_freq.replace("k", "000")) + args.eval_freq = int(args.eval_freq.replace("k", "000")) + + if args.eval_mode == "none": + args.eval_mode = None + + if args.algorithm in {"rad", "curl", "pad", "soda"}: + args.image_size = 100 + args.image_crop_size = 84 + else: + args.image_size = 84 + args.image_crop_size = 84 + + return args diff --git a/DMC/src/augmentations.py b/DMC/src/augmentations.py new file mode 100644 index 0000000..e40d20d --- /dev/null +++ b/DMC/src/augmentations.py @@ -0,0 +1,253 @@ +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms as TF +import torchvision.datasets as datasets +import utils +import os +import kornia +import random + +places_dataloader = None +places_iter = None + + +def _load_places(batch_size=256, image_size=84, num_workers=8, use_val=False): + global places_dataloader, places_iter + partition = "val" if use_val else "train" + print(f"Loading {partition} partition of places365_standard...") + for data_dir in utils.load_config("datasets"): + if os.path.exists(data_dir): + fp = os.path.join(data_dir, "places365_standard", partition) + if not os.path.exists(fp): + print(f"Warning: path {fp} does not exist, falling back to {data_dir}") + fp = data_dir + places_dataloader = torch.utils.data.DataLoader( + datasets.ImageFolder( + fp, + TF.Compose( + [ + TF.RandomResizedCrop(image_size), + TF.RandomHorizontalFlip(), + TF.ToTensor(), + ] + ), + ), + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=True, + ) + places_iter = iter(places_dataloader) + break + if places_iter is None: + raise FileNotFoundError( + "failed to find places365 data at any of the specified paths" + ) + print("Loaded dataset from", data_dir) + + +def _get_places_batch(batch_size): + global places_iter + try: + imgs, _ = next(places_iter) + if imgs.size(0) < batch_size: + places_iter = iter(places_dataloader) + imgs, _ = next(places_iter) + except StopIteration: + places_iter = iter(places_dataloader) + imgs, _ = next(places_iter) + return imgs.cuda() + + +def random_overlay(x, dataset="places365_standard"): + """Randomly overlay an image from Places""" + global places_iter + alpha = 0.5 + + if dataset == "places365_standard": + if places_dataloader is None: + _load_places(batch_size=x.size(0), image_size=x.size(-1)) + imgs = _get_places_batch(batch_size=x.size(0)).repeat(1, x.size(1) // 3, 1, 1) + else: + raise NotImplementedError( + f'overlay has not been implemented for dataset "{dataset}"' + ) + + return ((1 - alpha) * (x / 255.0) + (alpha) * imgs) * 255.0 + + +def attribution_augmentation(x, mask, dataset="places365_standard"): + """Complete non importnant pixels with a random image from Places""" + global places_iter + + if dataset == "places365_standard": + if places_dataloader is None: + _load_places(batch_size=x.size(0), image_size=x.size(-1)) + imgs = _get_places_batch(batch_size=x.size(0)).repeat(1, x.size(1) // 3, 1, 1) + else: + raise NotImplementedError( + f'overlay has not been implemented for dataset "{dataset}"' + ) + + # s_plus = random_conv(x) * mask + s_plus = x * mask + s_tilde = (((s_plus) / 255.0) + (imgs * (torch.ones_like(mask) - mask))) * 255.0 + s_minus = imgs * 255 + return s_tilde + + +def paired_aug(obs, mask): + mask = mask.float() + SEMANTIC = [kornia.augmentation.RandomAffine([-45., 45.], [0.3, 0.3], [0.5, 1.5], [0., 0.15]),kornia.augmentation.RandomErasing()] + no_sem = lambda x : random_overlay(x) + sem = random.sample(SEMANTIC,k=1)[0] + img_out = no_sem(sem(obs)) + + mask_out = sem(mask, sem._params) + return img_out, mask_out + +def attribution_random_patch_augmentation( + x, + cam, + image_size=84, + output_size=4, + quantile=0.90, + patch_proba=0.7, + dataset="places365_standard", +): + + if dataset == "places365_standard": + if places_dataloader is None: + _load_places(batch_size=x.size(0), image_size=x.size(-1)) + negative = _get_places_batch(batch_size=x.size(0)).repeat( + 1, x.size(1) // 3, 1, 1 + ) + else: + raise NotImplementedError( + f'overlay has not been implemented for dataset "{dataset}"' + ) + cam = cam.to(x.device) + cam = F.adaptive_avg_pool2d(cam, output_size=output_size) + q = torch.quantile(cam.flatten(1), quantile, 1) + mask = (cam >= q[:, None, None]).long() + exploration_mask = torch.rand(*mask.shape).to(x.device) + exploration_mask[~mask] = 0 + expl_max = torch.amax(exploration_mask.view(mask.size(0), -1), dim=1) + exploration_mask = ( + exploration_mask.view(-1, mask.size(1), mask.size(2)) == expl_max[:, None, None] + ).long() + bern = torch.bernoulli(torch.ones_like(mask) * patch_proba).long().to(x.device) + selected_patch = (mask * bern) + exploration_mask + selected_patch[selected_patch > 1] = 1 + selected_patch = F.upsample_nearest(selected_patch.float().unsqueeze(1), image_size) + # augmented_x = (((0.5) * (x / 255.0) + (0.5) * negative) * 255.0) * selected_patch + augmented_x = x * selected_patch + complementary_mask = ~(selected_patch.bool()) + negative = negative * (complementary_mask.float()) + return augmented_x + (negative * 255) + + +def blending_augmentation(x, mask, overlay=True): + s_plus, s_minus, s_tilde = attribution_augmentation(x, mask) + if overlay: + overlay_x = (1 - 0.5) * x + (0.5) * s_minus + s_tilde = overlay_x * (~mask) + s_plus + else: + s_tilde = s_minus * (~mask) + s_plus + return s_plus, s_minus, s_tilde + + +def random_conv(x): + """Applies a random conv2d, deviates slightly from https://arxiv.org/abs/1910.05396""" + n, c, h, w = x.shape + for i in range(n): + weights = torch.randn(3, 3, 3, 3).to(x.device) + temp_x = x[i : i + 1].reshape(-1, 3, h, w) / 255.0 + temp_x = F.pad(temp_x, pad=[1] * 4, mode="replicate") + out = torch.sigmoid(F.conv2d(temp_x, weights)) * 255.0 + total_out = out if i == 0 else torch.cat([total_out, out], axis=0) + return total_out.reshape(n, c, h, w) + + +def batch_from_obs(obs, batch_size=32): + """Copy a single observation along the batch dimension""" + if isinstance(obs, torch.Tensor): + if len(obs.shape) == 3: + obs = obs.unsqueeze(0) + return obs.repeat(batch_size, 1, 1, 1) + + if len(obs.shape) == 3: + obs = np.expand_dims(obs, axis=0) + return np.repeat(obs, repeats=batch_size, axis=0) + + +def prepare_pad_batch(obs, next_obs, action, batch_size=32): + """Prepare batch for self-supervised policy adaptation at test-time""" + batch_obs = batch_from_obs(torch.from_numpy(obs).cuda(), batch_size) + batch_next_obs = batch_from_obs(torch.from_numpy(next_obs).cuda(), batch_size) + batch_action = torch.from_numpy(action).cuda().unsqueeze(0).repeat(batch_size, 1) + + return random_crop_cuda(batch_obs), random_crop_cuda(batch_next_obs), batch_action + + +def identity(x): + return x + + +def random_shift(imgs, pad=4): + """Vectorized random shift, imgs: (B,C,H,W), pad: #pixels""" + _, _, h, w = imgs.shape + imgs = F.pad(imgs, (pad, pad, pad, pad), mode="replicate") + return kornia.augmentation.RandomCrop((h, w))(imgs) + + +def random_crop(x, size=84, w1=None, h1=None, return_w1_h1=False): + """Vectorized CUDA implementation of random crop, imgs: (B,C,H,W), size: output size""" + assert (w1 is None and h1 is None) or ( + w1 is not None and h1 is not None + ), "must either specify both w1 and h1 or neither of them" + assert isinstance(x, torch.Tensor) and x.is_cuda, "input must be CUDA tensor" + + n = x.shape[0] + img_size = x.shape[-1] + crop_max = img_size - size + + if crop_max <= 0: + if return_w1_h1: + return x, None, None + return x + + x = x.permute(0, 2, 3, 1) + + if w1 is None: + w1 = torch.LongTensor(n).random_(0, crop_max) + h1 = torch.LongTensor(n).random_(0, crop_max) + + windows = view_as_windows_cuda(x, (1, size, size, 1))[..., 0, :, :, 0] + cropped = windows[torch.arange(n), w1, h1] + + if return_w1_h1: + return cropped, w1, h1 + + return cropped + + +def view_as_windows_cuda(x, window_shape): + """PyTorch CUDA-enabled implementation of view_as_windows""" + assert isinstance(window_shape, tuple) and len(window_shape) == len( + x.shape + ), "window_shape must be a tuple with same number of dimensions as x" + + slices = tuple(slice(None, None, st) for st in torch.ones(4).long()) + win_indices_shape = [ + x.size(0), + x.size(1) - int(window_shape[1]), + x.size(2) - int(window_shape[2]), + x.size(3), + ] + + new_shape = tuple(list(win_indices_shape) + list(window_shape)) + strides = tuple(list(x[slices].stride()) + list(x.stride())) + + return x.as_strided(new_shape, strides) diff --git a/DMC/src/env/distracting_control/background.py b/DMC/src/env/distracting_control/background.py new file mode 100644 index 0000000..326ef63 --- /dev/null +++ b/DMC/src/env/distracting_control/background.py @@ -0,0 +1,265 @@ +# coding=utf-8 +# Copyright 2021 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""A wrapper for dm_control environments which applies color distractions.""" +import os + +from PIL import Image +import collections +from dm_control.rl import control +import numpy as np + +import utils +from dm_control.mujoco.wrapper import mjbindings + +DAVIS17_TRAINING_VIDEOS = [ + 'bear', 'bmx-bumps', 'boat', 'boxing-fisheye', 'breakdance-flare', 'bus', + 'car-turn', 'cat-girl', 'classic-car', 'color-run', 'crossing', + 'dance-jump', 'dancing', 'disc-jockey', 'dog-agility', 'dog-gooses', + 'dogs-scale', 'drift-turn', 'drone', 'elephant', 'flamingo', 'hike', + 'hockey', 'horsejump-low', 'kid-football', 'kite-walk', 'koala', + 'lady-running', 'lindy-hop', 'longboard', 'lucia', 'mallard-fly', + 'mallard-water', 'miami-surf', 'motocross-bumps', 'motorbike', 'night-race', + 'paragliding', 'planes-water', 'rallye', 'rhino', 'rollerblade', + 'schoolgirls', 'scooter-board', 'scooter-gray', 'sheep', 'skate-park', + 'snowboard', 'soccerball', 'stroller', 'stunt', 'surf', 'swing', 'tennis', + 'tractor-sand', 'train', 'tuk-tuk', 'upside-down', 'varanus-cage', 'walking' +] +DAVIS17_VALIDATION_VIDEOS = [ + 'bike-packing', 'blackswan', 'bmx-trees', 'breakdance', 'camel', + 'car-roundabout', 'car-shadow', 'cows', 'dance-twirl', 'dog', 'dogs-jump', + 'drift-chicane', 'drift-straight', 'goat', 'gold-fish', 'horsejump-high', + 'india', 'judo', 'kite-surf', 'lab-coat', 'libby', 'loading', 'mbike-trick', + 'motocross-jump', 'paragliding-launch', 'parkour', 'pigs', 'scooter-black', + 'shooting', 'soapbox' +] +SKY_TEXTURE_INDEX = 0 +Texture = collections.namedtuple('Texture', ('size', 'address', 'textures')) + + +def imread(filename): + img = Image.open(filename) + img_np = np.asarray(img) + return img_np + + +def size_and_flatten(image, ref_height, ref_width): + # Resize image if necessary and flatten the result. + image_height, image_width = image.shape[:2] + + if image_height != ref_height or image_width != ref_width: + image = np.asarray(Image.fromarray(image).resize(size=(ref_width, ref_height))) + + #utils.save_obs(torch.from_numpy(image).permute(2,0,1), 'image') + + return image.flatten(order='K') + + +def blend_to_background(alpha, image, background): + if alpha == 1.0: + return image + elif alpha == 0.0: + return background + else: + return (alpha * image.astype(np.float32) + + (1. - alpha) * background.astype(np.float32)).astype(np.uint8) + + +class DistractingBackgroundEnv(control.Environment): + """Environment wrapper for background visual distraction. + + **NOTE**: This wrapper should be applied BEFORE the pixel wrapper to make sure + the background image changes are applied before rendering occurs. + """ + + def __init__(self, + env, + dataset_paths=None, + dataset_videos=None, + video_alpha=1.0, + ground_plane_alpha=1.0, + num_videos=None, + dynamic=False, + seed=None, + shuffle_buffer_size=None): + + if not 0 <= video_alpha <= 1: + raise ValueError('`video_alpha` must be in the range [0, 1]') + + self._env = env + self._video_alpha = video_alpha + self._ground_plane_alpha = ground_plane_alpha + self._random_state = np.random.RandomState(seed=seed) + self._dynamic = dynamic + self._shuffle_buffer_size = shuffle_buffer_size + self._background = None + self._current_img_index = 0 + self._frame_count = 0 + + if not dataset_paths or num_videos == 0: + # Allow running the wrapper without backgrounds to still set the ground + # plane alpha value. + self._video_paths = [] + print('Warning: no dataset paths and/or number of videos set to 0!') + else: + # Use all videos if no specific ones were passed. + if not dataset_videos: + _, _, filenames = next(walk(mypath)) + dataset_videos = sorted(filenames) + # Replace video placeholders 'train'/'val' with the list of videos. + elif dataset_videos in ['train', 'training']: + dataset_videos = DAVIS17_TRAINING_VIDEOS + elif dataset_videos in ['val', 'validation']: + dataset_videos = DAVIS17_VALIDATION_VIDEOS + # Get complete paths for all videos. + for dataset_path in dataset_paths: + video_paths = [ + os.path.join(dataset_path, subdir) for subdir in dataset_videos + ] + if len(video_paths) > 0: + break + assert len(video_paths) > 0, 'DAVIS dataset not found!' + + # Optionally use only the first num_paths many paths. + if num_videos is not None: + if num_videos > len(video_paths) or num_videos < 0: + raise ValueError(f'`num_bakground_paths` is {num_videos} but ' + 'should not be larger than the number of available ' + f'background paths ({len(video_paths)}) and at ' + 'least 0.') + video_paths = video_paths[:num_videos] + + self._video_paths = video_paths + + def reset(self): + """Reset the background state.""" + self._frame_count = 0 + time_step = self._env.reset() + self._reset_background() + return time_step + + def _reset_background(self): + # Make grid semi-transparent. + if self._ground_plane_alpha is not None: + self._env.physics.named.model.mat_rgba['grid', + 'a'] = self._ground_plane_alpha + + # For some reason the height of the skybox is set to 4800 by default, + # which does not work with new textures. + self._env.physics.model.tex_height[SKY_TEXTURE_INDEX] = 800 + + # Set the sky texture reference. + sky_height = self._env.physics.model.tex_height[SKY_TEXTURE_INDEX] + sky_width = self._env.physics.model.tex_width[SKY_TEXTURE_INDEX] + sky_size = sky_height * sky_width * 3 + sky_address = self._env.physics.model.tex_adr[SKY_TEXTURE_INDEX] + + sky_texture = self._env.physics.model.tex_rgb[sky_address:sky_address + + sky_size].astype(np.float32) + + if self._video_paths: + + if self._shuffle_buffer_size: + # Shuffle images from all videos together to get background frames. + file_names = [ + os.path.join(path, fn) + for path in self._video_paths + for fn in utils.listdir(path) + ] + self._random_state.shuffle(file_names) + # Load only the first n images for performance reasons. + file_names = file_names[:self._shuffle_buffer_size] + images = [imread(fn) for fn in file_names] + else: + # Randomly pick a video and load all images. + video_path = self._random_state.choice(self._video_paths) + file_names = utils.listdir(video_path) + if not self._dynamic: + # Randomly pick a single static frame. + file_names = [self._random_state.choice(file_names)] + images = [imread(os.path.join(video_path, fn)) for fn in file_names] + + # Pick a random starting point and steping direction. + self._current_img_index = self._random_state.choice(len(images)) + self._step_direction = self._random_state.choice([-1, 1]) + + # Prepare images in the texture format by resizing and flattening. + + # Generate image textures. + texturized_images = [] + for image in images: + image_flattened = size_and_flatten(image, sky_height, sky_width) + new_texture = blend_to_background(self._video_alpha, image_flattened, + sky_texture) + texturized_images.append(new_texture) + + else: + + self._current_img_index = 0 + texturized_images = [sky_texture] + + self._background = Texture(sky_size, sky_address, texturized_images) + self._apply() + + def step(self, action): + self._frame_count += 1 + time_step = self._env.step(action) + + if time_step.first(): + self._reset_background() + return time_step + + if self._dynamic and self._video_paths and self._frame_count % 2 == 0: + # Move forward / backward in the image sequence by updating the index. + self._current_img_index += self._step_direction + + # Start moving forward if we are past the start of the images. + if self._current_img_index <= 0: + self._current_img_index = 0 + self._step_direction = abs(self._step_direction) + # Start moving backwards if we are past the end of the images. + if self._current_img_index >= len(self._background.textures): + self._current_img_index = len(self._background.textures) - 1 + self._step_direction = -abs(self._step_direction) + + self._apply() + return time_step + + def _apply(self): + """Apply the background texture to the physics.""" + + if self._background: + start = self._background.address + end = self._background.address + self._background.size + texture = self._background.textures[self._current_img_index] + + self._env.physics.model.tex_rgb[start:end] = texture + # Upload the new texture to the GPU. Note: we need to make sure that the + # OpenGL context belonging to this Physics instance is the current one. + with self._env.physics.contexts.gl.make_current() as ctx: + ctx.call( + mjbindings.mjlib.mjr_uploadTexture, + self._env.physics.model.ptr, + self._env.physics.contexts.mujoco.ptr, + SKY_TEXTURE_INDEX, + ) + + # Forward property and method calls to self._env. + def __getattr__(self, attr): + if hasattr(self._env, attr): + return getattr(self._env, attr) + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) diff --git a/DMC/src/env/distracting_control/camera.py b/DMC/src/env/distracting_control/camera.py new file mode 100644 index 0000000..3f91aac --- /dev/null +++ b/DMC/src/env/distracting_control/camera.py @@ -0,0 +1,358 @@ +# coding=utf-8 +# Copyright 2021 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""A wrapper for dm_control environments which applies camera distractions.""" + +import copy +from dm_control.rl import control +import numpy as np + +CAMERA_MODES = ['fixed', 'track', 'trackcom', 'targetbody', 'targetbodycom'] + + +def eul2mat(theta): + """Converts euler angles (x, y, z) to a rotation matrix.""" + + return np.array([[ + np.cos(theta[1]) * np.cos(theta[2]), + np.sin(theta[0]) * np.sin(theta[1]) * np.cos(theta[2]) - + np.sin(theta[2]) * np.cos(theta[0]), + np.sin(theta[1]) * np.cos(theta[0]) * np.cos(theta[2]) + + np.sin(theta[0]) * np.sin(theta[2]) + ], + [ + np.sin(theta[2]) * np.cos(theta[1]), + np.sin(theta[0]) * np.sin(theta[1]) * np.sin(theta[2]) + + np.cos(theta[0]) * np.cos(theta[2]), + np.sin(theta[1]) * np.sin(theta[2]) * np.cos(theta[0]) - + np.sin(theta[0]) * np.cos(theta[2]) + ], + [ + -np.sin(theta[1]), + np.sin(theta[0]) * np.cos(theta[1]), + np.cos(theta[0]) * np.cos(theta[1]) + ]]) + + +def _mat_from_theta(cos_theta, sin_theta, a): + """Builds a rotation matrix from theta and an orientation vector.""" + + row1 = [ + cos_theta + a[0]**2. * (1. - cos_theta), + a[0] * a[1] * (1 - cos_theta) - a[2] * sin_theta, + a[0] * a[2] * (1 - cos_theta) + a[1] * sin_theta + ] + row2 = [ + a[1] * a[0] * (1 - cos_theta) + a[2] * sin_theta, + cos_theta + a[1]**2. * (1 - cos_theta), + a[1] * a[2] * (1. - cos_theta) - a[0] * sin_theta + ] + row3 = [ + a[2] * a[0] * (1. - cos_theta) - a[1] * sin_theta, + a[2] * a[1] * (1. - cos_theta) + a[0] * sin_theta, + cos_theta + (a[2]**2.) * (1. - cos_theta) + ] + return np.stack([row1, row2, row3]) + + +def rotvec2mat(theta, vec): + """Converts a rotation around a vector to a rotation matrix.""" + + a = vec / np.sqrt(np.sum(vec**2.)) + sin_theta = np.sin(theta) + cos_theta = np.cos(theta) + + return _mat_from_theta(cos_theta, sin_theta, a) + + +def get_lookat_xmat_no_roll(agent_pos, camera_pos): + """Solves for the cam rotation centering the agent with 0 roll.""" + + # NOTE(austinstone): This method leads to wild oscillations around the north + # and south polls. + # For example, if agent is at (0., 0., 0.) and the camera is at (.01, 0., 1.), + # this will produce a yaw of 90 degrees whereas if the camera is slightly + # adjacent at (-.01, 0., 1.) this will produce a yaw of -90 degrees. I'm + # not sure what the fix is, as this seems like the behavior we want in all + # environments except for reacher. + delta_vec = agent_pos - camera_pos + delta_vec /= np.sqrt(np.sum(delta_vec**2.)) + yaw = np.arctan2(delta_vec[0], delta_vec[1]) + pitch = np.arctan2(delta_vec[2], np.sqrt(np.sum(delta_vec[:2]**2.))) + pitch += np.pi / 2. # Camera starts out looking at [0, 0, -1.] + return eul2mat([pitch, 0., -yaw]).flatten() + + +def get_lookat_xmat(agent_pos, camera_pos): + """Solves for the cam rotation centering the agent, allowing roll.""" + + # Solve for the rotation which centers the agent in the scene. + delta_vec = agent_pos - camera_pos + delta_vec /= np.sqrt(np.sum(delta_vec**2.)) + y_vec = np.array([0., 0., -1.]) # This is where the cam starts from. + a = np.cross(y_vec, delta_vec) + sin_theta = np.sqrt(np.sum(a**2.)) + cos_theta = np.dot(delta_vec, y_vec) + a /= (np.sqrt(np.sum(a**2.)) + .0001) + return _mat_from_theta(cos_theta, sin_theta, a) + + +def cart2sphere(cart): + r = np.sqrt(np.sum(cart**2.)) + h_angle = np.arctan2(cart[1], cart[0]) + v_angle = np.arctan2(np.sqrt(np.sum(cart[:2]**2.)), cart[2]) + return np.array([r, h_angle, v_angle]) + + +def sphere2cart(sphere): + r, h_angle, v_angle = sphere + x = r * np.sin(v_angle) * np.cos(h_angle) + y = r * np.sin(v_angle) * np.sin(h_angle) + z = r * np.cos(v_angle) + return np.array([x, y, z]) + + +def clip_cam_position(position, min_radius, max_radius, min_h_angle, + max_h_angle, min_v_angle, max_v_angle): + new_position = [-1., -1., -1.] + new_position[0] = np.clip(position[0], min_radius, max_radius) + new_position[1] = np.clip(position[1], min_h_angle, max_h_angle) + new_position[2] = np.clip(position[2], min_v_angle, max_v_angle) + return new_position + + +def get_lookat_point(physics, camera_id): + """Get the point that the camera is looking at. + + It is assumed that the "point" the camera looks at the agent distance + away and projected along the camera viewing matrix. + + Args: + physics: mujoco physics objects + camera_id: int + + Returns: + position: float32 np.array of length 3 + """ + dist_to_agent = physics.named.data.cam_xpos[ + camera_id] - physics.named.data.subtree_com[1] + dist_to_agent = np.sqrt(np.sum(dist_to_agent**2.)) + initial_viewing_mat = copy.deepcopy(physics.named.data.cam_xmat[camera_id]) + initial_viewing_mat = np.reshape(initial_viewing_mat, (3, 3)) + z_vec = np.array([0., 0., -dist_to_agent]) + rotated_vec = np.dot(initial_viewing_mat, z_vec) + return rotated_vec + physics.named.data.cam_xpos[camera_id] + + +class DistractingCameraEnv(control.Environment): + """Environment wrapper for camera pose visual distraction. + + **NOTE**: This wrapper should be applied BEFORE the pixel wrapper to make sure + the camera pose changes are applied before rendering occurs. + """ + + def __init__(self, + env, + camera_id, + horizontal_delta, + vertical_delta, + max_vel, + vel_std, + roll_delta, + max_roll_vel, + roll_std, + max_zoom_in_percent, + max_zoom_out_percent, + limit_to_upper_quadrant=False, + seed=None): + self._env = env + self._camera_id = camera_id + self._horizontal_delta = horizontal_delta + self._vertical_delta = vertical_delta + + self._horizontal_delta = horizontal_delta + self._vertical_delta = vertical_delta + self._max_vel = max_vel + self._vel_std = vel_std + self._roll_delta = roll_delta + self._max_roll_vel = max_roll_vel + self._roll_vel_std = roll_std + self._max_zoom_in_percent = max_zoom_in_percent + self._max_zoom_out_percent = max_zoom_out_percent + self._limit_to_upper_quadrant = limit_to_upper_quadrant + + self._random_state = np.random.RandomState(seed=seed) + + # These camera state parameters will be set on the first reset call. + self._camera_type = None + self._camera_initial_lookat_point = None + + self._camera_vel = None + self._max_h_angle = None + self._max_v_angle = None + self._min_h_angle = None + self._min_v_angle = None + self._radius = None + self._roll_vel = None + self._vel_scaling = None + self._frame_count = 0 + + def setup_camera(self): + """Set up camera motion ranges and state.""" + # Define boundaries on the range of the camera motion. + mode = self._env._physics.model.cam_mode[0] + + camera_type = CAMERA_MODES[mode] + assert camera_type in ['fixed', 'trackcom'] + + self._camera_type = camera_type + self._cam_initial_lookat_point = get_lookat_point(self._env.physics, + self._camera_id) + + start_pos = copy.deepcopy( + self._env.physics.named.data.cam_xpos[self._camera_id]) + + if self._camera_type != 'fixed': + # Center the camera relative to the agent's center of mass. + start_pos -= self._env.physics.named.data.subtree_com[1] + + start_r, start_h_angle, start_v_angle = cart2sphere(start_pos) + # Scale the velocity by the starting radius. Most environments have radius 4, + # but this downscales the velocity for the envs with radius < 4. + self._vel_scaling = start_r / 4. + self._max_h_angle = start_h_angle + self._horizontal_delta + self._min_h_angle = start_h_angle - self._horizontal_delta + self._max_v_angle = start_v_angle + self._vertical_delta + self._min_v_angle = start_v_angle - self._vertical_delta + + if self._limit_to_upper_quadrant: + # A centered cam is at np.pi / 2. + self._max_v_angle = min(self._max_v_angle, np.pi / 2.) + self._min_v_angle = max(self._min_v_angle, 0.) + # A centered cam is at -np.pi / 2. + self._max_h_angle = min(self._max_h_angle, 0.) + self._min_h_angle = max(self._min_h_angle, -np.pi) + + self._max_roll = self._roll_delta + self._min_roll = -self._roll_delta + self._min_radius = max(start_r - start_r * self._max_zoom_in_percent, 0.) + self._max_radius = start_r + start_r * self._max_zoom_out_percent + + # Decide the starting position for the camera. + self._h_angle = self._random_state.uniform(self._min_h_angle, + self._max_h_angle) + + self._v_angle = self._random_state.uniform(self._min_v_angle, + self._max_v_angle) + + self._radius = self._random_state.uniform(self._min_radius, + self._max_radius) + + self._roll = self._random_state.uniform(self._min_roll, self._max_roll) + + # Decide the starting velocity for the camera. + vel = self._random_state.randn(3) + vel /= np.sqrt(np.sum(vel**2.)) + vel *= self._random_state.uniform(0., self._max_vel) + self._camera_vel = vel + self._roll_vel = self._random_state.uniform(-self._max_roll_vel, + self._max_roll_vel) + + def reset(self): + """Reset the camera state. """ + self._frame_count = 0 + time_step = self._env.reset() + self.setup_camera() + self._apply() + return time_step + + + def step(self, action): + self._frame_count += 1 + time_step = self._env.step(action) + + if time_step.first(): + self.setup_camera() + + if self._frame_count % 2 == 0: + self._apply() + return time_step + + def _apply(self): + if not self._camera_type: + self.setup_camera() + + # Random walk the velocity. + vel_delta = self._random_state.randn(3) + self._camera_vel += vel_delta * self._vel_std * self._vel_scaling + self._roll_vel += self._random_state.randn() * self._roll_vel_std + + # Clip velocity if it gets too big. + vel_norm = np.sqrt(np.sum(self._camera_vel**2.)) + if vel_norm > self._max_vel * self._vel_scaling: + self._camera_vel *= (self._max_vel * self._vel_scaling) / vel_norm + + self._roll_vel = np.clip(self._roll_vel, -self._max_roll_vel, + self._max_roll_vel) + + cart_cam_pos = sphere2cart([self._radius, self._h_angle, self._v_angle]) + # Apply velocity vector to camera + sphere_cam_pos2 = cart2sphere(cart_cam_pos + self._camera_vel) + sphere_cam_pos2 = clip_cam_position(sphere_cam_pos2, self._min_radius, + self._max_radius, self._min_h_angle, + self._max_h_angle, self._min_v_angle, + self._max_v_angle) + + self._camera_vel = sphere2cart(sphere_cam_pos2) - cart_cam_pos + + self._radius, self._h_angle, self._v_angle = sphere_cam_pos2 + + roll2 = self._roll + self._roll_vel + roll2 = np.clip(roll2, self._min_roll, self._max_roll) + + self._roll_vel = roll2 - self._roll + self._roll = roll2 + + cart_cam_pos = sphere2cart(sphere_cam_pos2) + + if self._limit_to_upper_quadrant: + lookat_method = get_lookat_xmat_no_roll + else: + # This method avoids jitteriness at the pole but allows some roll + # in the camera matrix. This is important for reacher. + lookat_method = get_lookat_xmat + + if self._camera_type == 'fixed': + lookat_mat = lookat_method(self._cam_initial_lookat_point, + cart_cam_pos) + else: + # Go from agent centric to world coords + cart_cam_pos += self._env.physics.named.data.subtree_com[1] + lookat_mat = lookat_method( + get_lookat_point(self._env.physics, self._camera_id), cart_cam_pos) + + lookat_mat = np.reshape(lookat_mat, (3, 3)) + roll_mat = rotvec2mat(self._roll, np.array([0., 0., 1.])) + xmat = np.dot(lookat_mat, roll_mat) + self._env.physics.named.data.cam_xpos[self._camera_id] = cart_cam_pos + self._env.physics.named.data.cam_xmat[self._camera_id] = xmat.flatten() + + # Forward property and method calls to self._env. + def __getattr__(self, attr): + if hasattr(self._env, attr): + return getattr(self._env, attr) + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) diff --git a/DMC/src/env/distracting_control/color.py b/DMC/src/env/distracting_control/color.py new file mode 100644 index 0000000..878c713 --- /dev/null +++ b/DMC/src/env/distracting_control/color.py @@ -0,0 +1,106 @@ +# coding=utf-8 +# Copyright 2021 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""A wrapper for dm_control environments which applies color distractions.""" + +from dm_control.rl import control +import numpy as np + + +class DistractingColorEnv(control.Environment): + """Environment wrapper for color visual distraction. + + **NOTE**: This wrapper should be applied BEFORE the pixel wrapper to make sure + the color changes are applied before rendering occurs. + """ + + def __init__(self, env, step_std, max_delta, seed=None): + """Initialize the environment wrapper. + + Args: + env: instance of dm_control Environment to wrap with augmentations. + """ + if step_std < 0: + raise ValueError("`step_std` must be greater than or equal to 0.") + if max_delta < 0: + raise ValueError("`max_delta` must be greater than or equal to 0.") + + self._env = env + self._step_std = step_std + self._max_delta = max_delta + self._random_state = np.random.RandomState() + + self._cam_type = None + self._current_rgb = None + self._max_rgb = None + self._min_rgb = None + self._original_rgb = None + self._frame_count = 0 + + def reset(self): + """Reset the distractions state.""" + self._frame_count = 0 + time_step = self._env.reset() + self._reset_color() + return time_step + + def _reset_color(self): + # Save all original colors. + if self._original_rgb is None: + self._original_rgb = np.copy(self._env.physics.model.mat_rgba)[:, :3] + # Determine minimum and maximum rgb values. + self._max_rgb = np.clip(self._original_rgb + self._max_delta, 0.0, 1.0) + self._min_rgb = np.clip(self._original_rgb - self._max_delta, 0.0, 1.0) + + # Pick random colors in the allowed ranges. + r = self._random_state.uniform(size=self._min_rgb.shape) + self._current_rgb = self._min_rgb + r * (self._max_rgb - self._min_rgb) + + # Apply the color changes. + self._env.physics.model.mat_rgba[:, :3] = self._current_rgb + + def step(self, action): + self._frame_count += 1 + time_step = self._env.step(action) + + if time_step.first(): + self._reset_color() + return time_step + + if self._frame_count % 2 == 0: + color_change = self._random_state.randn(*self._current_rgb.shape) + color_change = color_change * self._step_std + else: + color_change = 0 + + new_color = self._current_rgb + color_change + + self._current_rgb = np.clip( + new_color, + a_min=self._min_rgb, + a_max=self._max_rgb, + ) + + # Apply the color changes. + self._env.physics.model.mat_rgba[:, :3] = self._current_rgb + return time_step + + # Forward property and method calls to self._env. + def __getattr__(self, attr): + if hasattr(self._env, attr): + return getattr(self._env, attr) + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) diff --git a/DMC/src/env/distracting_control/suite.py b/DMC/src/env/distracting_control/suite.py new file mode 100644 index 0000000..3ff0d45 --- /dev/null +++ b/DMC/src/env/distracting_control/suite.py @@ -0,0 +1,154 @@ +# coding=utf-8 +# Copyright 2021 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A collection of MuJoCo-based Reinforcement Learning environments. + +The suite provides a similar API to the original dm_control suite. +Users can configure the distractions on top of the original tasks. The suite is +targeted for loading environments directly with similar configurations as those +used in the original paper. Each distraction wrapper can be used independently +though. +""" +try: + from dm_control import suite # pylint: disable=g-import-not-at-top + from dm_control.suite.wrappers import pixels # pylint: disable=g-import-not-at-top +except ImportError: + suite = None + +from env.distracting_control import background +from env.distracting_control import camera +from env.distracting_control import color +from env.distracting_control import suite_utils + + +def is_available(): + return suite is not None + + +def load(domain_name, + task_name, + difficulty=None, + dynamic=False, + background_dataset_paths=None, + background_dataset_videos="train", + background_kwargs=None, + camera_kwargs=None, + color_kwargs=None, + task_kwargs=None, + environment_kwargs=None, + visualize_reward=False, + render_kwargs=None, + pixels_only=True, + pixels_observation_key="pixels"): + """Returns an environment from a domain name, task name and optional settings. + + ```python + env = suite.load('cartpole', 'balance') + ``` + + Note that this benchmark differs from the original implementation of Distracting Control Suite. + Notably, the `difficulty` argument has been changed, as well as changes to rate of change in the dynamic setting. + Tensorflow dependencies have also been dropped and replaced by PyTorch/NumPy equivalents. + The original codebase is available at https://github.com/google-research/google-research/tree/master/distracting_control. + + Users can also toggle dynamic properties for distractions. + + Args: + domain_name: A string containing the name of a domain. + task_name: A string containing the name of a task. + task_kwargs: Optional `dict` of keyword arguments for the task. + difficulty: Difficulty for the suite. Intensity scale, [0,1]. + dynamic: Boolean controlling whether distractions are dynamic or static. + backgound_dataset_path: String to the davis directory that contains the + video directories. + background_dataset_videos: String ('train'/'val') or list of strings of the + DAVIS videos to be used for backgrounds. + background_kwargs: Dict, overwrites settings for background distractions. + camera_kwargs: Dict, overwrites settings for camera distractions. + color_kwargs: Dict, overwrites settings for color distractions. + task_kwargs: Dict, dm control task kwargs. + environment_kwargs: Optional `dict` specifying keyword arguments for the + environment. + visualize_reward: Optional `bool`. If `True`, object colours in rendered + frames are set to indicate the reward at each step. Default `False`. + render_kwargs: Dict, render kwargs for pixel wrapper. + pixels_only: Boolean controlling the exclusion of states in the observation. + pixels_observation_key: Key in the observation used for the rendered image. + + Returns: + The requested environment. + """ + if not is_available(): + raise ImportError("dm_control module is not available. Make sure you " + "follow the installation instructions from the " + "dm_control package.") + + assert isinstance(difficulty, float), f'intensity must be a float' + assert difficulty >= 0 and difficulty <= 1, f'intensity must be in the [0,1] interval' + assert str(difficulty) in suite_utils.DIFFICULTY_NUM_VIDEOS.keys(), \ + f'intensity has only been implemented for the following values: {suite_utils.DIFFICULTY_NUM_VIDEOS.keys()}' + render_kwargs = render_kwargs or {} + if "camera_id" not in render_kwargs: + render_kwargs["camera_id"] = 2 if domain_name == "quadruped" else 0 + + env = suite.load( + domain_name, + task_name, + task_kwargs=task_kwargs, + environment_kwargs=environment_kwargs, + visualize_reward=visualize_reward) + + # Apply background distractions. + if difficulty: + final_background_kwargs = dict() + num_videos = suite_utils.DIFFICULTY_NUM_VIDEOS[str(difficulty)] + final_background_kwargs.update( + suite_utils.get_background_kwargs(domain_name, num_videos, dynamic, + background_dataset_paths, + background_dataset_videos)) + if background_kwargs: + # Overwrite kwargs with those passed here. + final_background_kwargs.update(background_kwargs) + env = background.DistractingBackgroundEnv(env, **final_background_kwargs) + + # Apply camera distractions. + if difficulty: + final_camera_kwargs = dict(camera_id=render_kwargs["camera_id"]) + final_camera_kwargs.update( + suite_utils.get_camera_kwargs(domain_name, difficulty, dynamic)) + if camera_kwargs: + # Overwrite kwargs with those passed here. + final_camera_kwargs.update(camera_kwargs) + env = camera.DistractingCameraEnv(env, **final_camera_kwargs) + + # Apply color distractions. + if difficulty: + final_color_kwargs = dict() + final_color_kwargs.update(suite_utils.get_color_kwargs(difficulty, dynamic)) + if color_kwargs: + # Overwrite kwargs with those passed here. + final_color_kwargs.update(color_kwargs) + env = color.DistractingColorEnv(env, **final_color_kwargs) + + # Apply Pixel wrapper after distractions. This is needed to ensure the + # changes from the distraction wrapper are applied to the MuJoCo environment + # before the rendering occurs. + env = pixels.Wrapper( + env, + pixels_only=pixels_only, + render_kwargs=render_kwargs, + observation_key=pixels_observation_key) + + return env diff --git a/DMC/src/env/distracting_control/suite_utils.py b/DMC/src/env/distracting_control/suite_utils.py new file mode 100644 index 0000000..9a262ed --- /dev/null +++ b/DMC/src/env/distracting_control/suite_utils.py @@ -0,0 +1,80 @@ +# coding=utf-8 +# Copyright 2021 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A collection of MuJoCo-based Reinforcement Learning environments. + +The suite provides a similar API to the original dm_control suite. +Users can configure the distractions on top of the original tasks. The suite is +targeted for loading environments directly with similar configurations as those +used in the original paper. Each distraction wrapper can be used independently +though. +""" +import numpy as np + +DIFFICULTY_NUM_VIDEOS = {'0.025': 2, '0.05': 2, '0.1': 4, '0.15': 6, '0.2': 8, '0.3': None, '0.4': None, '0.5': None} +DEFAULT_BACKGROUND_PATH = "$HOME/davis/" + + +def get_color_kwargs(scale, dynamic): + max_delta = scale + step_std = 0.03 * scale if dynamic else 0.0 + return dict(max_delta=max_delta, step_std=step_std) + + +def get_camera_kwargs(domain_name, scale, dynamic): + assert domain_name in ['reacher', 'cartpole', 'finger', 'cheetah', + 'ball_in_cup', 'walker'] + assert scale >= 0.0 + assert scale <= 1.0 + return dict( + vertical_delta=np.pi / 2 * scale, + horizontal_delta=np.pi / 2 * scale, + # Limit camera to -90 / 90 degree rolls. + roll_delta=np.pi / 2. * scale, + vel_std=.1 * scale if dynamic else 0., + max_vel=.4 * scale if dynamic else 0., + roll_std=np.pi / 300 * scale if dynamic else 0., + max_roll_vel=np.pi / 50 * scale if dynamic else 0., + max_zoom_in_percent=.5 * scale, + max_zoom_out_percent=1.5 * scale, + limit_to_upper_quadrant='reacher' not in domain_name, + ) + + +def get_background_kwargs(domain_name, + num_videos, + dynamic, + dataset_paths, + dataset_videos=None, + shuffle=False, + video_alpha=1.0): + assert domain_name in ['reacher', 'cartpole', 'finger', 'cheetah', + 'ball_in_cup', 'walker'] + if domain_name == 'reacher': + ground_plane_alpha = 0.0 + elif domain_name in ['walker', 'cheetah']: + ground_plane_alpha = 1.0 + else: + ground_plane_alpha = 0.3 + + return dict( + num_videos=num_videos, + video_alpha=video_alpha, + ground_plane_alpha=ground_plane_alpha, + dynamic=dynamic, + dataset_paths=dataset_paths, + dataset_videos=dataset_videos, + shuffle_buffer_size=100 if shuffle else None, + ) diff --git a/DMC/src/env/dm_control/AUTHORS b/DMC/src/env/dm_control/AUTHORS new file mode 100644 index 0000000..e72ca79 --- /dev/null +++ b/DMC/src/env/dm_control/AUTHORS @@ -0,0 +1,7 @@ +# This is the list of dm_control authors for copyright purposes. +# +# This does not necessarily list everyone who has contributed code, since in +# some cases, their employer may be the copyright holder. To see the full list +# of contributors, see the revision history in source control. + +DeepMind Technologies diff --git a/DMC/src/env/dm_control/CONTRIBUTING.md b/DMC/src/env/dm_control/CONTRIBUTING.md new file mode 100644 index 0000000..ae319c7 --- /dev/null +++ b/DMC/src/env/dm_control/CONTRIBUTING.md @@ -0,0 +1,23 @@ +# How to Contribute + +We'd love to accept your patches and contributions to this project. There are +just a few small guidelines you need to follow. + +## Contributor License Agreement + +Contributions to this project must be accompanied by a Contributor License +Agreement. You (or your employer) retain the copyright to your contribution, +this simply gives us permission to use and redistribute your contributions as +part of the project. Head over to to see +your current agreements on file or to sign a new one. + +You generally only need to submit a CLA once, so if you've already submitted one +(even if it was for a different project), you probably don't need to do it +again. + +## Code reviews + +All submissions, including submissions by project members, require review. We +use GitHub pull requests for this purpose. Consult +[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more +information on using pull requests. diff --git a/DMC/src/env/dm_control/LICENSE b/DMC/src/env/dm_control/LICENSE new file mode 100644 index 0000000..7a4a3ea --- /dev/null +++ b/DMC/src/env/dm_control/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/DMC/src/env/dm_control/README.md b/DMC/src/env/dm_control/README.md new file mode 100644 index 0000000..a352c4b --- /dev/null +++ b/DMC/src/env/dm_control/README.md @@ -0,0 +1,146 @@ +# `dm_control`: The DeepMind Control Suite and Package + +# ![all domains](all_domains.png) + +This package consists of the following "core" components: + +- [`dm_control.mujoco`]: Libraries that provide Python bindings to the MuJoCo + physics engine. + +- [`dm_control.suite`]: A set of Python Reinforcement Learning environments + powered by the MuJoCo physics engine. + +- [`dm_control.viewer`]: An interactive environment viewer. + +Additionally, the following components are available for the creation of more +complex control tasks: + +- [`dm_control.mjcf`]: A library for composing and modifying MuJoCo MJCF + models in Python. + +- `dm_control.composer`: A library for defining rich RL environments from + reusable, self-contained components. + +- [`dm_control.locomotion`]: Additional libraries for custom tasks. + +- [`dm_control.locomotion.soccer`]: Multi-agent soccer tasks. + +If you use this package, please cite our accompanying [tech report]: + +``` +@techreport{deepmindcontrolsuite2018, + title = {Deep{Mind} Control Suite}, + author = {Yuval Tassa and Yotam Doron and Alistair Muldal and Tom Erez + and Yazhe Li and Diego de Las Casas and David Budden and Abbas + Abdolmaleki and Josh Merel and Andrew Lefrancq and Timothy Lillicrap + and Martin Riedmiller}, + year = 2018, + month = jan, + howpublished = {https://arxiv.org/abs/1801.00690}, + url = {https://arxiv.org/abs/1801.00690}, + volume = {abs/1504.04804}, + institution = {DeepMind}, +} +``` + +## Requirements and Installation + +`dm_control` is regularly tested using the following platforms and Python +versions: + +| | Python 2.7 | Python 3.5 | +| ------------ | ---------- | ---------- | +| Ubuntu 14.04 | ✓ | ✓ | +| Ubuntu 16.04 | | ✓ | + +Various people have been successful in getting `dm_control` to work on other +Linux distros, OS X, and Windows. We do not provide active support for these, +but will endeavour to answer questions on a best-effort basis. + +Follow these steps to install `dm_control`: + +1. Download MuJoCo Pro 2.00 from the download page on the [MuJoCo website]. + MuJoCo Pro must be installed before `dm_control`, since `dm_control`'s + install script generates Python [`ctypes`] bindings based on MuJoCo's header + files. By default, `dm_control` assumes that the MuJoCo Zip archive is + extracted as `~/.mujoco/mujoco200_$PLATFORM` where `$PLATFORM` is either + `linux`, `win64`, or `macos`. + +2. Install the `dm_control` Python package by running `pip install dm_control`. + We recommend `pip install`ing into a `virtualenv`, or with the `--user` flag + to avoid interfering with system packages. At installation time, + `dm_control` looks for the MuJoCo headers from Step 1 in + `~/.mujoco/mujoco200_$PLATFORM/include`, however this path can be configured + with the `headers-dir` command line argument. + +3. Install a license key for MuJoCo, required by `dm_control` at runtime. See + the [MuJoCo license key page] for further details. By default, `dm_control` + looks for the MuJoCo license key file at `~/.mujoco/mjkey.txt`. + +4. If the license key (e.g. `mjkey.txt`) or the shared library provided by + MuJoCo Pro (e.g. `libmujoco200.so` or `libmujoco200.dylib`) are installed at + non-default paths, specify their locations using the `MJKEY_PATH` and + `MJLIB_PATH` environment variables respectively. These environment variables + should be set to the full path to the relevant file itself, e.g. + `export MJLIB_PATH=/path/to/libmujoco200.so`. + +## Versioning + +`dm_control` is released on a rolling basis: the latest commit on the `master` +branch of our GitHub repository represents our latest release. Our Python +package is versioned `0.0.N`, where `N` is the number that appears in the +`PiperOrigin-RevId` field of the commit message. We always ensure that `N` +strictly increases between a parent commit and its children. We do not upload +all versions to PyPI, and occasionally the latest version on PyPI may lag behind +the latest commit on GitHub. Should this happen, you can still install the +newest version available by running `pip install +git+git://github.com/deepmind/dm_control.git`. + +## Rendering + +The MuJoCo Python bindings support three different OpenGL rendering backends: +EGL (headless, hardware-accelerated), GLFW (windowed, hardware-accelerated), and +OSMesa (purely software-based). At least one of these three backends must be +available in order render through `dm_control`. + +* Hardware rendering with a windowing system is supported via GLFW and GLEW. + On Linux these can be installed using your distribution's package manager. + For example, on Debian and Ubuntu, this can be done by running `sudo apt-get + install libglfw3 libglew2.0`. Please note that: + + - [`dm_control.viewer`] can only be used with GLFW. + - GLFW will not work on headless machines. + +* "Headless" hardware rendering (i.e. without a windowing system such as X11) + requires [EXT_platform_device] support in the EGL driver. Recent Nvidia + drivers support this. You will also need GLEW. On Debian and Ubuntu, this + can be installed via `sudo apt-get install libglew2.0`. + +* Software rendering requires GLX and OSMesa. On Debian and Ubuntu these can + be installed using `sudo apt-get install libgl1-mesa-glx libosmesa6`. + +By default, `dm_control` will attempt to use GLFW first, then EGL, then OSMesa. +You can also specify a particular backend to use by setting the `MUJOCO_GL=` +environment variable to `"glfw"`, `"egl"`, or `"osmesa"`, respectively. + +## Additional instructions for Homebrew users on macOS + +1. The above instructions using `pip` should work, provided that you use a + Python interpreter that is installed by Homebrew (rather than the + system-default one). + +2. Before running, the `DYLD_LIBRARY_PATH` environment variable needs to be + updated with the path to the GLFW library. This can be done by running + `export DYLD_LIBRARY_PATH=$(brew --prefix)/lib:$DYLD_LIBRARY_PATH`. + +[EXT_platform_device]: https://www.khronos.org/registry/EGL/extensions/EXT/EGL_EXT_platform_device.txt +[MuJoCo license key page]: https://www.roboti.us/license.html +[MuJoCo website]: http://www.mujoco.org/ +[tech report]: https://arxiv.org/abs/1801.00690 +[`ctypes`]: https://docs.python.org/2/library/ctypes.html +[`dm_control.mjcf`]: dm_control/mjcf/README.md +[`dm_control.mujoco`]: dm_control/mujoco/README.md +[`dm_control.suite`]: dm_control/suite/README.md +[`dm_control.viewer`]: dm_control/viewer/README.md +[`dm_control.locomotion`]: dm_control/locomotion/README.md +[`dm_control.locomotion.soccer`]: dm_control/locomotion/soccer/README.md diff --git a/DMC/src/env/dm_control/all_domains.png b/DMC/src/env/dm_control/all_domains.png new file mode 100644 index 0000000..10c22fa Binary files /dev/null and b/DMC/src/env/dm_control/all_domains.png differ diff --git a/DMC/src/env/dm_control/dm_control/__init__.py b/DMC/src/env/dm_control/dm_control/__init__.py new file mode 100644 index 0000000..1ebb270 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/DMC/src/env/dm_control/dm_control/_render/__init__.py b/DMC/src/env/dm_control/dm_control/_render/__init__.py new file mode 100644 index 0000000..c3a678b --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/_render/__init__.py @@ -0,0 +1,89 @@ +# Copyright 2017-2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""OpenGL context management for rendering MuJoCo scenes. + +By default, the `Renderer` class will try to load one of the following rendering +APIs, in descending order of priority: GLFW > EGL > OSMesa. + +It is also possible to select a specific backend by setting the `MUJOCO_GL=` +environment variable to 'glfw', 'egl', or 'osmesa'. +""" + +import collections +import os + +from absl import logging +from dm_control._render import constants + +BACKEND = os.environ.get(constants.MUJOCO_GL) + + +# pylint: disable=g-import-not-at-top +def _import_egl(): + from dm_control._render.pyopengl.egl_renderer import EGLContext + return EGLContext + + +def _import_glfw(): + from dm_control._render.glfw_renderer import GLFWContext + return GLFWContext + + +def _import_osmesa(): + from dm_control._render.pyopengl.osmesa_renderer import OSMesaContext + return OSMesaContext +# pylint: enable=g-import-not-at-top + +_ALL_RENDERERS = collections.OrderedDict([ + (constants.GLFW, _import_glfw), + (constants.EGL, _import_egl), + (constants.OSMESA, _import_osmesa), +]) + + +if BACKEND is not None: + # If a backend was specified, try importing it and error if unsuccessful. + try: + import_func = _ALL_RENDERERS[BACKEND] + except KeyError: + raise RuntimeError( + 'Environment variable {} must be one of {!r}: got {!r}.' + .format(constants.MUJOCO_GL, _ALL_RENDERERS.keys(), BACKEND)) + logging.info('MUJOCO_GL=%s, attempting to import specified OpenGL backend.', + BACKEND) + Renderer = import_func() # pylint: disable=invalid-name +else: + logging.info('MUJOCO_GL is not set, so an OpenGL backend will be chosen ' + 'automatically.') + # Otherwise try importing them in descending order of priority until + # successful. + for name, import_func in _ALL_RENDERERS.items(): + try: + Renderer = import_func() + BACKEND = name + logging.info('Successfully imported OpenGL backend: %s', name) + break + except ImportError: + logging.info('Failed to import OpenGL backend: %s', name) + if BACKEND is None: + logging.info('No OpenGL backend could be imported. Attempting to create a ' + 'rendering context will result in a RuntimeError.') + + def Renderer(*args, **kwargs): # pylint: disable=function-redefined,invalid-name + del args, kwargs + raise RuntimeError('No OpenGL rendering backend is available.') + +USING_GPU = BACKEND in (constants.EGL, constants.GLFW) diff --git a/DMC/src/env/dm_control/dm_control/_render/base.py b/DMC/src/env/dm_control/dm_control/_render/base.py new file mode 100644 index 0000000..2c9f2ca --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/_render/base.py @@ -0,0 +1,141 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Base class for OpenGL context handlers. + +`ContextBase` defines a common API that OpenGL rendering contexts should conform +to. In addition, it provides a `make_current` context manager that: + +1. Makes this OpenGL context current within the appropriate rendering thread. +2. Yields an object exposing a `call` method that can be used to execute OpenGL + calls within the rendering thread. + +See the docstring for `dm_control.utils.render_executor` for further details +regarding rendering threads. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import atexit +import collections +import contextlib +import weakref + +from absl import logging +from dm_control._render import executor +import six + +_CURRENT_CONTEXT_FOR_THREAD = collections.defaultdict(lambda: None) +_CURRENT_THREAD_FOR_CONTEXT = collections.defaultdict(lambda: None) + + +@six.add_metaclass(abc.ABCMeta) +class ContextBase(object): + """Base class for managing OpenGL contexts.""" + + def __init__(self, + max_width, + max_height, + render_executor_class=executor.RenderExecutor): + """Initializes this context.""" + logging.debug('Using render executor class: %s', + render_executor_class.__name__) + self._render_executor = render_executor_class() + self._refcount = 0 + + self_weakref = weakref.ref(self) + def _free_at_exit(): + if self_weakref(): + self_weakref()._free_unconditionally() # pylint: disable=protected-access + atexit.register(_free_at_exit) + + with self._render_executor.execution_context() as ctx: + ctx.call(self._platform_init, max_width, max_height) + + def increment_refcount(self): + self._refcount += 1 + + def decrement_refcount(self): + self._refcount -= 1 + + @property + def terminated(self): + return self._render_executor.terminated + + @property + def thread(self): + return self._render_executor.thread + + def _free_on_executor_thread(self): + if _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread] == id(self): + del _CURRENT_THREAD_FOR_CONTEXT[id(self)] + del _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread] + self._platform_free() + + def free(self): + """Frees resources associated with this context if its refcount is zero.""" + if self._refcount == 0: + self._free_unconditionally() + + def _free_unconditionally(self): + self._render_executor.terminate(self._free_on_executor_thread) + + def __del__(self): + self._free_unconditionally() + + @contextlib.contextmanager + def make_current(self): + """Context manager that makes this Renderer's OpenGL context current. + + Yields: + An object that exposes a `call` method that can be used to call a + function on the dedicated rendering thread. + + Raises: + RuntimeError: If this context is already current on another thread. + """ + + with self._render_executor.execution_context() as ctx: + if _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread] != id(self): + if _CURRENT_THREAD_FOR_CONTEXT[id(self)]: + raise RuntimeError( + 'Cannot make context {!r} current on thread {!r}: ' + 'this context is already current on another thread {!r}.' + .format(self, self._render_executor.thread, + _CURRENT_THREAD_FOR_CONTEXT[id(self)])) + else: + current_context = ( + _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread]) + if current_context: + del _CURRENT_THREAD_FOR_CONTEXT[current_context] + _CURRENT_THREAD_FOR_CONTEXT[id(self)] = self._render_executor.thread + _CURRENT_CONTEXT_FOR_THREAD[self._render_executor.thread] = id(self) + ctx.call(self._platform_make_current) + yield ctx + + @abc.abstractmethod + def _platform_init(self, max_width, max_height): + """Performs an implementation-specific context initialization.""" + + @abc.abstractmethod + def _platform_make_current(self): + """Make the OpenGL context current on the executing thread.""" + + @abc.abstractmethod + def _platform_free(self): + """Performs an implementation-specific context cleanup.""" diff --git a/DMC/src/env/dm_control/dm_control/_render/base_test.py b/DMC/src/env/dm_control/dm_control/_render/base_test.py new file mode 100644 index 0000000..409b7c3 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/_render/base_test.py @@ -0,0 +1,146 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for the base rendering module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading +# Internal dependencies. +from absl.testing import absltest +from dm_control._render import base +from dm_control._render import executor + +WIDTH = 1024 +HEIGHT = 768 + + +class ContextBaseTests(absltest.TestCase): + + class ContextMock(base.ContextBase): + + def _platform_init(self, max_width, max_height): + self.init_thread = threading.current_thread() + self.make_current_count = 0 + self.max_width = max_width + self.max_height = max_height + self.free_thread = None + + def _platform_make_current(self): + self.make_current_count += 1 + self.make_current_thread = threading.current_thread() + + def _platform_free(self): + self.free_thread = threading.current_thread() + + def setUp(self): + super(ContextBaseTests, self).setUp() + self.context = ContextBaseTests.ContextMock(WIDTH, HEIGHT) + + def test_init(self): + self.assertIs(self.context.init_thread, self.context.thread) + self.assertEqual(self.context.max_width, WIDTH) + self.assertEqual(self.context.max_height, HEIGHT) + + def test_make_current(self): + self.assertEqual(self.context.make_current_count, 0) + + with self.context.make_current(): + pass + self.assertEqual(self.context.make_current_count, 1) + self.assertIs(self.context.make_current_thread, self.context.thread) + + # Already current, shouldn't trigger a call to `_platform_make_current`. + with self.context.make_current(): + pass + self.assertEqual(self.context.make_current_count, 1) + self.assertIs(self.context.make_current_thread, self.context.thread) + + def test_thread_sharing(self): + first_context = ContextBaseTests.ContextMock( + WIDTH, HEIGHT, executor.PassthroughRenderExecutor) + second_context = ContextBaseTests.ContextMock( + WIDTH, HEIGHT, executor.PassthroughRenderExecutor) + + with first_context.make_current(): + pass + self.assertEqual(first_context.make_current_count, 1) + + with first_context.make_current(): + pass + self.assertEqual(first_context.make_current_count, 1) + + with second_context.make_current(): + pass + self.assertEqual(second_context.make_current_count, 1) + + with second_context.make_current(): + pass + self.assertEqual(second_context.make_current_count, 1) + + with first_context.make_current(): + pass + self.assertEqual(first_context.make_current_count, 2) + + with second_context.make_current(): + pass + self.assertEqual(second_context.make_current_count, 2) + + def test_free(self): + with self.context.make_current(): + pass + + thread = self.context.thread + self.assertIn(id(self.context), base._CURRENT_THREAD_FOR_CONTEXT) + self.assertIn(thread, base._CURRENT_CONTEXT_FOR_THREAD) + + self.context.free() + self.assertIs(self.context.free_thread, thread) + self.assertIsNone(self.context.thread) + + self.assertNotIn(id(self.context), base._CURRENT_THREAD_FOR_CONTEXT) + self.assertNotIn(thread, base._CURRENT_CONTEXT_FOR_THREAD) + + def test_refcounting(self): + thread = self.context.thread + + self.assertEqual(self.context._refcount, 0) + self.context.increment_refcount() + self.assertEqual(self.context._refcount, 1) + + # Context should not be freed yet, since its refcount is still positive. + self.context.free() + self.assertIsNone(self.context.free_thread) + self.assertIs(self.context.thread, thread) + + # Decrement the refcount to zero. + self.context.decrement_refcount() + self.assertEqual(self.context._refcount, 0) + + # Now the context can be freed. + self.context.free() + self.assertIs(self.context.free_thread, thread) + self.assertIsNone(self.context.thread) + + def test_del(self): + self.assertIsNone(self.context.free_thread) + self.context.__del__() + self.assertIsNotNone(self.context.free_thread) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/_render/constants.py b/DMC/src/env/dm_control/dm_control/_render/constants.py new file mode 100644 index 0000000..2509dce --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/_render/constants.py @@ -0,0 +1,31 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""String constants for the rendering module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Name of the environment variable that selects a renderer platform. +MUJOCO_GL = 'MUJOCO_GL' + +# Name of the environment variable that selects a platform for PyOpenGL. +PYOPENGL_PLATFORM = 'PYOPENGL_PLATFORM' + +# Renderer platform specifiers. +OSMESA = 'osmesa' +GLFW = 'glfw' +EGL = 'egl' diff --git a/DMC/src/env/dm_control/dm_control/_render/executor/__init__.py b/DMC/src/env/dm_control/dm_control/_render/executor/__init__.py new file mode 100644 index 0000000..12a9c48 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/_render/executor/__init__.py @@ -0,0 +1,51 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""RenderExecutor executes OpenGL rendering calls on an appropriate thread. + +OpenGL calls must be made on the same thread as where an OpenGL context is +made current on. With GPU rendering, migrating OpenGL contexts between threads +can become expensive. We provide a thread-safe executor that maintains a +thread on which an OpenGL context can be kept permanently current, and any other +threads that wish to render with that context will have their rendering calls +offloaded to the dedicated thread. + +For single-threaded applications, set the `DISABLE_RENDER_THREAD_OFFLOADING` +environment variable before launching the Python interpreter. This will +eliminate the overhead of unnecessary thread-switching. +""" + +# pylint: disable=g-import-not-at-top +import os +_OFFLOAD = not bool(os.environ.get('DISABLE_RENDER_THREAD_OFFLOADING', '')) +del os + +from dm_control._render.executor.render_executor import BaseRenderExecutor +from dm_control._render.executor.render_executor import OffloadingRenderExecutor +from dm_control._render.executor.render_executor import PassthroughRenderExecutor + +_EXECUTORS = (PassthroughRenderExecutor, OffloadingRenderExecutor) + +try: + from dm_control._render.executor.native_mutex.render_executor import NativeMutexOffloadingRenderExecutor + _EXECUTORS += (NativeMutexOffloadingRenderExecutor,) +except ImportError: + NativeMutexOffloadingRenderExecutor = None + +if _OFFLOAD: + RenderExecutor = ( # pylint: disable=invalid-name + NativeMutexOffloadingRenderExecutor or OffloadingRenderExecutor) +else: + RenderExecutor = PassthroughRenderExecutor # pylint: disable=invalid-name diff --git a/DMC/src/env/dm_control/dm_control/_render/executor/render_executor.py b/DMC/src/env/dm_control/dm_control/_render/executor/render_executor.py new file mode 100644 index 0000000..a7911e7 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/_render/executor/render_executor.py @@ -0,0 +1,222 @@ +# Copyright 2017-2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""RenderExecutors executes OpenGL rendering calls on an appropriate thread. + +The purpose of these classes is to ensure that OpenGL calls are made on the +same thread as where an OpenGL context was made current. + +In a single-threaded setting, `PassthroughRenderExecutor` is essentially a no-op +that executes rendering calls on the same thread. This is provided to minimize +thread-switching overhead. + +In a multithreaded setting, `OffloadingRenderExecutor` maintains a separate +dedicated thread on which the OpenGL context is created and made current. All +subsequent rendering calls are then offloaded onto this dedicated thread. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import collections +import contextlib +import threading + +from concurrent import futures +import six + +_NOT_IN_CONTEXT = 'Cannot be called outside of an `execution_context`.' +_ALREADY_TERMINATED = 'This executor has already been terminated.' + + +class _FakeLock(object): + """An object with the same API as `threading.Lock` but that does nothing.""" + + def acquire(self, blocking=True): + pass + + def release(self): + pass + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_value, traceback): + del exc_type, exc_value, traceback + + +_FAKE_LOCK = _FakeLock() + + +@six.add_metaclass(abc.ABCMeta) +class BaseRenderExecutor(object): + """An object that manages rendering calls for an OpenGL context. + + This class helps ensure that OpenGL calls are made on the correct thread. The + usage pattern is as follows: + + ```python + executor = SomeRenderExecutorClass() + with executor.execution_context() as ctx: + ctx.call(an_opengl_call, arg, kwarg=foo) + result = ctx.call(another_opengl_call) + ``` + """ + + def __init__(self): + self._locked = 0 + self._terminated = False + + def _check_locked(self): + if not self._locked: + raise RuntimeError(_NOT_IN_CONTEXT) + + def _check_not_terminated(self): + if self._terminated: + raise RuntimeError(_ALREADY_TERMINATED) + + @contextlib.contextmanager + def execution_context(self): + """A context manager that allows calls to be offloaded to this executor.""" + self._check_not_terminated() + with self._lock_if_necessary: + self._locked += 1 + yield self + self._locked -= 1 + + @property + def terminated(self): + return self._terminated + + @abc.abstractproperty + def thread(self): + pass + + @abc.abstractproperty + def _lock_if_necessary(self): + pass + + @abc.abstractmethod + def call(self, *args, **kwargs): + pass + + @abc.abstractmethod + def terminate(self, cleanup_callable=None): + pass + + +class PassthroughRenderExecutor(BaseRenderExecutor): + """A no-op render executor that executes on the calling thread.""" + + def __init__(self): + super(PassthroughRenderExecutor, self).__init__() + self._mutex = threading.RLock() + + @property + def thread(self): + if not self._terminated: + return threading.current_thread() + else: + return None + + @property + def _lock_if_necessary(self): + return self._mutex + + def call(self, func, *args, **kwargs): + self._check_locked() + return func(*args, **kwargs) + + def terminate(self, cleanup_callable=None): + with self._lock_if_necessary: + if not self._terminated: + if cleanup_callable: + cleanup_callable() + self._terminated = True + + +class _ThreadPoolExecutorPool(object): + """A pool of reusable ThreadPoolExecutors.""" + + def __init__(self): + self._deque = collections.deque() + self._lock = threading.Lock() + + def acquire(self): + with self._lock: + if self._deque: + return self._deque.popleft() + else: + return futures.ThreadPoolExecutor(max_workers=1) + + def release(self, thread_pool_executor): + with self._lock: + self._deque.append(thread_pool_executor) + + +_THREAD_POOL_EXECUTOR_POOL = _ThreadPoolExecutorPool() + + +class OffloadingRenderExecutor(BaseRenderExecutor): + """A render executor that executes calls on a dedicated offload thread.""" + + def __init__(self): + super(OffloadingRenderExecutor, self).__init__() + self._mutex = threading.RLock() + self._executor = _THREAD_POOL_EXECUTOR_POOL.acquire() + self._thread = self._executor.submit(threading.current_thread).result() + + @property + def thread(self): + return self._thread + + @property + def _lock_if_necessary(self): + if threading.current_thread() is self.thread: + # If the offload thread needs to make a call to its own executor, for + # example when a weakref callback is triggered during an offloaded call, + # then we must not try to reacquire our own lock. + # Otherwise, a deadlock ensues. + return _FAKE_LOCK + else: + return self._mutex + + def call(self, func, *args, **kwargs): + self._check_locked() + return self._call_locked(func, *args, **kwargs) + + def _call_locked(self, func, *args, **kwargs): + if threading.current_thread() is self.thread: + # If the offload thread needs to make a call to its own executor, for + # example when a weakref callback is triggered during an offloaded call, + # we should just directly call the function. + # Otherwise, a deadlock ensues. + return func(*args, **kwargs) + else: + return self._executor.submit(func, *args, **kwargs).result() + + def terminate(self, cleanup_callable=None): + if self._terminated: + return + with self._lock_if_necessary: + if not self._terminated: + if cleanup_callable: + self._call_locked(cleanup_callable) + _THREAD_POOL_EXECUTOR_POOL.release(self._executor) + self._executor = None + self._thread = None + self._terminated = True diff --git a/DMC/src/env/dm_control/dm_control/_render/executor/render_executor_test.py b/DMC/src/env/dm_control/dm_control/_render/executor/render_executor_test.py new file mode 100644 index 0000000..56f68f8 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/_render/executor/render_executor_test.py @@ -0,0 +1,210 @@ +# Copyright 2017-2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for dm_control.utils.render_executor.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading +import time +import unittest + +from absl.testing import absltest +from absl.testing import parameterized +from dm_control._render import executor +import mock +from six.moves import range + + +def enforce_timeout(timeout): + def wrap(test_func): + def wrapped_test(self, *args, **kwargs): + thread = threading.Thread( + target=test_func, args=((self,) + args), kwargs=kwargs) + thread.daemon = True + thread.start() + thread.join(timeout=timeout) + self.assertFalse( + thread.is_alive(), + msg='Test timed out after {} seconds.'.format(timeout)) + return wrapped_test + return wrap + + +class RenderExecutorTest(parameterized.TestCase): + + def _make_executor(self, executor_type): + if (executor_type == executor.NativeMutexOffloadingRenderExecutor and + executor_type is None): + raise unittest.SkipTest( + 'NativeMutexOffloadingRenderExecutor is not available.') + else: + return executor_type() + + def test_passthrough_executor_thread(self): + render_executor = self._make_executor(executor.PassthroughRenderExecutor) + self.assertIs(render_executor.thread, threading.current_thread()) + render_executor.terminate() + + @parameterized.parameters(executor.OffloadingRenderExecutor, + executor.NativeMutexOffloadingRenderExecutor) + def test_offloading_executor_thread(self, executor_type): + render_executor = self._make_executor(executor_type) + self.assertIsNot(render_executor.thread, threading.current_thread()) + render_executor.terminate() + + @parameterized.parameters(executor.PassthroughRenderExecutor, + executor.OffloadingRenderExecutor, + executor.NativeMutexOffloadingRenderExecutor) + def test_call_on_correct_thread(self, executor_type): + render_executor = self._make_executor(executor_type) + with render_executor.execution_context() as ctx: + actual_executed_thread = ctx.call(threading.current_thread) + self.assertIs(actual_executed_thread, render_executor.thread) + render_executor.terminate() + + @parameterized.parameters(executor.PassthroughRenderExecutor, + executor.OffloadingRenderExecutor, + executor.NativeMutexOffloadingRenderExecutor) + def test_multithreaded(self, executor_type): + render_executor = self._make_executor(executor_type) + list_length = 5 + shared_list = [None] * list_length + + def fill_list(thread_idx): + def assign_value(i): + shared_list[i] = thread_idx + for _ in range(1000): + with render_executor.execution_context() as ctx: + for i in range(list_length): + ctx.call(assign_value, i) + # Other threads should be prevented from calling `assign_value` while + # this thread is inside the `execution_context`. + self.assertEqual(shared_list, [thread_idx] * list_length) + + threads = [threading.Thread(target=fill_list, args=(i,)) for i in range(9)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + render_executor.terminate() + + @parameterized.parameters(executor.PassthroughRenderExecutor, + executor.OffloadingRenderExecutor, + executor.NativeMutexOffloadingRenderExecutor) + def test_exception(self, executor_type): + render_executor = self._make_executor(executor_type) + message = 'fake error' + def raise_value_error(): + raise ValueError(message) + with render_executor.execution_context() as ctx: + with self.assertRaisesWithLiteralMatch(ValueError, message): + ctx.call(raise_value_error) + render_executor.terminate() + + @parameterized.parameters(executor.PassthroughRenderExecutor, + executor.OffloadingRenderExecutor, + executor.NativeMutexOffloadingRenderExecutor) + def test_terminate(self, executor_type): + render_executor = self._make_executor(executor_type) + cleanup = mock.MagicMock() + render_executor.terminate(cleanup) + cleanup.assert_called_once_with() + + @parameterized.parameters(executor.PassthroughRenderExecutor, + executor.OffloadingRenderExecutor, + executor.NativeMutexOffloadingRenderExecutor) + def test_call_outside_of_context(self, executor_type): + render_executor = self._make_executor(executor_type) + func = mock.MagicMock() + with self.assertRaisesWithLiteralMatch( + RuntimeError, executor.render_executor._NOT_IN_CONTEXT): + render_executor.call(func) + # Also test that the locked flag is properly cleared when leaving a context. + with render_executor.execution_context(): + render_executor.call(lambda: None) + with self.assertRaisesWithLiteralMatch( + RuntimeError, executor.render_executor._NOT_IN_CONTEXT): + render_executor.call(func) + func.assert_not_called() + render_executor.terminate() + + @parameterized.parameters(executor.PassthroughRenderExecutor, + executor.OffloadingRenderExecutor, + executor.NativeMutexOffloadingRenderExecutor) + def test_call_after_terminate(self, executor_type): + render_executor = self._make_executor(executor_type) + render_executor.terminate() + func = mock.MagicMock() + with self.assertRaisesWithLiteralMatch( + RuntimeError, executor.render_executor._ALREADY_TERMINATED): + with render_executor.execution_context() as ctx: + ctx.call(func) + func.assert_not_called() + + @parameterized.parameters(executor.PassthroughRenderExecutor, + executor.OffloadingRenderExecutor, + executor.NativeMutexOffloadingRenderExecutor) + def test_locking(self, executor_type): + render_executor = self._make_executor(executor_type) + other_thread_context_entered = threading.Condition() + other_thread_context_done = [False] + def other_thread_func(): + with render_executor.execution_context(): + with other_thread_context_entered: + other_thread_context_entered.notify() + time.sleep(1) + other_thread_context_done[0] = True + other_thread = threading.Thread(target=other_thread_func) + with other_thread_context_entered: + other_thread.start() + other_thread_context_entered.wait() + with render_executor.execution_context(): + self.assertTrue( + other_thread_context_done[0], + msg=('Main thread should not be able to enter the execution context ' + 'until the other thread is done.')) + + @parameterized.parameters(executor.PassthroughRenderExecutor, + executor.OffloadingRenderExecutor, + executor.NativeMutexOffloadingRenderExecutor) + @enforce_timeout(timeout=5.) + def test_reentrant_locking(self, executor_type): + render_executor = self._make_executor(executor_type) + def triple_lock(render_executor): + with render_executor.execution_context(): + with render_executor.execution_context(): + with render_executor.execution_context(): + pass + triple_lock(render_executor) + + @parameterized.parameters(executor.PassthroughRenderExecutor, + executor.OffloadingRenderExecutor, + executor.NativeMutexOffloadingRenderExecutor) + @enforce_timeout(timeout=5.) + def test_no_deadlock_in_callbacks(self, executor_type): + render_executor = self._make_executor(executor_type) + # This test times out in the event of a deadlock. + def callback(): + with render_executor.execution_context() as ctx: + ctx.call(lambda: None) + with render_executor.execution_context() as ctx: + ctx.call(callback) + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/_render/glfw_renderer.py b/DMC/src/env/dm_control/dm_control/_render/glfw_renderer.py new file mode 100644 index 0000000..847b1ad --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/_render/glfw_renderer.py @@ -0,0 +1,77 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""An OpenGL renderer backed by GLFW.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys +from dm_control._render import base +from dm_control._render import executor +import six + +# Re-raise any exceptions that occur during module import as `ImportError`s. +# This simplifies the conditional imports in `render/__init__.py`. +try: + import glfw # pylint: disable=g-import-not-at-top +except (ImportError, IOError, OSError) as exc: + _, exc, tb = sys.exc_info() + six.reraise(ImportError, ImportError(str(exc)), tb) +try: + glfw.init() +except glfw.GLFWError as exc: + _, exc, tb = sys.exc_info() + six.reraise(ImportError, ImportError(str(exc)), tb) + + +class GLFWContext(base.ContextBase): + """An OpenGL context backed by GLFW.""" + + def __init__(self, max_width, max_height): + # GLFWContext always uses `PassthroughRenderExecutor` rather than offloading + # rendering calls to a separate thread because GLFW can only be safely used + # from the main thread. + super(GLFWContext, self).__init__(max_width, max_height, + executor.PassthroughRenderExecutor) + + def _platform_init(self, max_width, max_height): + """Initializes this context. + + Args: + max_width: Integer specifying the maximum framebuffer width in pixels. + max_height: Integer specifying the maximum framebuffer height in pixels. + """ + glfw.window_hint(glfw.VISIBLE, 0) + glfw.window_hint(glfw.DOUBLEBUFFER, 0) + self._context = glfw.create_window(width=max_width, height=max_height, + title='Invisible window', monitor=None, + share=None) + # This reference prevents `glfw.destroy_window` from being garbage-collected + # before the last window is destroyed, otherwise we may get + # `AttributeError`s when the `__del__` method is later called. + self._destroy_window = glfw.destroy_window + + def _platform_make_current(self): + glfw.make_context_current(self._context) + + def _platform_free(self): + """Frees resources associated with this context.""" + if self._context: + if glfw.get_current_context() == self._context: + glfw.make_context_current(None) + self._destroy_window(self._context) + self._context = None diff --git a/DMC/src/env/dm_control/dm_control/_render/glfw_renderer_test.py b/DMC/src/env/dm_control/dm_control/_render/glfw_renderer_test.py new file mode 100644 index 0000000..2e3be3e --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/_render/glfw_renderer_test.py @@ -0,0 +1,75 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for GLFWContext.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +# Internal dependencies. +from absl.testing import absltest +from dm_control import _render +from dm_control.mujoco import wrapper +from dm_control.mujoco.testing import decorators + + +import mock # pylint: disable=g-import-not-at-top + +MAX_WIDTH = 1024 +MAX_HEIGHT = 1024 + +CONTEXT_PATH = _render.__name__ + '.glfw_renderer.glfw' + + +@unittest.skipUnless( + _render.BACKEND == _render.constants.GLFW, + reason='GLFW beckend not selected.') +class GLFWContextTest(absltest.TestCase): + + def test_init(self): + mock_context = mock.MagicMock() + with mock.patch(CONTEXT_PATH) as mock_glfw: + mock_glfw.create_window.return_value = mock_context + renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT) + self.assertIs(renderer._context, mock_context) + + def test_make_current(self): + mock_context = mock.MagicMock() + with mock.patch(CONTEXT_PATH) as mock_glfw: + mock_glfw.create_window.return_value = mock_context + renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT) + with renderer.make_current(): + pass + mock_glfw.make_context_current.assert_called_once_with(mock_context) + + def test_freeing(self): + mock_context = mock.MagicMock() + with mock.patch(CONTEXT_PATH) as mock_glfw: + mock_glfw.create_window.return_value = mock_context + renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT) + renderer.free() + mock_glfw.destroy_window.assert_called_once_with(mock_context) + self.assertIsNone(renderer._context) + + @decorators.run_threaded(num_threads=1, calls_per_thread=20) + def test_repeatedly_create_and_destroy_context(self): + renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT) + wrapper.MjrContext(wrapper.MjModel.from_xml_string(''), renderer) + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/_render/pyopengl/__init__.py b/DMC/src/env/dm_control/dm_control/_render/pyopengl/__init__.py new file mode 100644 index 0000000..a514c4b --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/_render/pyopengl/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/DMC/src/env/dm_control/dm_control/_render/pyopengl/egl_ext.py b/DMC/src/env/dm_control/dm_control/_render/pyopengl/egl_ext.py new file mode 100644 index 0000000..2e373ed --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/_render/pyopengl/egl_ext.py @@ -0,0 +1,79 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Extends OpenGL.EGL with definitions necessary for headless rendering.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ctypes +from OpenGL.platform import ctypesloader # pylint: disable=g-bad-import-order +try: + # Nvidia driver seems to need libOpenGL.so (as opposed to libGL.so) + # for multithreading to work properly. We load this in before everything else. + ctypesloader.loadLibrary(ctypes.cdll, 'OpenGL', mode=ctypes.RTLD_GLOBAL) +except OSError: + pass + +# pylint: disable=g-import-not-at-top + +from OpenGL import EGL +from OpenGL import error +from six.moves import range + + +# From the EGL_EXT_device_enumeration extension. +PFNEGLQUERYDEVICESEXTPROC = ctypes.CFUNCTYPE( + EGL.EGLBoolean, + EGL.EGLint, + ctypes.POINTER(EGL.EGLDeviceEXT), + ctypes.POINTER(EGL.EGLint), +) +try: + _eglQueryDevicesEXT = PFNEGLQUERYDEVICESEXTPROC( # pylint: disable=invalid-name + EGL.eglGetProcAddress('eglQueryDevicesEXT')) +except TypeError: + raise ImportError('eglQueryDevicesEXT is not available.') + + +# From the EGL_EXT_platform_device extension. +EGL_PLATFORM_DEVICE_EXT = 0x313F +PFNEGLGETPLATFORMDISPLAYEXTPROC = ctypes.CFUNCTYPE( + EGL.EGLDisplay, EGL.EGLenum, ctypes.c_void_p, ctypes.POINTER(EGL.EGLint)) +try: + eglGetPlatformDisplayEXT = PFNEGLGETPLATFORMDISPLAYEXTPROC( # pylint: disable=invalid-name + EGL.eglGetProcAddress('eglGetPlatformDisplayEXT')) +except TypeError: + raise ImportError('eglGetPlatformDisplayEXT is not available.') + + +# Wrap raw _eglQueryDevicesEXT function into something more Pythonic. +def eglQueryDevicesEXT(max_devices=10): # pylint: disable=invalid-name + devices = (EGL.EGLDeviceEXT * max_devices)() + num_devices = EGL.EGLint() + success = _eglQueryDevicesEXT(max_devices, devices, num_devices) + if success == EGL.EGL_TRUE: + return [devices[i] for i in range(num_devices.value)] + else: + raise error.GLError(err=EGL.eglGetError(), + baseOperation=eglQueryDevicesEXT, + result=success) + + +# Expose everything from upstream so that +# we can use this as a drop-in replacement for OpenGL.EGL. +# pylint: disable=wildcard-import,g-bad-import-order +from OpenGL.EGL import * diff --git a/DMC/src/env/dm_control/dm_control/_render/pyopengl/egl_renderer.py b/DMC/src/env/dm_control/dm_control/_render/pyopengl/egl_renderer.py new file mode 100644 index 0000000..249b490 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/_render/pyopengl/egl_renderer.py @@ -0,0 +1,135 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""An OpenGL renderer backed by EGL, provided through PyOpenGL.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import atexit +import ctypes +import os + +from dm_control._render import base +from dm_control._render import constants +from dm_control._render import executor + +PYOPENGL_PLATFORM = os.environ.get(constants.PYOPENGL_PLATFORM) + +if not PYOPENGL_PLATFORM: + os.environ[constants.PYOPENGL_PLATFORM] = constants.EGL +elif PYOPENGL_PLATFORM != constants.EGL: + raise ImportError( + 'Cannot use EGL rendering platform. ' + 'The PYOPENGL_PLATFORM environment variable is set to {!r} ' + '(should be either unset or {!r}).' + .format(PYOPENGL_PLATFORM, constants.EGL)) + + +# pylint: disable=g-import-not-at-top +from dm_control._render.pyopengl import egl_ext as EGL +from OpenGL import error + + +def create_initialized_headless_egl_display(): + """Creates an initialized EGL display directly on a device.""" + devices = EGL.eglQueryDevicesEXT() + if os.environ.get("CUDA_VISIBLE_DEVICES", None) is not None: + devices = [devices[int(os.environ["CUDA_VISIBLE_DEVICES"])]] + for device in devices: + display = EGL.eglGetPlatformDisplayEXT( + EGL.EGL_PLATFORM_DEVICE_EXT, device, None) + if display != EGL.EGL_NO_DISPLAY and EGL.eglGetError() == EGL.EGL_SUCCESS: + # `eglInitialize` may or may not raise an exception on failure depending + # on how PyOpenGL is configured. We therefore catch a `GLError` and also + # manually check the output of `eglGetError()` here. + try: + initialized = EGL.eglInitialize(display, None, None) + except error.GLError: + pass + else: + if initialized == EGL.EGL_TRUE and EGL.eglGetError() == EGL.EGL_SUCCESS: + return display + return EGL.EGL_NO_DISPLAY + + +EGL_DISPLAY = create_initialized_headless_egl_display() +if EGL_DISPLAY == EGL.EGL_NO_DISPLAY: + raise ImportError('Cannot initialize a headless EGL display.') +atexit.register(EGL.eglTerminate, EGL_DISPLAY) + + +EGL_ATTRIBUTES = ( + EGL.EGL_RED_SIZE, 8, + EGL.EGL_GREEN_SIZE, 8, + EGL.EGL_BLUE_SIZE, 8, + EGL.EGL_ALPHA_SIZE, 8, + EGL.EGL_DEPTH_SIZE, 24, + EGL.EGL_STENCIL_SIZE, 8, + EGL.EGL_COLOR_BUFFER_TYPE, EGL.EGL_RGB_BUFFER, + EGL.EGL_SURFACE_TYPE, EGL.EGL_PBUFFER_BIT, + EGL.EGL_RENDERABLE_TYPE, EGL.EGL_OPENGL_BIT, + EGL.EGL_NONE +) + + +class EGLContext(base.ContextBase): + """An OpenGL context backed by EGL.""" + + def __init__(self, max_width, max_height): + # EGLContext currently only works with `PassthroughRenderExecutor`. + # TODO(b/110927854) Make this work with the offloading executor. + super(EGLContext, self).__init__(max_width, max_height, + executor.PassthroughRenderExecutor) + + def _platform_init(self, unused_max_width, unused_max_height): + """Initialization this EGL context.""" + num_configs = ctypes.c_long() + config_size = 1 + config = EGL.EGLConfig() + EGL.eglReleaseThread() + EGL.eglChooseConfig( + EGL_DISPLAY, + EGL_ATTRIBUTES, + ctypes.byref(config), + config_size, + num_configs) + if num_configs.value < 1: + raise RuntimeError( + 'EGL failed to find a framebuffer configuration that matches the ' + 'desired attributes: {}'.format(EGL_ATTRIBUTES)) + EGL.eglBindAPI(EGL.EGL_OPENGL_API) + self._context = EGL.eglCreateContext( + EGL_DISPLAY, config, EGL.EGL_NO_CONTEXT, None) + if not self._context: + raise RuntimeError('Cannot create an EGL context.') + + def _platform_make_current(self): + if self._context: + success = EGL.eglMakeCurrent( + EGL_DISPLAY, EGL.EGL_NO_SURFACE, EGL.EGL_NO_SURFACE, self._context) + if not success: + raise RuntimeError('Failed to make the EGL context current.') + + def _platform_free(self): + """Frees resources associated with this context.""" + if self._context: + current_context = EGL.eglGetCurrentContext() + if current_context and self._context.address == current_context.address: + EGL.eglMakeCurrent(EGL_DISPLAY, EGL.EGL_NO_SURFACE, + EGL.EGL_NO_SURFACE, EGL.EGL_NO_CONTEXT) + EGL.eglDestroyContext(EGL_DISPLAY, self._context) + self._context = None diff --git a/DMC/src/env/dm_control/dm_control/_render/pyopengl/osmesa_renderer.py b/DMC/src/env/dm_control/dm_control/_render/pyopengl/osmesa_renderer.py new file mode 100644 index 0000000..4c2c54a --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/_render/pyopengl/osmesa_renderer.py @@ -0,0 +1,86 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""An OpenGL renderer backed by OSMesa.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from dm_control._render import base +from dm_control._render import constants + +PYOPENGL_PLATFORM = os.environ.get(constants.PYOPENGL_PLATFORM) + +if not PYOPENGL_PLATFORM: + os.environ[constants.PYOPENGL_PLATFORM] = constants.OSMESA +elif PYOPENGL_PLATFORM != constants.OSMESA: + raise ImportError( + 'Cannot use OSMesa rendering platform. ' + 'The PYOPENGL_PLATFORM environment variable is set to {!r} ' + '(should be either unset or {!r}).' + .format(PYOPENGL_PLATFORM, constants.OSMESA)) + +# pylint: disable=g-import-not-at-top +from OpenGL import GL +from OpenGL import osmesa +from OpenGL.GL import arrays + +_DEPTH_BITS = 24 +_STENCIL_BITS = 8 +_ACCUM_BITS = 0 + + +class OSMesaContext(base.ContextBase): + """An OpenGL context backed by OSMesa.""" + + def _platform_init(self, max_width, max_height): + """Initializes this OSMesa context.""" + self._context = osmesa.OSMesaCreateContextExt( + osmesa.OSMESA_RGBA, + _DEPTH_BITS, + _STENCIL_BITS, + _ACCUM_BITS, + None, # sharelist + ) + if not self._context: + raise RuntimeError('Failed to create OSMesa GL context.') + + self._height = max_height + self._width = max_width + + # Allocate a buffer to render into. + self._buffer = arrays.GLfloatArray.zeros((max_height, max_width, 4)) + + def _platform_make_current(self): + if self._context: + success = osmesa.OSMesaMakeCurrent( + self._context, + self._buffer, + GL.GL_FLOAT, + self._width, + self._height) + if not success: + raise RuntimeError('Failed to make OSMesa context current.') + + def _platform_free(self): + """Frees resources associated with this context.""" + if self._context and self._context == osmesa.OSMesaGetCurrentContext(): + osmesa.OSMesaMakeCurrent(None, None, GL.GL_FLOAT, 0, 0) + osmesa.OSMesaDestroyContext(self._context) + self._buffer = None + self._context = None diff --git a/DMC/src/env/dm_control/dm_control/_render/pyopengl/osmesa_renderer_test.py b/DMC/src/env/dm_control/dm_control/_render/pyopengl/osmesa_renderer_test.py new file mode 100644 index 0000000..774d7b8 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/_render/pyopengl/osmesa_renderer_test.py @@ -0,0 +1,75 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for OSMesaContext.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +# Internal dependencies. +from absl.testing import absltest +from dm_control import _render +import mock +from OpenGL import GL + +MAX_WIDTH = 640 +MAX_HEIGHT = 480 + +CONTEXT_PATH = _render.__name__ + '.pyopengl.osmesa_renderer.osmesa' +GL_ARRAYS_PATH = _render.__name__ + '.pyopengl.osmesa_renderer.arrays' + + +@unittest.skipUnless( + _render.BACKEND == _render.constants.OSMESA, + reason='OSMesa backend not selected.') +class OSMesaContextTest(absltest.TestCase): + + def test_init(self): + mock_context = mock.MagicMock() + with mock.patch(CONTEXT_PATH) as mock_osmesa: + mock_osmesa.OSMesaCreateContextExt.return_value = mock_context + renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT) + self.assertIs(renderer._context, mock_context) + renderer.free() + + def test_make_current(self): + mock_context = mock.MagicMock() + mock_buffer = mock.MagicMock() + with mock.patch(CONTEXT_PATH) as mock_osmesa: + with mock.patch(GL_ARRAYS_PATH) as mock_glarrays: + mock_osmesa.OSMesaCreateContextExt.return_value = mock_context + mock_glarrays.GLfloatArray.zeros.return_value = mock_buffer + renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT) + with renderer.make_current(): + pass + mock_osmesa.OSMesaMakeCurrent.assert_called_once_with( + mock_context, mock_buffer, GL.GL_FLOAT, MAX_WIDTH, MAX_HEIGHT) + renderer.free() + + def test_freeing(self): + mock_context = mock.MagicMock() + with mock.patch(CONTEXT_PATH) as mock_osmesa: + mock_osmesa.OSMesaCreateContextExt.return_value = mock_context + renderer = _render.Renderer(MAX_WIDTH, MAX_HEIGHT) + renderer.free() + mock_osmesa.OSMesaDestroyContext.assert_called_once_with(mock_context) + self.assertIsNone(renderer._context) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/autowrap/__init__.py b/DMC/src/env/dm_control/dm_control/autowrap/__init__.py new file mode 100644 index 0000000..f9817d4 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/autowrap/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + diff --git a/DMC/src/env/dm_control/dm_control/autowrap/autowrap.py b/DMC/src/env/dm_control/dm_control/autowrap/autowrap.py new file mode 100644 index 0000000..25fcdcf --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/autowrap/autowrap.py @@ -0,0 +1,147 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +r"""Automatically generates ctypes Python bindings for MuJoCo. + +Parses mjdata.h, mjmodel.h, mjrender.h, mjvisualize.h, mjxmacro.h and mujoco.h; +generates the following Python source files: + + constants.py: constants + enums.py: enums + sizes.py: size information for dynamically-shaped arrays + types.py: ctypes declarations for structs + wrappers.py: low-level Python wrapper classes for structs (these implement + getter/setter methods for struct members where applicable) + functions.py: ctypes function declarations for MuJoCo API functions + +Example usage: + + autowrap --header_paths='/path/to/mjmodel.h /path/to/mjdata.h ...' \ + --output_dir=/path/to/mjbindings +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import io +import os + +# Internal dependencies. +from absl import app +from absl import flags +from absl import logging +from dm_control.autowrap import binding_generator +from dm_control.autowrap import codegen_util + +import six + +_MJMODEL_H = "mjmodel.h" +_MJXMACRO_H = "mjxmacro.h" + +FLAGS = flags.FLAGS + +flags.DEFINE_spaceseplist( + "header_paths", None, + "Space-separated list of paths to MuJoCo header files.") + +flags.DEFINE_string("output_dir", None, + "Path to output directory for wrapper source files.") + + +def main(unused_argv): + special_header_paths = {} + + # Get the path to the mjmodel and mjxmacro header files. + # These header files need special handling. + for header in (_MJMODEL_H, _MJXMACRO_H): + for path in FLAGS.header_paths: + if path.endswith(header): + special_header_paths[header] = path + break + if header not in special_header_paths: + logging.fatal("List of inputs must contain a path to %s", header) + + # Make sure mjmodel.h is parsed first, since it is included by other headers. + srcs = codegen_util.UniqueOrderedDict() + sorted_header_paths = sorted(FLAGS.header_paths) + sorted_header_paths.remove(special_header_paths[_MJMODEL_H]) + sorted_header_paths.insert(0, special_header_paths[_MJMODEL_H]) + for p in sorted_header_paths: + with io.open(p, "r", errors="ignore") as f: + srcs[p] = f.read() + + # consts_dict should be a codegen_util.UniqueOrderedDict. + # This is a temporary workaround due to the fact that the parser does not yet + # handle nested `#if define(predicate)` blocks, which results in some + # constants being parsed twice. We therefore can't enforce the uniqueness of + # the keys in `consts_dict`. As of MuJoCo v1.30 there is only a single problem + # block beginning on line 10 in mujoco.h, and a single constant is affected + # (MJAPI). + consts_dict = collections.OrderedDict() + + # These are commented in `mjdata.h` but have no macros in `mjxmacro.h`. + hints_dict = codegen_util.UniqueOrderedDict({"buffer": ("nbuffer",), + "stack": ("nstack",)}) + + parser = binding_generator.BindingGenerator( + consts_dict=consts_dict, hints_dict=hints_dict) + + # Parse enums. + for pth, src in six.iteritems(srcs): + if pth is not special_header_paths[_MJXMACRO_H]: + parser.parse_enums(src) + + # Parse constants and type declarations. + for pth, src in six.iteritems(srcs): + if pth is not special_header_paths[_MJXMACRO_H]: + parser.parse_consts_typedefs(src) + + # Get shape hints from mjxmacro.h. + parser.parse_hints(srcs[special_header_paths[_MJXMACRO_H]]) + + # Parse structs and function pointer type declarations. + for pth, src in six.iteritems(srcs): + if pth is not special_header_paths[_MJXMACRO_H]: + parser.parse_structs_and_function_pointer_typedefs(src) + + # Parse functions. + for pth, src in six.iteritems(srcs): + if pth is not special_header_paths[_MJXMACRO_H]: + parser.parse_functions(src) + + # Parse global strings and function pointers. + for pth, src in six.iteritems(srcs): + if pth is not special_header_paths[_MJXMACRO_H]: + parser.parse_global_strings(src) + parser.parse_function_pointers(src) + + # Create the output directory if it doesn't already exist. + if not os.path.exists(FLAGS.output_dir): + os.makedirs(FLAGS.output_dir) + + # Generate Python source files and write them to the output directory. + parser.write_consts(os.path.join(FLAGS.output_dir, "constants.py")) + parser.write_enums(os.path.join(FLAGS.output_dir, "enums.py")) + parser.write_types(os.path.join(FLAGS.output_dir, "types.py")) + parser.write_wrappers(os.path.join(FLAGS.output_dir, "wrappers.py")) + parser.write_funcs_and_globals(os.path.join(FLAGS.output_dir, "functions.py")) + parser.write_index_dict(os.path.join(FLAGS.output_dir, "sizes.py")) + +if __name__ == "__main__": + flags.mark_flag_as_required("header_paths") + flags.mark_flag_as_required("output_dir") + app.run(main) diff --git a/DMC/src/env/dm_control/dm_control/autowrap/binding_generator.py b/DMC/src/env/dm_control/dm_control/autowrap/binding_generator.py new file mode 100644 index 0000000..74bb30b --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/autowrap/binding_generator.py @@ -0,0 +1,597 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Parses MuJoCo header files and generates Python bindings.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import pprint +import textwrap + +from absl import logging +from dm_control.autowrap import c_declarations +from dm_control.autowrap import codegen_util +from dm_control.autowrap import header_parsing +import pyparsing +import six + +# Absolute path to the top-level module. +_MODULE = "dm_control.mujoco.wrapper" + +# Imports used in all generated source files. +_BOILERPLATE_IMPORTS = [ + "from __future__ import absolute_import", + "from __future__ import division", + "from __future__ import print_function\n", +] + + +class Error(Exception): + pass + + +class BindingGenerator(object): + """Parses declarations from MuJoCo headers and generates Python bindings.""" + + def __init__(self, + enums_dict=None, + consts_dict=None, + typedefs_dict=None, + hints_dict=None, + types_dict=None, + funcs_dict=None, + strings_dict=None, + func_ptrs_dict=None, + index_dict=None): + """Constructs a new HeaderParser instance. + + The optional arguments listed below can be used to passing in dict-like + objects specifying pre-defined declarations. By default empty + UniqueOrderedDicts will be instantiated and then populated according to the + contents of the headers. + + Args: + enums_dict: Nested mappings from {enum_name: {member_name: value}}. + consts_dict: Mapping from {const_name: value}. + typedefs_dict: Mapping from {type_name: ctypes_typename}. + hints_dict: Mapping from {var_name: shape_tuple}. + types_dict: Mapping from {type_name: type_instance}. + funcs_dict: Mapping from {func_name: Function_instance}. + strings_dict: Mapping from {var_name: StaticStringArray_instance}. + func_ptrs_dict: Mapping from {var_name: FunctionPtr_instance}. + index_dict: Mapping from {lowercase_struct_name: {var_name: shape_tuple}}. + """ + self.enums_dict = (enums_dict if enums_dict is not None + else codegen_util.UniqueOrderedDict()) + self.consts_dict = (consts_dict if consts_dict is not None + else codegen_util.UniqueOrderedDict()) + self.typedefs_dict = (typedefs_dict if typedefs_dict is not None + else codegen_util.UniqueOrderedDict()) + self.hints_dict = (hints_dict if hints_dict is not None + else codegen_util.UniqueOrderedDict()) + self.types_dict = (types_dict if types_dict is not None + else codegen_util.UniqueOrderedDict()) + self.funcs_dict = (funcs_dict if funcs_dict is not None + else codegen_util.UniqueOrderedDict()) + self.strings_dict = (strings_dict if strings_dict is not None + else codegen_util.UniqueOrderedDict()) + self.func_ptrs_dict = (func_ptrs_dict if func_ptrs_dict is not None + else codegen_util.UniqueOrderedDict()) + self.index_dict = (index_dict if index_dict is not None + else codegen_util.UniqueOrderedDict()) + + def get_consts_and_enums(self): + consts_and_enums = self.consts_dict.copy() + for enum in six.itervalues(self.enums_dict): + consts_and_enums.update(enum) + return consts_and_enums + + def resolve_size(self, old_size): + """Resolves an array size identifier. + + The following conversions will be attempted: + + * If `old_size` is an integer it will be returned as-is. + * If `old_size` is a string of the form `"3"` it will be cast to an int. + * If `old_size` is a string in `self.consts_dict` then the value of the + constant will be returned. + * If `old_size` is a string of the form `"3*constant_name"` then the + result of `3*constant_value` will be returned. + * If `old_size` is a string that does not specify an int constant and + cannot be cast to an int (e.g. an identifier for a dynamic dimension, + such as `"ncontact"`) then it will be returned as-is. + + Args: + old_size: An int or string. + + Returns: + An int or string. + """ + if isinstance(old_size, int): + return old_size # If it's already an int then there's nothing left to do + elif "*" in old_size: + # If it's a string specifying a product (such as "2*mjMAXLINEPNT"), + # recursively resolve the components to ints and calculate the result. + size = 1 + for part in old_size.split("*"): + dim = self.resolve_size(part) + assert isinstance(dim, int) + size *= dim + return size + else: + # Recursively dereference any sizes declared in header macros + size = codegen_util.recursive_dict_lookup(old_size, + self.get_consts_and_enums()) + # Try to coerce the result to an int, return a string if this fails + return codegen_util.try_coerce_to_num(size, try_types=(int,)) + + def get_shape_tuple(self, old_size, squeeze=False): + """Generates a shape tuple from parser results. + + Args: + old_size: Either a `pyparsing.ParseResults`, or a valid int or string + input to `self.resolve_size` (see method docstring for further details). + squeeze: If True, any dimensions that are statically defined as 1 will be + removed from the shape tuple. + + Returns: + A shape tuple containing ints for dimensions that are statically defined, + and string size identifiers for dimensions that can only be determined at + runtime. + """ + if isinstance(old_size, pyparsing.ParseResults): + # For multi-dimensional arrays, convert each dimension separately + shape = tuple(self.resolve_size(dim) for dim in old_size) + else: + shape = (self.resolve_size(old_size),) + if squeeze: + shape = tuple(d for d in shape if d != 1) # Remove singleton dimensions + return shape + + def resolve_typename(self, old_ctypes_typename): + """Gets a qualified ctypes typename from typedefs_dict and C_TO_CTYPES.""" + + # Recursively dereference any typenames declared in self.typedefs_dict + new_ctypes_typename = codegen_util.recursive_dict_lookup( + old_ctypes_typename, self.typedefs_dict) + + # Try to convert to a ctypes native typename + new_ctypes_typename = header_parsing.C_TO_CTYPES.get( + new_ctypes_typename, new_ctypes_typename) + + if new_ctypes_typename == old_ctypes_typename: + logging.warning("Could not resolve typename '%s'", old_ctypes_typename) + + return new_ctypes_typename + + def get_type_from_token(self, token, parent=None): + """Accepts a token returned by a parser, returns a subclass of CDeclBase.""" + + comment = codegen_util.mangle_comment(token.comment) + is_const = token.is_const == "const" + + # An anonymous union declaration + if token.anonymous_union: + if not parent and parent.name: + raise Error( + "Anonymous unions must be members of a named struct or union.") + + # Generate a name based on the name of the parent. + name = codegen_util.mangle_varname(parent.name + "_anon_union") + + members = codegen_util.UniqueOrderedDict() + sub_structs = codegen_util.UniqueOrderedDict() + out = c_declarations.AnonymousUnion( + name, members, sub_structs, comment, parent) + + # Add members + for sub_token in token.members: + + # Recurse into nested structs + member = self.get_type_from_token(sub_token, parent=out) + out.members[member.name] = member + + # Nested sub-structures need special treatment + if isinstance(member, c_declarations.Struct): + out.sub_structs[member.name] = member + + # Add to dict of unions + self.types_dict[out.ctypes_typename] = out + + # A struct declaration + elif token.members: + + name = token.name + + # If the name is empty, see if there is a type declaration that matches + # this struct's typename + if not name: + for k, v in six.iteritems(self.typedefs_dict): + if v == token.typename: + name = k + + # Anonymous structs need a dummy typename + typename = token.typename + if not typename: + if parent: + typename = token.name + else: + raise Error( + "Anonymous structs that aren't members of a named struct are not " + "supported (name = '{token.name}').".format(token=token)) + + # Mangle the name if it contains any protected keywords + name = codegen_util.mangle_varname(name) + + members = codegen_util.UniqueOrderedDict() + sub_structs = codegen_util.UniqueOrderedDict() + out = c_declarations.Struct(name, typename, members, sub_structs, comment, + parent, is_const) + + # Map the old typename to the mangled typename in typedefs_dict + self.typedefs_dict[typename] = out.ctypes_typename + + # Add members + for sub_token in token.members: + + # Recurse into nested structs + member = self.get_type_from_token(sub_token, parent=out) + out.members[member.name] = member + + # Nested sub-structures need special treatment + if isinstance(member, c_declarations.Struct): + out.sub_structs[member.name] = member + + # Add to dict of structs + self.types_dict[out.ctypes_typename] = out + + else: + + name = codegen_util.mangle_varname(token.name) + typename = self.resolve_typename(token.typename) + + # 1D array with size defined at compile time + if token.size: + shape = self.get_shape_tuple(token.size) + if typename in {header_parsing.NONE, header_parsing.CTYPES_CHAR}: + out = c_declarations.StaticPtrArray( + name, typename, shape, comment, parent, is_const) + else: + out = c_declarations.StaticNDArray( + name, typename, shape, comment, parent, is_const) + + elif token.ptr: + + # Pointer to a numpy-compatible type, could be an array or a scalar + if typename in header_parsing.CTYPES_TO_NUMPY: + + # Multidimensional array (one or more dimensions might be undefined) + if name in self.hints_dict: + + # Dynamically-sized dimensions have string identifiers + shape = self.hints_dict[name] + if any(isinstance(d, six.string_types) for d in shape): + out = c_declarations.DynamicNDArray(name, typename, shape, + comment, parent, is_const) + else: + out = c_declarations.StaticNDArray(name, typename, shape, comment, + parent, is_const) + + # This must be a pointer to a scalar primitive + else: + out = c_declarations.ScalarPrimitivePtr(name, typename, comment, + parent, is_const) + + # Pointer to struct or other arbitrary type + else: + out = c_declarations.ScalarPrimitivePtr(name, typename, comment, + parent, is_const) + + # A struct we've already encountered + elif typename in self.types_dict: + s = self.types_dict[typename] + if isinstance(s, c_declarations.FunctionPtrTypedef): + out = c_declarations.FunctionPtr( + name, token.name, s.typename, comment) + else: + out = c_declarations.Struct(name, s.typename, s.members, + s.sub_structs, comment, parent) + + # Presumably this is a scalar primitive + else: + out = c_declarations.ScalarPrimitive(name, typename, comment, parent, + is_const) + + return out + + # Parsing functions. + # ---------------------------------------------------------------------------- + + def parse_hints(self, xmacro_src): + """Parses mjxmacro.h, update self.hints_dict.""" + parser = header_parsing.XMACRO + for tokens, _, _ in parser.scanString(xmacro_src): + for xmacro in tokens: + for member in xmacro.members: + # "Squeeze out" singleton dimensions. + shape = self.get_shape_tuple(member.dims, squeeze=True) + self.hints_dict.update({member.name: shape}) + + if codegen_util.is_macro_pointer(xmacro.name): + struct_name = codegen_util.macro_struct_name(xmacro.name) + if struct_name not in self.index_dict: + self.index_dict[struct_name] = {} + + self.index_dict[struct_name].update({member.name: shape}) + + def parse_enums(self, src): + """Parses mj*.h, update self.enums_dict.""" + parser = header_parsing.ENUM_DECL + for tokens, _, _ in parser.scanString(src): + for enum in tokens: + members = codegen_util.UniqueOrderedDict() + value = 0 + for member in enum.members: + # Leftward bitshift + if member.bit_lshift_a: + value = int(member.bit_lshift_a) << int(member.bit_lshift_b) + # Assignment + elif member.value: + value = int(member.value) + # Implicit count + else: + value += 1 + members.update({member.name: value}) + self.enums_dict.update({enum.name: members}) + + def parse_consts_typedefs(self, src): + """Updates self.consts_dict, self.typedefs_dict.""" + parser = (header_parsing.COND_DECL | + header_parsing.UNCOND_DECL) + for tokens, _, _ in parser.scanString(src): + self.recurse_into_conditionals(tokens) + + def recurse_into_conditionals(self, tokens): + """Called recursively within nested #if(n)def... #else... #endif blocks.""" + for token in tokens: + # Another nested conditional block + if token.predicate: + if (token.predicate in self.get_consts_and_enums() + and self.get_consts_and_enums()[token.predicate]): + self.recurse_into_conditionals(token.if_true) + else: + self.recurse_into_conditionals(token.if_false) + # One or more declarations + else: + if token.typename: + self.typedefs_dict.update({token.name: token.typename}) + elif token.value: + value = codegen_util.try_coerce_to_num(token.value) + # Avoid adding function aliases. + if isinstance(value, six.string_types): + continue + else: + self.consts_dict.update({token.name: value}) + else: + self.consts_dict.update({token.name: True}) + + def parse_structs_and_function_pointer_typedefs(self, src): + """Updates self.types_dict.""" + parser = (header_parsing.NESTED_STRUCTS | + header_parsing.FUNCTION_PTR_TYPE_DECL) + for tokens, _, _ in parser.scanString(src): + for token in tokens: + if token.return_type: + # This is a function type declaration. + self.types_dict[token.typename] = c_declarations.FunctionPtrTypedef( + token.typename, + self.get_type_from_token(token.return_type), + tuple(self.get_type_from_token(arg) for arg in token.arguments)) + else: + # This is a struct or a union. + self.get_type_from_token(token) + + def parse_functions(self, src): + """Updates self.funcs_dict.""" + parser = header_parsing.MJAPI_FUNCTION_DECL + for tokens, _, _ in parser.scanString(src): + for token in tokens: + name = codegen_util.mangle_varname(token.name) + comment = codegen_util.mangle_comment(token.comment) + if token.arguments: + args = codegen_util.UniqueOrderedDict() + for arg in token.arguments: + a = self.get_type_from_token(arg) + args[a.name] = a + else: + args = None + if token.return_value: + ret_val = self.get_type_from_token(token.return_value) + else: + ret_val = None + func = c_declarations.Function(name, args, ret_val, comment) + self.funcs_dict[func.name] = func + + def parse_global_strings(self, src): + """Updates self.strings_dict.""" + parser = header_parsing.MJAPI_STRING_ARRAY + for token, _, _ in parser.scanString(src): + name = codegen_util.mangle_varname(token.name) + shape = self.get_shape_tuple(token.dims) + self.strings_dict[name] = c_declarations.StaticStringArray( + name, shape, symbol_name=token.name) + + def parse_function_pointers(self, src): + """Updates self.func_ptrs_dict.""" + parser = header_parsing.MJAPI_FUNCTION_PTR + for token, _, _ in parser.scanString(src): + name = codegen_util.mangle_varname(token.name) + self.func_ptrs_dict[name] = c_declarations.FunctionPtr( + name, symbol_name=token.name, + type_name=token.typename, comment=token.comment) + + # Code generation methods + # ---------------------------------------------------------------------------- + + def make_header(self, imports=()): + """Returns a header string for an auto-generated Python source file.""" + docstring = textwrap.dedent(""" + \"\"\"Automatically generated by {scriptname:}. + + MuJoCo header version: {mujoco_version:} + \"\"\" + """.format(scriptname=os.path.split(__file__)[-1], + mujoco_version=self.consts_dict["mjVERSION_HEADER"])) + docstring = docstring[1:] # Strip the leading line break. + return "\n".join( + [docstring] + _BOILERPLATE_IMPORTS + list(imports) + ["\n"]) + + def write_consts(self, fname): + """Write constants.""" + imports = [ + "# pylint: disable=invalid-name", + ] + with open(fname, "w") as f: + f.write(self.make_header(imports)) + f.write(codegen_util.comment_line("Constants") + "\n") + for name, value in six.iteritems(self.consts_dict): + f.write("{0} = {1}\n".format(name, value)) + f.write("\n" + codegen_util.comment_line("End of generated code")) + + def write_enums(self, fname): + """Write enum definitions.""" + with open(fname, "w") as f: + imports = [ + "import collections", + "# pylint: disable=invalid-name", + "# pylint: disable=line-too-long", + ] + f.write(self.make_header(imports)) + f.write(codegen_util.comment_line("Enums")) + for enum_name, members in six.iteritems(self.enums_dict): + fields = ["\"{}\"".format(name) for name in six.iterkeys(members)] + values = [str(value) for value in six.itervalues(members)] + s = textwrap.dedent(""" + {0} = collections.namedtuple( + "{0}", + [{1}] + )({2}) + """).format(enum_name, ",\n ".join(fields), ", ".join(values)) + f.write(s) + f.write("\n" + codegen_util.comment_line("End of generated code")) + + def write_types(self, fname): + """Write ctypes struct and function type declarations.""" + imports = [ + "import ctypes", + ] + with open(fname, "w") as f: + f.write(self.make_header(imports)) + f.write(codegen_util.comment_line( + "ctypes struct, union, and function type declarations")) + for type_decl in six.itervalues(self.types_dict): + f.write("\n" + type_decl.ctypes_decl) + f.write("\n" + codegen_util.comment_line("End of generated code")) + + def write_wrappers(self, fname): + """Write wrapper classes for ctypes structs.""" + with open(fname, "w") as f: + imports = [ + "import ctypes", + "# pylint: disable=undefined-variable", + "# pylint: disable=wildcard-import", + "from {} import util".format(_MODULE), + "from {}.mjbindings.types import *".format(_MODULE), + ] + f.write(self.make_header(imports)) + f.write(codegen_util.comment_line("Low-level wrapper classes")) + for type_decl in six.itervalues(self.types_dict): + if isinstance(type_decl, c_declarations.Struct): + f.write("\n" + type_decl.wrapper_class) + f.write("\n" + codegen_util.comment_line("End of generated code")) + + def write_funcs_and_globals(self, fname): + """Write ctypes declarations for functions and global data.""" + imports = [ + "import collections", + "import ctypes", + "# pylint: disable=undefined-variable", + "# pylint: disable=wildcard-import", + "from {} import util".format(_MODULE), + "from {}.mjbindings.types import *".format(_MODULE), + "import numpy as np", + "# pylint: disable=line-too-long", + "# pylint: disable=invalid-name", + "# common_typos_disable", + ] + with open(fname, "w") as f: + f.write(self.make_header(imports)) + f.write("mjlib = util.get_mjlib()\n") + + f.write("\n" + codegen_util.comment_line("ctypes function declarations")) + for function in six.itervalues(self.funcs_dict): + f.write("\n" + function.ctypes_func_decl(cdll_name="mjlib")) + + # Only require strings for UI purposes. + f.write("\n" + codegen_util.comment_line("String arrays") + "\n") + for string_arr in six.itervalues(self.strings_dict): + f.write(string_arr.ctypes_var_decl(cdll_name="mjlib")) + + f.write("\n" + codegen_util.comment_line("Callback function pointers")) + + fields = ["'_{0}'".format(func_ptr.name) + for func_ptr in self.func_ptrs_dict.values()] + values = [func_ptr.ctypes_var_decl(cdll_name="mjlib") + for func_ptr in self.func_ptrs_dict.values()] + f.write(textwrap.dedent(""" + class _Callbacks(object): + + __slots__ = [ + {0} + ] + + def __init__(self): + {1} + """).format(",\n ".join(fields), "\n ".join(values))) + + indent = codegen_util.Indenter() + with indent: + for func_ptr in self.func_ptrs_dict.values(): + f.write(indent(func_ptr.getters_setters_with_custom_prefix("self._"))) + + f.write("\n\ncallbacks = _Callbacks() # pylint: disable=invalid-name") + f.write("\ndel _Callbacks\n") + + f.write("\n" + codegen_util.comment_line("End of generated code")) + + def write_index_dict(self, fname): + """Write file containing array shape information for indexing.""" + pp = pprint.PrettyPrinter() + output_string = pp.pformat(dict(self.index_dict)) + indent = codegen_util.Indenter() + imports = [ + "# pylint: disable=bad-continuation", + "# pylint: disable=line-too-long", + ] + with open(fname, "w") as f: + f.write(self.make_header(imports)) + f.write("array_sizes = (\n") + with indent: + f.write(output_string) + f.write("\n)") + f.write("\n" + codegen_util.comment_line("End of generated code")) diff --git a/DMC/src/env/dm_control/dm_control/autowrap/c_declarations.py b/DMC/src/env/dm_control/dm_control/autowrap/c_declarations.py new file mode 100644 index 0000000..fa532ac --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/autowrap/c_declarations.py @@ -0,0 +1,520 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Python representations of C declarations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import textwrap +from dm_control.autowrap import codegen_util +from dm_control.autowrap import header_parsing +import six + + +class CDeclBase(object): + """Base class for Python representations of C declarations.""" + + def __init__(self, **attrs): + self._attrs = attrs + for k, v in six.iteritems(attrs): + setattr(self, k, v) + + def __repr__(self): + """Pretty string representation.""" + attr_str = ", ".join("{0}={1!r}".format(k, v) + for k, v in six.iteritems(self._attrs)) + return "{0}({1})".format(type(self).__name__, attr_str) + + @property + def docstring(self): + """Auto-generate a docstring for self.""" + return "\n".join(textwrap.wrap(self.comment, 74)) + + @property + def ctypes_typename(self): + """ctypes typename.""" + return self.typename + + @property + def ctypes_ptr(self): + """String representation of self as a ctypes pointer.""" + return header_parsing.CTYPES_PTRS.get( + self.ctypes_typename, "ctypes.POINTER({})".format(self.ctypes_typename)) + + @property + def np_dtype(self): + """Get a numpy dtype name for self, fall back on self.ctypes_typename.""" + return header_parsing.CTYPES_TO_NUMPY.get(self.ctypes_typename, + self.ctypes_typename) + + @property + def np_flags(self): + """Tuple of strings specifying numpy.ndarray flags.""" + return ("C", "W") + + +class Struct(CDeclBase): + """C struct declaration.""" + + def __init__(self, name, typename, members, sub_structs, comment="", + parent=None, is_const=None): + super(Struct, self).__init__(name=name, + typename=typename, + members=members, + sub_structs=sub_structs, + comment=comment, + parent=parent, + is_const=is_const) + + @property + def ctypes_decl(self): + """Generates a ctypes.Structure declaration for self.""" + indent = codegen_util.Indenter() + lines = [] + lines.append(textwrap.dedent(""" + class {0.ctypes_typename}(ctypes.Structure): + \"\"\"{0.docstring}\"\"\"""".format(self))) + anonymous_fields = [member.name for member in six.itervalues(self.members) + if isinstance(member, AnonymousUnion)] + with indent: + if anonymous_fields: + lines.append(indent("_anonymous_ = [")) + with indent: + with indent: + for name in anonymous_fields: + lines.append(indent("'" + name + "',")) + lines.append(indent("]")) + + if self.members: + lines.append(indent("_fields_ = [")) + with indent: + with indent: + for member in six.itervalues(self.members): + lines.append(indent(member.ctypes_field_decl + ",")) + lines.append(indent("]\n")) + return "\n".join(lines) + + @property + def ctypes_typename(self): + """Mangles ctypes.Structure typenames to distinguish them from wrappers.""" + return codegen_util.mangle_struct_typename(self.typename) + + @property + def ctypes_field_decl(self): + """Generates a declaration for self as a field of a ctypes.Structure.""" + return "('{0.name}', {0.ctypes_typename})".format(self) # pylint: disable=missing-format-attribute + + @property + def wrapper_name(self): + return codegen_util.camel_case(self.typename) + "Wrapper" + + @property + def wrapper_class(self): + """Generates a Python class containing getter/setter methods for members.""" + indent = codegen_util.Indenter() + lines = [textwrap.dedent(""" + class {0.wrapper_name}(util.WrapperBase): + \"\"\"{0.docstring}\"\"\"""".format(self))] + with indent: + for member in six.itervalues(self.members): + if isinstance(member, AnonymousUnion): + for submember in six.itervalues(member.members): + lines.append(indent(submember.getters_setters)) + else: + lines.append(indent(member.getters_setters)) + lines.append("") # Add an extra newline at the end of the class definition. + return "\n".join(lines) + + @property + def getters_setters(self): + """Populates a Python class with getter & setter methods for self.""" + return textwrap.dedent(""" + @util.CachedProperty + def {0.name}(self): + \"\"\"{0.docstring}\"\"\" + return {0.wrapper_name}(ctypes.pointer(self._ptr.contents.{0.name}))""" # pylint: disable=missing-format-attribute + .format(self)) + + @property + def arg(self): + """String representation of self as a ctypes function argument.""" + return self.ctypes_typename + + +class AnonymousUnion(CDeclBase): + """Anonymous union declaration.""" + + def __init__(self, name, members, sub_structs, comment="", parent=None): + super(AnonymousUnion, self).__init__(name=name, + members=members, + sub_structs=sub_structs, + comment=comment, + parent=parent) + + @property + def ctypes_decl(self): + """Generates a ctypes.Union declaration for self.""" + indent = codegen_util.Indenter() + lines = [] + lines.append(textwrap.dedent(""" + class {0.ctypes_typename}(ctypes.Union): + \"\"\"{0.docstring}\"\"\"""".format(self))) + with indent: + if self.members: + lines.append(indent("_fields_ = [")) + with indent: + with indent: + for member in six.itervalues(self.members): + lines.append(indent(member.ctypes_field_decl + ",")) + lines.append(indent("]\n")) + return "\n".join(lines) + + @property + def ctypes_typename(self): + """Mangles ctypes.Union typenames to distinguish them from wrappers.""" + return codegen_util.mangle_struct_typename(self.name) + + @property + def ctypes_field_decl(self): + """Generates a declaration for self as a field of a ctypes.Structure.""" + return "('{0.name}', {0.ctypes_typename})".format(self) # pylint: disable=missing-format-attribute + + +class ScalarPrimitive(CDeclBase): + """A scalar value corresponding to a C primitive type.""" + + def __init__(self, name, typename, comment="", parent=None, is_const=None): + super(ScalarPrimitive, self).__init__(name=name, + typename=typename, + comment=comment, + parent=parent, + is_const=is_const) + + @property + def ctypes_field_decl(self): + """Generates a declaration for self as a field of a ctypes.Structure.""" + return "('{0.name}', {0.ctypes_typename})".format(self) # pylint: disable=missing-format-attribute + + @property + def getters_setters(self): + """Populates a Python class with getter & setter methods for self.""" + return textwrap.dedent(""" + @property + def {0.name}(self): + \"\"\"{0.docstring}\"\"\" + return self._ptr.contents.{0.name} + + @{0.name}.setter + def {0.name}(self, value): + self._ptr.contents.{0.name} = value""".format(self)) # pylint: disable=missing-format-attribute + + @property + def arg(self): + """String representation of self as a ctypes function argument.""" + return self.ctypes_typename + + +class ScalarPrimitivePtr(CDeclBase): + """Pointer to a ScalarPrimitive.""" + + def __init__(self, name, typename, comment="", parent=None, is_const=None): + super(ScalarPrimitivePtr, self).__init__(name=name, + typename=typename, + comment=comment, + parent=parent, + is_const=is_const) + + @property + def ctypes_field_decl(self): + """Generates a declaration for self as a field of a ctypes.Structure.""" + return "('{0.name}', {0.ctypes_ptr})".format(self) # pylint: disable=missing-format-attribute + + @property + def getters_setters(self): + """Populates a Python class with getter & setter methods for self.""" + return textwrap.dedent(""" + @property + def {0.name}(self): + \"\"\"{0.docstring}\"\"\" + return self._ptr.contents.{0.name} + + @{0.name}.setter + def {0.name}(self, value): + self._ptr.contents.{0.name} = value""".format(self)) # pylint: disable=missing-format-attribute + + @property + def arg(self): + """Generates string representation of self as a ctypes function argument.""" + # we assume that every pointer that maps to a numpy dtype corresponds to an + # array argument/return value + if self.ctypes_typename in header_parsing.CTYPES_TO_NUMPY: + return ("util.ndptr(dtype={0.np_dtype}, flags={0.np_flags!s})" + .format(self)) # pylint: disable=missing-format-attribute + else: + return self.ctypes_ptr + + +class StaticPtrArray(CDeclBase): + """Array of arbitrary pointers whose size can be inferred from the headers.""" + + def __init__(self, name, typename, shape, comment="", parent=None, + is_const=None): + super(StaticPtrArray, self).__init__(name=name, + typename=typename, + shape=shape, + comment=comment, + parent=parent, + is_const=is_const) + + @property + def ctypes_field_decl(self): + """Generates a declaration for self as a field of a ctypes.Structure.""" + if self.typename in header_parsing.CTYPES_PTRS: + return "('{0.name}', {0.ctypes_ptr} * {1})".format( # pylint: disable=missing-format-attribute + self, " * ".join(str(d) for d in self.shape)) + else: + return "('{0.name}', {0.ctypes_typename} * {1})".format( # pylint: disable=missing-format-attribute + self, " * ".join(str(d) for d in self.shape)) + + @property + def getters_setters(self): + """Populates a Python class with getter & setter methods for self.""" + return textwrap.dedent(""" + @property + def {0.name}(self): + \"\"\"{0.docstring}\"\"\" + return self._ptr.contents.{0.name}""".format(self)) # pylint: disable=missing-format-attribute + + @property + def arg(self): + """Generates string representation of self as a ctypes function argument.""" + return "{0.ctypes_typename}".format(self) + + +class StaticNDArray(CDeclBase): + """Numeric array whose dimensions can all be inferred from the headers.""" + + def __init__(self, name, typename, shape, comment="", parent=None, + is_const=None): + super(StaticNDArray, self).__init__(name=name, + typename=typename, + shape=shape, + comment=comment, + parent=parent, + is_const=is_const) + + @property + def ctypes_field_decl(self): + """Generates a declaration for self as a field of a ctypes.Structure.""" + return "('{0.name}', {0.ctypes_typename} * ({1}))".format( # pylint: disable=missing-format-attribute + self, " * ".join(str(d) for d in self.shape)) + + @property + def getters_setters(self): + """Populates a Python class with a getter method for self (no setter).""" + return textwrap.dedent(""" + @util.CachedProperty + def {0.name}(self): + \"\"\"{0.docstring}\"\"\" + return util.buf_to_npy(self._ptr.contents.{0.name}, {0.shape!s})""" # pylint: disable=missing-format-attribute + .format(self)) + + @property + def arg(self): + """Generates string representation of self as a ctypes function argument.""" + return ("util.ndptr(shape={0.shape}, dtype={0.np_dtype}, " # pylint: disable=missing-format-attribute + "flags={0.np_flags!s})".format(self)) + + +class DynamicNDArray(CDeclBase): + """Numeric array where one or more dimensions are determined at runtime.""" + + def __init__(self, name, typename, shape, comment="", parent=None, + is_const=None): + super(DynamicNDArray, self).__init__(name=name, + typename=typename, + shape=shape, + comment=comment, + parent=parent, + is_const=is_const) + + @property + def runtime_shape_str(self): + """String representation of shape tuple at runtime.""" + rs = [] + for d in self.shape: + # dynamically-sized dimension + if isinstance(d, six.string_types): + if self.parent and d in self.parent.members: + rs.append("self.{}".format(d)) + else: + rs.append("self._model.{}".format(d)) + # static dimension + else: + rs.append(str(d)) + return str(tuple(rs)).replace("'", "") # strip quotes from string rep + + @property + def ctypes_field_decl(self): + """Generates a declaration for self as a field of a ctypes.Structure.""" + return "('{0.name}', {0.ctypes_ptr})".format(self) # pylint: disable=missing-format-attribute + + @property + def getters_setters(self): + """Populates a Python class with a getter method for self (no setter).""" + return textwrap.dedent(""" + @util.CachedProperty + def {0.name}(self): + \"\"\"{0.docstring}\"\"\" + return util.buf_to_npy(self._ptr.contents.{0.name}, + {0.runtime_shape_str})""".format(self)) # pylint: disable=missing-format-attribute + + @property + def arg(self): + """Generates string representation of self as a ctypes function argument.""" + return ("util.ndptr(dtype={0.np_dtype}, flags={0.np_flags!s})" + .format(self)) # pylint: disable=missing-format-attribute + + +class Function(CDeclBase): + """A function declaration including input type(s) and return type.""" + + def __init__(self, name, arguments, return_value, comment=""): + super(Function, self).__init__(name=name, + arguments=arguments, + return_value=return_value, + comment=comment) + + def ctypes_func_decl(self, cdll_name): + """Generates a ctypes function declaration.""" + indent = codegen_util.Indenter() + lines = [] + lines.append("{0}.{1}.__doc__ = \"\"\"\n{2}\"\"\"".format( + cdll_name, self.name, self.docstring)) + if self.arguments: + lines.append("{0}.{1}.argtypes = [".format(cdll_name, self.name)) + with indent: + with indent: + lines.extend(indent(a.arg + ",") + for a in six.itervalues(self.arguments)) + lines.append("]") + else: + lines.append("{0}.{1}.argtypes = None".format(cdll_name, self.name)) + if self.return_value: + lines.append("{0}.{1}.restype = {2}".format( + cdll_name, self.name, self.return_value.arg)) + else: + lines.append("{0}.{1}.restype = None".format(cdll_name, self.name)) + lines.append("") # Force a newline after the declaration. + return "\n".join(lines) + + @property + def docstring(self): + """Generates a docstring.""" + indent = codegen_util.Indenter() + lines = textwrap.wrap(self.comment, width=80) + if self.arguments: + lines.append("\nArgs:") + with indent: + for a in six.itervalues(self.arguments): + s = "{a.name}: {a.arg}{const}".format( + a=a, const=(" " if a.is_const else "")) + lines.append(indent(s)) + if self.return_value: + lines.append("\nReturns:") + with indent: + lines.append(indent(self.return_value.arg)) + lines.append("") # Force a newline at the end of the docstring. + return "\n".join(lines) + + +class StaticStringArray(CDeclBase): + """A string array of fixed dimensions exported by MuJoCo.""" + + def __init__(self, name, shape, symbol_name): + super(StaticStringArray, self).__init__(name=name, + shape=shape, + symbol_name=symbol_name) + + def ctypes_var_decl(self, cdll_name=""): + """Generates a ctypes export statement.""" + + ptr_str = "ctypes.c_char_p" + for dim in self.shape[::-1]: + ptr_str = "({0} * {1!s})".format(ptr_str, dim) + + return "{0} = {1}.in_dll({2}, {3!r})\n".format( + self.name, ptr_str, cdll_name, self.symbol_name) + + +class FunctionPtrTypedef(CDeclBase): + """A type declaration for a C function pointer.""" + + def __init__(self, typename, return_type, argument_types): + super(FunctionPtrTypedef, self).__init__( + typename=typename, return_type=return_type, + argument_types=argument_types) + + @property + def ctypes_decl(self): + """Generates a ctypes.CFUNCTYPE declaration for self.""" + types = (self.return_type,) + self.argument_types + types_decl = ", ".join(t.arg for t in types) + return "{0} = ctypes.CFUNCTYPE({1})".format(self.typename, types_decl) + + +class FunctionPtr(CDeclBase): + """A pointer to an externally defined C function.""" + + def __init__(self, name, symbol_name, type_name, comment=""): + super(FunctionPtr, self).__init__( + name=name, symbol_name=symbol_name, + type_name=type_name, comment=comment) + + @property + def ctypes_field_decl(self): + """Generates a declaration for self as a field of a ctypes.Structure.""" + return "('{0.name}', {0.type_name})".format(self) # pylint: disable=missing-format-attribute + + def ctypes_var_decl(self, cdll_name=""): + """Generates a ctypes export statement.""" + + return "self._{0} = ctypes.c_void_p.in_dll({1}, {2!r})".format( + self.name, cdll_name, self.symbol_name) + + def getters_setters_with_custom_prefix(self, prefix): + return textwrap.dedent(""" + @property + def {0.name}(self): + if {1}{0.name}.value: + return {0.type_name}({1}{0.name}.value) + else: + return None + + @{0.name}.setter + def {0.name}(self, value): + new_func_ptr, wrapped_pyfunc = util.cast_func_to_c_void_p( + value, {0.type_name}) + # Prevents wrapped_pyfunc from being inadvertently garbage collected. + {1}{0.name}._wrapped_pyfunc = wrapped_pyfunc + {1}{0.name}.value = new_func_ptr.value + """.format(self, prefix)) # pylint: disable=missing-format-attribute + + @property + def getters_setters(self): + """Populates a Python class with getter & setter methods for self.""" + return self.getters_setters_with_custom_prefix(prefix="self._ptr.contents.") diff --git a/DMC/src/env/dm_control/dm_control/autowrap/codegen_util.py b/DMC/src/env/dm_control/dm_control/autowrap/codegen_util.py new file mode 100644 index 0000000..1565a57 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/autowrap/codegen_util.py @@ -0,0 +1,154 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Misc helper functions needed by autowrap.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import keyword +import re + +import six +from six.moves import builtins + +_MJXMACRO_SUFFIX = "_POINTERS" +_PYTHON_RESERVED_KEYWORDS = set(keyword.kwlist + dir(builtins)) +if not six.PY2: + _PYTHON_RESERVED_KEYWORDS.add("buffer") + + +class Indenter(object): + r"""Callable context manager for tracking string indentation levels. + + Args: + level: The initial indentation level. + indent_str: The string used to indent each line. + + Example usage: + + ```python + idt = Indenter() + s = idt("level 0\n") + with idt: + s += idt("level 1\n") + with idt: + s += idt("level 2\n") + s += idt("level 1 again\n") + s += idt("back to level 0\n") + print(s) + ``` + """ + + def __init__(self, level=0, indent_str=" "): + self.indent_str = indent_str + self.level = level + + def __enter__(self): + self.level += 1 + return self + + def __exit__(self, type_, value, traceback): + self.level -= 1 + + def __call__(self, string): + return indent(string, self.level, self.indent_str) + + +def indent(s, n=1, indent_str=" "): + """Inserts `n * indent_str` at the start of each non-empty line in `s`.""" + p = n * indent_str + return "".join((p + l) if l.lstrip() else l for l in s.splitlines(True)) + + +class UniqueOrderedDict(collections.OrderedDict): + """Subclass of `OrderedDict` that enforces the uniqueness of keys.""" + + def __setitem__(self, k, v): + if k in self: + raise ValueError("Key '{}' already exists.".format(k)) + super(UniqueOrderedDict, self).__setitem__(k, v) + + +def macro_struct_name(name, suffix=None): + """Converts mjxmacro struct names, e.g. "MJDATA_POINTERS" to "mjdata".""" + if suffix is None: + suffix = _MJXMACRO_SUFFIX + return name[:-len(suffix)].lower() + + +def is_macro_pointer(name): + """Returns True if the mjxmacro struct name contains pointer sizes.""" + return name.endswith(_MJXMACRO_SUFFIX) + + +def mangle_varname(s): + """Append underscores to ensure that `s` is not a reserved Python keyword.""" + while s in _PYTHON_RESERVED_KEYWORDS: + s += "_" + return s + + +def mangle_struct_typename(s): + """Strip leading underscores and make uppercase.""" + return s.lstrip("_").upper() + + +def mangle_comment(s): + """Strip extraneous whitespace, add full-stops at end of each line.""" + if not isinstance(s, six.string_types): + return "\n".join(mangle_comment(line) for line in s) + elif not s: + return "." + else: + out = "\n".join(" ".join(line.split()) for line in s.splitlines()) + if not out.endswith("."): + out += "." + return out + + +def camel_case(s): + """Convert a snake_case string (maybe with lowerCaseFirst) to CamelCase.""" + tokens = re.sub(r"([A-Z])", r" \1", s.replace("_", " ")).split() + return "".join(w.title() for w in tokens) + + +def try_coerce_to_num(s, try_types=(int, float)): + """Try to coerce string to Python numeric type, return None if empty.""" + if not s: + return None + for try_type in try_types: + try: + return try_type(s.rstrip("UuFf")) + except (ValueError, AttributeError): + continue + return s + + +def recursive_dict_lookup(key, try_dict, max_depth=10): + """Recursively map dictionary keys to values.""" + if max_depth < 0: + raise KeyError("Maximum recursion depth exceeded") + while key in try_dict: + key = try_dict[key] + return recursive_dict_lookup(key, try_dict, max_depth - 1) + return key + + +def comment_line(string, width=79, fill_char="-"): + """Wraps `string` in a padded comment line.""" + return "# {0:{2}^{1}}\n".format(string, width - 2, fill_char) diff --git a/DMC/src/env/dm_control/dm_control/autowrap/header_parsing.py b/DMC/src/env/dm_control/dm_control/autowrap/header_parsing.py new file mode 100644 index 0000000..002f109 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/autowrap/header_parsing.py @@ -0,0 +1,318 @@ +# Copyright 2017 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""pyparsing definitions and helper functions for parsing MuJoCo headers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pyparsing as pp +import six +from six.moves import map + +# NB: Don't enable parser memoization (`pp.ParserElement.enablePackrat()`), +# since this results in a ~6x slowdown. + + +NONE = "None" +CTYPES_CHAR = "ctypes.c_char" + +C_TO_CTYPES = { + # integers + "int": "ctypes.c_int", + "unsigned int": "ctypes.c_uint", + "char": CTYPES_CHAR, + "unsigned char": "ctypes.c_ubyte", + "size_t": "ctypes.c_size_t", + # floats + "float": "ctypes.c_float", + "double": "ctypes.c_double", + # pointers + "void": NONE, +} + +CTYPES_PTRS = {NONE: "ctypes.c_void_p"} + +CTYPES_TO_NUMPY = { + # integers + "ctypes.c_int": "np.intc", + "ctypes.c_uint": "np.uintc", + "ctypes.c_ubyte": "np.ubyte", + # floats + "ctypes.c_float": "np.float32", + "ctypes.c_double": "np.float64", +} + +# Helper functions for constructing recursive parsers. +# ------------------------------------------------------------------------------ + + +def _nested_scopes(opening, closing, body): + """Constructs a parser for (possibly nested) scopes.""" + scope = pp.Forward() + scope << pp.Group( # pylint: disable=expression-not-assigned + opening + + pp.ZeroOrMore(body | scope)("members") + + closing) + return scope + + +def _nested_if_else(if_, pred, else_, endif, match_if_true, match_if_false): + """Constructs a parser for (possibly nested) if...(else)...endif blocks.""" + ifelse = pp.Forward() + ifelse << pp.Group( # pylint: disable=expression-not-assigned + if_ + + pred("predicate") + + pp.ZeroOrMore(match_if_true | ifelse)("if_true") + + pp.Optional(else_ + + pp.ZeroOrMore(match_if_false | ifelse)("if_false")) + + endif) + return ifelse + + +# Some common string patterns to suppress. +# ------------------------------------------------------------------------------ +(X, LPAREN, RPAREN, LBRACK, RBRACK, LBRACE, RBRACE, SEMI, COMMA, EQUAL, FSLASH, + BSLASH) = list(map(pp.Suppress, "X()[]{};,=/\\")) +EOL = pp.LineEnd().suppress() + +# Comments, continuation. +# ------------------------------------------------------------------------------ +COMMENT = pp.Combine( + pp.Suppress("//") + + pp.Optional(pp.White()).suppress() + + pp.SkipTo(pp.LineEnd())) + +MULTILINE_COMMENT = pp.delimitedList( + COMMENT.copy().setWhitespaceChars(" \t"), delim=EOL) + +CONTINUATION = (BSLASH + pp.LineEnd()).suppress() + +# Preprocessor directives. +# ------------------------------------------------------------------------------ +DEFINE = pp.Keyword("#define").suppress() +IFDEF = pp.Keyword("#ifdef").suppress() +IFNDEF = pp.Keyword("#ifndef").suppress() +ELSE = pp.Keyword("#else").suppress() +ENDIF = pp.Keyword("#endif").suppress() + +# Variable names, types, literals etc. +# ------------------------------------------------------------------------------ +NAME = pp.Word(pp.alphanums + "_") +INT = pp.Word(pp.nums + "UuLl") +FLOAT = pp.Word(pp.nums + ".+-EeFf") +NUMBER = FLOAT | INT + +# Dimensions can be of the form `[3]`, `[constant_name]` or `[2*constant_name]` +ARRAY_DIM = pp.Combine( + LBRACK + + (INT | NAME) + + pp.Optional(pp.Literal("*")) + + pp.Optional(INT | NAME) + + RBRACK) + +PTR = pp.Literal("*") +EXTERN = pp.Keyword("extern") +NATIVE_TYPENAME = pp.MatchFirst( + [pp.Keyword(n) for n in six.iterkeys(C_TO_CTYPES)]) + +# Macros. +# ------------------------------------------------------------------------------ + +HDR_GUARD = DEFINE + "THIRD_PARTY_MUJOCO_HDRS_" + +# e.g. "#define mjUSEDOUBLE" +DEF_FLAG = pp.Group( + DEFINE + + NAME("name") + + (COMMENT("comment") | EOL)).ignore(HDR_GUARD) + +# e.g. "#define mjMINVAL 1E-14 // minimum value in any denominator" +DEF_CONST = pp.Group( + DEFINE + + NAME("name") + + (NUMBER | NAME)("value") + + (COMMENT("comment") | EOL)) + +# e.g. "X( mjtNum*, name_textadr, ntext, 1 )" +XMEMBER = pp.Group( + X + + LPAREN + + (NATIVE_TYPENAME | NAME)("typename") + + pp.Optional(PTR("ptr")) + + COMMA + + NAME("name") + + COMMA + + pp.delimitedList((INT | NAME), delim=COMMA)("dims") + + RPAREN) + +XMACRO = pp.Group( + pp.Optional(COMMENT("comment")) + + DEFINE + + NAME("name") + + CONTINUATION + + pp.delimitedList(XMEMBER, delim=CONTINUATION)("members")) + + +# Type/variable declarations. +# ------------------------------------------------------------------------------ +TYPEDEF = pp.Keyword("typedef").suppress() +STRUCT = pp.Keyword("struct") +UNION = pp.Keyword("union") +ENUM = pp.Keyword("enum").suppress() + +# e.g. "typedef unsigned char mjtByte; // used for true/false" +TYPE_DECL = pp.Group( + TYPEDEF + + pp.Optional(STRUCT) + + (NATIVE_TYPENAME | NAME)("typename") + + pp.Optional(PTR("ptr")) + + NAME("name") + + SEMI + + pp.Optional(COMMENT("comment"))) + +# Declarations of flags/constants/types. +UNCOND_DECL = DEF_FLAG | DEF_CONST | TYPE_DECL + +# Declarations inside (possibly nested) #if(n)def... #else... #endif... blocks. +COND_DECL = _nested_if_else(IFDEF, NAME, ELSE, ENDIF, UNCOND_DECL, UNCOND_DECL) +# Note: this doesn't work for '#if defined(FLAG)' blocks + +# e.g. "mjtNum gravity[3]; // gravitational acceleration" +STRUCT_MEMBER = pp.Group( + pp.Optional(STRUCT("struct")) + + (NATIVE_TYPENAME | NAME)("typename") + + pp.Optional(PTR("ptr")) + + NAME("name") + + pp.ZeroOrMore(ARRAY_DIM)("size") + + SEMI + + pp.Optional(COMMENT("comment"))) + +# Struct declaration within a union (non-nested). +UNION_STRUCT_DECL = pp.Group( + STRUCT("struct") + + pp.Optional(NAME("typename")) + + pp.Optional(COMMENT("comment")) + + LBRACE + + pp.OneOrMore(STRUCT_MEMBER)("members") + + RBRACE + + pp.Optional(NAME("name")) + + SEMI) + +ANONYMOUS_UNION_DECL = pp.Group( + pp.Optional(MULTILINE_COMMENT("comment")) + + UNION("anonymous_union") + + LBRACE + + pp.OneOrMore( + UNION_STRUCT_DECL | + STRUCT_MEMBER | + COMMENT.suppress())("members") + + RBRACE + + SEMI) + +# Multiple (possibly nested) struct declarations. +NESTED_STRUCTS = _nested_scopes( + opening=(STRUCT + + pp.Optional(NAME("typename")) + + pp.Optional(COMMENT("comment")) + + LBRACE), + closing=(RBRACE + pp.Optional(NAME("name")) + SEMI), + body=pp.OneOrMore( + STRUCT_MEMBER | + ANONYMOUS_UNION_DECL | + COMMENT.suppress())("members")) + +BIT_LSHIFT = INT("bit_lshift_a") + pp.Suppress("<<") + INT("bit_lshift_b") + +ENUM_LINE = pp.Group( + NAME("name") + + pp.Optional(EQUAL + (INT("value") ^ BIT_LSHIFT)) + + pp.Optional(COMMA) + + pp.Optional(COMMENT("comment"))) + +ENUM_DECL = pp.Group( + TYPEDEF + + ENUM + + NAME("typename") + + pp.Optional(COMMENT("comment")) + + LBRACE + + pp.OneOrMore(ENUM_LINE | COMMENT.suppress())("members") + + RBRACE + + pp.Optional(NAME("name")) + + SEMI) + +# Function declarations. +# ------------------------------------------------------------------------------ +MJAPI = pp.Keyword("MJAPI").suppress() +CONST = pp.Keyword("const") +VOID = pp.Group(pp.Keyword("void") + ~PTR).suppress() + +ARG = pp.Group( + pp.Optional(CONST("is_const")) + + (NATIVE_TYPENAME | NAME)("typename") + + pp.Optional(PTR("ptr")) + + NAME("name") + + pp.Optional(ARRAY_DIM("size"))) + +RET = pp.Group( + pp.Optional(CONST("is_const")) + + (NATIVE_TYPENAME | NAME)("typename") + + pp.Optional(PTR("ptr"))) + +FUNCTION_DECL = ( + (VOID | RET("return_value")) + + NAME("name") + + LPAREN + + (VOID | pp.delimitedList(ARG, delim=COMMA)("arguments")) + + RPAREN + + SEMI) + +MJAPI_FUNCTION_DECL = pp.Group( + pp.Optional(MULTILINE_COMMENT("comment")) + + pp.LineStart() + + MJAPI + + FUNCTION_DECL) + +# e.g. +# // predicate function: set enable/disable based on item category +# typedef int (*mjfItemEnable)(int category, void* data); +FUNCTION_PTR_TYPE_DECL = pp.Group( + pp.Optional(MULTILINE_COMMENT("comment")) + + TYPEDEF + + RET("return_type") + + LPAREN + + PTR + + NAME("typename") + + RPAREN + + LPAREN + + (VOID | pp.delimitedList(ARG, delim=COMMA)("arguments")) + + RPAREN + + SEMI) + +# Global variables. +# ------------------------------------------------------------------------------ + +MJAPI_STRING_ARRAY = ( + MJAPI + + EXTERN + + CONST + + pp.Keyword("char") + + PTR + + NAME("name") + + pp.OneOrMore(ARRAY_DIM)("dims") + + SEMI) + +MJAPI_FUNCTION_PTR = MJAPI + EXTERN + NAME("typename") + NAME("name") + SEMI diff --git a/DMC/src/env/dm_control/dm_control/composer/__init__.py b/DMC/src/env/dm_control/dm_control/composer/__init__.py new file mode 100644 index 0000000..18d5037 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2018-2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Module containing abstract base classes for Composer environments.""" + +from dm_control.composer.arena import Arena +from dm_control.composer.constants import * # pylint: disable=wildcard-import +from dm_control.composer.define import cached_property +from dm_control.composer.define import observable +from dm_control.composer.entity import Entity +from dm_control.composer.entity import FreePropObservableMixin +from dm_control.composer.entity import ModelWrapperEntity +from dm_control.composer.entity import Observables +from dm_control.composer.environment import Environment +from dm_control.composer.environment import EpisodeInitializationError +from dm_control.composer.environment import HOOK_NAMES +from dm_control.composer.initializer import Initializer +from dm_control.composer.robot import Robot +from dm_control.composer.task import NullTask +from dm_control.composer.task import Task diff --git a/DMC/src/env/dm_control/dm_control/composer/arena.py b/DMC/src/env/dm_control/dm_control/composer/arena.py new file mode 100644 index 0000000..1085faa --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/arena.py @@ -0,0 +1,52 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""The base empty arena that defines global settings for Composer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from dm_control import mjcf +from dm_control.composer import entity as entity_module + +_ARENA_XML_PATH = os.path.join(os.path.dirname(__file__), 'arena.xml') + + +class Arena(entity_module.Entity): + """The base empty arena that defines global settings for Composer.""" + + def _build(self, name=None): + """Initializes this arena. + + Args: + name: (optional) A string, the name of this arena. If `None`, use the + model name defined in the MJCF file. + """ + self._mjcf_root = mjcf.from_path(_ARENA_XML_PATH) + if name: + self._mjcf_root.model = name + + def add_free_entity(self, entity): + """Includes an entity in the arena as a free-moving body.""" + frame = self.attach(entity) + frame.add('freejoint') + return frame + + @property + def mjcf_model(self): + return self._mjcf_root diff --git a/DMC/src/env/dm_control/dm_control/composer/arena.xml b/DMC/src/env/dm_control/dm_control/composer/arena.xml new file mode 100644 index 0000000..e8a1774 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/arena.xml @@ -0,0 +1,11 @@ + + + + diff --git a/DMC/src/env/dm_control/dm_control/composer/constants.py b/DMC/src/env/dm_control/dm_control/composer/constants.py new file mode 100644 index 0000000..bf8616b --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/constants.py @@ -0,0 +1,22 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Module defining constant values for Composer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +SENSOR_SITES_GROUP = 4 diff --git a/DMC/src/env/dm_control/dm_control/composer/define.py b/DMC/src/env/dm_control/dm_control/composer/define.py new file mode 100644 index 0000000..f03f220 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/define.py @@ -0,0 +1,65 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Decorators for Entity methods returning elements and observables.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import threading + + +class cached_property(property): # pylint: disable=invalid-name + """A property that is evaluated only once per object instance.""" + + def __init__(self, func, doc=None): + super(cached_property, self).__init__(fget=func, doc=doc) + self.lock = threading.RLock() + + def __get__(self, obj, cls): + if obj is None: + return self + name = self.fget.__name__ + obj_dict = obj.__dict__ + try: + # Try returning a precomputed value without locking first. + # Profiling shows that the lock takes up a non-trivial amount of time. + return obj_dict[name] + except KeyError: + # The value hasn't been computed, now we have to lock. + with self.lock: + try: + # Check again whether another thread has already computed the value. + return obj_dict[name] + except KeyError: + # Otherwise call the function, cache the result, and return it + return obj_dict.setdefault(name, self.fget(obj)) + + +# A decorator for base.Observables methods returning an observable. This +# decorator should be used by abstract base classes to indicate sub-classes need +# to implement a corresponding @observavble annotated method. +abstract_observable = abc.abstractproperty # pylint: disable=invalid-name + + +class observable(cached_property): # pylint: disable=invalid-name + """A decorator for base.Observables methods returning an observable. + + The body of the decorated function is evaluated at Entity construction time + and the observable is cached. + """ + pass diff --git a/DMC/src/env/dm_control/dm_control/composer/entity.py b/DMC/src/env/dm_control/dm_control/composer/entity.py new file mode 100644 index 0000000..c729101 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/entity.py @@ -0,0 +1,583 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Module defining the abstract entity class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import collections +import os +import weakref + +from absl import logging +from dm_control import mjcf +from dm_control.composer import define +from dm_control.mujoco.wrapper import mjbindings +import numpy as np +import six + +_OPTION_KEYS = set(['update_interval', 'buffer_size', 'delay', 'aggregator', + 'corruptor', 'enabled']) + +_NO_ATTACHMENT_FRAME = 'No attachment frame found.' + + +# The component order differs from that used by the open-source `tf` package. +def _multiply_quaternions(quat1, quat2): + result = np.empty_like(quat1) + mjbindings.mjlib.mju_mulQuat(result, quat1, quat2) + return result + + +def _rotate_vector(vec, quat): + """Rotates a vector by the given quaternion.""" + result = np.empty_like(vec) + mjbindings.mjlib.mju_rotVecQuat(result, vec, quat) + return result + + +class _ObservableKeys(object): + """Helper object that implements the `observables.dict_keys` functionality.""" + + def __init__(self, entity, observables): + self._entity = entity + self._observables = observables + + def __getattr__(self, name): + try: + model_identifier = self._entity.mjcf_model.full_identifier + except AttributeError: + raise ValueError('cannot retrieve the full identifier of mjcf_model') + return os.path.join(model_identifier, name) + + def __dir__(self): + out = set(self._observables.keys()) + out.update(dir(super(_ObservableKeys, self))) + return list(out) + + +class Observables(object): + """Base-class for Entity observables. + + Subclasses should declare getter methods annotated with @define.observable + decorator and returning an observable object. + """ + + def __init__(self, entity): + self._entity = weakref.proxy(entity) + + self._observables = collections.OrderedDict() + self._keys_helper = _ObservableKeys(self._entity, self._observables) + + # Ensure consistent ordering. + for attr_name in sorted(dir(type(self))): + type_attr = getattr(type(self), attr_name) + if isinstance(type_attr, define.observable): + self._observables[attr_name] = getattr(self, attr_name) + + @property + def dict_keys(self): + return self._keys_helper + + def as_dict(self, fully_qualified=True): + """Returns an OrderedDict of observables belonging to this Entity. + + The returned observables will include any added using the _add_observable + method, as well as any generated by a method decorated with the + @define.observable annotation. + + Args: + fully_qualified: (bool) Whether the dict keys should be prefixed with the + parent entity's full model identifier. + """ + + if fully_qualified: + # We need to make sure that this property doesn't raise an AttributeError, + # otherwise __getattr__ is executed and we get a very funky error. + try: + model_identifier = self._entity.mjcf_model.full_identifier + except AttributeError: + raise ValueError('cannot retrieve the full identifier of mjcf_model') + + return collections.OrderedDict( + [(os.path.join(model_identifier, name), observable) + for name, observable in six.iteritems(self._observables)]) + else: + # Return a copy to prevent dict being edited. + return self._observables.copy() + + def get_observable(self, name, name_fully_qualified=False): + """Returns the observable with the given name. + + Args: + name: (str) The identifier of the observable. + name_fully_qualified: (bool) Whether the provided name is prefixed by the + model's full identifier. + """ + + if name_fully_qualified: + try: + model_identifier = self._entity.mjcf_model.full_identifier + except AttributeError: + raise ValueError('cannot retrieve the full identifier of mjcf_model') + return self._observables[name.replace(model_identifier, '')] + else: + return self._observables[name] + + def set_options(self, options): + """Configure Observables with an options dict. + + Args: + options: A dict of dicts of configuration options keyed on + observable names, or a dict of configuration options, which will + propagate those options to all observables. + """ + if options is None: + options = {} + elif options.keys() and set(options.keys()).issubset(_OPTION_KEYS): + options = dict([(key, options) for key in self._observables.keys()]) + + for obs_key, obs_options in six.iteritems(options): + try: + obs = self._observables[obs_key] + except KeyError: + raise KeyError('No observable with name {!r}'.format(obs_key)) + obs.configure(**obs_options) + + def enable_all(self): + """Enable all observables of this entity.""" + for obs in self._observables.values(): + obs.enabled = True + + def disable_all(self): + """Disable all observables of this entity.""" + for obs in self._observables.values(): + obs.enabled = False + + def add_observable(self, name, observable, enabled=True): + self._observables[name] = observable + self._observables[name].enabled = enabled + + +@six.add_metaclass(abc.ABCMeta) +class FreePropObservableMixin(object): + """Enforce observables of a free-moving object.""" + + @abc.abstractproperty + def position(self): + pass + + @abc.abstractproperty + def orientation(self): + pass + + @abc.abstractproperty + def linear_velocity(self): + pass + + @abc.abstractproperty + def angular_velocity(self): + pass + + +@six.add_metaclass(abc.ABCMeta) +class Entity(object): + """The abstract base class for an entity in a Composer environment.""" + + def __init__(self, *args, **kwargs): + """Entity constructor. + + Subclasses should not override this method, instead implement a _build + method. + + Args: + *args: Arguments passed through to the _build method. + **kwargs: Keyword arguments. Passed through to the _build method, apart + from the following. + `observable_options`: A dictionary of Observable + configuration options. + """ + self._post_init_hooks = [] + + self._parent = None + self._attached = [] + + try: + observable_options = kwargs.pop('observable_options') + except KeyError: + observable_options = None + + self._build(*args, **kwargs) + self._observables = self._build_observables() + self._observables.set_options(observable_options) + + @abc.abstractmethod + def _build(self, *args, **kwargs): + """Entity initialization method to be overridden by subclasses.""" + raise NotImplementedError + + def _build_observables(self): + """Entity observables initialization method. + + Returns: + An object subclassing the Observables class. + """ + return Observables(self) + + def iter_entities(self, exclude_self=False): + """An iterator that recursively iterates through all attached entities. + + Args: + exclude_self: (optional) Whether to exclude this `Entity` itself from the + iterator. + + Yields: + If `exclude_self` is `False`, the first value yielded is this Entity + itself. The following Entities are then yielded recursively in a + depth-first fashion, following the order in which the Entities are + attached. + """ + if not exclude_self: + yield self + for attached_entity in self._attached: + for attached_entity_of_attached_entity in attached_entity.iter_entities(): + yield attached_entity_of_attached_entity + + @property + def observables(self): + """The observables defined by this entity.""" + return self._observables + + def initialize_episode_mjcf(self, random_state): + """Callback executed when the MJCF model is modified between episodes.""" + pass + + def after_compile(self, physics, random_state): + """Callback executed after the Mujoco Physics is recompiled.""" + pass + + def initialize_episode(self, physics, random_state): + """Callback executed during episode initialization.""" + pass + + def before_step(self, physics, random_state): + """Callback executed before an agent control step.""" + pass + + def before_substep(self, physics, random_state): + """Callback executed before a simulation step.""" + pass + + def after_substep(self, physics, random_state): + """A callback which is executed after a simulation step.""" + pass + + def after_step(self, physics, random_state): + """Callback executed after an agent control step.""" + pass + + @abc.abstractproperty + def mjcf_model(self): + raise NotImplementedError + + def attach(self, entity, attach_site=None): + """Attaches an `Entity` without any additional degrees of freedom. + + Args: + entity: The `Entity` to attach. + attach_site: (optional) The site to which to attach the entity's model. If + not set, defaults to self.attachment_site. + + Returns: + The frame of the attached model. + """ + + if attach_site is None: + attach_site = self.attachment_site + + frame = attach_site.attach(entity.mjcf_model) + self._attached.append(entity) + entity._parent = weakref.ref(self) # pylint: disable=protected-access + return frame + + def detach(self): + """Detaches this entity if it has previously been attached.""" + if self._parent is not None: + parent = self._parent() + if parent: # Weakref might dereference to None during garbage collection. + self.mjcf_model.detach() + parent._attached.remove(self) # pylint: disable=protected-access + self._parent = None + else: + raise RuntimeError('Cannot detach an entity that is not attached.') + + @property + def parent(self): + """Returns the `Entity` to which this entity is attached, or `None`.""" + return self._parent() if self._parent else None + + @property + def attachment_site(self): + return self.mjcf_model + + @property + def root_body(self): + if self.parent: + return mjcf.get_attachment_frame(self.mjcf_model) + else: + return self.mjcf_model.worldbody + + def global_vector_to_local_frame(self, physics, vec_in_world_frame): + """Linearly transforms a world-frame vector into entity's local frame. + + Note that this function does not perform an affine transformation of the + vector. In other words, the input vector is assumed to be specified with + respect to the same origin as this entity's local frame. This function + can also be applied to matrices whose innermost dimensions are either 2 or + 3. In this case, a matrix with the same leading dimensions is returned + where the innermost vectors are replaced by their values computed in the + local frame. + + Args: + physics: An `mjcf.Physics` instance. + vec_in_world_frame: A NumPy array with last dimension of shape (2,) or + (3,) that represents a vector quantity in the world frame. + + Returns: + The same quantity as `vec_in_world_frame` but reexpressed in this + entity's local frame. The returned np.array has the same shape as + np.asarray(vec_in_world_frame). + + Raises: + ValueError: if `vec_in_world_frame` does not have shape ending with (2,) + or (3,). + """ + vec_in_world_frame = np.asarray(vec_in_world_frame) + + xmat = np.reshape(physics.bind(self.root_body).xmat, (3, 3)) + # The ordering of the np.dot is such that the transformation holds for any + # matrix whose final dimensions are (2,) or (3,). + if vec_in_world_frame.shape[-1] == 2: + return np.dot(vec_in_world_frame, xmat[:2, :2]) + elif vec_in_world_frame.shape[-1] == 3: + return np.dot(vec_in_world_frame, xmat) + else: + raise ValueError('`vec_in_world_frame` should have shape with final ' + 'dimension 2 or 3: got {}'.format( + vec_in_world_frame.shape)) + + def global_xmat_to_local_frame(self, physics, xmat): + """Transforms another entity's `xmat` into this entity's local frame. + + This function takes another entity's (E) xmat, which is an SO(3) matrix + from E's frame to the world frame, and turns it to a matrix that transforms + from E's frame into this entity's local frame. + + Args: + physics: An `mjcf.Physics` instance. + xmat: A NumPy array of shape (3, 3) or (9,) that represents another + entity's xmat. + + Returns: + The `xmat` reexpressed in this entity's local frame. The returned + np.array has the same shape as np.asarray(xmat). + + Raises: + ValueError: if `xmat` does not have shape (3, 3) or (9,). + """ + xmat = np.asarray(xmat) + + input_shape = xmat.shape + if xmat.shape == (9,): + xmat = np.reshape(xmat, (3, 3)) + + self_xmat = np.reshape(physics.bind(self.root_body).xmat, (3, 3)) + if xmat.shape == (3, 3): + return np.reshape(np.dot(self_xmat.T, xmat), input_shape) + else: + raise ValueError('`xmat` should have shape (3, 3) or (9,): got {}'.format( + xmat.shape)) + + def get_pose(self, physics): + """Get the position and orientation of this entity relative to its parent. + + Note that the semantics differ slightly depending on whether or not the + entity has a free joint: + + * If it has a free joint the position and orientation are always given in + global coordinates. + * If the entity is fixed or attached with a different joint type then the + position and orientation are given relative to the parent frame. + + For entities that are either attached directly to the worldbody, or to other + entities that are positioned at the global origin (e.g. the arena) the + global and relative poses are equivalent. + + Args: + physics: An instance of `mjcf.Physics`. + + Returns: + A 2-tuple where the first entry is a (3,) numpy array representing the + position and the second is a (4,) numpy array representing orientation as + a quaternion. + + Raises: + RuntimeError: If the entity is not attached. + """ + root_joint = mjcf.get_frame_freejoint(self.mjcf_model) + if root_joint: + position = physics.bind(root_joint).qpos[:3] + quaternion = physics.bind(root_joint).qpos[3:] + else: + attachment_frame = mjcf.get_attachment_frame(self.mjcf_model) + if attachment_frame is None: + raise RuntimeError(_NO_ATTACHMENT_FRAME) + position = physics.bind(attachment_frame).pos + quaternion = physics.bind(attachment_frame).quat + return position, quaternion + + def set_pose(self, physics, position=None, quaternion=None): + """Sets position and/or orientation of this entity relative to its parent. + + If the entity is attached with a free joint, this method will set the + respective DoFs of the joint. If the entity is either fixed or attached with + a different joint type, this method will update the position and/or + quaternion of the attachment frame. + + Note that the semantics differ slightly between the two cases: the DoFs of a + free body are specified in global coordinates, whereas the position of a + non-free body is specified in relative coordinates with respect to the + parent frame. However, for entities that are either attached directly to the + worldbody, or to other entities that are positioned at the global origin + (e.g. the arena), there is no difference between the two cases. + + Args: + physics: An instance of `mjcf.Physics`. + position: (optional) A NumPy array of size 3. + quaternion: (optional) A NumPy array of size 4. + + Raises: + RuntimeError: If the entity is not attached. + """ + root_joint = mjcf.get_frame_freejoint(self.mjcf_model) + if root_joint: + if position is not None: + physics.bind(root_joint).qpos[:3] = position + if quaternion is not None: + physics.bind(root_joint).qpos[3:] = quaternion + else: + attachment_frame = mjcf.get_attachment_frame(self.mjcf_model) + if attachment_frame is None: + raise RuntimeError(_NO_ATTACHMENT_FRAME) + if position is not None: + physics.bind(attachment_frame).pos = position + if quaternion is not None: + normalised_quaternion = quaternion / np.linalg.norm(quaternion) + physics.bind(attachment_frame).quat = normalised_quaternion + + def shift_pose(self, + physics, + position=None, + quaternion=None, + rotate_velocity=False): + """Shifts the position and/or orientation from its current configuration. + + This is a convenience function that performs the same operation as + `set_pose`, but where the specified `position` is added to the current + position, and the specified `quaternion` is premultiplied to the current + quaternion. + + Args: + physics: An instance of `mjcf.Physics`. + position: (optional) A NumPy array of size 3. + quaternion: (optional) A NumPy array of size 4. + rotate_velocity: (optional) A bool, whether to shift the current linear + velocity along with the pose. This will rotate the current linear + velocity, which is expressed relative to the world frame. The angular + velocity, which is expressed relative to the local frame is left + unchanged. + + Raises: + RuntimeError: If the entity is not attached. + """ + current_position, current_quaternion = self.get_pose(physics) + new_position, new_quaternion = None, None + if position is not None: + new_position = current_position + position + if quaternion is not None: + quaternion = np.array(quaternion, dtype=np.float64, copy=False) + new_quaternion = _multiply_quaternions(quaternion, current_quaternion) + root_joint = mjcf.get_frame_freejoint(self.mjcf_model) + if root_joint and rotate_velocity: + # Rotate the linear velocity. The angular velocity (qvel[3:]) + # is left unchanged, as it is expressed in the local frame. + # When rotatating the body frame the angular velocity already + # tracks the rotation but the linear velocity does not. + velocity = physics.bind(root_joint).qvel[:3] + rotated_velocity = _rotate_vector(velocity, quaternion) + self.set_velocity(physics, rotated_velocity) + self.set_pose(physics, new_position, new_quaternion) + + def set_velocity(self, physics, velocity=None, angular_velocity=None): + """Sets the linear velocity and/or angular velocity of this free entity. + + If the entity is attached with a free joint, this method will set the + respective DoFs of the joint. Otherwise a warning is logged. + + Args: + physics: An instance of `mjcf.Physics`. + velocity: (optional) A NumPy array of size 3 specifying the + linear velocity. + angular_velocity: (optional) A NumPy array of size 3 specifying the + angular velocity + """ + root_joint = mjcf.get_frame_freejoint(self.mjcf_model) + if root_joint: + if velocity is not None: + physics.bind(root_joint).qvel[:3] = velocity + if angular_velocity is not None: + physics.bind(root_joint).qvel[3:] = angular_velocity + else: + logging.warning('Cannot set velocity on Entity with no free joint.') + + def configure_joints(self, physics, position): + """Configures this entity's internal joints. + + The default implementation of this method simply sets the `qpos` of all + joints in this entity to the values specified in the `position` argument. + Entity subclasses with actuated joints may override this method to achieve a + stable reconfiguration of joint positions, for example the control signal + of position actuators may be changed to match the new joint positions. + + Args: + physics: An instance of `mjcf.Physics`. + position: The desired position of this entity's joints. + """ + joints = self.mjcf_model.find_all('joint', exclude_attachments=True) + physics.bind(joints).qpos = position + + +class ModelWrapperEntity(Entity): + """An entity class that wraps an MJCF model without any additional logic.""" + + def _build(self, mjcf_model): + self._mjcf_model = mjcf_model + + @property + def mjcf_model(self): + return self._mjcf_model diff --git a/DMC/src/env/dm_control/dm_control/composer/entity_test.py b/DMC/src/env/dm_control/dm_control/composer/entity_test.py new file mode 100644 index 0000000..b27f398 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/entity_test.py @@ -0,0 +1,435 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for composer.Entity.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +# Internal dependencies. + +from absl.testing import absltest +from absl.testing import parameterized +from dm_control import mjcf +from dm_control.composer import arena +from dm_control.composer import define +from dm_control.composer import entity +from dm_control.composer.observation.observable import base as observable +import numpy as np +import six +from six.moves import range + +_NO_ROTATION = (1, 0, 0, 0) # Tests support for non-arrays and non-floats. +_NINETY_DEGREES_ABOUT_X = np.array( + [np.cos(np.pi / 4), np.sin(np.pi / 4), 0., 0.]) +_NINETY_DEGREES_ABOUT_Y = np.array( + [np.cos(np.pi / 4), 0., np.sin(np.pi / 4), 0.]) +_NINETY_DEGREES_ABOUT_Z = np.array( + [np.cos(np.pi / 4), 0., 0., np.sin(np.pi / 4)]) +_FORTYFIVE_DEGREES_ABOUT_X = np.array( + [np.cos(np.pi / 8), np.sin(np.pi / 8), 0., 0.]) + +_TEST_ROTATIONS = [ + # Triplets of original rotation, new rotation and final rotation. + (None, _NO_ROTATION, _NO_ROTATION), + (_NO_ROTATION, _NINETY_DEGREES_ABOUT_Z, _NINETY_DEGREES_ABOUT_Z), + (_FORTYFIVE_DEGREES_ABOUT_X, _NINETY_DEGREES_ABOUT_Y, + np.array([0.65328, 0.2706, 0.65328, -0.2706])), +] + + +def _param_product(**param_lists): + keys, values = zip(*param_lists.items()) + for combination in itertools.product(*values): + yield dict(zip(keys, combination)) + + +class TestEntity(entity.Entity): + """Simple test entity that does nothing but declare some observables.""" + + def _build(self, name='test_entity'): + self._mjcf_root = mjcf.element.RootElement(model=name) + self._mjcf_root.worldbody.add('geom', type='sphere', size=(0.1,)) + + def _build_observables(self): + return TestEntityObservables(self) + + @property + def mjcf_model(self): + return self._mjcf_root + + +class TestEntityObservables(entity.Observables): + """Trivial observables for the test entity.""" + + @define.observable + def observable0(self): + return observable.Generic(lambda phys: 0.0) + + @define.observable + def observable1(self): + return observable.Generic(lambda phys: 1.0) + + +class EntityTest(parameterized.TestCase): + + def setUp(self): + super(EntityTest, self).setUp() + self.entity = TestEntity() + + def testNumObservables(self): + """Tests that the observables dict has the right number of entries.""" + self.assertLen(self.entity.observables.as_dict(), 2) + + def testObservableNames(self): + """Tests that the observables dict keys correspond to the observable names. + """ + obs = self.entity.observables.as_dict() + self.assertIn('observable0', obs) + self.assertIn('observable1', obs) + + subentity = TestEntity(name='subentity') + self.entity.attach(subentity) + self.assertIn('subentity/observable0', subentity.observables.as_dict()) + self.assertEqual(subentity.observables.dict_keys.observable0, + 'subentity/observable0') + self.assertIn('observable0', dir(subentity.observables.dict_keys)) + self.assertIn('subentity/observable1', subentity.observables.as_dict()) + self.assertEqual(subentity.observables.dict_keys.observable1, + 'subentity/observable1') + self.assertIn('observable1', dir(subentity.observables.dict_keys)) + + def testEnableDisableObservables(self): + """Test the enabling and disable functionality for observables.""" + all_obs = self.entity.observables.as_dict() + + self.entity.observables.enable_all() + for obs in all_obs.values(): + self.assertTrue(obs.enabled) + + self.entity.observables.disable_all() + for obs in all_obs.values(): + self.assertFalse(obs.enabled) + + self.entity.observables.observable0.enabled = True + self.assertTrue(all_obs['observable0'].enabled) + + def testObservableDefaultOptions(self): + corruptor = lambda x: x + options = { + 'update_interval': 2, + 'buffer_size': 10, + 'delay': 1, + 'aggregator': 'max', + 'corruptor': corruptor, + 'enabled': True + } + self.entity.observables.set_options(options) + + for obs in self.entity.observables.as_dict().values(): + self.assertEqual(obs.update_interval, 2) + self.assertEqual(obs.delay, 1) + self.assertEqual(obs.buffer_size, 10) + self.assertEqual(obs.aggregator, observable.AGGREGATORS['max']) + self.assertEqual(obs.corruptor, corruptor) + self.assertTrue(obs.enabled) + + def testObservablePartialDefaultOptions(self): + options = {'update_interval': 2, 'delay': 1} + self.entity.observables.set_options(options) + + for obs in self.entity.observables.as_dict().values(): + self.assertEqual(obs.update_interval, 2) + self.assertEqual(obs.delay, 1) + self.assertEqual(obs.buffer_size, None) + self.assertEqual(obs.aggregator, None) + self.assertEqual(obs.corruptor, None) + + def testObservableOptionsInvalidName(self): + options = {'asdf': None} + with six.assertRaisesRegex( + self, KeyError, 'No observable with name \'asdf\''): + self.entity.observables.set_options(options) + + def testObservableInvalidOptions(self): + options = {'observable0': {'asdf': 2}} + with six.assertRaisesRegex(self, AttributeError, + 'Cannot add attribute asdf in configure.'): + self.entity.observables.set_options(options) + + def testObservableOptions(self): + options = { + 'observable0': { + 'update_interval': 2, + 'delay': 3 + }, + 'observable1': { + 'update_interval': 4, + 'delay': 5 + } + } + self.entity.observables.set_options(options) + observables = self.entity.observables.as_dict() + self.assertEqual(observables['observable0'].update_interval, 2) + self.assertEqual(observables['observable0'].delay, 3) + self.assertEqual(observables['observable0'].buffer_size, None) + self.assertEqual(observables['observable0'].aggregator, None) + self.assertEqual(observables['observable0'].corruptor, None) + self.assertFalse(observables['observable0'].enabled) + + self.assertEqual(observables['observable1'].update_interval, 4) + self.assertEqual(observables['observable1'].delay, 5) + self.assertEqual(observables['observable1'].buffer_size, None) + self.assertEqual(observables['observable1'].aggregator, None) + self.assertEqual(observables['observable1'].corruptor, None) + self.assertFalse(observables['observable1'].enabled) + + def testObservableOptionsEntityConstructor(self): + options = { + 'observable0': { + 'update_interval': 2, + 'delay': 3 + }, + 'observable1': { + 'update_interval': 4, + 'delay': 5 + } + } + ent = TestEntity(observable_options=options) + observables = ent.observables.as_dict() + self.assertEqual(observables['observable0'].update_interval, 2) + self.assertEqual(observables['observable0'].delay, 3) + self.assertEqual(observables['observable0'].buffer_size, None) + self.assertEqual(observables['observable0'].aggregator, None) + self.assertEqual(observables['observable0'].corruptor, None) + self.assertFalse(observables['observable0'].enabled) + + self.assertEqual(observables['observable1'].update_interval, 4) + self.assertEqual(observables['observable1'].delay, 5) + self.assertEqual(observables['observable1'].buffer_size, None) + self.assertEqual(observables['observable1'].aggregator, None) + self.assertEqual(observables['observable1'].corruptor, None) + self.assertFalse(observables['observable1'].enabled) + + def testObservablePartialOptions(self): + options = {'observable0': {'update_interval': 2, 'delay': 3}} + self.entity.observables.set_options(options) + observables = self.entity.observables.as_dict() + self.assertEqual(observables['observable0'].update_interval, 2) + self.assertEqual(observables['observable0'].delay, 3) + self.assertEqual(observables['observable0'].buffer_size, None) + self.assertEqual(observables['observable0'].aggregator, None) + self.assertEqual(observables['observable0'].corruptor, None) + self.assertFalse(observables['observable0'].enabled) + + self.assertEqual(observables['observable1'].update_interval, 1) + self.assertEqual(observables['observable1'].delay, None) + self.assertEqual(observables['observable1'].buffer_size, None) + self.assertEqual(observables['observable1'].aggregator, None) + self.assertEqual(observables['observable1'].corruptor, None) + self.assertFalse(observables['observable1'].enabled) + + def testAttach(self): + entities = [TestEntity() for _ in range(4)] + entities[0].attach(entities[1]) + entities[1].attach(entities[2]) + entities[0].attach(entities[3]) + + self.assertIsNone(entities[0].parent) + self.assertIs(entities[1].parent, entities[0]) + self.assertIs(entities[2].parent, entities[1]) + self.assertIs(entities[3].parent, entities[0]) + + self.assertIsNone(entities[0].mjcf_model.parent_model) + self.assertIs(entities[1].mjcf_model.parent_model, entities[0].mjcf_model) + self.assertIs(entities[2].mjcf_model.parent_model, entities[1].mjcf_model) + self.assertIs(entities[3].mjcf_model.parent_model, entities[0].mjcf_model) + + self.assertEqual(list(entities[0].iter_entities()), entities) + + def testDetach(self): + entities = [TestEntity() for _ in range(4)] + entities[0].attach(entities[1]) + entities[1].attach(entities[2]) + entities[0].attach(entities[3]) + + entities[1].detach() + with six.assertRaisesRegex(self, RuntimeError, 'not attached'): + entities[1].detach() + + self.assertIsNone(entities[0].parent) + self.assertIsNone(entities[1].parent) + self.assertIs(entities[2].parent, entities[1]) + self.assertIs(entities[3].parent, entities[0]) + + self.assertIsNone(entities[0].mjcf_model.parent_model) + self.assertIsNone(entities[1].mjcf_model.parent_model) + self.assertIs(entities[2].mjcf_model.parent_model, entities[1].mjcf_model) + self.assertIs(entities[3].mjcf_model.parent_model, entities[0].mjcf_model) + + self.assertEqual(list(entities[0].iter_entities()), + [entities[0], entities[3]]) + + def testIterEntitiesExcludeSelf(self): + entities = [TestEntity() for _ in range(4)] + entities[0].attach(entities[1]) + entities[1].attach(entities[2]) + entities[0].attach(entities[3]) + self.assertEqual( + list(entities[0].iter_entities(exclude_self=True)), entities[1:]) + + def testGlobalVectorToLocalFrame(self): + parent = TestEntity() + parent.mjcf_model.worldbody.add( + 'site', xyaxes=[0, 1, 0, -1, 0, 0]).attach(self.entity.mjcf_model) + physics = mjcf.Physics.from_mjcf_model(parent.mjcf_model) + + # 3D vectors + np.testing.assert_allclose( + self.entity.global_vector_to_local_frame(physics, [0, 1, 0]), + [1, 0, 0], atol=1e-10) + np.testing.assert_allclose( + self.entity.global_vector_to_local_frame(physics, [-1, 0, 0]), + [0, 1, 0], atol=1e-10) + np.testing.assert_allclose( + self.entity.global_vector_to_local_frame(physics, [0, 0, 1]), + [0, 0, 1], atol=1e-10) + + # 2D vectors; z-component is ignored + np.testing.assert_allclose( + self.entity.global_vector_to_local_frame(physics, [0, 1]), + [1, 0], atol=1e-10) + np.testing.assert_allclose( + self.entity.global_vector_to_local_frame(physics, [-1, 0]), + [0, 1], atol=1e-10) + + def testGlobalMatrixToLocalFrame(self): + parent = TestEntity() + parent.mjcf_model.worldbody.add( + 'site', xyaxes=[0, 1, 0, -1, 0, 0]).attach(self.entity.mjcf_model) + physics = mjcf.Physics.from_mjcf_model(parent.mjcf_model) + + rotation_atob = np.array([[0, 1, 0], [0, 0, -1], [-1, 0, 0]]) + ego_rotation_atob = np.array([[0, 0, -1], [0, -1, 0], [-1, 0, 0]]) + + np.testing.assert_allclose( + self.entity.global_xmat_to_local_frame(physics, rotation_atob), + ego_rotation_atob, atol=1e-10) + + flat_rotation_atob = np.reshape(rotation_atob, -1) + flat_rotation_ego_atob = np.reshape(ego_rotation_atob, -1) + np.testing.assert_allclose( + self.entity.global_xmat_to_local_frame( + physics, flat_rotation_atob), + flat_rotation_ego_atob, atol=1e-10) + + @parameterized.parameters(*_param_product( + position=[None, [1., 0., -1.]], + quaternion=[None, _FORTYFIVE_DEGREES_ABOUT_X, _NINETY_DEGREES_ABOUT_Z], + freejoint=[False, True], + )) + def testSetPose(self, position, quaternion, freejoint): + # Setup entity. + test_arena = arena.Arena() + subentity = TestEntity(name='subentity') + frame = test_arena.attach(subentity) + if freejoint: + frame.add('freejoint') + + physics = mjcf.Physics.from_mjcf_model(test_arena.mjcf_model) + + if quaternion is None: + ground_truth_quat = _NO_ROTATION + else: + ground_truth_quat = quaternion + + if position is None: + ground_truth_pos = np.zeros(shape=(3,)) + else: + ground_truth_pos = position + + subentity.set_pose(physics, position=position, quaternion=quaternion) + + np.testing.assert_array_equal(physics.bind(frame).xpos, ground_truth_pos) + np.testing.assert_array_equal(physics.bind(frame).xquat, ground_truth_quat) + + @parameterized.parameters(*_param_product( + original_position=[[-2, -1, -1.], [1., 0., -1.]], + position=[None, [1., 0., -1.]], + original_quaternion=_TEST_ROTATIONS[0], + quaternion=_TEST_ROTATIONS[1], + expected_quaternion=_TEST_ROTATIONS[2], + freejoint=[False, True], + )) + def testShiftPose(self, original_position, position, original_quaternion, + quaternion, expected_quaternion, freejoint): + # Setup entity. + test_arena = arena.Arena() + subentity = TestEntity(name='subentity') + frame = test_arena.attach(subentity) + if freejoint: + frame.add('freejoint') + + physics = mjcf.Physics.from_mjcf_model(test_arena.mjcf_model) + + # Set the original position + subentity.set_pose( + physics, position=original_position, quaternion=original_quaternion) + + if position is None: + ground_truth_pos = original_position + else: + ground_truth_pos = original_position + np.array(position) + subentity.shift_pose(physics, position=position, quaternion=quaternion) + np.testing.assert_array_equal(physics.bind(frame).xpos, ground_truth_pos) + + updated_quat = physics.bind(frame).xquat + np.testing.assert_array_almost_equal(updated_quat, expected_quaternion, + 1e-4) + + @parameterized.parameters(False, True) + def testShiftPoseWithVelocity(self, rotate_velocity): + # Setup entity. + test_arena = arena.Arena() + subentity = TestEntity(name='subentity') + frame = test_arena.attach(subentity) + frame.add('freejoint') + + physics = mjcf.Physics.from_mjcf_model(test_arena.mjcf_model) + + # Set the original position + subentity.set_pose(physics, position=[0., 0., 0.]) + + # Set velocity in y dim. + subentity.set_velocity(physics, [0., 1., 0.]) + + # Rotate the entity around the z axis. + subentity.shift_pose( + physics, quaternion=[0., 0., 0., 1.], rotate_velocity=rotate_velocity) + + physics.forward() + updated_position, _ = subentity.get_pose(physics) + if rotate_velocity: + # Should not have moved in the y dim. + np.testing.assert_array_almost_equal(updated_position[1], 0.) + else: + # Should not have moved in the x dim. + np.testing.assert_array_almost_equal(updated_position[0], 0.) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/composer/environment.py b/DMC/src/env/dm_control/dm_control/composer/environment.py new file mode 100644 index 0000000..68e904e --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/environment.py @@ -0,0 +1,450 @@ +# Copyright 2018-2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""RL environment classes for Composer tasks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import warnings +import weakref + +from absl import logging +from dm_control import mjcf +from dm_control.composer import observation +from dm_control.rl import control +import dm_env +import numpy as np +from six.moves import range + +warnings.simplefilter('always', DeprecationWarning) + +_STEPS_LOGGING_INTERVAL = 10000 + +HOOK_NAMES = ('initialize_episode_mjcf', + 'after_compile', + 'initialize_episode', + 'before_step', + 'before_substep', + 'after_substep', + 'after_step') + +_empty_function = lambda: None + + +def _empty_function_with_docstring(): + """Some docstring.""" + +_EMPTY_CODE = _empty_function.__code__.co_code +_EMPTY_WITH_DOCSTRING_CODE = _empty_function_with_docstring.__code__.co_code + + +def _callable_is_trivial(f): + return (f.__code__.co_code == _EMPTY_CODE or + f.__code__.co_code == _EMPTY_WITH_DOCSTRING_CODE) + + +class EpisodeInitializationError(RuntimeError): + """Raised by a `composer.Task` when it fails to initialize an episode.""" + + +class _Hook(object): + + __slots__ = ('entity_hooks', 'extra_hooks') + + def __init__(self): + self.entity_hooks = [] + self.extra_hooks = [] + + +class _EnvironmentHooks(object): + """Helper object that scans and memoizes various hooks in a task. + + This object exist to ensure that we do not incur a substantial overhead in + calling empty entity hooks in more complicated tasks. + """ + + __slots__ = (('_task', '_episode_step_count') + + tuple('_' + hook_name for hook_name in HOOK_NAMES)) + + def __init__(self, task): + self._task = task + self._episode_step_count = 0 + for hook_name in HOOK_NAMES: + slot_name = '_' + hook_name + setattr(self, slot_name, _Hook()) + self.refresh_entity_hooks() + + def refresh_entity_hooks(self): + """Scans and memoizes all non-trivial entity hooks.""" + for hook_name in HOOK_NAMES: + hooks = [] + for entity in self._task.root_entity.iter_entities(): + entity_hook = getattr(entity, hook_name) + # Ignore any hook that is a no-op to avoid function call overhead. + if not _callable_is_trivial(entity_hook): + hooks.append(entity_hook) + getattr(self, '_' + hook_name).entity_hooks = hooks + + def add_extra_hook(self, hook_name, hook_callable): + if hook_name not in HOOK_NAMES: + raise ValueError('{!r} is not a valid hook name'.format(hook_name)) + if not callable(hook_callable): + raise ValueError('{!r} is not a callable'.format(hook_callable)) + getattr(self, '_' + hook_name).extra_hooks.append(hook_callable) + + def initialize_episode_mjcf(self, random_state): + self._task.initialize_episode_mjcf(random_state) + for entity_hook in self._initialize_episode_mjcf.entity_hooks: + entity_hook(random_state) + for extra_hook in self._initialize_episode_mjcf.extra_hooks: + extra_hook(random_state) + + def after_compile(self, physics, random_state): + self._task.after_compile(physics, random_state) + for entity_hook in self._after_compile.entity_hooks: + entity_hook(physics, random_state) + for extra_hook in self._after_compile.extra_hooks: + extra_hook(physics, random_state) + + def initialize_episode(self, physics, random_state): + self._episode_step_count = 0 + self._task.initialize_episode(physics, random_state) + for entity_hook in self._initialize_episode.entity_hooks: + entity_hook(physics, random_state) + for extra_hook in self._initialize_episode.extra_hooks: + extra_hook(physics, random_state) + + def before_step(self, physics, action, random_state): + self._episode_step_count += 1 + if self._episode_step_count % _STEPS_LOGGING_INTERVAL == 0: + logging.info('The current episode has been running for %d steps.', + self._episode_step_count) + self._task.before_step(physics, action, random_state) + for entity_hook in self._before_step.entity_hooks: + entity_hook(physics, random_state) + for extra_hook in self._before_step.extra_hooks: + extra_hook(physics, action, random_state) + + def before_substep(self, physics, action, random_state): + self._task.before_substep(physics, action, random_state) + for entity_hook in self._before_substep.entity_hooks: + entity_hook(physics, random_state) + for extra_hooks in self._before_substep.extra_hooks: + extra_hooks(physics, action, random_state) + + def after_substep(self, physics, random_state): + self._task.after_substep(physics, random_state) + for entity_hook in self._after_substep.entity_hooks: + entity_hook(physics, random_state) + for extra_hook in self._after_substep.extra_hooks: + extra_hook(physics, random_state) + + def after_step(self, physics, random_state): + self._task.after_step(physics, random_state) + for entity_hook in self._after_step.entity_hooks: + entity_hook(physics, random_state) + for extra_hook in self._after_step.extra_hooks: + extra_hook(physics, random_state) + + +class _CommonEnvironment(object): + """Common components for RL environments.""" + + def __init__(self, task, time_limit=float('inf'), random_state=None, + n_sub_steps=None, + raise_exception_on_physics_error=True, + strip_singleton_obs_buffer_dim=False): + """Initializes an instance of `_CommonEnvironment`. + + Args: + task: Instance of `composer.base.Task`. + time_limit: (optional) A float, the time limit in seconds beyond which an + episode is forced to terminate. + random_state: Optional, either an int seed or an `np.random.RandomState` + object. If None (default), the random number generator will self-seed + from a platform-dependent source of entropy. + n_sub_steps: (DEPRECATED) An integer, number of physics steps to take per + agent control step. New code should instead override the + `control_substep` property of the task. + raise_exception_on_physics_error: (optional) A boolean, indicating whether + `PhysicsError` should be raised as an exception. If `False`, physics + errors will result in the current episode being terminated with a + warning logged, and a new episode started. + strip_singleton_obs_buffer_dim: (optional) A boolean, if `True`, + the array shape of observations with `buffer_size == 1` will not have a + leading buffer dimension. + """ + self._task = task + if not isinstance(random_state, np.random.RandomState): + self._random_state = np.random.RandomState(random_state) + else: + self._random_state = random_state + self._hooks = _EnvironmentHooks(self._task) + self._time_limit = time_limit + self._raise_exception_on_physics_error = raise_exception_on_physics_error + self._strip_singleton_obs_buffer_dim = strip_singleton_obs_buffer_dim + + if n_sub_steps is not None: + warnings.simplefilter('once', DeprecationWarning) + warnings.warn('The `n_sub_steps` argument is deprecated. Please override ' + 'the `control_timestep` property of the task instead.', + DeprecationWarning) + self._overridden_n_sub_steps = n_sub_steps + + self._recompile_physics_and_update_observables() + + def add_extra_hook(self, hook_name, hook_callable): + self._hooks.add_extra_hook(hook_name, hook_callable) + + def _recompile_physics_and_update_observables(self): + """Sets up the environment for latest MJCF model from the task.""" + self._physics_proxy = None + self._recompile_physics() + if isinstance(self._physics, weakref.ProxyType): + self._physics_proxy = self._physics + else: + self._physics_proxy = weakref.proxy(self._physics) + + if self._overridden_n_sub_steps is not None: + self._n_sub_steps = self._overridden_n_sub_steps + else: + self._n_sub_steps = self._task.physics_steps_per_control_step + + self._hooks.refresh_entity_hooks() + self._hooks.after_compile(self._physics_proxy, self._random_state) + self._observation_updater = self._make_observation_updater() + self._observation_updater.reset(self._physics_proxy, self._random_state) + + def _recompile_physics(self): + """Creates a new Physics using the latest MJCF model from the task.""" + if getattr(self, '_physics', None): + self._physics.free() + self._physics = mjcf.Physics.from_mjcf_model( + self._task.root_entity.mjcf_model) + + def _make_observation_updater(self): + return observation.Updater( + self._task.observables, self._task.physics_steps_per_control_step, + self._strip_singleton_obs_buffer_dim) + + @property + def physics(self): + """Returns a `weakref.ProxyType` pointing to the current `mjcf.Physics`. + + Note that the underlying `mjcf.Physics` will be destroyed whenever the MJCF + model is recompiled. It is therefore unsafe for external objects to hold a + reference to `environment.physics`. Attempting to access attributes of a + dead `Physics` instance will result in a `ReferenceError`. + """ + return self._physics_proxy + + @property + def task(self): + return self._task + + @property + def random_state(self): + return self._random_state + + def control_timestep(self): + """Returns the interval between agent actions in seconds.""" + if self._overridden_n_sub_steps is not None: + return self.physics.timestep() * self._overridden_n_sub_steps + else: + return self.task.control_timestep + + +class Environment(_CommonEnvironment, dm_env.Environment): + """Reinforcement learning environment for Composer tasks.""" + + def __init__(self, task, time_limit=float('inf'), random_state=None, + n_sub_steps=None, + raise_exception_on_physics_error=True, + strip_singleton_obs_buffer_dim=False, + max_reset_attempts=1): + """Initializes an instance of `Environment`. + + Args: + task: Instance of `composer.base.Task`. + time_limit: (optional) A float, the time limit in seconds beyond which + an episode is forced to terminate. + random_state: (optional) an int seed or `np.random.RandomState` instance. + n_sub_steps: (DEPRECATED) An integer, number of physics steps to take per + agent control step. New code should instead override the + `control_substep` property of the task. + raise_exception_on_physics_error: (optional) A boolean, indicating whether + `PhysicsError` should be raised as an exception. If `False`, physics + errors will result in the current episode being terminated with a + warning logged, and a new episode started. + strip_singleton_obs_buffer_dim: (optional) A boolean, if `True`, + the array shape of observations with `buffer_size == 1` will not have a + leading buffer dimension. + max_reset_attempts: (optional) Maximum number of times to try resetting + the environment. If an `EpisodeInitializationError` is raised + during this process, an environment reset is reattempted up to this + number of times. If this count is exceeded then the most recent + exception will be allowed to propagate. Defaults to 1, i.e. no failure + is allowed. + """ + super(Environment, self).__init__( + task=task, + time_limit=time_limit, + random_state=random_state, + n_sub_steps=n_sub_steps, + raise_exception_on_physics_error=raise_exception_on_physics_error, + strip_singleton_obs_buffer_dim=strip_singleton_obs_buffer_dim) + self._max_reset_attempts = max_reset_attempts + self._reset_next_step = True + + def reset(self): + failed_attempts = 0 + while True: + try: + return self._reset_attempt() + except EpisodeInitializationError as e: + failed_attempts += 1 + if failed_attempts < self._max_reset_attempts: + logging.error('Error during episode reset: %s', repr(e)) + else: + raise + + def _reset_attempt(self): + self._hooks.initialize_episode_mjcf(self._random_state) + self._recompile_physics_and_update_observables() + with self._physics.reset_context(): + self._hooks.initialize_episode(self._physics_proxy, self._random_state) + self._observation_updater.reset(self._physics_proxy, self._random_state) + self._reset_next_step = False + return dm_env.TimeStep( + step_type=dm_env.StepType.FIRST, + reward=None, + discount=None, + observation=self._observation_updater.get_observation()) + + # TODO(b/129061424): Remove this method. + def step_spec(self): + """DEPRECATED: please use `reward_spec` and `discount_spec` instead.""" + warnings.warn('`step_spec` is deprecated, please use `reward_spec` and ' + '`discount_spec` instead.', DeprecationWarning) + if (self._task.get_reward_spec() is None or + self._task.get_discount_spec() is None): + raise NotImplementedError + return dm_env.TimeStep( + step_type=None, + reward=self._task.get_reward_spec(), + discount=self._task.get_discount_spec(), + observation=self._observation_updater.observation_spec(), + ) + + def step(self, action): + """Updates the environment using the action and returns a `TimeStep`.""" + if self._reset_next_step: + self._reset_next_step = False + return self.reset() + + self._hooks.before_step(self._physics_proxy, action, self._random_state) + self._observation_updater.prepare_for_next_control_step() + + try: + for i in range(self._n_sub_steps): + self._hooks.before_substep(self._physics_proxy, action, + self._random_state) + self._physics.step() + self._hooks.after_substep(self._physics_proxy, self._random_state) + # The final observation update must happen after all the hooks in + # `self._hooks.after_step` is called. Otherwise, if any of these hooks + # modify the physics state then we might capture an observation that is + # inconsistent with the final physics state. + if i < self._n_sub_steps - 1: + self._observation_updater.update() + physics_is_divergent = False + except control.PhysicsError as e: + if not self._raise_exception_on_physics_error: + logging.warning(e) + physics_is_divergent = True + else: + raise + + self._hooks.after_step(self._physics_proxy, self._random_state) + self._observation_updater.update() + + if not physics_is_divergent: + reward = self._task.get_reward(self._physics_proxy) + discount = self._task.get_discount(self._physics_proxy) + terminating = ( + self._task.should_terminate_episode(self._physics_proxy) + or self._physics.time() >= self._time_limit + ) + else: + reward = 0.0 + discount = 0.0 + terminating = True + + obs = self._observation_updater.get_observation() + + if not terminating: + return dm_env.TimeStep(dm_env.StepType.MID, reward, discount, obs) + else: + self._reset_next_step = True + return dm_env.TimeStep(dm_env.StepType.LAST, reward, discount, obs) + + def action_spec(self): + """Returns the action specification for this environment.""" + return self._task.action_spec(self._physics_proxy) + + def reward_spec(self): + """Describes the reward returned by this environment. + + This will be the output of `self.task.reward_spec()` if it is not None, + otherwise it will be the default spec returned by + `dm_env.Environment.reward_spec()`. + + Returns: + A `specs.Array` instance, or a nested dict, list or tuple of + `specs.Array`s. + """ + task_reward_spec = self._task.get_reward_spec() + if task_reward_spec is not None: + return task_reward_spec + else: + return super(Environment, self).reward_spec() + + def discount_spec(self): + """Describes the discount returned by this environment. + + This will be the output of `self.task.discount_spec()` if it is not None, + otherwise it will be the default spec returned by + `dm_env.Environment.discount_spec()`. + + Returns: + A `specs.Array` instance, or a nested dict, list or tuple of + `specs.Array`s. + """ + task_discount_spec = self._task.get_discount_spec() + if task_discount_spec is not None: + return task_discount_spec + else: + return super(Environment, self).discount_spec() + + def observation_spec(self): + """Returns the observation specification for this environment. + + Returns: + An `OrderedDict` mapping observation name to `specs.Array` containing + observation shape and dtype. + """ + return self._observation_updater.observation_spec() diff --git a/DMC/src/env/dm_control/dm_control/composer/environment_hooks_test.py b/DMC/src/env/dm_control/dm_control/composer/environment_hooks_test.py new file mode 100644 index 0000000..5fcebbf --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/environment_hooks_test.py @@ -0,0 +1,46 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for Entity and Task hooks in an Environment.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Internal dependencies. +from absl.testing import absltest +from dm_control import composer +from dm_control.composer import hooks_test_utils +import numpy as np +from six.moves import range + + +class EnvironmentHooksTest(hooks_test_utils.HooksTestMixin, absltest.TestCase): + + def testEnvironmentHooksScheduling(self): + env = composer.Environment(self.task) + for hook_name in composer.HOOK_NAMES: + env.add_extra_hook(hook_name, getattr(self.extra_hooks, hook_name)) + for _ in range(self.num_episodes): + with self.track_episode(): + env.reset() + for _ in range(self.steps_per_episode): + env.step([0.1, 0.2, 0.3, 0.4]) + np.testing.assert_array_equal(env.physics.data.ctrl, + [0.1, 0.2, 0.3, 0.4]) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/composer/environment_test.py b/DMC/src/env/dm_control/dm_control/composer/environment_test.py new file mode 100644 index 0000000..835450f --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/environment_test.py @@ -0,0 +1,106 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for dm_control.composer.environment.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Internal dependencies. +from absl.testing import absltest +from absl.testing import parameterized +from dm_control import composer +from dm_control import mjcf +from dm_control.composer.observation import observable +import dm_env +import mock +import numpy as np +from six.moves import range + + +class DummyTask(composer.NullTask): + + def __init__(self): + null_entity = composer.ModelWrapperEntity(mjcf.RootElement()) + super(DummyTask, self).__init__(null_entity) + + @property + def task_observables(self): + time = observable.Generic(lambda physics: physics.time()) + time.enabled = True + return {'time': time} + + +class DummyTaskWithResetFailures(DummyTask): + + def __init__(self, num_reset_failures): + super(DummyTaskWithResetFailures, self).__init__() + self.num_reset_failures = num_reset_failures + self.reset_counter = 0 + + def initialize_episode_mjcf(self, random_state): + self.reset_counter += 1 + + def initialize_episode(self, physics, random_state): + if self.reset_counter <= self.num_reset_failures: + raise composer.EpisodeInitializationError() + + +class EnvironmentTest(parameterized.TestCase): + + def test_failed_resets(self): + total_reset_failures = 5 + env_reset_attempts = 2 + task = DummyTaskWithResetFailures(num_reset_failures=total_reset_failures) + env = composer.Environment(task, max_reset_attempts=env_reset_attempts) + for _ in range(total_reset_failures // env_reset_attempts): + with self.assertRaises(composer.EpisodeInitializationError): + env.reset() + env.reset() # should not raise an exception + self.assertEqual(task.reset_counter, total_reset_failures + 1) + + @parameterized.parameters( + dict(name='reward_spec', defined_in_task=True), + dict(name='reward_spec', defined_in_task=False), + dict(name='discount_spec', defined_in_task=True), + dict(name='discount_spec', defined_in_task=False)) + def test_get_spec(self, name, defined_in_task): + task = DummyTask() + env = composer.Environment(task) + with mock.patch.object(task, 'get_' + name) as mock_task_get_spec: + if defined_in_task: + expected_spec = mock.Mock() + mock_task_get_spec.return_value = expected_spec + else: + expected_spec = getattr(dm_env.Environment, name)(env) + mock_task_get_spec.return_value = None + spec = getattr(env, name)() + mock_task_get_spec.assert_called_once_with() + self.assertSameStructure(spec, expected_spec) + + def test_can_provide_observation(self): + task = DummyTask() + env = composer.Environment(task) + obs = env.reset().observation + self.assertLen(obs, 1) + np.testing.assert_array_equal(obs['time'], env.physics.time()) + for _ in range(20): + obs = env.step([]).observation + self.assertLen(obs, 1) + np.testing.assert_array_equal(obs['time'], env.physics.time()) + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/composer/hooks_test_utils.py b/DMC/src/env/dm_control/dm_control/composer/hooks_test_utils.py new file mode 100644 index 0000000..f1a449f --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/hooks_test_utils.py @@ -0,0 +1,326 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Utilities for testing environment hooks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import contextlib +import inspect + +from dm_control import composer +from dm_control import mjcf +from six.moves import range + + +def add_bodies_and_actuators(mjcf_model, num_actuators): + if num_actuators % 2: + raise ValueError('num_actuators is not a multiple of 2') + for _ in range(num_actuators // 2): + body = mjcf_model.worldbody.add('body') + body.add('inertial', pos=[0, 0, 0], mass=1, diaginertia=[1, 1, 1]) + joint_x = body.add('joint', axis=[1, 0, 0]) + mjcf_model.actuator.add('position', joint=joint_x) + joint_y = body.add('joint', axis=[0, 1, 0]) + mjcf_model.actuator.add('position', joint=joint_y) + + +class HooksTracker(object): + """Helper class for tracking call order of callbacks.""" + + def __init__(self, test_case, physics_timestep, control_timestep, + *args, **kwargs): + super(HooksTracker, self).__init__(*args, **kwargs) + self.tracked = False + self._test_case = test_case + self._call_count = collections.defaultdict(lambda: 0) + self._physics_timestep = physics_timestep + self._physics_steps_per_control_step = ( + round(int(control_timestep / physics_timestep))) + + mro = inspect.getmro(type(self)) + self._has_super = mro[mro.index(HooksTracker) + 1] != object + + def assertEqual(self, actual, expected, msg=''): + msg = '{}: {}: {!r} != {!r}'.format(type(self), msg, actual, expected) + self._test_case.assertEqual(actual, expected, msg) + + def assertHooksNotCalled(self, *hook_names): + for hook_name in hook_names: + self.assertEqual( + self._call_count[hook_name], 0, + 'assertHooksNotCalled: hook_name = {!r}'.format(hook_name)) + + def assertHooksCalledOnce(self, *hook_names): + for hook_name in hook_names: + self.assertEqual( + self._call_count[hook_name], 1, + 'assertHooksCalledOnce: hook_name = {!r}'.format(hook_name)) + + def assertCompleteEpisode(self, control_steps): + self.assertHooksCalledOnce('initialize_episode_mjcf', + 'after_compile', + 'initialize_episode') + physics_steps = control_steps * self._physics_steps_per_control_step + self.assertEqual(self._call_count['before_step'], control_steps) + self.assertEqual(self._call_count['before_substep'], physics_steps) + self.assertEqual(self._call_count['after_substep'], physics_steps) + self.assertEqual(self._call_count['after_step'], control_steps) + + def assertPhysicsStepCountEqual(self, physics, expected_count): + actual_count = int(round(physics.time() / self._physics_timestep)) + self.assertEqual(actual_count, expected_count) + + def reset_call_counts(self): + self._call_count = collections.defaultdict(lambda: 0) + + def initialize_episode_mjcf(self, random_state): + """Implements `initialize_episode_mjcf` Composer callback.""" + if self._has_super: + super(HooksTracker, self).initialize_episode_mjcf(random_state) + if not self.tracked: + return + self.assertHooksNotCalled('after_compile', + 'initialize_episode', + 'before_step', + 'before_substep', + 'after_substep', + 'after_step') + self._call_count['initialize_episode_mjcf'] += 1 + + def after_compile(self, physics, random_state): + """Implements `after_compile` Composer callback.""" + if self._has_super: + super(HooksTracker, self).after_compile(physics, random_state) + if not self.tracked: + return + self.assertHooksCalledOnce('initialize_episode_mjcf') + self.assertHooksNotCalled('initialize_episode', + 'before_step', + 'before_substep', + 'after_substep', + 'after_step') + # Number of physics steps is always consistent with `before_substep`. + self.assertPhysicsStepCountEqual(physics, + self._call_count['before_substep']) + self._call_count['after_compile'] += 1 + + def initialize_episode(self, physics, random_state): + """Implements `initialize_episode` Composer callback.""" + if self._has_super: + super(HooksTracker, self).initialize_episode(physics, random_state) + if not self.tracked: + return + self.assertHooksCalledOnce('initialize_episode_mjcf', + 'after_compile') + self.assertHooksNotCalled('before_step', + 'before_substep', + 'after_substep', + 'after_step') + # Number of physics steps is always consistent with `before_substep`. + self.assertPhysicsStepCountEqual(physics, + self._call_count['before_substep']) + self._call_count['initialize_episode'] += 1 + + def before_step(self, physics, *args): + """Implements `before_step` Composer callback.""" + if self._has_super: + super(HooksTracker, self).before_step(physics, *args) + if not self.tracked: + return + self.assertHooksCalledOnce('initialize_episode_mjcf', + 'after_compile', + 'initialize_episode') + + # `before_step` is only called in between complete control steps. + self.assertEqual( + self._call_count['after_step'], self._call_count['before_step']) + + # Complete control steps imply complete physics steps. + self.assertEqual( + self._call_count['after_substep'], self._call_count['before_substep']) + + # Number of physics steps is always consistent with `before_substep`. + self.assertPhysicsStepCountEqual(physics, + self._call_count['before_substep']) + + self._call_count['before_step'] += 1 + + def before_substep(self, physics, *args): + """Implements `before_substep` Composer callback.""" + if self._has_super: + super(HooksTracker, self).before_substep(physics, *args) + if not self.tracked: + return + self.assertHooksCalledOnce('initialize_episode_mjcf', + 'after_compile', + 'initialize_episode') + + # We are inside a partial control step, so `after_step` should lag behind. + self.assertEqual( + self._call_count['after_step'], self._call_count['before_step'] - 1) + + # `before_substep` is only called in between complete physics steps. + self.assertEqual( + self._call_count['after_substep'], self._call_count['before_substep']) + + # Number of physics steps is always consistent with `before_substep`. + self.assertPhysicsStepCountEqual( + physics, self._call_count['before_substep']) + + self._call_count['before_substep'] += 1 + + def after_substep(self, physics, random_state): + """Implements `after_substep` Composer callback.""" + if self._has_super: + super(HooksTracker, self).after_substep(physics, random_state) + if not self.tracked: + return + self.assertHooksCalledOnce('initialize_episode_mjcf', + 'after_compile', + 'initialize_episode') + + # We are inside a partial control step, so `after_step` should lag behind. + self.assertEqual( + self._call_count['after_step'], self._call_count['before_step'] - 1) + + # We are inside a partial physics step, so `after_substep` should be behind. + self.assertEqual(self._call_count['after_substep'], + self._call_count['before_substep'] - 1) + + # Number of physics steps is always consistent with `before_substep`. + self.assertPhysicsStepCountEqual( + physics, self._call_count['before_substep']) + + self._call_count['after_substep'] += 1 + + def after_step(self, physics, random_state): + """Implements `after_step` Composer callback.""" + if self._has_super: + super(HooksTracker, self).after_step(physics, random_state) + if not self.tracked: + return + self.assertHooksCalledOnce('initialize_episode_mjcf', + 'after_compile', + 'initialize_episode') + + # We are inside a partial control step, so `after_step` should lag behind. + self.assertEqual( + self._call_count['after_step'], self._call_count['before_step'] - 1) + + # `after_step` is only called in between complete physics steps. + self.assertEqual( + self._call_count['after_substep'], self._call_count['before_substep']) + + # Number of physics steps is always consistent with `before_substep`. + self.assertPhysicsStepCountEqual( + physics, self._call_count['before_substep']) + + # Check that the number of physics steps is consistent with control steps. + self.assertEqual( + self._call_count['before_substep'], + self._call_count['before_step'] * self._physics_steps_per_control_step) + + self._call_count['after_step'] += 1 + + +class TrackedEntity(HooksTracker, composer.Entity): + """A `composer.Entity` that tracks call order of callbacks.""" + + def _build(self, name): + self._mjcf_root = mjcf.RootElement(model=name) + + @property + def mjcf_model(self): + return self._mjcf_root + + @property + def name(self): + return self._mjcf_root.model + + +class TrackedTask(HooksTracker, composer.NullTask): + """A `composer.Task` that tracks call order of callbacks.""" + + def __init__(self, physics_timestep, control_timestep, *args, **kwargs): + super(TrackedTask, self).__init__(physics_timestep=physics_timestep, + control_timestep=control_timestep, + *args, **kwargs) + self.set_timesteps(physics_timestep=physics_timestep, + control_timestep=control_timestep) + add_bodies_and_actuators(self.root_entity.mjcf_model, num_actuators=4) + + +class HooksTestMixin(object): + """A mixin for an `absltest.TestCase` to track call order of callbacks.""" + + def setUp(self): + """Sets up the test case.""" + super(HooksTestMixin, self).setUp() + + self.num_episodes = 5 + self.steps_per_episode = 100 + + self.control_timestep = 0.05 + self.physics_timestep = 0.002 + + self.extra_hooks = HooksTracker(physics_timestep=self.physics_timestep, + control_timestep=self.control_timestep, + test_case=self) + + self.entities = [] + for i in range(9): + self.entities.append(TrackedEntity(name='entity_{}'.format(i), + physics_timestep=self.physics_timestep, + control_timestep=self.control_timestep, + test_case=self)) + + ######################################## + # Make the following entity hierarchy # + # 0 # + # 1 2 3 # + # 4 5 6 7 # + # 8 # + ######################################## + + self.entities[4].attach(self.entities[8]) + self.entities[1].attach(self.entities[4]) + self.entities[1].attach(self.entities[5]) + self.entities[0].attach(self.entities[1]) + + self.entities[2].attach(self.entities[6]) + self.entities[2].attach(self.entities[7]) + self.entities[0].attach(self.entities[2]) + + self.entities[0].attach(self.entities[3]) + + self.task = TrackedTask(root_entity=self.entities[0], + physics_timestep=self.physics_timestep, + control_timestep=self.control_timestep, + test_case=self) + + @contextlib.contextmanager + def track_episode(self): + tracked_objects = [self.task, self.extra_hooks] + self.entities + for obj in tracked_objects: + obj.reset_call_counts() + obj.tracked = True + yield + for obj in tracked_objects: + obj.assertCompleteEpisode(self.steps_per_episode) + obj.tracked = False diff --git a/DMC/src/env/dm_control/dm_control/composer/initializer.py b/DMC/src/env/dm_control/dm_control/composer/initializer.py new file mode 100644 index 0000000..e355125 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/initializer.py @@ -0,0 +1,33 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Module defining the abstract initializer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + + +@six.add_metaclass(abc.ABCMeta) +class Initializer(object): + """The abstract base class for an initializer.""" + + @abc.abstractmethod + def __call__(self, physics, random_state): + raise NotImplementedError diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/__init__.py b/DMC/src/env/dm_control/dm_control/composer/observation/__init__.py new file mode 100644 index 0000000..dd9cbfc --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/observation/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Multi-rate observation and buffering framework for Composer environments.""" + +from dm_control.composer.observation import observable +from dm_control.composer.observation.obs_buffer import Buffer +from dm_control.composer.observation.updater import DEFAULT_BUFFER_SIZE +from dm_control.composer.observation.updater import DEFAULT_DELAY +from dm_control.composer.observation.updater import DEFAULT_UPDATE_INTERVAL +from dm_control.composer.observation.updater import Updater diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/fake_physics.py b/DMC/src/env/dm_control/dm_control/composer/observation/fake_physics.py new file mode 100644 index 0000000..e9cc6c6 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/observation/fake_physics.py @@ -0,0 +1,76 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""A fake Physics class for unit testing observation framework.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib + +from dm_control.composer.observation import observable +from dm_control.rl import control +import numpy as np + + +class FakePhysics(control.Physics): + """A fake Physics class for unit testing observation framework.""" + + def __init__(self): + self._step_counter = 0 + self._observables = { + 'twice': observable.Generic(FakePhysics.twice), + 'repeated': observable.Generic(FakePhysics.repeated, update_interval=5), + 'matrix': observable.Generic(FakePhysics.matrix, update_interval=3) + } + + def step(self, sub_steps=1): + self._step_counter += 1 + + @property + def observables(self): + return self._observables + + def twice(self): + return 2*self._step_counter + + def repeated(self): + return [self._step_counter, self._step_counter] + + def sqrt(self): + return np.sqrt(self._step_counter) + + def matrix(self): + return [[self._step_counter] * 3] * 2 + + def time(self): + return self._step_counter + + def timestep(self): + return 1.0 + + def set_control(self, ctrl): + pass + + def reset(self): + self._step_counter = 0 + + def after_reset(self): + pass + + @contextlib.contextmanager + def suppress_physics_errors(self): + yield diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/obs_buffer.py b/DMC/src/env/dm_control/dm_control/composer/observation/obs_buffer.py new file mode 100644 index 0000000..876c0fa --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/observation/obs_buffer.py @@ -0,0 +1,251 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""""An object that manages the buffering and delaying of observation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import numpy as np +import six +from six.moves import range + + +class InFlightObservation(object): + """Represents a delayed observation that may not have arrived yet. + + Attributes: + arrival: The time at which this observation will be delivered. + timestamp: The time at which this observation was made. + delay: The amount of delay between the time at which this observation was + made and the time at which it is delivered. + value: The value of this observation. + """ + + __slots__ = ('arrival', 'timestamp', 'delay', 'value') + + def __init__(self, timestamp, delay, value): + self.arrival = timestamp + delay + self.timestamp = timestamp + self.delay = delay + self.value = value + + def __lt__(self, other): + # This is implemented to facilitate sorting. + return self.arrival < other.arrival + + +class Buffer(object): + """An object that manages the buffering and delaying of observation.""" + + def __init__(self, buffer_size, shape, dtype, pad_value=0, + strip_singleton_buffer_dim=False): + """Initializes this observation buffer. + + Args: + buffer_size: The size of the buffer returned by `read`. Note + that this does *not* affect size of the internal buffer held by this + object, which always grow as large as is necessary in the presence of + large delays. + shape: The shape of a single observation held by this buffer, which can + either be a single integer or an iterable of integers. The shape of the + buffer returned by `read` will then be + `(buffer_size, shape[0], ..., shape[n])`, unless `buffer_size == 1` + and `strip_singleton_buffer_dim == True`. + dtype: The NumPy dtype of observation entries. + pad_value: (optional) The value that is used to pad the buffer returned + by `read` when the number of observation entries is less + then `buffer_size`. Specifically, the buffer will be padded by + `np.full(shape, pad_value, dtype)`. + strip_singleton_buffer_dim: (optional) A boolean, if `True` and + `buffer_size == 1` then the leading dimension will not be added to the + shape of the array returned by `read`. + """ + self._buffer_size = buffer_size + try: + shape = tuple(shape) + except TypeError: + if isinstance(shape, int): + shape = (shape,) + else: + raise + + self._has_buffer_dim = not (strip_singleton_buffer_dim and buffer_size == 1) + if self._has_buffer_dim: + self._buffered_shape = (buffer_size,) + shape + else: + self._buffered_shape = shape + self._dtype = dtype + + # The "arrived" deque contains entries that are due to be delivered now. + # This deque should never grow beyond buffer_size. + self._arrived_deque = collections.deque(maxlen=buffer_size) + for _ in range(buffer_size): + self._arrived_deque.append( + InFlightObservation(-np.inf, 0, np.full(shape, pad_value, dtype))) + + # The "pending" deque contains entries that are stored for future delivery. + # This deque can grow arbitrarily large in presence of long delays. + self._pending_deque = collections.deque() + + def _update_arrived_deque(self, timestamp): + while self._pending_deque and self._pending_deque[0].arrival <= timestamp: + self._arrived_deque.append(self._pending_deque.popleft()) + + @property + def shape(self): + return self._buffered_shape + + @property + def dtype(self): + return self._dtype + + def insert(self, timestamp, delay, value): + """Inserts a new observation to the buffer. + + This function implicitly updates the internal "clock" of this buffer to + the timestamp of the new observation, and the internal buffer is trimmed + accordingly, i.e. at most `buffer_size` items whose delayed arrival time + preceeds `timestamp` are kept. + + Args: + timestamp: The time at which this observation was made. + delay: The amount of delay between the time at which this observation was + made and the time at which it is delivered. + value: The value of this observation. + + Raises: + ValueError: if `delay` is negative. + """ + self._update_arrived_deque(timestamp) + new_obs = InFlightObservation(timestamp, delay, np.array(value)) + arrival = new_obs.arrival + if delay == 0: + # No delay, so the new observation is due for immediate delivery. + # Add it to the arrived deque. + self._arrived_deque.append(new_obs) + elif delay > 0: + if not self._pending_deque or arrival > self._pending_deque[-1].arrival: + # New observation's arrival time is monotonic. + # Technically, we can handle this in the general code branch below, + # but since this is assumed to be the "typical" case, the special + # handling here saves us from repeatedly allocating and deallocating + # an empty temporary deque. + self._pending_deque.append(new_obs) + else: + # General, out-of-order observation. + arriving_after_new_obs = collections.deque() + while self._pending_deque and arrival < self._pending_deque[-1].arrival: + arriving_after_new_obs.appendleft(self._pending_deque.pop()) + self._pending_deque.append(new_obs) + for existing_obs in arriving_after_new_obs: + self._pending_deque.append(existing_obs) + else: + raise ValueError('`delay` should not be negative: ' + 'got {!r}'.format(delay)) + + def read(self, current_time): + """Reads the content of the buffer at the given timestamp.""" + self._update_arrived_deque(current_time) + if self._has_buffer_dim: + out = np.empty(self._buffered_shape, dtype=self._dtype) + for i, obs in enumerate(self._arrived_deque): + out[i] = obs.value + else: + out = self._arrived_deque[0].value.copy() + return out + + def drop_unobserved_upcoming_items(self, observation_schedule, read_interval): + """Plans an optimal observation schedule for an upcoming control period. + + This function determines which of the proposed upcoming observations will + never in fact be delivered and removes them from the observation schedule. + + We assume that observations will only be queried at times that are integer + multiples of `read_interval`. If more observations are generated during + the upcoming control step than the `buffer_size` of this `Buffer` + then of those new observations will never be required. This function takes + into account the delayed arrival time and existing buffered items in the + planning process. + + Args: + observation_schedule: An list of `(timestamp, delay)` tuples, where + `timestamp` is the time at which the observation value will be produced, + and `delay` is the amount of time the observation will be delayed by. + This list will be modified in place. + read_interval: The time interval between successive calls to `read`. + We assume that observations will only be queried at times that are + integer multiples of `read_interval`. + """ + # Private deques to simulate what the deques will look like in the future, + # according to the proposed upcoming observation schedule. + future_arrived_deque = collections.deque() + future_pending_deque = collections.deque() + + # Take existing buffered observations into account when planning the + # upcoming schedule. + def get_next_existing_timestamp(): + for obs in reversed(self._pending_deque): + yield InFlightObservation(obs.timestamp, obs.delay, None) + while True: + yield InFlightObservation(-np.inf, 0, None) + existing_timestamp_iter = get_next_existing_timestamp() + existing_timestamp = six.next(existing_timestamp_iter) + + # Build the simulated state of the pending deque at the end of the proposed + # schedule. + sorted_schedule = sorted([InFlightObservation(time[0], time[1], None) + for time in observation_schedule]) + for new_timestamp in reversed(sorted_schedule): + # We don't need to worry about any existing item that are delivered before + # the first new item, since those are purged independently of our + # proposed new observations. + while existing_timestamp.arrival > new_timestamp.arrival: + future_pending_deque.appendleft(existing_timestamp) + existing_timestamp = six.next(existing_timestamp_iter) + future_pending_deque.appendleft(new_timestamp) + + # Find the next timestep at which `read` is called. + first_proposed_timestamp = min(t for t, _ in observation_schedule) + next_read_time = read_interval * int(np.ceil( + first_proposed_timestamp // read_interval)) + + # Build the simulated state of the arrived deque at each subsequent + # control steps. + while future_pending_deque: + # Keep track of observations that are delivered for the first time + # during this control timestep. + newly_arrived = collections.deque() + while (future_pending_deque and + future_pending_deque[0].arrival <= next_read_time): + # `fake_observation` is an `InFlightObservation` without `value`. + fake_observation = future_pending_deque.popleft() + future_arrived_deque.append(fake_observation) + newly_arrived.append(fake_observation) + while len(future_arrived_deque) > self._buffer_size: + stale = future_arrived_deque.popleft() + # Newly-arrived items that become immediately stale are never actually + # delivered. + if newly_arrived and stale == newly_arrived[0]: + newly_arrived.popleft() + # `stale` might either be one of the existing pending observations or + # from the proposed schedule. + if stale.timestamp >= first_proposed_timestamp: + observation_schedule.remove((stale.timestamp, stale.delay)) + + next_read_time += read_interval diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/obs_buffer_test.py b/DMC/src/env/dm_control/dm_control/composer/observation/obs_buffer_test.py new file mode 100644 index 0000000..73841fb --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/observation/obs_buffer_test.py @@ -0,0 +1,83 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for observation.obs_buffer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Internal dependencies. + +from absl.testing import absltest +from absl.testing import parameterized +from dm_control.composer.observation import obs_buffer +import numpy as np +from six.moves import range + + +def _generate_constant_schedule(update_timestep, delay, + control_timestep, n_observed_steps): + first = update_timestep + last = control_timestep * n_observed_steps + 1 + return [(i, delay) for i in range(first, last, update_timestep)] + + +class BufferTest(parameterized.TestCase): + + def testOutOfOrderArrival(self): + buf = obs_buffer.Buffer(buffer_size=3, shape=(), dtype=np.float) + buf.insert(timestamp=0, delay=4, value=1) + buf.insert(timestamp=1, delay=2, value=2) + buf.insert(timestamp=2, delay=3, value=3) + np.testing.assert_array_equal(buf.read(current_time=2), [0., 0., 0.]) + np.testing.assert_array_equal(buf.read(current_time=3), [0., 0., 2.]) + np.testing.assert_array_equal(buf.read(current_time=4), [0., 2., 1.]) + np.testing.assert_array_equal(buf.read(current_time=5), [2., 1., 3.]) + np.testing.assert_array_equal(buf.read(current_time=6), [2., 1., 3.]) + + @parameterized.parameters(((3, 3),), ((),)) + def testStripSingletonDimension(self, shape): + buf = obs_buffer.Buffer(buffer_size=1, shape=shape, dtype=np.float, + strip_singleton_buffer_dim=True) + expected_value = np.full(shape, 42, dtype=np.float) + buf.insert(timestamp=0, delay=0, value=expected_value) + np.testing.assert_array_equal(buf.read(current_time=1), expected_value) + + def testPlanToSingleUndelayedObservation(self): + buf = obs_buffer.Buffer( + buffer_size=1, shape=(), dtype=np.float) + control_timestep = 20 + observation_schedule = _generate_constant_schedule( + update_timestep=1, delay=0, + control_timestep=control_timestep, n_observed_steps=1) + buf.drop_unobserved_upcoming_items( + observation_schedule, read_interval=control_timestep) + self.assertEqual(observation_schedule, [(20, 0)]) + + def testPlanTwoStepsAhead(self): + buf = obs_buffer.Buffer( + buffer_size=1, shape=(), dtype=np.float) + control_timestep = 5 + observation_schedule = _generate_constant_schedule( + update_timestep=2, delay=3, + control_timestep=control_timestep, n_observed_steps=2) + buf.drop_unobserved_upcoming_items( + observation_schedule, read_interval=control_timestep) + self.assertEqual(observation_schedule, [(2, 3), (6, 3), (10, 3)]) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/observable/__init__.py b/DMC/src/env/dm_control/dm_control/composer/observation/observable/__init__.py new file mode 100644 index 0000000..cd61957 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/observation/observable/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Module for observables in the Composer library.""" + +from dm_control.composer.observation.observable.base import Generic +from dm_control.composer.observation.observable.base import MujocoCamera +from dm_control.composer.observation.observable.base import MujocoFeature +from dm_control.composer.observation.observable.base import Observable + +from dm_control.composer.observation.observable.mjcf import MJCFCamera +from dm_control.composer.observation.observable.mjcf import MJCFFeature diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/observable/base.py b/DMC/src/env/dm_control/dm_control/composer/observation/observable/base.py new file mode 100644 index 0000000..5b8cddc --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/observation/observable/base.py @@ -0,0 +1,318 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Classes representing observables.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import functools + +from dm_env import specs +import numpy as np +import six + + +def _make_aggregator(np_reducer_func, bounds_preserving): + result = functools.partial(np_reducer_func, axis=0) + setattr(result, 'bounds_reserving', bounds_preserving) + return result + + +AGGREGATORS = { + 'min': _make_aggregator(np.min, True), + 'max': _make_aggregator(np.max, True), + 'mean': _make_aggregator(np.mean, True), + 'median': _make_aggregator(np.median, True), + 'sum': _make_aggregator(np.sum, False), +} + + +def _get_aggregator(name_or_callable): + """Returns aggregator from predefined set by name, else returns callable.""" + if name_or_callable is None: + return None + elif not callable(name_or_callable): + try: + return AGGREGATORS[name_or_callable] + except KeyError: + raise KeyError('Unrecognized aggregator name: {!r}. Valid names: {}.' + .format(name_or_callable, AGGREGATORS.keys())) + else: + return name_or_callable + + +@six.add_metaclass(abc.ABCMeta) +class Observable(object): + """Abstract base class for an observable.""" + + def __init__(self, update_interval, buffer_size, delay, + aggregator, corruptor): + self._update_interval = update_interval + self._buffer_size = buffer_size + self._delay = delay + self._aggregator = _get_aggregator(aggregator) + self._corruptor = corruptor + self._enabled = False + + @property + def update_interval(self): + return self._update_interval + + @update_interval.setter + def update_interval(self, value): + self._update_interval = value + + @property + def buffer_size(self): + return self._buffer_size + + @buffer_size.setter + def buffer_size(self, value): + self._buffer_size = value + + @property + def delay(self): + return self._delay + + @delay.setter + def delay(self, value): + self._delay = value + + @property + def aggregator(self): + return self._aggregator + + @aggregator.setter + def aggregator(self, value): + self._aggregator = _get_aggregator(value) + + @property + def corruptor(self): + return self._corruptor + + @corruptor.setter + def corruptor(self, value): + self._corruptor = value + + @property + def enabled(self): + return self._enabled + + @enabled.setter + def enabled(self, value): + self._enabled = value + + @property + def array_spec(self): + """The `ArraySpec` which describes observation arrays from this observable. + + If this property is `None`, then the specification should be inferred by + actually retrieving an observation from this observable. + """ + return None + + @abc.abstractmethod + def _callable(self, physics): + pass + + def observation_callable(self, physics, random_state=None): + """A callable which returns a (potentially corrupted) observation.""" + raw_callable = self._callable(physics) + if self._corruptor: + def _corrupted(): + return self._corruptor(raw_callable(), random_state=random_state) + return _corrupted + else: + return raw_callable + + def __call__(self, physics, random_state=None): + """Convenience function to just call an observable.""" + return self.observation_callable(physics, random_state)() + + def configure(self, **kwargs): + """Sets multiple attributes of this observable. + + Args: + **kwargs: The keyword argument names correspond to the attributes + being modified. + Raises: + AttributeError: If kwargs contained an attribute not in the observable. + """ + for key, value in six.iteritems(kwargs): + if not hasattr(self, key): + raise AttributeError('Cannot add attribute %s in configure.' % key) + self.__setattr__(key, value) + + +class Generic(Observable): + """A generic observable defined via a callable.""" + + def __init__(self, raw_observation_callable, update_interval=1, + buffer_size=None, delay=None, + aggregator=None, corruptor=None): + """Initializes this observable. + + Args: + raw_observation_callable: A callable which accepts a single argument of + type `control.base.Physics` and returns the observation value. + update_interval: (optional) An integer, number of simulation steps between + successive updates to the value of this observable. + buffer_size: (optional) The maximum size of the returned buffer. + This option is only relevant when used in conjunction with an + `observation.Updater`. If None, `observation.DEFAULT_BUFFER_SIZE` will + be used. + delay: (optional) Number of additional simulation steps that must be + taken before an observation is returned. This option is only relevant + when used in conjunction with an`observation.Updater`. If None, + `observation.DEFAULT_DELAY` will be used. + aggregator: (optional) Name of an item in `AGGREGATORS` or a callable that + performs a reduction operation over the first dimension of the buffered + observation before it is returned. A value of `None` means that no + aggregation will be performed and the whole buffer will be returned. + corruptor: (optional) A callable which takes a single observation as + an argument, modifies it, and returns it. An example use case for this + is to add random noise to the observation. When used in a + `BufferedWrapper`, the corruptor is applied to the observation before + it is added to the buffer. In particular, this means that the aggregator + operates on corrupted observations. + """ + self._raw_callable = raw_observation_callable + super(Generic, self).__init__( + update_interval, buffer_size, delay, aggregator, corruptor) + + def _callable(self, physics): + return lambda: self._raw_callable(physics) + + +class MujocoFeature(Observable): + """An observable corresponding to a named MuJoCo feature.""" + + def __init__(self, kind, feature_name, update_interval=1, + buffer_size=None, delay=None, + aggregator=None, corruptor=None): + """Initializes this observable. + + Args: + kind: A string corresponding to a field name in MuJoCo's mjData struct. + feature_name: A string, or list of strings, or a callable returning + either, corresponding to the name(s) of an entity in the + MuJoCo XML model. + update_interval: (optional) An integer, number of simulation steps between + successive updates to the value of this observable. + buffer_size: (optional) The maximum size of the returned buffer. + This option is only relevant when used in conjunction with an + `observation.Updater`. If None, `observation.DEFAULT_BUFFER_SIZE` will + be used. + delay: (optional) Number of additional simulation steps that must be + taken before an observation is returned. This option is only relevant + when used in conjunction with an`observation.Updater`. If None, + `observation.DEFAULT_DELAY` will be used. + aggregator: (optional) Name of an item in `AGGREGATORS` or a callable that + performs a reduction operation over the first dimension of the buffered + observation before it is returned. A value of `None` means that no + aggregation will be performed and the whole buffer will be returned. + corruptor: (optional) A callable which takes a single observation as + an argument, modifies it, and returns it. An example use case for this + is to add random noise to the observation. When used in a + `BufferedWrapper`, the corruptor is applied to the observation before + it is added to the buffer. In particular, this means that the aggregator + operates on corrupted observations. + """ + self._kind = kind + self._feature_name = feature_name + super(MujocoFeature, self).__init__( + update_interval, buffer_size, delay, aggregator, corruptor) + + def _callable(self, physics): + named_indexer_for_kind = physics.named.data.__getattribute__(self._kind) + if callable(self._feature_name): + return lambda: named_indexer_for_kind[self._feature_name()] + else: + return lambda: named_indexer_for_kind[self._feature_name] + + +class MujocoCamera(Observable): + """An observable corresponding to a MuJoCo camera.""" + + def __init__(self, camera_name, height=240, width=320, update_interval=1, + buffer_size=None, delay=None, + aggregator=None, corruptor=None, depth=False): + """Initializes this observable. + + Args: + camera_name: A string corresponding to the name of a camera in the + MuJoCo XML model. + height: (optional) An integer, the height of the rendered image. + width: (optional) An integer, the width of the rendered image. + update_interval: (optional) An integer, number of simulation steps between + successive updates to the value of this observable. + buffer_size: (optional) The maximum size of the returned buffer. + This option is only relevant when used in conjunction with an + `observation.Updater`. If None, `observation.DEFAULT_BUFFER_SIZE` will + be used. + delay: (optional) Number of additional simulation steps that must be + taken before an observation is returned. This option is only relevant + when used in conjunction with an`observation.Updater`. If None, + `observation.DEFAULT_DELAY` will be used. + aggregator: (optional) Name of an item in `AGGREGATORS` or a callable that + performs a reduction operation over the first dimension of the buffered + observation before it is returned. A value of `None` means that no + aggregation will be performed and the whole buffer will be returned. + corruptor: (optional) A callable which takes a single observation as + an argument, modifies it, and returns it. An example use case for this + is to add random noise to the observation. When used in a + `BufferedWrapper`, the corruptor is applied to the observation before + it is added to the buffer. In particular, this means that the aggregator + operates on corrupted observations. + depth: (optional) A boolean. If `True`, renders a depth image (1-channel) + instead of RGB (3-channel). + """ + self._camera_name = camera_name + self._height = height + self._width = width + + self._n_channels = 1 if depth else 3 + self._dtype = np.float32 if depth else np.uint8 + self._depth = depth + super(MujocoCamera, self).__init__( + update_interval, buffer_size, delay, aggregator, corruptor) + + @property + def height(self): + return self._height + + @height.setter + def height(self, value): + self._height = value + + @property + def width(self): + return self._width + + @width.setter + def width(self, value): + self._width = value + + @property + def array_spec(self): + return specs.Array( + shape=(self._height, self._width, self._n_channels), dtype=self._dtype) + + def _callable(self, physics): + return lambda: physics.render( # pylint: disable=g-long-lambda + self._height, self._width, self._camera_name, depth=self._depth) diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/observable/base_test.py b/DMC/src/env/dm_control/dm_control/composer/observation/observable/base_test.py new file mode 100644 index 0000000..0f519c2 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/observation/observable/base_test.py @@ -0,0 +1,154 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for observable.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Internal dependencies. + +from absl.testing import absltest +from dm_control import mujoco +from dm_control.composer.observation import fake_physics +from dm_control.composer.observation.observable import base +import numpy as np +import six + + +_MJCF = """ + + + + + + + + + + + +""" + + +class _FakeBaseObservable(base.Observable): + + def _callable(self, physics): + pass + + +class ObservableTest(absltest.TestCase): + + def testBaseProperties(self): + fake_observable = _FakeBaseObservable(update_interval=42, + buffer_size=5, + delay=10, + aggregator=None, + corruptor=None) + self.assertEqual(fake_observable.update_interval, 42) + self.assertEqual(fake_observable.buffer_size, 5) + self.assertEqual(fake_observable.delay, 10) + + fake_observable.update_interval = 48 + self.assertEqual(fake_observable.update_interval, 48) + + fake_observable.buffer_size = 7 + self.assertEqual(fake_observable.buffer_size, 7) + + fake_observable.delay = 13 + self.assertEqual(fake_observable.delay, 13) + + enabled = not fake_observable.enabled + fake_observable.enabled = not fake_observable.enabled + self.assertEqual(fake_observable.enabled, enabled) + + def testGeneric(self): + physics = fake_physics.FakePhysics() + repeated_observable = base.Generic( + fake_physics.FakePhysics.repeated, update_interval=42) + repeated_observation = repeated_observable.observation_callable(physics)() + self.assertEqual(repeated_observable.update_interval, 42) + np.testing.assert_array_equal(repeated_observation, [0, 0]) + + def testMujocoFeature(self): + physics = mujoco.Physics.from_xml_string(_MJCF) + + hinge_observable = base.MujocoFeature( + kind='qpos', feature_name='my_hinge') + hinge_observation = hinge_observable.observation_callable(physics)() + np.testing.assert_array_equal( + hinge_observation, physics.named.data.qpos['my_hinge']) + + box_observable = base.MujocoFeature( + kind='geom_xpos', feature_name='small_sphere', update_interval=5) + box_observation = box_observable.observation_callable(physics)() + self.assertEqual(box_observable.update_interval, 5) + np.testing.assert_array_equal( + box_observation, physics.named.data.geom_xpos['small_sphere']) + + observable_from_callable = base.MujocoFeature( + kind='geom_xpos', feature_name=lambda: ['my_box', 'small_sphere']) + observation_from_callable = ( + observable_from_callable.observation_callable(physics)()) + np.testing.assert_array_equal( + observation_from_callable, + physics.named.data.geom_xpos[['my_box', 'small_sphere']]) + + def testMujocoCamera(self): + physics = mujoco.Physics.from_xml_string(_MJCF) + + camera_observable = base.MujocoCamera( + camera_name='world', height=480, width=640, update_interval=7) + self.assertEqual(camera_observable.update_interval, 7) + camera_observation = camera_observable.observation_callable(physics)() + np.testing.assert_array_equal( + camera_observation, physics.render(480, 640, 'world')) + self.assertEqual(camera_observation.shape, + camera_observable.array_spec.shape) + self.assertEqual(camera_observation.dtype, + camera_observable.array_spec.dtype) + + camera_observable.height = 300 + camera_observable.width = 400 + camera_observation = camera_observable.observation_callable(physics)() + self.assertEqual(camera_observable.height, 300) + self.assertEqual(camera_observable.width, 400) + np.testing.assert_array_equal( + camera_observation, physics.render(300, 400, 'world')) + self.assertEqual(camera_observation.shape, + camera_observable.array_spec.shape) + self.assertEqual(camera_observation.dtype, + camera_observable.array_spec.dtype) + + def testCorruptor(self): + physics = fake_physics.FakePhysics() + def add_twelve(old_value, random_state): + del random_state # Unused. + return [x + 12 for x in old_value] + repeated_observable = base.Generic( + fake_physics.FakePhysics.repeated, corruptor=add_twelve) + corrupted = repeated_observable.observation_callable( + physics=physics, random_state=None)() + np.testing.assert_array_equal(corrupted, [12, 12]) + + def testInvalidAggregatorName(self): + name = 'invalid_name' + with six.assertRaisesRegex(self, KeyError, 'Unrecognized aggregator name'): + _ = _FakeBaseObservable(update_interval=3, buffer_size=2, delay=1, + aggregator=name, corruptor=None) + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/observable/mjcf.py b/DMC/src/env/dm_control/dm_control/composer/observation/observable/mjcf.py new file mode 100644 index 0000000..84993c5 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/observation/observable/mjcf.py @@ -0,0 +1,257 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Observables that are defined in terms of MJCF elements.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from dm_control import mjcf +from dm_control.composer.observation.observable import base +from dm_env import specs +import numpy as np + + +_BOTH_SEGMENTATION_AND_DEPTH_ENABLED = ( + '`segmentation` and `depth` cannot both be `True`.') + + +def _check_mjcf_element(obj): + if not isinstance(obj, mjcf.Element): + raise ValueError( + 'expected an `mjcf.Element`, got type {}: {}'.format(type(obj), obj)) + + +def _check_mjcf_element_iterable(obj_iterable): + if not isinstance(obj_iterable, collections.Iterable): + obj_iterable = (obj_iterable,) + for obj in obj_iterable: + _check_mjcf_element(obj) + + +class MJCFFeature(base.Observable): + """An observable corresponding to an element in an MJCF model.""" + + def __init__(self, kind, mjcf_element, update_interval=1, + buffer_size=None, delay=None, + aggregator=None, corruptor=None, index=None): + """Initializes this observable. + + Args: + kind: The name of an attribute of a bound `mjcf.Physics` instance. See the + docstring for `mjcf.Physics.bind()` for examples showing this syntax. + mjcf_element: An `mjcf.Element`, or iterable of `mjcf.Element`. + update_interval: (optional) An integer, number of simulation steps between + successive updates to the value of this observable. + buffer_size: (optional) The maximum size of the returned buffer. + This option is only relevant when used in conjunction with an + `observation.Updater`. If None, `observation.DEFAULT_BUFFER_SIZE` will + be used. + delay: (optional) Number of additional simulation steps that must be + taken before an observation is returned. This option is only relevant + when used in conjunction with an`observation.Updater`. If None, + `observation.DEFAULT_DELAY` will be used. + aggregator: (optional) Name of an item in `AGGREGATORS` or a callable that + performs a reduction operation over the first dimension of the buffered + observation before it is returned. A value of `None` means that no + aggregation will be performed and the whole buffer will be returned. + corruptor: (optional) A callable which takes a single observation as + an argument, modifies it, and returns it. An example use case for this + is to add random noise to the observation. When used in a + `BufferedWrapper`, the corruptor is applied to the observation before + it is added to the buffer. In particular, this means that the aggregator + operates on corrupted observations. + index: (optional) An index that is to be applied to an array attribute + to pick out a slice or particular items. As a syntactic sugar, + `MJCFFeature` also implements `__getitem__` that returns a copy of the + same observable with an index applied. + + Raises: + ValueError: if `mjcf_element` is not an `mjcf.Element`. + """ + _check_mjcf_element_iterable(mjcf_element) + self._kind = kind + self._mjcf_element = mjcf_element + self._index = index + super(MJCFFeature, self).__init__( + update_interval, buffer_size, delay, aggregator, corruptor) + + def _callable(self, physics): + binding = physics.bind(self._mjcf_element) + if self._index is not None: + return lambda: getattr(binding, self._kind)[self._index] + else: + return lambda: getattr(binding, self._kind) + + def __getitem__(self, key): + if self._index is not None: + raise NotImplementedError( + 'slicing an already-sliced MJCFFeature observable is not supported') + return MJCFFeature(self._kind, self._mjcf_element, self._update_interval, + self._buffer_size, self._delay, self._aggregator, + self._corruptor, key) + + +class MJCFCamera(base.Observable): + """An observable corresponding to a camera in an MJCF model.""" + + def __init__(self, + mjcf_element, + height=240, + width=320, + update_interval=1, + buffer_size=None, + delay=None, + aggregator=None, + corruptor=None, + depth=False, + segmentation=False, + scene_option=None): + """Initializes this observable. + + Args: + mjcf_element: A `mjcf.Element`. + height: (optional) An integer, the height of the rendered image. + width: (optional) An integer, the width of the rendered image. + update_interval: (optional) An integer, number of simulation steps between + successive updates to the value of this observable. + buffer_size: (optional) The maximum size of the returned buffer. + This option is only relevant when used in conjunction with an + `observation.Updater`. If None, `observation.DEFAULT_BUFFER_SIZE` will + be used. + delay: (optional) Number of additional simulation steps that must be + taken before an observation is returned. This option is only relevant + when used in conjunction with an`observation.Updater`. If None, + `observation.DEFAULT_DELAY` will be used. + aggregator: (optional) Name of an item in `AGGREGATORS` or a callable that + performs a reduction operation over the first dimension of the buffered + observation before it is returned. A value of `None` means that no + aggregation will be performed and the whole buffer will be returned. + corruptor: (optional) A callable which takes a single observation as + an argument, modifies it, and returns it. An example use case for this + is to add random noise to the observation. When used in a + `BufferedWrapper`, the corruptor is applied to the observation before + it is added to the buffer. In particular, this means that the aggregator + operates on corrupted observations. + depth: (optional) A boolean. If `True`, renders a depth image (1-channel) + instead of RGB (3-channel). + segmentation: (optional) A boolean. If `True`, renders a segmentation mask + (2-channel, int32) labeling the objects in the scene with their + (mjModel ID, mjtObj enum object type) pair. Background pixels are + set to (-1, -1). + scene_option: An optional `wrapper.MjvOption` instance that can be used to + render the scene with custom visualization options. If None then the + default options will be used. + + Raises: + ValueError: if `mjcf_element` is not a element. + ValueError: if segmentation and depth flags are both set to True. + """ + _check_mjcf_element(mjcf_element) + if mjcf_element.tag != 'camera': + raise ValueError( + 'expected a element: got {}'.format(mjcf_element)) + self._mjcf_element = mjcf_element + self._height = height + self._width = width + + if segmentation and depth: + raise ValueError(_BOTH_SEGMENTATION_AND_DEPTH_ENABLED) + if segmentation: + self._dtype = np.int32 + self._n_channels = 2 + elif depth: + self._dtype = np.float32 + self._n_channels = 1 + else: + self._dtype = np.uint8 + self._n_channels = 3 + self._depth = depth + self._segmentation = segmentation + self._scene_option = scene_option + super(MJCFCamera, self).__init__( + update_interval, buffer_size, delay, aggregator, corruptor) + + @property + def height(self): + return self._height + + @height.setter + def height(self, value): + self._height = value + + @property + def width(self): + return self._width + + @width.setter + def width(self, value): + self._width = value + + @property + def depth(self): + return self._depth + + @depth.setter + def depth(self, value): + self._depth = value + + @property + def segmentation(self): + return self._segmentation + + @segmentation.setter + def segmentation(self, value): + self._segmentation = value + + @property + def array_spec(self): + if self._depth: + # Note that these are loose bounds - the exact bounds are given by: + # extent*(znear, zfar), however the values of these parameters are unknown + # since we don't have access to the compiled model within this method. + minimum = 0.0 + maximum = np.inf + elif self._segmentation: + # -1 denotes background pixels. See dm_control.mujoco.Camera.render for + # further details. + minimum = -1 + maximum = np.iinfo(self._dtype).max + else: + minimum = np.iinfo(self._dtype).min + maximum = np.iinfo(self._dtype).max + + return specs.BoundedArray( + minimum=minimum, + maximum=maximum, + shape=(self._height, self._width, self._n_channels), + dtype=self._dtype) + + def _callable(self, physics): + + def get_observation(): + pixels = physics.render( + height=self._height, + width=self._width, + camera_id=self._mjcf_element.full_identifier, + depth=self._depth, + segmentation=self._segmentation, + scene_option=self._scene_option) + return np.atleast_3d(pixels) + + return get_observation diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/observable/mjcf_test.py b/DMC/src/env/dm_control/dm_control/composer/observation/observable/mjcf_test.py new file mode 100644 index 0000000..9150fa9 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/observation/observable/mjcf_test.py @@ -0,0 +1,181 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for mjcf observables.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Internal dependencies. + +from absl.testing import absltest +from absl.testing import parameterized +from dm_control import mjcf +from dm_control.composer.observation.observable import mjcf as mjcf_observable +from dm_env import specs +import numpy as np +import six + +_MJCF = """ + + + + + + + + + + + +""" + + +class ObservableTest(parameterized.TestCase): + + def testMJCFFeature(self): + mjcf_root = mjcf.from_xml_string(_MJCF) + physics = mjcf.Physics.from_mjcf_model(mjcf_root) + + my_hinge = mjcf_root.find('joint', 'my_hinge') + hinge_observable = mjcf_observable.MJCFFeature( + kind='qpos', mjcf_element=my_hinge) + hinge_observation = hinge_observable.observation_callable(physics)() + np.testing.assert_array_equal( + hinge_observation, physics.named.data.qpos[my_hinge.full_identifier]) + + small_sphere = mjcf_root.find('geom', 'small_sphere') + sphere_observable = mjcf_observable.MJCFFeature( + kind='xpos', mjcf_element=small_sphere, update_interval=5) + sphere_observation = sphere_observable.observation_callable(physics)() + self.assertEqual(sphere_observable.update_interval, 5) + np.testing.assert_array_equal( + sphere_observation, physics.named.data.geom_xpos[ + small_sphere.full_identifier]) + + my_box = mjcf_root.find('geom', 'my_box') + list_observable = mjcf_observable.MJCFFeature( + kind='xpos', mjcf_element=[my_box, small_sphere]) + list_observation = ( + list_observable.observation_callable(physics)()) + np.testing.assert_array_equal( + list_observation, + physics.named.data.geom_xpos[[my_box.full_identifier, + small_sphere.full_identifier]]) + + with six.assertRaisesRegex(self, ValueError, 'expected an `mjcf.Element`'): + mjcf_observable.MJCFFeature('qpos', 'my_hinge') + with six.assertRaisesRegex(self, ValueError, 'expected an `mjcf.Element`'): + mjcf_observable.MJCFFeature('geom_xpos', [my_box, 'small_sphere']) + + def testMJCFFeatureIndex(self): + mjcf_root = mjcf.from_xml_string(_MJCF) + physics = mjcf.Physics.from_mjcf_model(mjcf_root) + + small_sphere = mjcf_root.find('geom', 'small_sphere') + sphere_xmat = np.array( + physics.named.data.geom_xmat[small_sphere.full_identifier]) + + observable_xrow = mjcf_observable.MJCFFeature( + 'xmat', small_sphere, index=[1, 3, 5, 7]) + np.testing.assert_array_equal( + observable_xrow.observation_callable(physics)(), + sphere_xmat[[1, 3, 5, 7]]) + + observable_yyzz = mjcf_observable.MJCFFeature('xmat', small_sphere)[2:6] + np.testing.assert_array_equal( + observable_yyzz.observation_callable(physics)(), sphere_xmat[2:6]) + + def testMJCFCamera(self): + mjcf_root = mjcf.from_xml_string(_MJCF) + physics = mjcf.Physics.from_mjcf_model(mjcf_root) + + camera = mjcf_root.find('camera', 'world') + camera_observable = mjcf_observable.MJCFCamera( + mjcf_element=camera, height=480, width=640, update_interval=7) + self.assertEqual(camera_observable.update_interval, 7) + camera_observation = camera_observable.observation_callable(physics)() + np.testing.assert_array_equal( + camera_observation, physics.render(480, 640, 'world')) + self.assertEqual(camera_observation.shape, + camera_observable.array_spec.shape) + self.assertEqual(camera_observation.dtype, + camera_observable.array_spec.dtype) + + camera_observable.height = 300 + camera_observable.width = 400 + camera_observation = camera_observable.observation_callable(physics)() + self.assertEqual(camera_observable.height, 300) + self.assertEqual(camera_observable.width, 400) + np.testing.assert_array_equal( + camera_observation, physics.render(300, 400, 'world')) + self.assertEqual(camera_observation.shape, + camera_observable.array_spec.shape) + self.assertEqual(camera_observation.dtype, + camera_observable.array_spec.dtype) + + with six.assertRaisesRegex(self, ValueError, 'expected an `mjcf.Element`'): + mjcf_observable.MJCFCamera('world') + with six.assertRaisesRegex(self, ValueError, 'expected an `mjcf.Element`'): + mjcf_observable.MJCFCamera([camera]) + with six.assertRaisesRegex(self, ValueError, 'expected a '): + mjcf_observable.MJCFCamera(mjcf_root.find('body', 'body')) + + @parameterized.parameters( + dict(camera_type='rgb', channels=3, dtype=np.uint8, + minimum=0, maximum=255), + dict(camera_type='depth', channels=1, dtype=np.float32, + minimum=0., maximum=np.inf), + dict(camera_type='segmentation', channels=2, dtype=np.int32, + minimum=-1, maximum=np.iinfo(np.int32).max), + ) + def testMJCFCameraSpecs(self, camera_type, channels, dtype, minimum, maximum): + width = 640 + height = 480 + shape = (height, width, channels) + expected_spec = specs.BoundedArray( + shape=shape, dtype=dtype, minimum=minimum, maximum=maximum) + mjcf_root = mjcf.from_xml_string(_MJCF) + camera = mjcf_root.find('camera', 'world') + observable_kwargs = {} if camera_type == 'rgb' else {camera_type: True} + camera_observable = mjcf_observable.MJCFCamera( + mjcf_element=camera, height=height, width=width, update_interval=7, + **observable_kwargs) + self.assertEqual(camera_observable.array_spec, expected_spec) + + def testMJCFSegCamera(self): + mjcf_root = mjcf.from_xml_string(_MJCF) + physics = mjcf.Physics.from_mjcf_model(mjcf_root) + camera = mjcf_root.find('camera', 'world') + camera_observable = mjcf_observable.MJCFCamera( + mjcf_element=camera, height=480, width=640, update_interval=7, + segmentation=True) + self.assertEqual(camera_observable.update_interval, 7) + camera_observation = camera_observable.observation_callable(physics)() + np.testing.assert_array_equal( + camera_observation, + physics.render(480, 640, 'world', segmentation=True)) + camera_observable.array_spec.validate(camera_observation) + + def testErrorIfSegmentationAndDepthBothEnabled(self): + camera = mjcf.from_xml_string(_MJCF).find('camera', 'world') + with self.assertRaisesWithLiteralMatch( + ValueError, mjcf_observable._BOTH_SEGMENTATION_AND_DEPTH_ENABLED): + mjcf_observable.MJCFCamera(mjcf_element=camera, segmentation=True, + depth=True) + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/updater.py b/DMC/src/env/dm_control/dm_control/composer/observation/updater.py new file mode 100644 index 0000000..937745a --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/observation/updater.py @@ -0,0 +1,302 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""An object that creates and updates buffers for enabled observables.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +from absl import logging + +from dm_control.composer.observation import obs_buffer +from dm_env import specs +import numpy as np +import six +from six.moves import range + +DEFAULT_BUFFER_SIZE = 1 +DEFAULT_UPDATE_INTERVAL = 1 +DEFAULT_DELAY = 0 + + +class _EnabledObservable(object): + """Encapsulates an enabled observable, its buffer, and its update schedule.""" + + __slots__ = ('observable', 'observation_callable', + 'buffer', 'update_schedule') + + def __init__(self, observable, physics, random_state, + strip_singleton_buffer_dim): + self.observable = observable + self.observation_callable = ( + observable.observation_callable(physics, random_state)) + + obs_spec = self.observable.array_spec + if obs_spec is None: + # We take an observation to determine the shape and dtype of the array. + # This occurs outside of an episode and doesn't affect environment + # behavior. At this point the physics state is not guaranteed to be valid, + # so we might get a `PhysicsError` if the observation callable calls + # `physics.forward`. We suppress such errors since they do not matter as + # far as the shape and dtype of the observation are concerned. + with physics.suppress_physics_errors(): + obs_array = self.observation_callable() + obs_array = np.asarray(obs_array) + obs_spec = specs.Array(shape=obs_array.shape, dtype=obs_array.dtype) + self.buffer = obs_buffer.Buffer( + buffer_size=(observable.buffer_size or DEFAULT_BUFFER_SIZE), + shape=obs_spec.shape, dtype=obs_spec.dtype, + strip_singleton_buffer_dim=strip_singleton_buffer_dim) + self.update_schedule = collections.deque() + + +def _call_if_callable(arg): + if callable(arg): + return arg() + else: + return arg + + +def _validate_structure(structure): + """Validates the structure of the given observables collection. + + The collection must either be a dict, or a (list or tuple) of dicts. + + Args: + structure: A candidate collection of observables. + + Returns: + A boolean that is `True` if `structure` is either a list or a tuple, or + `False` otherwise. + + Raises: + ValueError: If `structure` is neither a dict nor a (list or tuple) of dicts. + """ + is_nested = isinstance(structure, (list, tuple)) + if is_nested: + is_valid = all(isinstance(obj, dict) for obj in structure) + else: + is_valid = isinstance(structure, dict) + if not is_valid: + raise ValueError( + '`observables` should be a dict, or a (list or tuple) of dicts' + ': got {}'.format(structure)) + return is_nested + + +class Updater(object): + """Creates and updates buffers for enabled observables.""" + + def __init__(self, observables, physics_steps_per_control_step=1, + strip_singleton_buffer_dim=False): + self._physics_steps_per_control_step = physics_steps_per_control_step + self._strip_singleton_buffer_dim = strip_singleton_buffer_dim + self._step_counter = 0 + self._observables = observables + self._is_nested = _validate_structure(observables) + self._enabled_structure = None + self._enabled_list = None + + def reset(self, physics, random_state): + """Resets this updater's state.""" + + def make_buffers_dict(observables): + """Makes observable states in a dict.""" + # Use `type(observables)` so that our output structure respects the + # original dict subclass (e.g. OrderedDict). + out_dict = type(observables)() + for key, value in six.iteritems(observables): + if value.enabled: + out_dict[key] = _EnabledObservable(value, physics, random_state, + self._strip_singleton_buffer_dim) + return out_dict + + if self._is_nested: + self._enabled_structure = type(self._observables)( + make_buffers_dict(obs_dict) for obs_dict in self._observables) + self._enabled_list = [] + for enabled_dict in self._enabled_structure: + self._enabled_list.extend(enabled_dict.values()) + else: + self._enabled_structure = make_buffers_dict(self._observables) + self._enabled_list = self._enabled_structure.values() + + self._step_counter = 0 + for enabled in self._enabled_list: + first_delay = _call_if_callable(enabled.observable.delay or DEFAULT_DELAY) + enabled.buffer.insert( + 0, first_delay, + enabled.observation_callable()) + + def observation_spec(self): + """The observation specification for this environment. + + Returns a dict mapping the names of enabled observations to their + corresponding `Array` or `BoundedArray` specs. + + If an obs has a BoundedArray spec, but uses an aggregator that + does not preserve those bounds (such as `sum`), it will be mapped to an + (unbounded) `Array` spec. If using a bounds-preserving custom aggregator + `my_agg`, give it an attribute `my_agg.preserves_bounds = True` to indicate + to this method that it is bounds-preserving. + + The returned specification is only valid as of the previous call + to `reset`. In particular, it is an error to call this function before + the first call to `reset`. + + Returns: + A dict mapping observation name to `Array` or `BoundedArray` spec + containing the observation shape and dtype, and possibly bounds. + + Raises: + RuntimeError: If this method is called before `reset` has been called. + """ + if self._enabled_structure is None: + raise RuntimeError('`reset` must be called before `observation_spec`.') + + def make_observation_spec_dict(enabled_dict): + """Makes a dict of enabled observation specs from of observables.""" + out_dict = type(enabled_dict)() + for name, enabled in six.iteritems(enabled_dict): + + if isinstance(enabled.observable.array_spec, specs.BoundedArray): + bounds = (enabled.observable.array_spec.minimum, + enabled.observable.array_spec.maximum) + else: + bounds = None + + if enabled.observable.aggregator: + aggregator = enabled.observable.aggregator + aggregated = aggregator(np.zeros(enabled.buffer.shape, + dtype=enabled.buffer.dtype)) + shape = aggregated.shape + dtype = aggregated.dtype + + # Ditch bounds if the aggregator isn't known to be bounds-preserving. + if bounds: + if not hasattr(aggregator, 'preserves_bounds'): + logging.warning('Ignoring the bounds of this observable\'s spec, ' + 'as its aggregator method has no boolean ' + '`preserves_bounds` attrubute.') + bounds = None + elif not aggregator.preserves_bounds: + bounds = None + else: + shape = enabled.buffer.shape + dtype = enabled.buffer.dtype + + if bounds: + spec = specs.BoundedArray(minimum=bounds[0], + maximum=bounds[1], + shape=shape, + dtype=dtype, + name=name) + else: + spec = specs.Array(shape=shape, dtype=dtype, name=name) + + out_dict[name] = spec + return out_dict + + if self._is_nested: + enabled_specs = type(self._enabled_structure)( + make_observation_spec_dict(enabled_dict) + for enabled_dict in self._enabled_structure) + else: + enabled_specs = make_observation_spec_dict(self._enabled_structure) + + return enabled_specs + + def prepare_for_next_control_step(self): + """Simulates the next control step and optimizes the update schedule.""" + if self._enabled_structure is None: + raise RuntimeError('`reset` must be called before `before_step`.') + for enabled in self._enabled_list: + update_interval = ( + enabled.observable.update_interval or DEFAULT_UPDATE_INTERVAL) + delay = enabled.observable.delay or DEFAULT_DELAY + buffer_size = enabled.observable.buffer_size or DEFAULT_BUFFER_SIZE + + if (update_interval == DEFAULT_UPDATE_INTERVAL and delay == DEFAULT_DELAY + and buffer_size < self._physics_steps_per_control_step): + for i in reversed(range(buffer_size)): + next_step = ( + self._step_counter + self._physics_steps_per_control_step - i) + next_delay = DEFAULT_DELAY + enabled.update_schedule.append((next_step, next_delay)) + else: + if enabled.update_schedule: + last_scheduled_step = enabled.update_schedule[-1][0] + else: + last_scheduled_step = self._step_counter + max_step = self._step_counter + 2 * self._physics_steps_per_control_step + while last_scheduled_step < max_step: + next_update_interval = _call_if_callable(update_interval) + next_step = last_scheduled_step + next_update_interval + next_delay = _call_if_callable(delay) + enabled.update_schedule.append((next_step, next_delay)) + last_scheduled_step = next_step + # Optimize the schedule by planning ahead and dropping unseen entries. + enabled.buffer.drop_unobserved_upcoming_items( + enabled.update_schedule, self._physics_steps_per_control_step) + + def update(self): + if self._enabled_structure is None: + raise RuntimeError('`reset` must be called before `after_substep`.') + self._step_counter += 1 + for enabled in self._enabled_list: + if (enabled.update_schedule and + enabled.update_schedule[0][0] == self._step_counter): + timestamp, delay = enabled.update_schedule.popleft() + enabled.buffer.insert( + timestamp, delay, + enabled.observation_callable()) + + def get_observation(self): + """Gets the current observation. + + The returned observation is only valid as of the previous call + to `reset`. In particular, it is an error to call this function before + the first call to `reset`. + + Returns: + A dict, or list of dicts, or tuple of dicts, of observation values. + The returned structure corresponds to the structure of the `observables` + that was given at initialization time. + + Raises: + RuntimeError: If this method is called before `reset` has been called. + """ + if self._enabled_structure is None: + raise RuntimeError('`reset` must be called before `observation`.') + + def aggregate_dict(enabled_dict): + out_dict = type(enabled_dict)() + for name, enabled in six.iteritems(enabled_dict): + if enabled.observable.aggregator: + aggregated = enabled.observable.aggregator( + enabled.buffer.read(self._step_counter)) + else: + aggregated = enabled.buffer.read(self._step_counter) + out_dict[name] = aggregated + return out_dict + + if self._is_nested: + return type(self._enabled_structure)( + aggregate_dict(enabled_dict) + for enabled_dict in self._enabled_structure) + else: + return aggregate_dict(self._enabled_structure) diff --git a/DMC/src/env/dm_control/dm_control/composer/observation/updater_test.py b/DMC/src/env/dm_control/dm_control/composer/observation/updater_test.py new file mode 100644 index 0000000..5dd04f7 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/observation/updater_test.py @@ -0,0 +1,279 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for observation.observation_updater.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import itertools +import math + +# Internal dependencies. + +from absl.testing import absltest +from absl.testing import parameterized +from dm_control.composer.observation import fake_physics +from dm_control.composer.observation import observable +from dm_control.composer.observation import updater +from dm_env import specs +import numpy as np +import six +from six.moves import range + + +class DeterministicSequence(object): + + def __init__(self, sequence): + self._iter = itertools.cycle(sequence) + + def __call__(self, random_state=None): + del random_state # unused + return six.next(self._iter) + + +class BoundedGeneric(observable.Generic): + + def __init__(self, raw_observation_callable, minimum, maximum, **kwargs): + super(BoundedGeneric, self).__init__( + raw_observation_callable=raw_observation_callable, + **kwargs) + self._bounds = (minimum, maximum) + + @property + def array_spec(self): + datum = np.array(self(None, None)) + return specs.BoundedArray(shape=datum.shape, + dtype=datum.dtype, + minimum=self._bounds[0], + maximum=self._bounds[1]) + + +class UpdaterTest(parameterized.TestCase): + + @parameterized.parameters(list, tuple) + def testNestedSpecsAndValues(self, list_or_tuple): + observables = list_or_tuple(( + {'one': observable.Generic(lambda _: 1.), + 'two': observable.Generic(lambda _: [2, 2]), + }, collections.OrderedDict([ + ('three', observable.Generic(lambda _: np.full((2, 2), 3))), + ('four', observable.Generic(lambda _: [4.])), + ('five', observable.Generic(lambda _: 5)), + ('six', BoundedGeneric(lambda _: [2, 2], 1, 4)), + ('seven', BoundedGeneric(lambda _: 2, 1, 4, aggregator='sum')), + ]) + )) + + observables[0]['two'].enabled = True + observables[1]['three'].enabled = True + observables[1]['five'].enabled = True + observables[1]['six'].enabled = True + observables[1]['seven'].enabled = True + + observation_updater = updater.Updater(observables) + observation_updater.reset(physics=fake_physics.FakePhysics(), + random_state=None) + + def make_spec(obs): + array = np.array(obs.observation_callable(None, None)()) + shape = array.shape if obs.aggregator else (1,) + array.shape + + if (isinstance(obs, BoundedGeneric) and + obs.aggregator is not observable.base.AGGREGATORS['sum']): + return specs.BoundedArray(shape=shape, + dtype=array.dtype, + minimum=obs.array_spec.minimum, + maximum=obs.array_spec.maximum) + else: + return specs.Array(shape=shape, dtype=array.dtype) + + expected_specs = list_or_tuple(( + {'two': make_spec(observables[0]['two'])}, + collections.OrderedDict([ + ('three', make_spec(observables[1]['three'])), + ('five', make_spec(observables[1]['five'])), + ('six', make_spec(observables[1]['six'])), + ('seven', make_spec(observables[1]['seven'])), + ]) + )) + + actual_specs = observation_updater.observation_spec() + self.assertIs(type(actual_specs), type(expected_specs)) + for actual_dict, expected_dict in zip(actual_specs, expected_specs): + self.assertIs(type(actual_dict), type(expected_dict)) + self.assertEqual(actual_dict, expected_dict) + + def make_value(obs): + value = obs(physics=None, random_state=None) + if obs.aggregator: + return value + else: + value = np.array(value) + value = value[np.newaxis, ...] + return value + + expected_values = list_or_tuple(( + {'two': make_value(observables[0]['two'])}, + collections.OrderedDict([ + ('three', make_value(observables[1]['three'])), + ('five', make_value(observables[1]['five'])), + ('six', make_value(observables[1]['six'])), + ('seven', make_value(observables[1]['seven'])), + ]) + )) + + actual_values = observation_updater.get_observation() + self.assertIs(type(actual_values), type(expected_values)) + for actual_dict, expected_dict in zip(actual_values, expected_values): + self.assertIs(type(actual_dict), type(expected_dict)) + self.assertLen(actual_dict, len(expected_dict)) + for actual, expected in zip(six.iteritems(actual_dict), + six.iteritems(expected_dict)): + actual_name, actual_value = actual + expected_name, expected_value = expected + self.assertEqual(actual_name, expected_name) + np.testing.assert_array_equal(actual_value, expected_value) + + def assertCorrectSpec( + self, spec, expected_shape, expected_dtype, expected_name): + self.assertEqual(spec.shape, expected_shape) + self.assertEqual(spec.dtype, expected_dtype) + self.assertEqual(spec.name, expected_name) + + def testObservationSpecInference(self): + physics = fake_physics.FakePhysics() + physics.observables['repeated'].buffer_size = 5 + physics.observables['matrix'].buffer_size = 4 + physics.observables['sqrt'] = observable.Generic( + fake_physics.FakePhysics.sqrt, buffer_size=3) + + for obs in six.itervalues(physics.observables): + obs.enabled = True + + observation_updater = updater.Updater(physics.observables) + observation_updater.reset(physics=physics, random_state=None) + + spec = observation_updater.observation_spec() + self.assertCorrectSpec(spec['repeated'], (5, 2), np.int, 'repeated') + self.assertCorrectSpec(spec['matrix'], (4, 2, 3), np.int, 'matrix') + self.assertCorrectSpec(spec['sqrt'], (3,), np.float, 'sqrt') + + def testObservation(self): + physics = fake_physics.FakePhysics() + physics.observables['repeated'].buffer_size = 5 + physics.observables['matrix'].delay = 1 + physics.observables['sqrt'] = observable.Generic( + fake_physics.FakePhysics.sqrt, update_interval=7, + buffer_size=3, delay=2) + for obs in six.itervalues(physics.observables): + obs.enabled = True + with physics.reset_context(): + pass + + physics_steps_per_control_step = 5 + observation_updater = updater.Updater( + physics.observables, physics_steps_per_control_step) + observation_updater.reset(physics=physics, random_state=None) + + for control_step in range(0, 200): + observation_updater.prepare_for_next_control_step() + for _ in range(physics_steps_per_control_step): + physics.step() + observation_updater.update() + + step_counter = (control_step + 1) * physics_steps_per_control_step + + observation = observation_updater.get_observation() + def assert_correct_buffer(obs_name, expected_callable, + observation=observation, + step_counter=step_counter): + update_interval = (physics.observables[obs_name].update_interval + or updater.DEFAULT_UPDATE_INTERVAL) + buffer_size = (physics.observables[obs_name].buffer_size + or updater.DEFAULT_BUFFER_SIZE) + delay = (physics.observables[obs_name].delay + or updater.DEFAULT_DELAY) + + # The final item in the buffer is the current time, less the delay, + # rounded _down_ to the nearest multiple of the update interval. + end = update_interval * int( + math.floor((step_counter - delay) / update_interval)) + + # Figure out the first item in the buffer by working backwards from + # the final item in multiples of the update interval. + start = end - (buffer_size - 1) * update_interval + + # Clamp both the start and end step number below by zero. + buffer_range = range(max(0, start), max(0, end + 1), update_interval) + + # Arrays with expected shapes, filled with expected default values. + expected_value_spec = observation_updater.observation_spec()[obs_name] + expected_values = np.zeros(shape=expected_value_spec.shape, + dtype=expected_value_spec.dtype) + + # The arrays are filled from right to left, such that the most recent + # entry is the rightmost one, and any padding is on the left. + for index, timestamp in enumerate(reversed(buffer_range)): + expected_values[-(index+1)] = expected_callable(timestamp) + + np.testing.assert_array_equal(observation[obs_name], expected_values) + + assert_correct_buffer('twice', lambda x: 2*x) + assert_correct_buffer('matrix', lambda x: [[x]*3]*2) + assert_correct_buffer('repeated', lambda x: [x, x]) + assert_correct_buffer('sqrt', np.sqrt) + + def testVariableRatesAndDelays(self): + physics = fake_physics.FakePhysics() + physics.observables['time'] = observable.Generic( + lambda physics: physics.time(), + buffer_size=3, + # observations produced on step numbers 20*N + [0, 3, 5, 8, 11, 15, 16] + update_interval=DeterministicSequence([3, 2, 3, 3, 4, 1, 4]), + # observations arrive on step numbers 20*N + [3, 8, 7, 12, 11, 17, 20] + delay=DeterministicSequence([3, 5, 2, 5, 1, 2, 4])) + physics.observables['time'].enabled = True + + physics_steps_per_control_step = 10 + observation_updater = updater.Updater( + physics.observables, physics_steps_per_control_step) + observation_updater.reset(physics=physics, random_state=None) + + # Run through a few cycles of the variation sequences to make sure that + # cross-control-boundary behaviour is correct. + for i in range(5): + observation_updater.prepare_for_next_control_step() + for _ in range(physics_steps_per_control_step): + physics.step() + observation_updater.update() + np.testing.assert_array_equal( + observation_updater.get_observation()['time'], + 20*i + np.array([0, 5, 3])) + + observation_updater.prepare_for_next_control_step() + for _ in range(physics_steps_per_control_step): + physics.step() + observation_updater.update() + # Note that #11 is dropped since it arrives after #8, + # whose large delay caused it to cross the control step boundary at #10. + np.testing.assert_array_equal( + observation_updater.get_observation()['time'], + 20*i + np.array([8, 15, 16])) + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/composer/robot.py b/DMC/src/env/dm_control/dm_control/composer/robot.py new file mode 100644 index 0000000..fa4dfad --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/robot.py @@ -0,0 +1,38 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Module defining the abstract robot class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +from dm_control.composer import entity +import numpy as np +import six + +DOWN_QUATERNION = np.array([0., 0.70710678118, 0.70710678118, 0.]) + + +@six.add_metaclass(abc.ABCMeta) +class Robot(entity.Entity): + """The abstract base class for robots.""" + + @abc.abstractproperty + def actuators(self): + """Returns the actuator elements of the robot.""" + raise NotImplementedError diff --git a/DMC/src/env/dm_control/dm_control/composer/task.py b/DMC/src/env/dm_control/dm_control/composer/task.py new file mode 100644 index 0000000..606a109 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/task.py @@ -0,0 +1,332 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Abstract base class for a Composer task.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import collections +import copy +import sys + +from dm_control import mujoco +from dm_env import specs +import six +from six.moves import range + + +def _check_timesteps_divisible(control_timestep, physics_timestep): + num_steps = control_timestep / physics_timestep + rounded_num_steps = int(round(num_steps)) + if abs(num_steps - rounded_num_steps) > 1e-6: + raise ValueError( + 'Control timestep should be an integer multiple of physics timestep' + ': got {!r} and {!r}'.format(control_timestep, physics_timestep)) + return rounded_num_steps + + +@six.add_metaclass(abc.ABCMeta) +class Task(object): + """Abstract base class for a Composer task.""" + + @abc.abstractproperty + def root_entity(self): + """A `base.Entity` instance for this task.""" + raise NotImplementedError + + def iter_entities(self): + return self.root_entity.iter_entities() + + @property + def observables(self): + """An OrderedDict of `control.Observable` instances for this task. + + Task subclasses should generally NOT override this property. + + This property is automatically computed by combining the observables dict + provided by each `Entity` present in this task, and any additional + observables returned via the `task_observables` property. + + To provide an observable to an agent, the task code should either set + `enabled` property of an `Entity`-bound observable to `True`, or override + the `task_observables` property to provide additional observables not bound + to an `Entity`. + + Returns: + An `collections.OrderedDict` mapping strings to instances of + `control.Observable`. + """ + # Make a shallow copy of the OrderedDict, not the Observables themselves. + observables = copy.copy(self.task_observables) + for entity in self.root_entity.iter_entities(): + observables.update(entity.observables.as_dict()) + return observables + + @property + def task_observables(self): + """An OrderedDict of task-specific `control.Observable` instances. + + A task should override this property if it wants to provide additional + observables to the agent that are not already provided by any `Entity` that + forms part of the task's model. For example, this may be used to provide + observations that is derived from relative poses between two entities. + + Returns: + An `collections.OrderedDict` mapping strings to instances of + `control.Observable`. + """ + return collections.OrderedDict() + + def after_compile(self, physics, random_state): + """A callback which is executed after the Mujoco Physics is recompiled. + + Args: + physics: An instance of `control.Physics`. + random_state: An instance of `np.random.RandomState`. + """ + pass + + def _check_root_entity(self, callee_name): + try: + _ = self.root_entity + except: # pylint: disable=bare-except + err_type, err, tb = sys.exc_info() + message = ( + 'call to `{}` made before `root_entity` is available;\n' + 'original error message: {}'.format(callee_name, str(err))) + six.reraise(err_type, err_type(message), tb) + + @property + def control_timestep(self): + """Returns the agent's control timestep for this task (in seconds).""" + self._check_root_entity('control_timestep') + if hasattr(self, '_control_timestep'): + return self._control_timestep + else: + return self.physics_timestep + + @control_timestep.setter + def control_timestep(self, new_value): + """Changes the agent's control timestep for this task. + + Args: + new_value: the new control timestep (in seconds). + + Raises: + ValueError: if `new_value` is set and is not divisible by + `physics_timestep`. + """ + self._check_root_entity('control_timestep') + _check_timesteps_divisible(new_value, self.physics_timestep) + self._control_timestep = new_value + + @property + def physics_timestep(self): + """Returns the physics timestep for this task (in seconds).""" + self._check_root_entity('physics_timestep') + if self.root_entity.mjcf_model.option.timestep is None: + return 0.002 # MuJoCo's default. + else: + return self.root_entity.mjcf_model.option.timestep + + @physics_timestep.setter + def physics_timestep(self, new_value): + """Changes the physics simulation timestep for this task. + + Args: + new_value: the new simulation timestep (in seconds). + + Raises: + ValueError: if `control_timestep` is set and is not divisible by + `new_value`. + """ + self._check_root_entity('physics_timestep') + if hasattr(self, '_control_timestep'): + _check_timesteps_divisible(self._control_timestep, new_value) + self.root_entity.mjcf_model.option.timestep = new_value + + def set_timesteps(self, control_timestep, physics_timestep): + """Changes the agent's control timestep and physics simulation timestep. + + This is equivalent to modifying `control_timestep` and `physics_timestep` + simultaneously. The divisibility check is performed between the two + new values. + + Args: + control_timestep: the new agent's control timestep (in seconds). + physics_timestep: the new physics simulation timestep (in seconds). + + Raises: + ValueError: if `control_timestep` is not divisible by `physics_timestep`. + """ + self._check_root_entity('set_timesteps') + _check_timesteps_divisible(control_timestep, physics_timestep) + self.root_entity.mjcf_model.option.timestep = physics_timestep + self._control_timestep = control_timestep + + @property + def physics_steps_per_control_step(self): + """Returns number of physics steps per agent's control step.""" + return _check_timesteps_divisible( + self.control_timestep, self.physics_timestep) + + def action_spec(self, physics): + """Returns a `BoundedArray` spec matching the `Physics` actuators. + + BoundedArray.name should contain a tab-separated list of actuator names. + When overloading this method, non-MuJoCo actuators should be added to the + top of the list when possible, as a matter of convention. + + Args: + physics: used to query actuator names in the model. + """ + names = [physics.model.id2name(i, 'actuator') or str(i) + for i in range(physics.model.nu)] + action_spec = mujoco.action_spec(physics) + return specs.BoundedArray(shape=action_spec.shape, + dtype=action_spec.dtype, + minimum=action_spec.minimum, + maximum=action_spec.maximum, + name='\t'.join(names)) + + def get_reward_spec(self): + """Optional method to define non-scalar rewards for a `Task`.""" + return None + + def get_discount_spec(self): + """Optional method to define non-scalar discounts for a `Task`.""" + return None + + def initialize_episode_mjcf(self, random_state): + """Modifies the MJCF model of this task before the next episode begins. + + The Environment calls this method and recompiles the physics + if necessary before calling `initialize_episode`. + + Args: + random_state: An instance of `np.random.RandomState`. + """ + pass + + def initialize_episode(self, physics, random_state): + """Modifies the physics state before the next episode begins. + + The Environment calls this method after `initialize_episode_mjcf`, and also + after the physics has been recompiled if necessary. + + Args: + physics: An instance of `control.Physics`. + random_state: An instance of `np.random.RandomState`. + """ + pass + + def before_step(self, physics, action, random_state): + """A callback which is executed before an agent control step. + + The default implementation sets the control signal for the actuators in + `physics` to be equal to `action`. Subclasses that override this method + should ensure that the overriding method also sets the control signal before + returning, either by calling `super(..., self).before_step`, or by setting + the control signal explicitly (e.g. in order to create a non-trivial mapping + between `action` and the control signal). + + Args: + physics: An instance of `control.Physics`. + action: A NumPy array corresponding to agent actions. + random_state: An instance of `np.random.RandomState` (unused). + """ + del random_state # Unused. + physics.set_control(action) + + def before_substep(self, physics, action, random_state): + """A callback which is executed before a simulation step. + + Actuation can be set, or overridden, in this callback. + + Args: + physics: An instance of `control.Physics`. + action: A NumPy array corresponding to agent actions. + random_state: An instance of `np.random.RandomState`. + """ + pass + + def after_substep(self, physics, random_state): + """A callback which is executed after a simulation step. + + Args: + physics: An instance of `control.Physics`. + random_state: An instance of `np.random.RandomState`. + """ + pass + + def after_step(self, physics, random_state): + """A callback which is executed after an agent control step. + + Args: + physics: An instance of `control.Physics`. + random_state: An instance of `np.random.RandomState`. + """ + pass + + @abc.abstractmethod + def get_reward(self, physics): + """Calculates the reward signal given the physics state. + + Args: + physics: A Physics object. + + Returns: + A float + """ + raise NotImplementedError + + def should_terminate_episode(self, physics): # pylint: disable=unused-argument + """Determines whether the episode should terminate given the physics state. + + Args: + physics: A Physics object + + Returns: + A boolean + """ + return False + + def get_discount(self, physics): # pylint: disable=unused-argument + """Calculates the reward discount factor given the physics state. + + Args: + physics: A Physics object + + Returns: + A float + """ + return 1.0 + + +class NullTask(Task): + """A class that wraps a single `Entity` into a `Task` with no reward.""" + + def __init__(self, root_entity): + self._root_entity = root_entity + + @property + def root_entity(self): + return self._root_entity + + def get_reward(self, physics): + return 0.0 diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/__init__.py b/DMC/src/env/dm_control/dm_control/composer/variation/__init__.py new file mode 100644 index 0000000..4831e1e --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/variation/__init__.py @@ -0,0 +1,137 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""A module that helps manage model variation in Composer environments.""" + +import collections +import copy + +from dm_control.composer.variation.base import Variation +from dm_control.composer.variation.variation_values import evaluate +import six + + +class _VariationInfo(object): + + __slots__ = ['initial_value', 'variation'] + + def __init__(self, initial_value=None, variation=None): + self.initial_value = initial_value + self.variation = variation + + +class MJCFVariator(object): + """Helper object for applying variations to MJCF attributes. + + An instance of this class remembers the original value of each MJCF attribute + the first time a variation is applied. The original value is then passed as an + argument to each variation callable. + """ + + def __init__(self): + self._variations = collections.defaultdict(dict) + + def bind_attributes(self, element, **kwargs): + """Binds variations to attributes of an MJCF element. + + Args: + element: An `mjcf.Element` object. + **kwargs: Keyword arguments mapping attribute names to the corresponding + variations. A variation is either a fixed value or a callable that + optionally takes the original value of an attribute and returns a + new value. + """ + for attribute_name, variation in six.iteritems(kwargs): + if variation is None and attribute_name in self._variations[element]: + del self._variations[element][attribute_name] + else: + initial_value = copy.copy(getattr(element, attribute_name)) + self._variations[element][attribute_name] = ( + _VariationInfo(initial_value, variation)) + + def apply_variations(self, random_state): + """Applies variations in-place to the specified MJCF element. + + Args: + random_state: A `numpy.random.RandomState` instance. + """ + for element, attribute_variations in six.iteritems(self._variations): + new_values = {} + for attribute_name, variation_info in six.iteritems(attribute_variations): + current_value = getattr(element, attribute_name) + if variation_info.initial_value is None: + variation_info.initial_value = copy.copy(current_value) + new_values[attribute_name] = evaluate( + variation_info.variation, variation_info.initial_value, + current_value, random_state) + element.set_attributes(**new_values) + + def clear(self): + """Clears all bound attribute variations.""" + self._variations.clear() + + def reset_initial_values(self): + for variations in six.itervalues(self._variations): + for variation_info in six.itervalues(variations): + variation_info.initial_value = None + + +class PhysicsVariator(object): + """Helper object for applying variations to MjModel and MjData. + + An instance of this class remembers the original value of each attribute + the first time a variation is applied. The original value is then passed as an + argument to each variation callable. + """ + + def __init__(self): + self._variations = collections.defaultdict(dict) + + def bind_attributes(self, element, **kwargs): + """Binds variations to attributes of an MJCF element. + + Args: + element: An `mjcf.Element` object. + **kwargs: Keyword arguments mapping attribute names to the corresponding + variations. A variation is either a fixed value or a callable that + optionally takes the original value of an attribute and returns a + new value. + """ + for attribute_name, variation in six.iteritems(kwargs): + if variation is None and attribute_name in self._variations[element]: + del self._variations[element][attribute_name] + else: + self._variations[element][attribute_name] = ( + _VariationInfo(None, variation)) + + def apply_variations(self, physics, random_state): + for element, variations in six.iteritems(self._variations): + binding = physics.bind(element) + for attribute_name, variation_info in six.iteritems(variations): + current_value = getattr(binding, attribute_name) + if variation_info.initial_value is None: + variation_info.initial_value = copy.copy(current_value) + setattr(binding, attribute_name, evaluate( + variation_info.variation, variation_info.initial_value, + current_value, random_state)) + + def clear(self): + """Clears all bound attribute variations.""" + self._variations.clear() + + def reset_initial_values(self): + for variations in six.itervalues(self._variations): + for variation_info in six.itervalues(variations): + variation_info.initial_value = None diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/base.py b/DMC/src/env/dm_control/dm_control/composer/variation/base.py new file mode 100644 index 0000000..7783c8d --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/variation/base.py @@ -0,0 +1,99 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Base class for variations and binary operations on variations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import operator + +from dm_control.composer.variation import variation_values +import six + + +@six.add_metaclass(abc.ABCMeta) +class Variation(object): + """Abstract base class for variations.""" + + @abc.abstractmethod + def __call__(self, initial_value, current_value, random_state): + """Generates a value for this variation. + + Args: + initial_value: The original value of the attribute being varied. + Absolute variations may ignore this argument. + current_value: The current value of the attribute being varied. + Absolute variations may ignore this argument. + random_state: A `numpy.RandomState` used to generate the value. + Deterministic variations may ignore this argument. + + Returns: + The next value for this variation. + """ + + def __add__(self, other): + return _BinaryOperation(operator.add, self, other) + + def __radd__(self, other): + return _BinaryOperation(operator.add, other, self) + + def __sub__(self, other): + return _BinaryOperation(operator.sub, self, other) + + def __rsub__(self, other): + return _BinaryOperation(operator.sub, other, self) + + def __mul__(self, other): + return _BinaryOperation(operator.mul, self, other) + + def __rmul__(self, other): + return _BinaryOperation(operator.mul, other, self) + + def __truediv__(self, other): + return _BinaryOperation(operator.truediv, self, other) + + def __rtruediv__(self, other): + return _BinaryOperation(operator.truediv, other, self) + + def __floordiv__(self, other): + return _BinaryOperation(operator.floordiv, self, other) + + def __rfloordiv__(self, other): + return _BinaryOperation(operator.floordiv, other, self) + + def __pow__(self, other): + return _BinaryOperation(operator.pow, self, other) + + def __rpow__(self, other): + return _BinaryOperation(operator.pow, other, self) + + +class _BinaryOperation(Variation): + """Represents the result of applying a binary operator to two Variations.""" + + def __init__(self, op, first, second): + self._first = first + self._second = second + self._op = op + + def __call__(self, initial_value=None, current_value=None, random_state=None): + first_value = variation_values.evaluate( + self._first, initial_value, current_value, random_state) + second_value = variation_values.evaluate( + self._second, initial_value, current_value, random_state) + return self._op(first_value, second_value) diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/colors.py b/DMC/src/env/dm_control/dm_control/composer/variation/colors.py new file mode 100644 index 0000000..2fc091d --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/variation/colors.py @@ -0,0 +1,79 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Variations in colors. + +Classes in this module allow users to specify a variations for each channel in +a variety of color spaces. The generated values are always RGBA arrays. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import colorsys + +from dm_control.composer.variation import base +from dm_control.composer.variation import variation_values +import numpy as np + + +class RgbVariation(base.Variation): + """Represents a variation in the RGB color space. + + This class allows users to specify independent variations in the R, G, B, and + alpha channels of a color, and generates the corresponding array of RGBA + values. + """ + + def __init__(self, r, g, b, alpha=1.0): + self._r, self._g, self._b = r, g, b + self._alpha = alpha + + def __call__(self, initial_value=None, current_value=None, random_state=None): + return np.asarray( + variation_values.evaluate([self._r, self._g, self._b, self._alpha], + initial_value, current_value, random_state)) + + +class HsvVariation(base.Variation): + """Represents a variation in the HSV color space. + + This class allows users to specify independent variations in the H, S, V, and + alpha channels of a color, and generates the corresponding array of RGBA + values. + """ + + def __init__(self, h, s, v, alpha=1.0): + self._h, self._s, self._v = h, s, v + self._alpha = alpha + + def __call__(self, initial_value=None, current_value=None, random_state=None): + h, s, v, alpha = variation_values.evaluate( + (self._h, self._s, self._v, self._alpha), initial_value, current_value, + random_state) + return np.asarray(list(colorsys.hsv_to_rgb(h, s, v)) + [alpha]) + + +class GrayVariation(HsvVariation): + """Represents a variation in gray level. + + This class allows users to specify independent variations in the gray level + and alpha channels of a color, and generates the corresponding array of RGBA + values. + """ + + def __init__(self, gray_level, alpha=1.0): + super(GrayVariation, self).__init__(h=0.0, s=0.0, v=gray_level, alpha=alpha) diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/deterministic.py b/DMC/src/env/dm_control/dm_control/composer/variation/deterministic.py new file mode 100644 index 0000000..afea874 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/variation/deterministic.py @@ -0,0 +1,51 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Deterministic variations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from dm_control.composer.variation import base + + +class Constant(base.Variation): + """Wraps a constant value into a Variation object. + + This class is provided mainly for use in tests, to check that variations are + invoked correctly without having to introduce randomness in test cases. + """ + + def __init__(self, value): + self._value = value + + def __call__(self, initial_value=None, current_value=None, random_state=None): + return self._value + + +class Sequence(base.Variation): + """Variation representing a fixed sequence of values.""" + + def __init__(self, values): + self._values = values + self._iterator = iter(self._values) + + def __call__(self, initial_value=None, current_value=None, random_state=None): + try: + return next(self._iterator) + except StopIteration: + self._iterator = iter(self._values) + return next(self._iterator) diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/distributions.py b/DMC/src/env/dm_control/dm_control/composer/variation/distributions.py new file mode 100644 index 0000000..e948bf9 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/variation/distributions.py @@ -0,0 +1,212 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Standard statistical distributions that conform to the Variation API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import functools + +from dm_control.composer import variation +from dm_control.composer.variation import base +import numpy as np +import six + + +@six.add_metaclass(abc.ABCMeta) +class Distribution(base.Variation): + """Base Distribution class for sampling a parametrized distribution. + + Subclasses need to implement `_callable`, which needs to return a callable + based on the random_state passed as arg. This callable then gets called using + the arguments passed to the constructor, after being evaluated. This allows + the distribution parameters themselves to be instances of `base.Variation`. + By default samples are drawn in the shape of `initial_value`, unless the + optional `single_sample` constructor arg is set to `True`, in which case only + a single sample is drawn. + """ + __slots__ = ('_single_sample', '_args', '_kwargs') + + def __init__(self, *args, **kwargs): + self._single_sample = kwargs.pop('single_sample', False) + self._args = args + self._kwargs = kwargs + + def __call__(self, initial_value=None, current_value=None, random_state=None): + local_random_state = random_state or np.random + size = (None if self._single_sample or initial_value is None # pylint: disable=g-long-ternary + else np.shape(initial_value)) + local_args = variation.evaluate(self._args, + initial_value=initial_value, + current_value=current_value, + random_state=random_state) + local_kwargs = variation.evaluate(self._kwargs, + initial_value=initial_value, + current_value=current_value, + random_state=random_state) + return self._callable(local_random_state)(*local_args, + size=size, + **local_kwargs) + + @abc.abstractmethod + def _callable(self, random_state): + raise NotImplementedError + + +class Uniform(Distribution): + __slots__ = () + + def __init__(self, low=0.0, high=1.0, single_sample=False): + super(Uniform, self).__init__(low=low, high=high, + single_sample=single_sample) + + def _callable(self, random_state): + return random_state.uniform + + +class UniformInteger(Distribution): + __slots__ = () + + def __init__(self, low, high=None, single_sample=False): + super(UniformInteger, self).__init__(low, high=high, + single_sample=single_sample) + + def _callable(self, random_state): + return random_state.randint + + +class UniformChoice(Distribution): + __slots__ = () + + def __init__(self, choices, single_sample=False): + super(UniformChoice, self).__init__(choices, single_sample=single_sample) + + def _callable(self, random_state): + return random_state.choice + + +class UniformPointOnSphere(base.Variation): + __slots__ = () + + def __call__(self, initial_value=None, + current_value=None, random_state=None): + random_state = random_state or np.random + size = 3 if initial_value is None else np.append(np.shape(initial_value), 3) + axis = random_state.normal(size=size) + axis /= np.linalg.norm(axis, axis=-1, keepdims=True) + return axis + + +class Normal(Distribution): + __slots__ = () + + def __init__(self, loc=0.0, scale=1.0, single_sample=False): + super(Normal, self).__init__(loc=loc, scale=scale, + single_sample=single_sample) + + def _callable(self, random_state): + return random_state.normal + + +class LogNormal(Distribution): + __slots__ = () + + def __init__(self, mean=0.0, sigma=1.0, single_sample=False): + super(LogNormal, self).__init__(mean=mean, sigma=sigma, + single_sample=single_sample) + + def _callable(self, random_state): + return random_state.lognormal + + +class Exponential(Distribution): + __slots__ = () + + def __init__(self, scale=1.0, single_sample=False): + super(Exponential, self).__init__(scale=scale, single_sample=single_sample) + + def _callable(self, random_state): + return random_state.exponential + + +class Poisson(Distribution): + __slots__ = () + + def __init__(self, lam=1.0, single_sample=False): + super(Poisson, self).__init__(lam=lam, single_sample=single_sample) + + def _callable(self, random_state): + return random_state.poisson + + +class Bernoulli(Distribution): + __slots__ = () + + def __init__(self, prob=0.5, single_sample=False): + super(Bernoulli, self).__init__(prob, single_sample=single_sample) + + def _callable(self, random_state): + return functools.partial(random_state.binomial, 1) + + +_NEGATIVE_STDEV = '`stdev` must be >= 0, got {}.' +_NEGATIVE_TIMESCALE = '`timescale` must be >= 0, got {}.' + + +class BiasedRandomWalk(base.Variation): + """A Class for generating noise from a zero-mean Ornstein-Uhlenbeck process. + + Let + `retain = np.exp(-1. / timescale)` + and + `scale = stdev * sqrt(1 - (retain * retain))` + Then the discete-time first-order filtered diffusion process + `x_next = retain * x + N(0, scale))` + has standard deviation `stdev` and characteristic timescale `timescale`. + """ + __slots__ = ('_scale', '_value') + + def __init__(self, stdev=0.1, timescale=10.): + """Initializes a `BiasedRandomWalk`. + + Args: + stdev: Float. Standard deviation of the output sequence. + timescale: Integer. Number of timesteps characteristic of the random walk. + After `timescale` steps the correlation is reduced by exp(-1). Larger + or equal to 0, where a value of 0 is an uncorrelated normal + distribution. + + Raises: + ValueError: if either `stdev` or `timescale` is negative. + """ + if stdev < 0: + raise ValueError(_NEGATIVE_STDEV.format(stdev)) + if timescale < 0: + raise ValueError(_NEGATIVE_TIMESCALE.format(timescale)) + elif timescale == 0: + self._retain = 0. + else: + self._retain = np.exp(-1. / timescale) + self._scale = stdev * np.sqrt(1 - (self._retain * self._retain)) + self._value = 0.0 + + def __call__(self, initial_value=None, current_value=None, random_state=None): + random_state = random_state or np.random + self._value = (self._retain * self._value + + random_state.normal(loc=0.0, scale=self._scale)) + return self._value diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/distributions_test.py b/DMC/src/env/dm_control/dm_control/composer/variation/distributions_test.py new file mode 100644 index 0000000..1abe269 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/variation/distributions_test.py @@ -0,0 +1,115 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for distributions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Internal dependencies. +from absl.testing import absltest +from absl.testing import parameterized +from dm_control.composer.variation import distributions +import numpy as np +from six.moves import range + +RANDOM_SEED = 123 +NUM_ITERATIONS = 100 + + +def _make_random_state(): + return np.random.RandomState(RANDOM_SEED) + + +class DistributionsTest(parameterized.TestCase): + + def setUp(self): + super(DistributionsTest, self).setUp() + self._variation_random_state = _make_random_state() + self._np_random_state = _make_random_state() + + def testUniform(self): + lower, upper = [2, 3, 4], [5, 6, 7] + variation = distributions.Uniform(low=lower, high=upper) + for _ in range(NUM_ITERATIONS): + np.testing.assert_array_equal( + variation(random_state=self._variation_random_state), + self._np_random_state.uniform(lower, upper)) + + def testUniformChoice(self): + choices = ['apple', 'banana', 'cherry'] + variation = distributions.UniformChoice(choices) + for _ in range(NUM_ITERATIONS): + self.assertEqual( + variation(random_state=self._variation_random_state), + self._np_random_state.choice(choices)) + + def testUniformPointOnSphere(self): + variation = distributions.UniformPointOnSphere() + samples = [] + for _ in range(NUM_ITERATIONS): + sample = variation(random_state=self._variation_random_state) + self.assertEqual(sample.size, 3) + np.testing.assert_approx_equal(np.linalg.norm(sample), 1.0) + samples.append(sample) + # Make sure that none of the samples are the same. + self.assertLen( + set(np.reshape(np.asarray(samples), -1)), 3 * NUM_ITERATIONS) + + def testNormal(self): + loc, scale = 1, 2 + variation = distributions.Normal(loc=loc, scale=scale) + for _ in range(NUM_ITERATIONS): + self.assertEqual( + variation(random_state=self._variation_random_state), + self._np_random_state.normal(loc, scale)) + + def testExponential(self): + scale = 3 + variation = distributions.Exponential(scale=scale) + for _ in range(NUM_ITERATIONS): + self.assertEqual( + variation(random_state=self._variation_random_state), + self._np_random_state.exponential(scale)) + + def testPoisson(self): + lam = 4 + variation = distributions.Poisson(lam=lam) + for _ in range(NUM_ITERATIONS): + self.assertEqual( + variation(random_state=self._variation_random_state), + self._np_random_state.poisson(lam)) + + @parameterized.parameters(0, 10) + def testBiasedRandomWalk(self, timescale): + stdev = 1. + variation = distributions.BiasedRandomWalk(stdev=stdev, timescale=timescale) + sequence = [variation(random_state=self._variation_random_state) + for _ in range(int(max(timescale, 1)*NUM_ITERATIONS*1000))] + self.assertAlmostEqual(np.mean(sequence), 0., delta=0.01) + self.assertAlmostEqual(np.std(sequence), stdev, delta=0.01) + + @parameterized.parameters( + dict(arg_name='stdev', template=distributions._NEGATIVE_STDEV), + dict(arg_name='timescale', template=distributions._NEGATIVE_TIMESCALE)) + def testBiasedRandomWalkExceptions(self, arg_name, template): + bad_value = -1. + with self.assertRaisesWithLiteralMatch( + ValueError, template.format(bad_value)): + _ = distributions.BiasedRandomWalk(**{arg_name: bad_value}) + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/noises.py b/DMC/src/env/dm_control/dm_control/composer/variation/noises.py new file mode 100644 index 0000000..a43dee3 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/variation/noises.py @@ -0,0 +1,63 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Meta-variations that modify original values by a specified variation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from dm_control.composer.variation import base +from dm_control.composer.variation import variation_values + + +class Additive(base.Variation): + """A variation that adds to an existing value. + + This variation takes a value generated by another variation and adds it to an + existing value. In cumulative mode, the generated value is added to the + current value being varied. In non-cumulative mode, the generated value is + added to a fixed initial value. + """ + + def __init__(self, variation, cumulative=False): + self._variation = variation + self._cumulative = cumulative + + def __call__(self, initial_value=None, current_value=None, random_state=None): + base_value = current_value if self._cumulative else initial_value + return base_value + ( + variation_values.evaluate(self._variation, initial_value, current_value, + random_state)) + + +class Multiplicative(base.Variation): + """A variation that multiplies to an existing value. + + This variation takes a value generated by another variation and multiplies it + to an existing value. In cumulative mode, the generated value is multiplied to + the current value being varied. In non-cumulative mode, the generated value is + multiplied to a fixed initial value. + """ + + def __init__(self, variation, cumulative=False): + self._variation = variation + self._cumulative = cumulative + + def __call__(self, initial_value=None, current_value=None, random_state=None): + base_value = current_value if self._cumulative else initial_value + return base_value * ( + variation_values.evaluate(self._variation, initial_value, current_value, + random_state)) diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/noises_test.py b/DMC/src/env/dm_control/dm_control/composer/variation/noises_test.py new file mode 100644 index 0000000..da8145f --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/variation/noises_test.py @@ -0,0 +1,93 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for noises.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Internal dependencies. +from absl.testing import absltest +from absl.testing import parameterized +from dm_control.composer.variation import deterministic +from dm_control.composer.variation import noises +from six.moves import range + +NUM_ITERATIONS = 100 + + +class NoisesTest(parameterized.TestCase): + + @parameterized.parameters(False, True) + def testAdditive(self, use_constant_variation_object): + amount = 2 + if use_constant_variation_object: + variation = noises.Additive(deterministic.Constant(amount)) + else: + variation = noises.Additive(amount) + initial_value = 0 + current_value = initial_value + for _ in range(NUM_ITERATIONS): + current_value = variation( + initial_value=initial_value, current_value=current_value) + self.assertEqual(current_value, initial_value + amount) + + @parameterized.parameters(False, True) + def testAdditiveCumulative(self, use_constant_variation_object): + amount = 3 + if use_constant_variation_object: + variation = noises.Additive( + deterministic.Constant(amount), cumulative=True) + else: + variation = noises.Additive(amount, cumulative=True) + initial_value = 1 + current_value = initial_value + for i in range(NUM_ITERATIONS): + current_value = variation( + initial_value=initial_value, current_value=current_value) + self.assertEqual(current_value, initial_value + amount * (i + 1)) + + @parameterized.parameters(False, True) + def testMultiplicative(self, use_constant_variation_object): + amount = 23 + if use_constant_variation_object: + variation = noises.Multiplicative(deterministic.Constant(amount)) + else: + variation = noises.Multiplicative(amount) + initial_value = 3 + current_value = initial_value + for _ in range(NUM_ITERATIONS): + current_value = variation( + initial_value=initial_value, current_value=current_value) + self.assertEqual(current_value, initial_value * amount) + + @parameterized.parameters(False, True) + def testMultiplicativeCumulative(self, use_constant_variation_object): + amount = 2 + if use_constant_variation_object: + variation = noises.Multiplicative( + deterministic.Constant(amount), cumulative=True) + else: + variation = noises.Multiplicative(amount, cumulative=True) + initial_value = 3 + current_value = initial_value + for i in range(NUM_ITERATIONS): + current_value = variation( + initial_value=initial_value, current_value=current_value) + self.assertEqual(current_value, initial_value * amount ** (i + 1)) + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/rotations.py b/DMC/src/env/dm_control/dm_control/composer/variation/rotations.py new file mode 100644 index 0000000..b84e19b --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/variation/rotations.py @@ -0,0 +1,80 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Variations in 3D rotations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from dm_control.composer.variation import base +from dm_control.composer.variation import variation_values +import numpy as np + +IDENTITY_QUATERNION = np.array([1., 0., 0., 0.]) + + +class UniformQuaternion(base.Variation): + """Uniformly distributed unit quaternions.""" + + def __call__(self, initial_value=None, current_value=None, random_state=None): + random_state = random_state or np.random + u1, u2, u3 = random_state.uniform([0.] * 3, [1., 2. * np.pi, 2. * np.pi]) + return np.array([np.sqrt(1. - u1) * np.sin(u2), + np.sqrt(1. - u1) * np.cos(u2), + np.sqrt(u1) * np.sin(u3), + np.sqrt(u1) * np.cos(u3)]) + + +class QuaternionFromAxisAngle(base.Variation): + """Quaternion variation specified in terms of variations in axis and angle.""" + + def __init__(self, axis, angle): + self._axis = axis + self._angle = angle + + def __call__(self, initial_value=None, current_value=None, random_state=None): + random_state = random_state or np.random + axis = variation_values.evaluate( + self._axis, initial_value, current_value, random_state) + angle = variation_values.evaluate( + self._angle, initial_value, current_value, random_state) + sine, cosine = np.sin(angle / 2), np.cos(angle / 2) + return np.array([cosine, axis[0] * sine, axis[1] * sine, axis[2] * sine]) + + +class QuaternionPreMultiply(base.Variation): + """A variation that pre-multiplies an existing quaternion value. + + This variation takes a quaternion value generated by another variation and + pre-multiplies it to an existing value. In cumulative mode, the new quaternion + is pre-multiplied to the current value being varied. In non-cumulative mode, + the new quaternion is pre-multiplied to a fixed initial value. + """ + + def __init__(self, quat, cumulative=False): + self._quat = quat + self._cumulative = cumulative + + def __call__(self, initial_value=None, current_value=None, random_state=None): + random_state = random_state or np.random + q1 = variation_values.evaluate(self._quat, initial_value, current_value, + random_state) + q2 = current_value if self._cumulative else initial_value + return np.array([ + q1[0]*q2[0] - q1[1]*q2[1] - q1[2]*q2[2] - q1[3]*q2[3], + q1[0]*q2[1] + q1[1]*q2[0] + q1[2]*q2[3] - q1[3]*q2[2], + q1[0]*q2[2] - q1[1]*q2[3] + q1[2]*q2[0] + q1[3]*q2[1], + q1[0]*q2[3] + q1[1]*q2[2] - q1[2]*q2[1] + q1[3]*q2[0]]) diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/variation_test.py b/DMC/src/env/dm_control/dm_control/composer/variation/variation_test.py new file mode 100644 index 0000000..8e4b752 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/variation/variation_test.py @@ -0,0 +1,52 @@ +# Copyright 2018 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for base variation operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import operator +from absl.testing import absltest +from absl.testing import parameterized +from dm_control.composer import variation +from dm_control.composer.variation import deterministic + + +class VariationTest(parameterized.TestCase): + + def setUp(self): + self.value_1 = 3 + self.variation_1 = deterministic.Constant(self.value_1) + self.value_2 = 5 + self.variation_2 = deterministic.Constant(self.value_2) + + @parameterized.parameters(['add', 'sub', 'mul', 'truediv', 'floordiv', 'pow']) + def test_operator(self, name): + func = getattr(operator, name) + self.assertEqual( + variation.evaluate(func(self.value_1, self.variation_2)), + func(self.value_1, self.value_2)) + self.assertEqual( + variation.evaluate(func(self.variation_1, self.value_2)), + func(self.value_1, self.value_2)) + self.assertEqual( + variation.evaluate(func(self.variation_1, self.variation_2)), + func(self.value_1, self.value_2)) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/composer/variation/variation_values.py b/DMC/src/env/dm_control/dm_control/composer/variation/variation_values.py new file mode 100644 index 0000000..8791225 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/composer/variation/variation_values.py @@ -0,0 +1,39 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Utilities for handling nested structures of callables or constants.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tree + + +def evaluate(structure, *args, **kwargs): + """Evaluates a arbitrarily nested structure of callables or constant values. + + Args: + structure: An arbitrarily nested structure of callables or constant values. + By "structures", we mean lists, tuples, namedtuples, or dicts. + *args: Positional arguments passed to each callable in `structure`. + **kwargs: Keyword arguments passed to each callable in `structure. + + Returns: + The same nested structure, with each callable replaced by the value returned + by calling it. + """ + return tree.map_structure( + lambda x: x(*args, **kwargs) if callable(x) else x, structure) diff --git a/DMC/src/env/dm_control/dm_control/entities/__init__.py b/DMC/src/env/dm_control/dm_control/entities/__init__.py new file mode 100644 index 0000000..4224c02 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/entities/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/DMC/src/env/dm_control/dm_control/entities/props/__init__.py b/DMC/src/env/dm_control/dm_control/entities/props/__init__.py new file mode 100644 index 0000000..5b8863c --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/entities/props/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Composer entities corresponding to props. + +A "prop" is typically a non-actuated entity representing an object in the world. +""" + +# Prop imports +from dm_control.entities.props.position_detector import PositionDetector +from dm_control.entities.props.primitive import Primitive diff --git a/DMC/src/env/dm_control/dm_control/entities/props/position_detector.py b/DMC/src/env/dm_control/dm_control/entities/props/position_detector.py new file mode 100644 index 0000000..7332ab7 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/entities/props/position_detector.py @@ -0,0 +1,270 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Detects the presence of registered entities within a cuboidal region.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from dm_control import composer +from dm_control import mjcf +import numpy as np + +_RENDERED_HEIGHT_IN_2D_MODE = 0.1 + + +def _ensure_3d(pos): + # Pad the array with a zero if its length is 2. + if len(pos) == 2: + return np.hstack([pos, 0.]) + return pos + + +class _Detection(object): + + __slots__ = ('entity', 'detected') + + def __init__(self, entity, detected=False): + self.entity = entity + self.detected = detected + + +class PositionDetector(composer.Entity): + """Detects the presence of registered entities within a cuboidal region. + + An entity is considered "detected" if the `xpos` value of any one of its geom + lies within the active region defined by this detector. Note that this is NOT + a contact-based detector. Generally speaking, a geom will not be detected + until it is already "half inside" the region. + + This detector supports both 2D and 3D modes. In 2D mode, the active region + has an effective infinite height along the z-direction. + + This detector also provides an "inverted" detection mode, where an entity is + detected when it is not inside the detector's region. + """ + + def _build(self, + pos, + size, + inverted=False, + visible=False, + rgba=(1, 0, 0, 0.25), + detected_rgba=(0, 1, 0, 0.25), + name='position_detector'): + """Builds the detector. + + Args: + pos: The position at the center of this detector's active region. Should + be an array-like object of length 3 in 3D mode, or length 2 in 2D mode. + size: The half-lengths of this detector's active region. Should + be an array-like object of length 3 in 3D mode, or length 2 in 2D mode. + inverted: (optional) A boolean, whether to operate in inverted detection + mode. If `True`, an entity is detected when it is not in the active + region. + visible: (optional) A boolean, whether this detector is visible by + default in rendered images. If `False`, this detector's active zone + is placed in MuJoCo rendering group 4, which is not rendered by default, + but can be toggled on (e.g. in `dm_control.viewer`) for debugging + purposes. + rgba: (optional) The color to render when nothing is detected. + detected_rgba: (optional) The color to render when an entity is detected. + name: (optional) XML element name of this position detector. + + Raises: + ValueError: If the `pos` and `size` arrays do not have the same length. + """ + if len(pos) != len(size): + raise ValueError('`pos` and `size` should have the same length: ' + 'got {!r} and {!r}'.format(pos, size)) + + self._inverted = inverted + self._detected = False + self._lower = np.array(pos) - np.array(size) + self._upper = np.array(pos) + np.array(size) + self._lower_3d = _ensure_3d(self._lower) + self._upper_3d = _ensure_3d(self._upper) + self._mid_3d = (self._lower_3d + self._upper_3d) / 2. + + self._entities = [] + self._entity_geoms = {} + + self._rgba = np.asarray(rgba) + self._detected_rgba = np.asarray(detected_rgba) + + render_pos = np.zeros(3) + render_pos[:len(pos)] = pos + + render_size = np.full(3, _RENDERED_HEIGHT_IN_2D_MODE) + render_size[:len(size)] = size + + self._mjcf_root = mjcf.RootElement(model=name) + self._site = self._mjcf_root.worldbody.add( + 'site', name='detection_zone', type='box', + pos=render_pos, size=render_size, rgba=self._rgba) + self._lower_site = self._mjcf_root.worldbody.add( + 'site', name='lower', pos=self._lower_3d, size=[0.05], + rgba=self._rgba) + self._mid_site = self._mjcf_root.worldbody.add( + 'site', name='mid', pos=self._mid_3d, size=[0.05], + rgba=self._rgba) + self._upper_site = self._mjcf_root.worldbody.add( + 'site', name='upper', pos=self._upper_3d, size=[0.05], + rgba=self._rgba) + self._lower_sensor = self._mjcf_root.sensor.add( + 'framepos', objtype='site', objname=self._lower_site, + name='{}_lower'.format(name)) + self._mid_sensor = self._mjcf_root.sensor.add( + 'framepos', objtype='site', objname=self._mid_site, + name='{}_mid'.format(name)) + self._upper_sensor = self._mjcf_root.sensor.add( + 'framepos', objtype='site', objname=self._upper_site, + name='{}_upper'.format(name)) + + if not visible: + self._site.group = composer.SENSOR_SITES_GROUP + self._lower_site.group = composer.SENSOR_SITES_GROUP + self._mid_site.group = composer.SENSOR_SITES_GROUP + self._upper_site.group = composer.SENSOR_SITES_GROUP + + def resize(self, pos, size): + if len(pos) != len(size): + raise ValueError('`pos` and `size` should have the same length: ' + 'got {!r} and {!r}'.format(pos, size)) + self._lower = np.array(pos) - np.array(size) + self._upper = np.array(pos) + np.array(size) + + self._lower_3d = _ensure_3d(self._lower) + self._upper_3d = _ensure_3d(self._upper) + self._mid_3d = (self._lower_3d + self._upper_3d) / 2. + + render_pos = np.zeros(3) + render_pos[:len(pos)] = pos + + render_size = np.full(3, _RENDERED_HEIGHT_IN_2D_MODE) + render_size[:len(size)] = size + + self._site.pos = render_pos + self._site.size = render_size + self._lower_site.pos = self._lower_3d + self._mid_site.pos = self._mid_3d + self._upper_site.pos = self._upper_3d + + def set_colors(self, rgba, detected_rgba): + self.set_color(rgba) + self.set_detected_color(detected_rgba) + + def set_color(self, rgba): + self._rgba[:3] = rgba + self._site.rgba = self._rgba + + def set_detected_color(self, detected_rgba): + self._detected_rgba[:3] = detected_rgba + + def set_position(self, physics, pos): + physics.bind(self._site).pos = pos + size = physics.bind(self._site).size[:3] + self._lower = np.array(pos) - np.array(size) + self._upper = np.array(pos) + np.array(size) + + self._lower_3d = _ensure_3d(self._lower) + self._upper_3d = _ensure_3d(self._upper) + self._mid_3d = (self._lower_3d + self._upper_3d) / 2. + + physics.bind(self._lower_site).pos = self._lower_3d + physics.bind(self._mid_site).pos = self._mid_3d + physics.bind(self._upper_site).pos = self._upper_3d + + @property + def mjcf_model(self): + return self._mjcf_root + + def register_entities(self, *entities): + for entity in entities: + self._entities.append(_Detection(entity)) + self._entity_geoms[entity] = entity.mjcf_model.find_all('geom') + + def deregister_entities(self): + self._entities = [] + + @property + def detected_entities(self): + """A list of detected entities.""" + return [ + detection.entity for detection in self._entities if detection.detected] + + def initialize_episode_mjcf(self, unused_random_state): + self._entity_geoms = {} + for detection in self._entities: + entity = detection.entity + self._entity_geoms[entity] = entity.mjcf_model.find_all('geom') + + def initialize_episode(self, physics, unused_random_state): + self._update_detection(physics) + + def after_substep(self, physics, unused_random_state): + self._update_detection(physics) + + def _is_in_zone(self, xpos): + return (np.all(self._lower < xpos[:len(self._lower)]) + and np.all(self._upper > xpos[:len(self._upper)])) + + def _update_detection(self, physics): + previously_detected = self._detected + self._detected = False + for detection in self._entities: + detection.detected = False + for geom in self._entity_geoms[detection.entity]: + if self._is_in_zone(physics.bind(geom).xpos) != self._inverted: + detection.detected = True + self._detected = True + break + + if self._detected and not previously_detected: + physics.bind(self._site).rgba = self._detected_rgba + elif previously_detected and not self._detected: + physics.bind(self._site).rgba = self._rgba + + def site_pos(self, physics): + return physics.bind(self._site).pos + + @property + def activated(self): + return self._detected + + @property + def upper(self): + return self._upper + + @property + def lower(self): + return self._lower + + @property + def mid(self): + return (self._lower + self._upper) / 2. + + @property + def lower_sensor(self): + return self._lower_sensor + + @property + def mid_sensor(self): + return self._mid_sensor + + @property + def upper_sensor(self): + return self._upper_sensor diff --git a/DMC/src/env/dm_control/dm_control/entities/props/position_detector_test.py b/DMC/src/env/dm_control/dm_control/entities/props/position_detector_test.py new file mode 100644 index 0000000..cfdd073 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/entities/props/position_detector_test.py @@ -0,0 +1,133 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for dm_control.composer.props.position_detector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Internal dependencies. +from absl.testing import absltest +from absl.testing import parameterized +from dm_control import composer +from dm_control.entities.props import position_detector +from dm_control.entities.props import primitive +import numpy as np + + +class PositionDetectorTest(parameterized.TestCase): + + def setUp(self): + super(PositionDetectorTest, self).setUp() + self.arena = composer.Arena() + self.props = [ + primitive.Primitive(geom_type='sphere', size=(0.1,)), + primitive.Primitive(geom_type='sphere', size=(0.1,)) + ] + for prop in self.props: + self.arena.add_free_entity(prop) + self.task = composer.NullTask(self.arena) + + def assertDetected(self, entity, detector): + if not self.inverted: + self.assertIn(entity, detector.detected_entities) + else: + self.assertNotIn(entity, detector.detected_entities) + + def assertNotDetected(self, entity, detector): + if not self.inverted: + self.assertNotIn(entity, detector.detected_entities) + else: + self.assertIn(entity, detector.detected_entities) + + @parameterized.parameters(False, True) + def test3DDetection(self, inverted): + self.inverted = inverted + + detector_pos = np.array([0.3, 0.2, 0.1]) + detector_size = np.array([0.1, 0.2, 0.3]) + detector = position_detector.PositionDetector( + pos=detector_pos, size=detector_size, inverted=inverted) + detector.register_entities(*self.props) + self.arena.attach(detector) + env = composer.Environment(self.task) + + env.reset() + self.assertNotDetected(self.props[0], detector) + self.assertNotDetected(self.props[1], detector) + + def initialize_episode(physics, unused_random_state): + for prop in self.props: + prop.set_pose(physics, detector_pos) + self.task.initialize_episode = initialize_episode + env.reset() + self.assertDetected(self.props[0], detector) + self.assertDetected(self.props[1], detector) + + self.props[0].set_pose(env.physics, detector_pos - detector_size) + env.step([]) + self.assertNotDetected(self.props[0], detector) + self.assertDetected(self.props[1], detector) + + self.props[0].set_pose(env.physics, detector_pos - detector_size / 2) + self.props[1].set_pose(env.physics, detector_pos + detector_size * 1.01) + env.step([]) + self.assertDetected(self.props[0], detector) + self.assertNotDetected(self.props[1], detector) + + @parameterized.parameters(False, True) + def test2DDetection(self, inverted): + self.inverted = inverted + + detector_pos = np.array([0.3, 0.2]) + detector_size = np.array([0.1, 0.2]) + detector = position_detector.PositionDetector( + pos=detector_pos, size=detector_size, inverted=inverted) + detector.register_entities(*self.props) + self.arena.attach(detector) + env = composer.Environment(self.task) + + env.reset() + self.assertNotDetected(self.props[0], detector) + self.assertNotDetected(self.props[1], detector) + + def initialize_episode(physics, unused_random_state): + # In 2D mode, detection should occur no matter how large |z| is. + self.props[0].set_pose(physics, [detector_pos[0], detector_pos[1], 1e+6]) + self.props[1].set_pose(physics, [detector_pos[0], detector_pos[1], -1e+6]) + self.task.initialize_episode = initialize_episode + env.reset() + self.assertDetected(self.props[0], detector) + self.assertDetected(self.props[1], detector) + + self.props[0].set_pose( + env.physics, [detector_pos[0] - detector_size[0], detector_pos[1], 0]) + env.step([]) + self.assertNotDetected(self.props[0], detector) + self.assertDetected(self.props[1], detector) + + self.props[0].set_pose( + env.physics, [detector_pos[0] - detector_size[0] / 2, + detector_pos[1] + detector_size[1] / 2, 0]) + self.props[1].set_pose( + env.physics, [detector_pos[0], detector_pos[1] + detector_size[1], 0]) + env.step([]) + self.assertDetected(self.props[0], detector) + self.assertNotDetected(self.props[1], detector) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/entities/props/primitive.py b/DMC/src/env/dm_control/dm_control/entities/props/primitive.py new file mode 100644 index 0000000..718e8a8 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/entities/props/primitive.py @@ -0,0 +1,112 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Prop consisting of a single geom with position and velocity sensors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from dm_control import composer +from dm_control import mjcf +from dm_control.composer import define +from dm_control.composer.observation import observable + + +class Primitive(composer.Entity): + """A prop consisting of a single geom with position and velocity sensors.""" + + def _build(self, geom_type, size, name=None, **kwargs): + """Initializes the prop. + + Args: + geom_type: String specifying the geom type. + size: List or numpy array of up to 3 numbers, depending on `geom_type`: + geom_type='box', size=[x_half_length, y_half_length, z_half_length] + geom_type='capsule', size=[radius, half_length] + geom_type='cylinder', size=[radius, half_length] + geom_type='ellipsoid', size=[x_radius, y_radius, z_radius] + geom_type='sphere', size=[radius] + name: (optional) A string, the name of this prop. + **kwargs: Additional geom parameters. Please see the MuJoCo documentation + for further details: http://www.mujoco.org/book/XMLreference.html#geom. + """ + self._mjcf_root = mjcf.element.RootElement(model=name) + self._geom = self._mjcf_root.worldbody.add( + 'geom', name='geom', type=geom_type, size=size, **kwargs) + self._position = self._mjcf_root.sensor.add( + 'framepos', name='position', objtype='geom', objname=self.geom) + self._orientation = self._mjcf_root.sensor.add( + 'framequat', name='orientation', objtype='geom', objname=self.geom) + self._linear_velocity = self._mjcf_root.sensor.add( + 'framelinvel', name='linear_velocity', objtype='geom', + objname=self.geom) + self._angular_velocity = self._mjcf_root.sensor.add( + 'frameangvel', name='angular_velocity', objtype='geom', + objname=self.geom) + + def _build_observables(self): + return PrimitiveObservables(self) + + @property + def geom(self): + """The geom belonging to this prop.""" + return self._geom + + @property + def position(self): + """Sensor that returns the prop position.""" + return self._position + + @property + def orientation(self): + """Sensor that returns the prop orientation (as a quaternion).""" + # TODO(b/120829807): Consider returning a rotation matrix instead. + return self._orientation + + @property + def linear_velocity(self): + """Sensor that returns the linear velocity of the prop.""" + return self._linear_velocity + + @property + def angular_velocity(self): + """Sensor that returns the angular velocity of the prop.""" + return self._angular_velocity + + @property + def mjcf_model(self): + return self._mjcf_root + + +class PrimitiveObservables(composer.Observables, + composer.FreePropObservableMixin): + """Primitive entity's observables.""" + + @define.observable + def position(self): + return observable.MJCFFeature('sensordata', self._entity.position) + + @define.observable + def orientation(self): + return observable.MJCFFeature('sensordata', self._entity.orientation) + + @define.observable + def linear_velocity(self): + return observable.MJCFFeature('sensordata', self._entity.linear_velocity) + + @define.observable + def angular_velocity(self): + return observable.MJCFFeature('sensordata', self._entity.angular_velocity) diff --git a/DMC/src/env/dm_control/dm_control/entities/props/primitive_test.py b/DMC/src/env/dm_control/dm_control/entities/props/primitive_test.py new file mode 100644 index 0000000..a731e0d --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/entities/props/primitive_test.py @@ -0,0 +1,100 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for dm_control.composer.props.primitive.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from absl.testing import parameterized +from dm_control import composer +from dm_control import mjcf +from dm_control.entities.props import primitive +import numpy as np + + +class PrimitiveTest(parameterized.TestCase): + + def _make_free_prop(self, geom_type='sphere', size=(0.1,), **kwargs): + prop = primitive.Primitive(geom_type=geom_type, size=size, **kwargs) + arena = composer.Arena() + arena.add_free_entity(prop) + physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model) + return prop, physics + + @parameterized.parameters([ + dict(geom_type='sphere', size=[0.1]), + dict(geom_type='capsule', size=[0.1, 0.2]), + dict(geom_type='cylinder', size=[0.1, 0.2]), + dict(geom_type='box', size=[0.1, 0.2, 0.3]), + dict(geom_type='ellipsoid', size=[0.1, 0.2, 0.3]), + ]) + def test_instantiation(self, geom_type, size): + name = 'foo' + rgba = [1., 0., 1., 0.5] + prop, physics = self._make_free_prop( + geom_type=geom_type, size=size, name=name, rgba=rgba) + # Check that the name and other kwargs are set correctly. + self.assertEqual(prop.mjcf_model.model, name) + np.testing.assert_array_equal(physics.bind(prop.geom).rgba, rgba) + # Check that we can step without anything breaking. + physics.step() + + @parameterized.parameters([ + dict(position=[0., 0., 0.]), + dict(position=[0.1, -0.2, 0.3]), + ]) + def test_position_observable(self, position): + prop, physics = self._make_free_prop() + prop.set_pose(physics, position=position) + observation = prop.observables.position(physics) + np.testing.assert_array_equal(position, observation) + + @parameterized.parameters([ + dict(quat=[1., 0., 0., 0.]), + dict(quat=[0., -1., 1., 0.]), + ]) + def test_orientation_observable(self, quat): + prop, physics = self._make_free_prop() + normalized_quat = np.array(quat) / np.linalg.norm(quat) + prop.set_pose(physics, quaternion=normalized_quat) + observation = prop.observables.orientation(physics) + np.testing.assert_array_almost_equal(normalized_quat, observation) + + @parameterized.parameters([ + dict(velocity=[0., 0., 0.]), + dict(velocity=[0.1, -0.2, 0.3]), + ]) + def test_linear_velocity_observable(self, velocity): + prop, physics = self._make_free_prop() + prop.set_velocity(physics, velocity=velocity) + observation = prop.observables.linear_velocity(physics) + np.testing.assert_array_almost_equal(velocity, observation) + + @parameterized.parameters([ + dict(angular_velocity=[0., 0., 0.]), + dict(angular_velocity=[0.1, -0.2, 0.3]), + ]) + def test_angular_velocity_observable(self, angular_velocity): + prop, physics = self._make_free_prop() + prop.set_velocity(physics, angular_velocity=angular_velocity) + observation = prop.observables.angular_velocity(physics) + np.testing.assert_array_almost_equal(angular_velocity, observation) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/README.md b/DMC/src/env/dm_control/dm_control/locomotion/README.md new file mode 100644 index 0000000..0941561 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/README.md @@ -0,0 +1,88 @@ +# Locomotion task library + +This package contains reusable components for defining control tasks that are +related to locomotion. New users are encouraged to start by browsing the +`examples/` subdirectory, which contains preconfigured RL environments +associated with various research papers. These examples can serve as starting +points or be customized to design new environments using the components +available from this library. + +

+ + +

+ +## Terminology + +This library facilitates the creation of environments that require **walkers** +to perform a **task** in an **arena**. + +- **walkers** refer to detached bodies that can move around in the + environment. + +- **arenas** refer to the surroundings in which the walkers and possibly other + objects exist. + +- **tasks** refer to the specification of observations and rewards that are + passed from the "environment" to the "agent", along with runtime details + such as initialization and termination logic. + +## Installation and requirements + +See [the documentation for `dm_control`][installation-and-requirements]. + +## Quickstart + +```python +from dm_control import composer +from dm_control.locomotion.examples import basic_cmu_2019 +import numpy as np + +# Build an example environment. +env = basic_cmu_2019.cmu_humanoid_run_walls() + +# Get the `action_spec` describing the control inputs. +action_spec = env.action_spec() + +# Step through the environment for one episode with random actions. +time_step = env.reset() +while not time_step.last(): + action = np.random.uniform(action_spec.minimum, action_spec.maximum, + size=action_spec.shape) + time_step = env.step(action) + print("reward = {}, discount = {}, observations = {}.".format( + time_step.reward, time_step.discount, time_step.observation)) +``` + +[`dm_control.viewer`] can also be used to visualize and interact with the +environment, e.g.: + +```python +from dm_control import viewer + +viewer.launch(environment_loader=basic_cmu_2019.cmu_humanoid_run_walls) +``` + +## Publications + +This library contains environments that were adapted from several research +papers. Relevant references include: + +- [Emergence of Locomotion Behaviours in Rich Environments (2017)][heess2017]. + +- [Learning human behaviors from motion capture by adversarial imitation + (2017)][merel2017]. + +- [Hierarchical visuomotor control of humanoids (2019)][merel2019a]. + +- [Neural probabilistic motor primitives for humanoid control (2019)][merel2019b]. + +- [Deep neuroethology of a virtual rodent (2020)][merel2020]. + +[installation-and-requirements]: ../../README.md#installation-and-requirements +[`dm_control.viewer`]: ../viewer/README.md +[heess2017]: https://arxiv.org/abs/1707.02286 +[merel2017]: https://arxiv.org/abs/1707.02201 +[merel2019a]: https://arxiv.org/abs/1811.09656 +[merel2019b]: https://arxiv.org/abs/1811.11711 +[merel2020]: https://openreview.net/pdf?id=SyxrxR4KPS diff --git a/DMC/src/env/dm_control/dm_control/locomotion/__init__.py b/DMC/src/env/dm_control/dm_control/locomotion/__init__.py new file mode 100644 index 0000000..4224c02 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/__init__.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/__init__.py new file mode 100644 index 0000000..3f81949 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Arenas for Locomotion tasks.""" + +from dm_control.locomotion.arenas.bowl import Bowl +from dm_control.locomotion.arenas.corridors import EmptyCorridor +from dm_control.locomotion.arenas.corridors import GapsCorridor +from dm_control.locomotion.arenas.corridors import WallsCorridor +from dm_control.locomotion.arenas.floors import Floor +from dm_control.locomotion.arenas.labmaze_textures import FloorTextures +from dm_control.locomotion.arenas.labmaze_textures import SkyBox +from dm_control.locomotion.arenas.labmaze_textures import WallTextures +from dm_control.locomotion.arenas.mazes import MazeWithTargets +from dm_control.locomotion.arenas.mazes import RandomMazeWithTargets diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/assets/__init__.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/assets/__init__.py new file mode 100644 index 0000000..74b35c0 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/assets/__init__.py @@ -0,0 +1,63 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Locomotion texture assets.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import os +import sys + +ROOT_DIR = '../locomotion/arenas/assets' + + +def get_texturedir(style): + return os.path.join(ROOT_DIR, style) + +SKY_STYLES = ('outdoor_natural') + +SkyBox = collections.namedtuple( + 'SkyBox', ('file', 'gridsize', 'gridlayout')) + + +def get_sky_texture_info(style): + if style not in SKY_STYLES: + raise ValueError('`style` should be one of {}: got {!r}'.format( + SKY_STYLES, style)) + return SkyBox(file='OutdoorSkybox2048.png', + gridsize='3 4', + gridlayout='.U..LFRB.D..') + + +GROUND_STYLES = ('outdoor_natural') + +GroundTexture = collections.namedtuple( + 'GroundTexture', ('file', 'type')) + + +def get_ground_texture_info(style): + if style not in GROUND_STYLES: + raise ValueError('`style` should be one of {}: got {!r}'.format( + GROUND_STYLES, style)) + return GroundTexture( + file='OutdoorGrassFloorD.png', + type='2d') + + + + diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/assets/outdoor_natural/OutdoorGrassFloorD.png b/DMC/src/env/dm_control/dm_control/locomotion/arenas/assets/outdoor_natural/OutdoorGrassFloorD.png new file mode 100644 index 0000000..2a93f5b Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/arenas/assets/outdoor_natural/OutdoorGrassFloorD.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/assets/outdoor_natural/OutdoorSkybox2048.png b/DMC/src/env/dm_control/dm_control/locomotion/arenas/assets/outdoor_natural/OutdoorSkybox2048.png new file mode 100644 index 0000000..d6f9a58 Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/arenas/assets/outdoor_natural/OutdoorSkybox2048.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/bowl.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/bowl.py new file mode 100644 index 0000000..d591f51 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/bowl.py @@ -0,0 +1,139 @@ +# Copyright 2020 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Bowl arena with bumps.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from dm_control import composer +from dm_control.locomotion.arenas import assets as locomotion_arenas_assets +from dm_control.mujoco.wrapper import mjbindings + +import numpy as np +from scipy import ndimage + +mjlib = mjbindings.mjlib + +_TOP_CAMERA_DISTANCE = 100 +_TOP_CAMERA_Y_PADDING_FACTOR = 1.1 + +# Constants related to terrain generation. +_TERRAIN_SMOOTHNESS = .5 # 0.0: maximally bumpy; 1.0: completely smooth. +_TERRAIN_BUMP_SCALE = .2 # Spatial scale of terrain bumps (in meters). + + +class Bowl(composer.Arena): + """A bowl arena with sinusoidal bumps.""" + + def _build(self, size=(10, 10), aesthetic='default', name='bowl'): + super(Bowl, self)._build(name=name) + + self._hfield = self._mjcf_root.asset.add( + 'hfield', + name='terrain', + nrow=201, + ncol=201, + size=(6, 6, 0.5, 0.1)) + + if aesthetic != 'default': + ground_info = locomotion_arenas_assets.get_ground_texture_info(aesthetic) + sky_info = locomotion_arenas_assets.get_sky_texture_info(aesthetic) + texturedir = locomotion_arenas_assets.get_texturedir(aesthetic) + self._mjcf_root.compiler.texturedir = texturedir + + self._texture = self._mjcf_root.asset.add( + 'texture', name='aesthetic_texture', file=ground_info.file, + type=ground_info.type) + self._material = self._mjcf_root.asset.add( + 'material', name='aesthetic_material', texture=self._texture, + texuniform='true') + self._skybox = self._mjcf_root.asset.add( + 'texture', name='aesthetic_skybox', file=sky_info.file, + type='skybox', gridsize=sky_info.gridsize, + gridlayout=sky_info.gridlayout) + self._terrain_geom = self._mjcf_root.worldbody.add( + 'geom', + name='terrain', + type='hfield', + pos=(0, 0, -0.01), + hfield='terrain', + material=self._material) + self._ground_geom = self._mjcf_root.worldbody.add( + 'geom', + type='plane', + name='groundplane', + size=list(size) + [0.5], + material=self._material) + else: + self._terrain_geom = self._mjcf_root.worldbody.add( + 'geom', + name='terrain', + type='hfield', + rgba=(0.2, 0.3, 0.4, 1), + pos=(0, 0, -0.01), + hfield='terrain') + self._ground_geom = self._mjcf_root.worldbody.add( + 'geom', + type='plane', + name='groundplane', + rgba=(0.2, 0.3, 0.4, 1), + size=list(size) + [0.5]) + + self._mjcf_root.visual.headlight.set_attributes( + ambient=[.4, .4, .4], diffuse=[.8, .8, .8], specular=[.1, .1, .1]) + + self._regenerate = True + + def regenerate(self, random_state): + # regeneration of the bowl requires physics, so postponed to initialization. + self._regenerate = True + + def initialize_episode(self, physics, random_state): + if self._regenerate: + self._regenerate = False + + # Get heightfield resolution, assert that it is square. + res = physics.bind(self._hfield).nrow + assert res == physics.bind(self._hfield).ncol + + # Sinusoidal bowl shape. + row_grid, col_grid = np.ogrid[-1:1:res*1j, -1:1:res*1j] + radius = np.clip(np.sqrt(col_grid**2 + row_grid**2), .1, 1) + bowl_shape = .5 - np.cos(2*np.pi*radius)/2 + + # Random smooth bumps. + terrain_size = 2 * physics.bind(self._hfield).size[0] + bump_res = int(terrain_size / _TERRAIN_BUMP_SCALE) + bumps = random_state.uniform(_TERRAIN_SMOOTHNESS, 1, (bump_res, bump_res)) + smooth_bumps = ndimage.zoom(bumps, res / float(bump_res)) + + # Terrain is elementwise product. + terrain = bowl_shape * smooth_bumps + start_idx = physics.bind(self._hfield).adr + physics.model.hfield_data[start_idx:start_idx+res**2] = terrain.ravel() + + # If we have a rendering context, we need to re-upload the modified + # heightfield data. + if physics.contexts: + with physics.contexts.gl.make_current() as ctx: + ctx.call(mjlib.mjr_uploadHField, + physics.model.ptr, + physics.contexts.mujoco.ptr, + physics.bind(self._hfield).element_id) + + @property + def ground_geoms(self): + return (self._terrain_geom, self._ground_geom) diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/bowl_test.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/bowl_test.py new file mode 100644 index 0000000..5699346 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/bowl_test.py @@ -0,0 +1,35 @@ +# Copyright 2020 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for locomotion.arenas.bowl.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from dm_control import mjcf +from dm_control.locomotion.arenas import bowl + + +class BowlTest(absltest.TestCase): + + def test_can_compile_mjcf(self): + + arena = bowl.Bowl() + mjcf.Physics.from_mjcf_model(arena.mjcf_model) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/corridors.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/corridors.py new file mode 100644 index 0000000..4129023 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/corridors.py @@ -0,0 +1,433 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Corridor-based arenas.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +from dm_control import composer +from dm_control.composer import variation +from dm_control.locomotion.arenas import assets as locomotion_arenas_assets +import six + +_SIDE_WALLS_GEOM_GROUP = 3 +_CORRIDOR_X_PADDING = 2.0 +_WALL_THICKNESS = 0.16 +_SIDE_WALL_HEIGHT = 4.0 +_DEFAULT_ALPHA = 0.5 + + +@six.add_metaclass(abc.ABCMeta) +class Corridor(composer.Arena): + """Abstract base class for corridor-type arenas.""" + + @abc.abstractmethod + def regenerate(self, random_state): + raise NotImplementedError + + @abc.abstractproperty + def corridor_length(self): + raise NotImplementedError + + @abc.abstractproperty + def corridor_width(self): + raise NotImplementedError + + @abc.abstractproperty + def ground_geoms(self): + raise NotImplementedError + + def is_at_target_position(self, position, tolerance=0.0): + """Checks if a `position` is within `tolerance' of an end of the corridor. + + This can also be used to evaluate more complicated T-shaped or L-shaped + corridors. + + Args: + position: An iterable of 2 elements corresponding to the x and y location + of the position to evaluate. + tolerance: A `float` tolerance to use while evaluating the position. + + Returns: + A `bool` indicating whether the `position` is within the `tolerance` of an + end of the corridor. + """ + x, _ = position + return x > self.corridor_length - tolerance + + +class EmptyCorridor(Corridor): + """An empty corridor with planes around the perimeter.""" + + def _build(self, + corridor_width=4, + corridor_length=40, + visible_side_planes=True, + name='empty_corridor'): + """Builds the corridor. + + Args: + corridor_width: A number or a `composer.variation.Variation` object that + specifies the width of the corridor. + corridor_length: A number or a `composer.variation.Variation` object that + specifies the length of the corridor. + visible_side_planes: Whether to the side planes that bound the corridor's + perimeter should be rendered. + name: The name of this arena. + """ + super(EmptyCorridor, self)._build(name=name) + + self._corridor_width = corridor_width + self._corridor_length = corridor_length + + self._walls_body = self._mjcf_root.worldbody.add('body', name='walls') + + self._mjcf_root.visual.map.znear = 0.0005 + self._mjcf_root.asset.add( + 'texture', type='skybox', builtin='gradient', + rgb1=[0.4, 0.6, 0.8], rgb2=[0, 0, 0], width=100, height=600) + self._mjcf_root.visual.headlight.set_attributes( + ambient=[0.4, 0.4, 0.4], diffuse=[0.8, 0.8, 0.8], + specular=[0.1, 0.1, 0.1]) + + alpha = _DEFAULT_ALPHA if visible_side_planes else 0.0 + self._ground_plane = self._mjcf_root.worldbody.add( + 'geom', type='plane', rgba=[0.5, 0.5, 0.5, 1], size=[1, 1, 1]) + self._left_plane = self._mjcf_root.worldbody.add( + 'geom', type='plane', xyaxes=[1, 0, 0, 0, 0, 1], size=[1, 1, 1], + rgba=[1, 0, 0, alpha], group=_SIDE_WALLS_GEOM_GROUP) + self._right_plane = self._mjcf_root.worldbody.add( + 'geom', type='plane', xyaxes=[-1, 0, 0, 0, 0, 1], size=[1, 1, 1], + rgba=[1, 0, 0, alpha], group=_SIDE_WALLS_GEOM_GROUP) + self._near_plane = self._mjcf_root.worldbody.add( + 'geom', type='plane', xyaxes=[0, 1, 0, 0, 0, 1], size=[1, 1, 1], + rgba=[1, 0, 0, alpha], group=_SIDE_WALLS_GEOM_GROUP) + self._far_plane = self._mjcf_root.worldbody.add( + 'geom', type='plane', xyaxes=[0, -1, 0, 0, 0, 1], size=[1, 1, 1], + rgba=[1, 0, 0, alpha], group=_SIDE_WALLS_GEOM_GROUP) + + self._current_corridor_length = None + self._current_corridor_width = None + + def regenerate(self, random_state): + """Regenerates this corridor. + + New values are drawn from the `corridor_width` and `corridor_height` + distributions specified in `_build`. The corridor is resized accordingly. + + Args: + random_state: A `numpy.random.RandomState` object that is passed to the + `Variation` objects. + """ + self._walls_body.geom.clear() + corridor_width = variation.evaluate(self._corridor_width, + random_state=random_state) + corridor_length = variation.evaluate(self._corridor_length, + random_state=random_state) + self._current_corridor_length = corridor_length + self._current_corridor_width = corridor_width + + self._ground_plane.pos = [corridor_length / 2, 0, 0] + self._ground_plane.size = [ + corridor_length / 2 + _CORRIDOR_X_PADDING, corridor_width / 2, 1] + + self._left_plane.pos = [ + corridor_length / 2, corridor_width / 2, _SIDE_WALL_HEIGHT / 2] + self._left_plane.size = [ + corridor_length / 2 + _CORRIDOR_X_PADDING, _SIDE_WALL_HEIGHT / 2, 1] + + self._right_plane.pos = [ + corridor_length / 2, -corridor_width / 2, _SIDE_WALL_HEIGHT / 2] + self._right_plane.size = [ + corridor_length / 2 + _CORRIDOR_X_PADDING, _SIDE_WALL_HEIGHT / 2, 1] + + self._near_plane.pos = [ + -_CORRIDOR_X_PADDING, 0, _SIDE_WALL_HEIGHT / 2] + self._near_plane.size = [corridor_width / 2, _SIDE_WALL_HEIGHT / 2, 1] + + self._far_plane.pos = [ + corridor_length + _CORRIDOR_X_PADDING, 0, _SIDE_WALL_HEIGHT / 2] + self._far_plane.size = [corridor_width / 2, _SIDE_WALL_HEIGHT / 2, 1] + + @property + def corridor_length(self): + return self._current_corridor_length + + @property + def corridor_width(self): + return self._current_corridor_width + + @property + def ground_geoms(self): + return (self._ground_plane,) + + +class GapsCorridor(EmptyCorridor): + """A corridor that consists of multiple platforms separated by gaps.""" + + def _build(self, + platform_length=1., + gap_length=2.5, + corridor_width=4, + corridor_length=40, + ground_rgba=(0.5, 0.5, 0.5, 1), + visible_side_planes=False, + aesthetic='default', + name='gaps_corridor'): + """Builds the corridor. + + Args: + platform_length: A number or a `composer.variation.Variation` object that + specifies the size of the platforms along the corridor. + gap_length: A number or a `composer.variation.Variation` object that + specifies the size of the gaps along the corridor. + corridor_width: A number or a `composer.variation.Variation` object that + specifies the width of the corridor. + corridor_length: A number or a `composer.variation.Variation` object that + specifies the length of the corridor. + ground_rgba: A sequence of 4 numbers or a `composer.variation.Variation` + object specifying the color of the ground. + visible_side_planes: Whether to the side planes that bound the corridor's + perimeter should be rendered. + aesthetic: option to adjust the material properties and skybox + name: The name of this arena. + """ + super(GapsCorridor, self)._build( + corridor_width=corridor_width, + corridor_length=corridor_length, + visible_side_planes=visible_side_planes, + name=name) + + self._platform_length = platform_length + self._gap_length = gap_length + self._ground_rgba = ground_rgba + self._aesthetic = aesthetic + + if self._aesthetic != 'default': + ground_info = locomotion_arenas_assets.get_ground_texture_info(aesthetic) + sky_info = locomotion_arenas_assets.get_sky_texture_info(aesthetic) + texturedir = locomotion_arenas_assets.get_texturedir(aesthetic) + self._mjcf_root.compiler.texturedir = texturedir + + self._ground_texture = self._mjcf_root.asset.add( + 'texture', name='aesthetic_texture', file=ground_info.file, + type=ground_info.type) + self._ground_material = self._mjcf_root.asset.add( + 'material', name='aesthetic_material', texture=self._ground_texture, + texuniform='true') + # remove existing skybox + for texture in self._mjcf_root.asset.find_all('texture'): + if texture.type == 'skybox': + texture.remove() + self._skybox = self._mjcf_root.asset.add( + 'texture', name='aesthetic_skybox', file=sky_info.file, + type='skybox', gridsize=sky_info.gridsize, + gridlayout=sky_info.gridlayout) + + self._ground_body = self._mjcf_root.worldbody.add('body', name='ground') + + def regenerate(self, random_state): + """Regenerates this corridor. + + New values are drawn from the `corridor_width` and `corridor_height` + distributions specified in `_build`. The corridor resized accordingly, and + new sets of platforms are created according to values drawn from the + `platform_length`, `gap_length`, and `ground_rgba` distributions specified + in `_build`. + + Args: + random_state: A `numpy.random.RandomState` object that is passed to the + `Variation` objects. + """ + # Resize the entire corridor first. + super(GapsCorridor, self).regenerate(random_state) + + # Move the ground plane down and make it invisible. + self._ground_plane.pos = [self._current_corridor_length / 2, 0, -10] + self._ground_plane.rgba = [0, 0, 0, 0] + + # Clear the existing platform pieces. + self._ground_body.geom.clear() + + # Make the first platform larger. + platform_length = 3. * _CORRIDOR_X_PADDING + platform_pos = [ + platform_length / 2, + 0, + -_WALL_THICKNESS, + ] + platform_size = [ + platform_length / 2, + self._current_corridor_width / 2, + _WALL_THICKNESS, + ] + if self._aesthetic != 'default': + self._ground_body.add( + 'geom', + type='box', + name='start_floor', + pos=platform_pos, + size=platform_size, + material=self._ground_material) + else: + self._ground_body.add( + 'geom', + type='box', + rgba=variation.evaluate(self._ground_rgba, random_state), + name='start_floor', + pos=platform_pos, + size=platform_size) + + current_x = platform_length + platform_id = 0 + while current_x < self._current_corridor_length: + platform_length = variation.evaluate( + self._platform_length, random_state=random_state) + platform_pos = [ + current_x + platform_length / 2., + 0, + -_WALL_THICKNESS, + ] + platform_size = [ + platform_length / 2, + self._current_corridor_width / 2, + _WALL_THICKNESS, + ] + if self._aesthetic != 'default': + self._ground_body.add( + 'geom', + type='box', + name='floor_{}'.format(platform_id), + pos=platform_pos, + size=platform_size, + material=self._ground_material) + else: + self._ground_body.add( + 'geom', + type='box', + rgba=variation.evaluate(self._ground_rgba, random_state), + name='floor_{}'.format(platform_id), + pos=platform_pos, + size=platform_size) + + platform_id += 1 + + # Move x to start of the next platform. + current_x += platform_length + variation.evaluate( + self._gap_length, random_state=random_state) + + @property + def ground_geoms(self): + return (self._ground_plane,) + tuple(self._ground_body.find_all('geom')) + + +class WallsCorridor(EmptyCorridor): + """A corridor obstructed by multiple walls aligned against the two sides.""" + + def _build(self, + wall_gap=2.5, + wall_width=2.5, + wall_height=2.0, + swap_wall_side=True, + wall_rgba=(1, 1, 1, 1), + corridor_width=4, + corridor_length=40, + visible_side_planes=False, + name='walls_corridor'): + """Builds the corridor. + + Args: + wall_gap: A number or a `composer.variation.Variation` object that + specifies the gap between each consecutive pair obstructing walls. + wall_width: A number or a `composer.variation.Variation` object that + specifies the width that the obstructing walls extend into the corridor. + wall_height: A number or a `composer.variation.Variation` object that + specifies the height of the obstructing walls. + swap_wall_side: A boolean or a `composer.variation.Variation` object that + specifies whether the next obstructing wall should be aligned against + the opposite side of the corridor compared to the previous one. + wall_rgba: A sequence of 4 numbers or a `composer.variation.Variation` + object specifying the color of the walls. + corridor_width: A number or a `composer.variation.Variation` object that + specifies the width of the corridor. + corridor_length: A number or a `composer.variation.Variation` object that + specifies the length of the corridor. + visible_side_planes: Whether to the side planes that bound the corridor's + perimeter should be rendered. + name: The name of this arena. + """ + super(WallsCorridor, self)._build( + corridor_width=corridor_width, + corridor_length=corridor_length, + visible_side_planes=visible_side_planes, + name=name) + + self._wall_height = wall_height + self._wall_rgba = wall_rgba + self._wall_gap = wall_gap + self._wall_width = wall_width + self._swap_wall_side = swap_wall_side + + def regenerate(self, random_state): + """Regenerates this corridor. + + New values are drawn from the `corridor_width` and `corridor_height` + distributions specified in `_build`. The corridor resized accordingly, and + new sets of obstructing walls are created according to values drawn from the + `wall_gap`, `wall_width`, `wall_height`, and `wall_rgba` distributions + specified in `_build`. + + Args: + random_state: A `numpy.random.RandomState` object that is passed to the + `Variation` objects. + """ + super(WallsCorridor, self).regenerate(random_state) + wall_x = variation.evaluate( + self._wall_gap, random_state=random_state) - _CORRIDOR_X_PADDING + wall_side = 0 + wall_id = 0 + while wall_x < self._current_corridor_length: + wall_width = variation.evaluate( + self._wall_width, random_state=random_state) + wall_height = variation.evaluate( + self._wall_height, random_state=random_state) + wall_rgba = variation.evaluate(self._wall_rgba, random_state=random_state) + if variation.evaluate(self._swap_wall_side, random_state=random_state): + wall_side = 1 - wall_side + + wall_pos = [ + wall_x, + (2 * wall_side - 1) * (self._current_corridor_width - wall_width) / 2, + wall_height / 2 + ] + wall_size = [_WALL_THICKNESS / 2, wall_width / 2, wall_height / 2] + self._walls_body.add( + 'geom', + type='box', + name='wall_{}'.format(wall_id), + pos=wall_pos, + size=wall_size, + rgba=wall_rgba) + + wall_id += 1 + wall_x += variation.evaluate(self._wall_gap, random_state=random_state) + + @property + def ground_geoms(self): + return (self._ground_plane,) diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/corridors_test.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/corridors_test.py new file mode 100644 index 0000000..328a455 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/corridors_test.py @@ -0,0 +1,89 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for locomotion.arenas.corridors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from absl.testing import parameterized +from dm_control import mjcf +from dm_control.composer.variation import deterministic +from dm_control.locomotion.arenas import corridors +from six.moves import zip + + +class CorridorsTest(parameterized.TestCase): + + @parameterized.parameters([ + corridors.EmptyCorridor, + corridors.GapsCorridor, + corridors.WallsCorridor, + ]) + def test_can_compile_mjcf(self, arena_type): + arena = arena_type() + mjcf.Physics.from_mjcf_model(arena.mjcf_model) + + @parameterized.parameters([ + corridors.EmptyCorridor, + corridors.GapsCorridor, + corridors.WallsCorridor, + ]) + def test_can_regenerate_corridor_size(self, arena_type): + width_sequence = [5.2, 3.8, 7.4] + length_sequence = [21.1, 19.4, 16.3] + + arena = arena_type( + corridor_width=deterministic.Sequence(width_sequence), + corridor_length=deterministic.Sequence(length_sequence)) + + # Add a probe geom that will generate contacts with the side walls. + probe_body = arena.mjcf_model.worldbody.add('body', name='probe') + probe_joint = probe_body.add('freejoint') + probe_geom = probe_body.add('geom', name='probe', type='box') + + for expected_width, expected_length in zip(width_sequence, length_sequence): + # No random_state is required since we are using deterministic variations. + arena.regenerate(random_state=None) + + def resize_probe_geom_and_assert_num_contacts( + delta_size, expected_num_contacts, + expected_width=expected_width, expected_length=expected_length): + probe_geom.size = [ + (expected_length / 2 + corridors._CORRIDOR_X_PADDING) + delta_size, + expected_width / 2 + delta_size, 0.1] + physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model) + probe_geomid = physics.bind(probe_geom).element_id + physics.bind(probe_joint).qpos[:3] = [expected_length / 2, 0, 100] + physics.forward() + probe_contacts = [c for c in physics.data.contact + if c.geom1 == probe_geomid or c.geom2 == probe_geomid] + self.assertLen(probe_contacts, expected_num_contacts) + + epsilon = 1e-7 + + # If the probe geom is epsilon-smaller than the expected corridor size, + # then we expect to detect no contact. + resize_probe_geom_and_assert_num_contacts(-epsilon, 0) + + # If the probe geom is epsilon-larger than the expected corridor size, + # then we expect to generate 4 contacts with each side wall, so 16 total. + resize_probe_geom_and_assert_num_contacts(epsilon, 16) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/covering.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/covering.py new file mode 100644 index 0000000..82de82b --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/covering.py @@ -0,0 +1,143 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Calculates a covering of text mazes with overlapping rectangular walls.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import numpy as np +from six.moves import range + +GridCoordinates = collections.namedtuple('GridCoordinates', ('y', 'x')) +MazeWall = collections.namedtuple('MazeWall', ('start', 'end')) + + +class _MazeWallCoveringContext(object): + """Calculates a covering of text mazes with overlapping rectangular walls. + + This class uses a greedy algorithm to try and minimize the number of geoms + generated to create a given maze. The solution is not guaranteed to be + optimal, but in most cases should result in a significantly smaller number of + geoms than if each cell were treated as an individual box. + """ + + def __init__(self, text_maze, wall_char='*', make_odd_sized_walls=False): + """Initializes this _MazeWallCoveringContext. + + Args: + text_maze: A `labmaze.TextGrid` instance. + wall_char: (optional) The character that signifies a wall. + make_odd_sized_walls: (optional) A boolean, if `True` all wall sections + generated span odd numbers of grid cells. This option exists primarily + to appease MuJoCo's texture repeating algorithm. + """ + self._text_maze = text_maze + self._wall_char = wall_char + self._make_odd_sized_walls = make_odd_sized_walls + self._covered = np.full(text_maze.shape, False, dtype=np.bool) + self._maze_size = GridCoordinates(*text_maze.shape) + self._next_start = GridCoordinates(0, 0) + self._calculated = False + self._walls = () + + def calculate(self): + """Calculates a covering of text mazes with overlapping rectangular walls. + + Returns: + A tuple of `MazeWall` objects, each describing the corners of a wall. + """ + if not self._calculated: + self._calculated = True + self._find_next_start() + walls = [] + while self._next_start.y < self._maze_size.y: + walls.append(self._find_next_wall()) + self._find_next_start() + self._walls = tuple(walls) + return self._walls + + def _find_next_start(self): + """Moves `self._next_start` to the top-left corner of the next wall.""" + for y in range(self._next_start.y, self._maze_size.y): + start_x = self._next_start.x if y == self._next_start.y else 0 + for x in range(start_x, self._maze_size.x): + if self._text_maze[y, x] == self._wall_char and not self._covered[y, x]: + self._next_start = GridCoordinates(y, x) + return + self._next_start = self._maze_size + + def _scan_row(self, row, start_col, end_col): + """Scans a row of text maze to find the longest strip of wall.""" + for col in range(start_col, end_col): + if (self._text_maze[row, col] != self._wall_char + or self._covered[row, col]): + return col + return end_col + + def _find_next_wall(self): + """Finds the largest piece of rectangular wall at the current location. + + This function assumes that `self._next_start` is already at the top-left + corner of the next piece of wall. + + Returns: + A `MazeWall` named tuple representing the next piece of wall created. + """ + start = self._next_start + x = self._maze_size.x + end_x_for_rows = [] + total_cells = [] + + for y in range(start.y, self._maze_size.y): + x = self._scan_row(y, start.x, x) + if x > start.x: + if self._make_odd_sized_walls and (x - start.x) % 2 == 0: + x -= 1 + end_x_for_rows.append(x) + total_cells.append((x - start.x) * (y - start.y + 1)) + y += 1 + else: + break + + if not self._make_odd_sized_walls: + end_y_offset = total_cells.index(max(total_cells)) + else: + end_y_offset = 2 * total_cells[::2].index(max(total_cells[::2])) + end = GridCoordinates(start.y + end_y_offset + 1, + end_x_for_rows[end_y_offset]) + self._covered[start.y:end.y, start.x:end.x] = True + self._next_start = GridCoordinates(start.y, end.x) + return MazeWall(start, end) + + +def make_walls(text_maze, wall_char='*', make_odd_sized_walls=False): + """Calculates a covering of text mazes with overlapping rectangular walls. + + Args: + text_maze: A `labmaze.TextMaze` instance. + wall_char: (optional) The character that signifies a wall. + make_odd_sized_walls: (optional) A boolean, if `True` all wall sections + generated span odd numbers of grid cells. This option exists primarily + to appease MuJoCo's texture repeating algorithm. + + Returns: + A tuple of `MazeWall` objects, each describing the corners of a wall. + """ + wall_covering_context = _MazeWallCoveringContext( + text_maze, wall_char=wall_char, make_odd_sized_walls=make_odd_sized_walls) + return wall_covering_context.calculate() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/covering_test.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/covering_test.py new file mode 100644 index 0000000..b008319 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/covering_test.py @@ -0,0 +1,80 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for arenas.mazes.covering.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from dm_control.locomotion.arenas import covering +import labmaze +import numpy as np +import six +from six.moves import range + +if six.PY3: + _STRING_DTYPE = '|U1' +else: + _STRING_DTYPE = '|S1' + + +class CoveringTest(absltest.TestCase): + + def testRandomMazes(self): + maze = labmaze.RandomMaze(height=17, width=17, + max_rooms=5, room_min_size=3, room_max_size=5, + spawns_per_room=0, objects_per_room=0, + random_seed=54321) + for _ in range(1000): + maze.regenerate() + walls = covering.make_walls(maze.entity_layer) + reconstructed = np.full(maze.entity_layer.shape, ' ', dtype=_STRING_DTYPE) + for wall in walls: + reconstructed[wall.start.y:wall.end.y, wall.start.x:wall.end.x] = '*' + np.testing.assert_array_equal(reconstructed, maze.entity_layer) + + def testOddCovering(self): + maze = labmaze.RandomMaze(height=17, width=17, + max_rooms=5, room_min_size=3, room_max_size=5, + spawns_per_room=0, objects_per_room=0, + random_seed=54321) + for _ in range(1000): + maze.regenerate() + walls = covering.make_walls(maze.entity_layer, make_odd_sized_walls=True) + reconstructed = np.full(maze.entity_layer.shape, ' ', dtype=_STRING_DTYPE) + for wall in walls: + reconstructed[wall.start.y:wall.end.y, wall.start.x:wall.end.x] = '*' + np.testing.assert_array_equal(reconstructed, maze.entity_layer) + for wall in walls: + self.assertEqual((wall.end.y - wall.start.y) % 2, 1) + self.assertEqual((wall.end.x - wall.start.x) % 2, 1) + + def testNoOverlappingWalls(self): + maze_string = """..** + .*** + .*** + """.replace(' ', '') + walls = covering.make_walls(labmaze.TextGrid(maze_string)) + surface = 0 + for wall in walls: + size_x = wall.end.x - wall.start.x + size_y = wall.end.y - wall.start.y + surface += size_x * size_y + self.assertEqual(surface, 8) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/floors.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/floors.py new file mode 100644 index 0000000..47afc91 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/floors.py @@ -0,0 +1,105 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Simple floor arenas.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from dm_control import composer +from dm_control.locomotion.arenas import assets as locomotion_arenas_assets +import numpy as np + +_TOP_CAMERA_DISTANCE = 100 +_TOP_CAMERA_Y_PADDING_FACTOR = 1.1 + + +class Floor(composer.Arena): + """A simple floor arena with a checkered pattern.""" + + def _build(self, size=(8, 8), reflectance=.2, aesthetic='default', + name='floor'): + super(Floor, self)._build(name=name) + self._size = size + + self._mjcf_root.visual.headlight.set_attributes( + ambient=[.4, .4, .4], diffuse=[.8, .8, .8], specular=[.1, .1, .1]) + + if aesthetic != 'default': + ground_info = locomotion_arenas_assets.get_ground_texture_info(aesthetic) + sky_info = locomotion_arenas_assets.get_sky_texture_info(aesthetic) + texturedir = locomotion_arenas_assets.get_texturedir(aesthetic) + self._mjcf_root.compiler.texturedir = texturedir + + self._ground_texture = self._mjcf_root.asset.add( + 'texture', name='aesthetic_texture', file=ground_info.file, + type=ground_info.type) + self._ground_material = self._mjcf_root.asset.add( + 'material', name='aesthetic_material', texture=self._ground_texture, + texuniform='true') + self._skybox = self._mjcf_root.asset.add( + 'texture', name='aesthetic_skybox', file=sky_info.file, + type='skybox', gridsize=sky_info.gridsize, + gridlayout=sky_info.gridlayout) + else: + self._ground_texture = self._mjcf_root.asset.add( + 'texture', + rgb1=[.2, .3, .4], + rgb2=[.1, .2, .3], + type='2d', + builtin='checker', + name='groundplane', + width=300, + height=300, + mark='edge', + markrgb=[0.8, 0.8, 0.8]) + self._ground_material = self._mjcf_root.asset.add( + 'material', + name='groundplane', + texrepeat=[3, 3], + texuniform=True, + reflectance=reflectance, + texture=self._ground_texture) + + # Build groundplane. + self._ground_geom = self._mjcf_root.worldbody.add( + 'geom', + type='plane', + name='groundplane', + material=self._ground_material, + size=list(size) + [0.5]) + + # Choose the FOV so that the floor always fits nicely within the frame + # irrespective of actual floor size. + fovy_radians = 2 * np.arctan2(_TOP_CAMERA_Y_PADDING_FACTOR * size[1], + _TOP_CAMERA_DISTANCE) + self._top_camera = self._mjcf_root.worldbody.add( + 'camera', + name='top_camera', + pos=[0, 0, _TOP_CAMERA_DISTANCE], + zaxis=[0, 0, 1], + fovy=np.rad2deg(fovy_radians)) + + @property + def ground_geoms(self): + return (self._ground_geom,) + + def regenerate(self, random_state): + pass + + @property + def size(self): + return self._size diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/floors_test.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/floors_test.py new file mode 100644 index 0000000..8979036 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/floors_test.py @@ -0,0 +1,53 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for locomotion.arenas.floors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from dm_control import mjcf +from dm_control.locomotion.arenas import floors +import numpy as np + + +class FloorsTest(absltest.TestCase): + + def test_can_compile_mjcf(self): + arena = floors.Floor() + mjcf.Physics.from_mjcf_model(arena.mjcf_model) + + def test_size(self): + floor_size = (12.9, 27.1) + arena = floors.Floor(size=floor_size) + self.assertEqual(tuple(arena.ground_geoms[0].size[:2]), floor_size) + + def test_top_camera(self): + floor_width, floor_height = 12.9, 27.1 + arena = floors.Floor(size=[floor_width, floor_height]) + + self.assertGreater(floors._TOP_CAMERA_Y_PADDING_FACTOR, 1) + np.testing.assert_array_equal(arena._top_camera.zaxis, (0, 0, 1)) + + expected_camera_y = floor_height * floors._TOP_CAMERA_Y_PADDING_FACTOR + np.testing.assert_allclose( + np.tan(np.deg2rad(arena._top_camera.fovy / 2)), + expected_camera_y / arena._top_camera.pos[2]) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/labmaze_textures.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/labmaze_textures.py new file mode 100644 index 0000000..593146a --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/labmaze_textures.py @@ -0,0 +1,87 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""LabMaze textures.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from dm_control import composer +from dm_control import mjcf +from labmaze import assets as labmaze_assets +import six + + +class SkyBox(composer.Entity): + """Represents a texture asset for the sky box.""" + + def _build(self, style): + labmaze_textures = labmaze_assets.get_sky_texture_paths(style) + self._mjcf_root = mjcf.RootElement(model='labmaze_' + style) + self._texture = self._mjcf_root.asset.add( + 'texture', type='skybox', name='texture', + fileleft=labmaze_textures.left, fileright=labmaze_textures.right, + fileup=labmaze_textures.up, filedown=labmaze_textures.down, + filefront=labmaze_textures.front, fileback=labmaze_textures.back) + + @property + def mjcf_model(self): + return self._mjcf_root + + @property + def texture(self): + return self._texture + + +class WallTextures(composer.Entity): + """Represents wall texture assets.""" + + def _build(self, style): + labmaze_textures = labmaze_assets.get_wall_texture_paths(style) + self._mjcf_root = mjcf.RootElement(model='labmaze_' + style) + self._textures = [] + for texture_name, texture_path in six.iteritems(labmaze_textures): + self._textures.append(self._mjcf_root.asset.add( + 'texture', type='2d', name=texture_name, + file=texture_path.format(texture_name))) + + @property + def mjcf_model(self): + return self._mjcf_root + + @property + def textures(self): + return self._textures + + +class FloorTextures(composer.Entity): + """Represents floor texture assets.""" + + def _build(self, style): + labmaze_textures = labmaze_assets.get_floor_texture_paths(style) + self._mjcf_root = mjcf.RootElement(model='labmaze_' + style) + self._textures = [] + for texture_name, texture_path in six.iteritems(labmaze_textures): + self._textures.append(self._mjcf_root.asset.add( + 'texture', type='2d', name=texture_name, + file=texture_path.format(texture_name))) + + @property + def mjcf_model(self): + return self._mjcf_root + + @property + def textures(self): + return self._textures diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/mazes.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/mazes.py new file mode 100644 index 0000000..b30afb9 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/mazes.py @@ -0,0 +1,466 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Maze-based arenas.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import string + +from absl import logging +from dm_control import composer +from dm_control.composer.observation import observable +from dm_control.locomotion.arenas import assets as locomotion_arenas_assets +from dm_control.locomotion.arenas import covering +import labmaze +import numpy as np +import six +from six.moves import range +from six.moves import zip + + +# Put all "actual" wall geoms in a separate group since they are not rendered. +_WALL_GEOM_GROUP = 3 + +_TOP_CAMERA_DISTANCE = 100 +_TOP_CAMERA_Y_PADDING_FACTOR = 1.1 + +_DEFAULT_WALL_CHAR = '*' +_DEFAULT_FLOOR_CHAR = '.' + + +class MazeWithTargets(composer.Arena): + """A 2D maze with target positions specified by a LabMaze-style text maze.""" + + def _build(self, maze, xy_scale=2.0, z_height=2.0, + skybox_texture=None, wall_textures=None, floor_textures=None, + aesthetic='default', name='maze'): + """Initializes this maze arena. + + Args: + maze: A `labmaze.BaseMaze` instance. + xy_scale: The size of each maze cell in metres. + z_height: The z-height of the maze in metres. + skybox_texture: (optional) A `composer.Entity` that provides a texture + asset for the skybox. + wall_textures: (optional) Either a `composer.Entity` that provides texture + assets for the maze walls, or a dict mapping printable characters to + such Entities. In the former case, the maze walls are assumed to be + represented by '*' in the maze's entity layer. In the latter case, + the dict's keys specify the different characters that can be present + in the maze's entity layer, and the dict's values are the corresponding + texture providers. + floor_textures: (optional) A `composer.Entity` that provides texture + assets for the maze floor. Unlike with walls, we do not currently + support per-variation floor texture. Instead, we sample textures from + the same texture provider for each variation in the variations layer. + aesthetic: option to adjust the material properties and skybox + name: (optional) A string, the name of this arena. + """ + super(MazeWithTargets, self)._build(name) + self._maze = maze + self._xy_scale = xy_scale + self._z_height = z_height + + self._x_offset = (self._maze.width - 1) / 2 + self._y_offset = (self._maze.height - 1) / 2 + + self._mjcf_root.default.geom.rgba = [1, 1, 1, 1] + + if aesthetic != 'default': + sky_info = locomotion_arenas_assets.get_sky_texture_info(aesthetic) + texturedir = locomotion_arenas_assets.get_texturedir(aesthetic) + self._mjcf_root.compiler.texturedir = texturedir + self._skybox = self._mjcf_root.asset.add( + 'texture', name='aesthetic_skybox', file=sky_info.file, + type='skybox', gridsize=sky_info.gridsize, + gridlayout=sky_info.gridlayout) + elif skybox_texture: + self._skybox_texture = skybox_texture.texture + self.attach(skybox_texture) + else: + self._skybox_texture = self._mjcf_root.asset.add( + 'texture', type='skybox', name='skybox', builtin='gradient', + rgb1=[.4, .6, .8], rgb2=[0, 0, 0], width=100, height=100) + + self._texturing_geom_names = [] + self._texturing_material_names = [] + if wall_textures: + if isinstance(wall_textures, dict): + for texture_provider in set(wall_textures.values()): + self.attach(texture_provider) + self._wall_textures = { + wall_char: texture_provider.textures + for wall_char, texture_provider in six.iteritems(wall_textures) + } + else: + self.attach(wall_textures) + self._wall_textures = {_DEFAULT_WALL_CHAR: wall_textures.textures} + else: + self._wall_textures = {_DEFAULT_WALL_CHAR: [self._mjcf_root.asset.add( + 'texture', type='2d', name='wall', builtin='flat', + rgb1=[.8, .8, .8], width=100, height=100)]} + + if aesthetic != 'default': + ground_info = locomotion_arenas_assets.get_ground_texture_info(aesthetic) + self._floor_textures = [ + self._mjcf_root.asset.add( + 'texture', + name='aesthetic_texture_main', + file=ground_info.file, + type=ground_info.type), + self._mjcf_root.asset.add( + 'texture', + name='aesthetic_texture', + file=ground_info.file, + type=ground_info.type) + ] + elif floor_textures: + self._floor_textures = floor_textures.textures + self.attach(floor_textures) + else: + self._floor_textures = [self._mjcf_root.asset.add( + 'texture', type='2d', name='floor', builtin='flat', + rgb1=[.2, .2, .2], width=100, height=100)] + + ground_x = ((self._maze.width - 1) + 1) * (xy_scale / 2) + ground_y = ((self._maze.height - 1) + 1) * (xy_scale / 2) + self._mjcf_root.worldbody.add( + 'geom', name='ground', type='plane', + pos=[0, 0, 0], size=[ground_x, ground_y, 1], rgba=[0, 0, 0, 0]) + + self._maze_body = self._mjcf_root.worldbody.add('body', name='maze_body') + + self._mjcf_root.visual.map.znear = 0.0005 + + # Choose the FOV so that the maze always fits nicely within the frame + # irrespective of actual maze size. + maze_size = max(self._maze.width, self._maze.height) + top_camera_fovy = (360 / np.pi) * np.arctan2( + _TOP_CAMERA_Y_PADDING_FACTOR * maze_size * self._xy_scale / 2, + _TOP_CAMERA_DISTANCE) + self._top_camera = self._mjcf_root.worldbody.add( + 'camera', name='top_camera', + pos=[0, 0, _TOP_CAMERA_DISTANCE], zaxis=[0, 0, 1], fovy=top_camera_fovy) + + self._target_positions = () + self._spawn_positions = () + + self._text_maze_regenerated_hook = None + self._tile_geom_names = {} + + def _build_observables(self): + return MazeObservables(self) + + @property + def top_camera(self): + return self._top_camera + + @property + def xy_scale(self): + return self._xy_scale + + @property + def z_height(self): + return self._z_height + + @property + def maze(self): + return self._maze + + @property + def text_maze_regenerated_hook(self): + """A callback that is executed after the LabMaze object is regenerated.""" + return self._text_maze_modifier + + @text_maze_regenerated_hook.setter + def text_maze_regenerated_hook(self, hook): + self._text_maze_regenerated_hook = hook + + @property + def target_positions(self): + """A tuple of Cartesian target positions generated for the current maze.""" + return self._target_positions + + @property + def spawn_positions(self): + """The Cartesian position at which the agent should be spawned.""" + return self._spawn_positions + + @property + def target_grid_positions(self): + """A tuple of grid coordinates of targets generated for the current maze.""" + return self._target_grid_positions + + @property + def spawn_grid_positions(self): + """The grid-coordinate position at which the agent should be spawned.""" + return self._spawn_grid_positions + + def regenerate(self): + """Generates a new maze layout.""" + self._maze.regenerate() + logging.debug('GENERATED MAZE:\n%s', self._maze.entity_layer) + self._find_spawn_and_target_positions() + + if self._text_maze_regenerated_hook: + self._text_maze_regenerated_hook() + + # Remove old texturing planes. + for geom_name in self._texturing_geom_names: + del self._mjcf_root.worldbody.geom[geom_name] + self._texturing_geom_names = [] + + # Remove old texturing materials. + for material_name in self._texturing_material_names: + del self._mjcf_root.asset.material[material_name] + self._texturing_material_names = [] + + # Remove old actual-wall geoms. + self._maze_body.geom.clear() + + self._current_wall_texture = { + wall_char: np.random.choice(wall_textures) + for wall_char, wall_textures in six.iteritems(self._wall_textures) + } + + for wall_char in self._wall_textures: + self._make_wall_geoms(wall_char) + self._make_floor_variations() + + def _make_wall_geoms(self, wall_char): + walls = covering.make_walls( + self._maze.entity_layer, wall_char=wall_char, make_odd_sized_walls=True) + for i, wall in enumerate(walls): + wall_mid = covering.GridCoordinates( + (wall.start.y + wall.end.y - 1) / 2, + (wall.start.x + wall.end.x - 1) / 2) + wall_pos = np.array([(wall_mid.x - self._x_offset) * self._xy_scale, + -(wall_mid.y - self._y_offset) * self._xy_scale, + self._z_height / 2]) + wall_size = np.array([(wall.end.x - wall_mid.x - 0.5) * self._xy_scale, + (wall.end.y - wall_mid.y - 0.5) * self._xy_scale, + self._z_height / 2]) + self._maze_body.add('geom', name='wall{}_{}'.format(wall_char, i), + type='box', pos=wall_pos, size=wall_size, + group=_WALL_GEOM_GROUP) + self._make_wall_texturing_planes(wall_char, i, wall_pos, wall_size) + + def _make_wall_texturing_planes(self, wall_char, wall_id, + wall_pos, wall_size): + xyaxes = { + 'x': {-1: [0, -1, 0, 0, 0, 1], 1: [0, 1, 0, 0, 0, 1]}, + 'y': {-1: [1, 0, 0, 0, 0, 1], 1: [-1, 0, 0, 0, 0, 1]}, + 'z': {-1: [-1, 0, 0, 0, 1, 0], 1: [1, 0, 0, 0, 1, 0]} + } + for direction_index, direction in enumerate(('x', 'y', 'z')): + index = list(i for i in range(3) if i != direction_index) + delta_vector = np.array([int(i == direction_index) for i in range(3)]) + material_name = 'wall{}_{}_{}'.format(wall_char, wall_id, direction) + self._texturing_material_names.append(material_name) + mat = self._mjcf_root.asset.add( + 'material', name=material_name, + texture=self._current_wall_texture[wall_char], + texrepeat=(2 * wall_size[index] / self._xy_scale)) + for sign, sign_name in zip((-1, 1), ('neg', 'pos')): + if direction == 'z' and sign == -1: + continue + geom_name = ( + 'wall{}_{}_texturing_{}_{}'.format( + wall_char, wall_id, sign_name, direction)) + self._texturing_geom_names.append(geom_name) + self._mjcf_root.worldbody.add( + 'geom', type='plane', name=geom_name, + pos=(wall_pos + sign * delta_vector * wall_size), + size=np.concatenate([wall_size[index], [self._xy_scale]]), + xyaxes=xyaxes[direction][sign], material=mat, + contype=0, conaffinity=0) + + def _make_floor_variations(self, build_tile_geoms_fn=None): + """Builds the floor tiles. + + Args: + build_tile_geoms_fn: An optional callable returning floor tile geoms. + If not passed, the floor will be built using a default covering method. + Takes a kwarg `wall_char` that can be used control how active floor + tiles are selected. + """ + main_floor_texture = np.random.choice(self._floor_textures) + for variation in _DEFAULT_FLOOR_CHAR + string.ascii_uppercase: + if variation not in self._maze.variations_layer: + break + + if build_tile_geoms_fn is None: + # Break the floor variation down to odd-sized tiles. + tiles = covering.make_walls(self._maze.variations_layer, + wall_char=variation, + make_odd_sized_walls=True) + else: + tiles = build_tile_geoms_fn(wall_char=variation) + + # Sample a texture that's not the same as the main floor texture. + variation_texture = main_floor_texture + if variation != _DEFAULT_FLOOR_CHAR: + if len(self._floor_textures) == 1: + return + else: + while variation_texture is main_floor_texture: + variation_texture = np.random.choice(self._floor_textures) + + for i, tile in enumerate(tiles): + tile_mid = covering.GridCoordinates( + (tile.start.y + tile.end.y - 1) / 2, + (tile.start.x + tile.end.x - 1) / 2) + tile_pos = np.array([(tile_mid.x - self._x_offset) * self._xy_scale, + -(tile_mid.y - self._y_offset) * self._xy_scale, + 0.0]) + tile_size = np.array([(tile.end.x - tile_mid.x - 0.5) * self._xy_scale, + (tile.end.y - tile_mid.y - 0.5) * self._xy_scale, + self._xy_scale]) + if variation == _DEFAULT_FLOOR_CHAR: + tile_name = 'floor_{}'.format(i) + else: + tile_name = 'floor_{}_{}'.format(variation, i) + self._tile_geom_names[tile.start] = tile_name + self._texturing_material_names.append(tile_name) + self._texturing_geom_names.append(tile_name) + material = self._mjcf_root.asset.add( + 'material', name=tile_name, texture=variation_texture, + texrepeat=(2 * tile_size[[0, 1]] / self._xy_scale)) + self._mjcf_root.worldbody.add( + 'geom', name=tile_name, type='plane', material=material, + pos=tile_pos, size=tile_size, contype=0, conaffinity=0) + + @property + def ground_geoms(self): + return tuple([ + geom for geom in self.mjcf_model.find_all('geom') + if 'ground' in geom.name + ]) + + def find_token_grid_positions(self, tokens): + out = {token: [] for token in tokens} + for y in range(self._maze.entity_layer.shape[0]): + for x in range(self._maze.entity_layer.shape[1]): + for token in tokens: + if self._maze.entity_layer[y, x] == token: + out[token].append((y, x)) + return out + + def grid_to_world_positions(self, grid_positions): + out = [] + for y, x in grid_positions: + out.append(np.array([(x - self._x_offset) * self._xy_scale, + -(y - self._y_offset) * self._xy_scale, + 0.0])) + return out + + def world_to_grid_positions(self, world_positions): + out = [] + # the order of x, y is reverse between grid positions format and + # world positions format. + for x, y, _ in world_positions: + out.append(np.array([self._y_offset - y / self._xy_scale, + self._x_offset + x / self._xy_scale])) + return out + + def _find_spawn_and_target_positions(self): + grid_positions = self.find_token_grid_positions([ + labmaze.defaults.OBJECT_TOKEN, labmaze.defaults.SPAWN_TOKEN]) + self._target_grid_positions = tuple( + grid_positions[labmaze.defaults.OBJECT_TOKEN]) + self._spawn_grid_positions = tuple( + grid_positions[labmaze.defaults.SPAWN_TOKEN]) + self._target_positions = tuple( + self.grid_to_world_positions(self._target_grid_positions)) + self._spawn_positions = tuple( + self.grid_to_world_positions(self._spawn_grid_positions)) + + +class MazeObservables(composer.Observables): + + @composer.observable + def top_camera(self): + return observable.MJCFCamera(self._entity.top_camera) + + +class RandomMazeWithTargets(MazeWithTargets): + """A randomly generated 2D maze with target positions.""" + + def _build(self, + x_cells, + y_cells, + xy_scale=2.0, + z_height=2.0, + max_rooms=labmaze.defaults.MAX_ROOMS, + room_min_size=labmaze.defaults.ROOM_MIN_SIZE, + room_max_size=labmaze.defaults.ROOM_MAX_SIZE, + spawns_per_room=labmaze.defaults.SPAWN_COUNT, + targets_per_room=labmaze.defaults.OBJECT_COUNT, + max_variations=labmaze.defaults.MAX_VARIATIONS, + simplify=labmaze.defaults.SIMPLIFY, + skybox_texture=None, + wall_textures=None, + floor_textures=None, + aesthetic='default', + name='random_maze'): + """Initializes this random maze arena. + + Args: + x_cells: The number of cells along the x-direction of the maze. Must be + an odd integer. + y_cells: The number of cells along the y-direction of the maze. Must be + an odd integer. + xy_scale: The size of each maze cell in metres. + z_height: The z-height of the maze in metres. + max_rooms: (optional) The maximum number of rooms in each generated maze. + room_min_size: (optional) The minimum size of each room generated. + room_max_size: (optional) The maximum size of each room generated. + spawns_per_room: (optional) Number of spawn points + to generate in each room. + targets_per_room: (optional) Number of targets to generate in each room. + max_variations: (optional) Maximum number of variations to generate + in the variations layer. + simplify: (optional) flag to simplify the maze. + skybox_texture: (optional) A `composer.Entity` that provides a texture + asset for the skybox. + wall_textures: (optional) A `composer.Entity` that provides texture + assets for the maze walls. + floor_textures: (optional) A `composer.Entity` that provides texture + assets for the maze floor. + aesthetic: option to adjust the material properties and skybox + name: (optional) A string, the name of this arena. + """ + random_seed = np.random.randint(2147483648) # 2**31 + super(RandomMazeWithTargets, self)._build( + maze=labmaze.RandomMaze( + height=y_cells, + width=x_cells, + max_rooms=max_rooms, + room_min_size=room_min_size, + room_max_size=room_max_size, + max_variations=max_variations, + spawns_per_room=spawns_per_room, + objects_per_room=targets_per_room, + simplify=simplify, + random_seed=random_seed), + xy_scale=xy_scale, + z_height=z_height, + skybox_texture=skybox_texture, + wall_textures=wall_textures, + floor_textures=floor_textures, + aesthetic=aesthetic, + name=name) diff --git a/DMC/src/env/dm_control/dm_control/locomotion/arenas/mazes_test.py b/DMC/src/env/dm_control/dm_control/locomotion/arenas/mazes_test.py new file mode 100644 index 0000000..80a61c7 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/arenas/mazes_test.py @@ -0,0 +1,52 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for locomotion.arenas.mazes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from dm_control import mjcf +from dm_control.locomotion.arenas import labmaze_textures +from dm_control.locomotion.arenas import mazes + + +class MazesTest(absltest.TestCase): + + def test_can_compile_mjcf(self): + + # Set the wall and floor textures to match DMLab and set the skybox. + skybox_texture = labmaze_textures.SkyBox(style='sky_03') + wall_textures = labmaze_textures.WallTextures(style='style_01') + floor_textures = labmaze_textures.FloorTextures(style='style_01') + + arena = mazes.RandomMazeWithTargets( + x_cells=11, + y_cells=11, + xy_scale=3, + max_rooms=4, + room_min_size=4, + room_max_size=5, + spawns_per_room=1, + targets_per_room=3, + skybox_texture=skybox_texture, + wall_textures=wall_textures, + floor_textures=floor_textures) + mjcf.Physics.from_mjcf_model(arena.mjcf_model) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/examples/__init__.py b/DMC/src/env/dm_control/dm_control/locomotion/examples/__init__.py new file mode 100644 index 0000000..4224c02 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/examples/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/examples/basic_cmu_2019.py b/DMC/src/env/dm_control/dm_control/locomotion/examples/basic_cmu_2019.py new file mode 100644 index 0000000..d508fcd --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/examples/basic_cmu_2019.py @@ -0,0 +1,225 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Produces reference environments for CMU humanoid locomotion tasks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +from dm_control import composer +from dm_control.composer.variation import distributions +from dm_control.locomotion.arenas import corridors as corr_arenas +from dm_control.locomotion.arenas import floors +from dm_control.locomotion.arenas import labmaze_textures +from dm_control.locomotion.arenas import mazes +from dm_control.locomotion.props import target_sphere +from dm_control.locomotion.tasks import corridors as corr_tasks +from dm_control.locomotion.tasks import go_to_target +from dm_control.locomotion.tasks import random_goal_maze +from dm_control.locomotion.walkers import cmu_humanoid +from labmaze import fixed_maze + + +def cmu_humanoid_run_walls(random_state=None): + """Requires a CMU humanoid to run down a corridor obstructed by walls.""" + + # Build a position-controlled CMU humanoid walker. + walker = cmu_humanoid.CMUHumanoidPositionControlled( + observable_options={'egocentric_camera': dict(enabled=True)}) + + # Build a corridor-shaped arena that is obstructed by walls. + arena = corr_arenas.WallsCorridor( + wall_gap=4., + wall_width=distributions.Uniform(1, 7), + wall_height=3.0, + corridor_width=10, + corridor_length=100) + + # Build a task that rewards the agent for running down the corridor at a + # specific velocity. + task = corr_tasks.RunThroughCorridor( + walker=walker, + arena=arena, + walker_spawn_position=(0.5, 0, 0), + target_velocity=3.0, + physics_timestep=0.005, + control_timestep=0.03) + + return composer.Environment(time_limit=30, + task=task, + random_state=random_state, + strip_singleton_obs_buffer_dim=True) + + +def cmu_humanoid_run_gaps(random_state=None): + """Requires a CMU humanoid to run down a corridor with gaps.""" + + # Build a position-controlled CMU humanoid walker. + walker = cmu_humanoid.CMUHumanoidPositionControlled( + observable_options={'egocentric_camera': dict(enabled=True)}) + + # Build a corridor-shaped arena with gaps, where the sizes of the gaps and + # platforms are uniformly randomized. + arena = corr_arenas.GapsCorridor( + platform_length=distributions.Uniform(.3, 2.5), + gap_length=distributions.Uniform(.5, 1.25), + corridor_width=10, + corridor_length=100) + + # Build a task that rewards the agent for running down the corridor at a + # specific velocity. + task = corr_tasks.RunThroughCorridor( + walker=walker, + arena=arena, + walker_spawn_position=(0.5, 0, 0), + target_velocity=3.0, + physics_timestep=0.005, + control_timestep=0.03) + + return composer.Environment(time_limit=30, + task=task, + random_state=random_state, + strip_singleton_obs_buffer_dim=True) + + +def cmu_humanoid_go_to_target(random_state=None): + """Requires a CMU humanoid to go to a target.""" + + # Build a position-controlled CMU humanoid walker. + walker = cmu_humanoid.CMUHumanoidPositionControlled() + + # Build a standard floor arena. + arena = floors.Floor() + + # Build a task that rewards the agent for going to a target. + task = go_to_target.GoToTarget( + walker=walker, + arena=arena, + physics_timestep=0.005, + control_timestep=0.03) + + return composer.Environment(time_limit=30, + task=task, + random_state=random_state, + strip_singleton_obs_buffer_dim=True) + + +def cmu_humanoid_maze_forage(random_state=None): + """Requires a CMU humanoid to find all items in a maze.""" + + # Build a position-controlled CMU humanoid walker. + walker = cmu_humanoid.CMUHumanoidPositionControlled( + observable_options={'egocentric_camera': dict(enabled=True)}) + + # Build a maze with rooms and targets. + skybox_texture = labmaze_textures.SkyBox(style='sky_03') + wall_textures = labmaze_textures.WallTextures(style='style_01') + floor_textures = labmaze_textures.FloorTextures(style='style_01') + arena = mazes.RandomMazeWithTargets( + x_cells=11, + y_cells=11, + xy_scale=3, + max_rooms=4, + room_min_size=4, + room_max_size=5, + spawns_per_room=1, + targets_per_room=3, + skybox_texture=skybox_texture, + wall_textures=wall_textures, + floor_textures=floor_textures, + ) + + # Build a task that rewards the agent for obtaining targets. + task = random_goal_maze.ManyGoalsMaze( + walker=walker, + maze_arena=arena, + target_builder=functools.partial( + target_sphere.TargetSphere, + radius=0.4, + rgb1=(0, 0, 0.4), + rgb2=(0, 0, 0.7)), + target_reward_scale=50., + physics_timestep=0.005, + control_timestep=0.03, + ) + + return composer.Environment(time_limit=30, + task=task, + random_state=random_state, + strip_singleton_obs_buffer_dim=True) + + +def cmu_humanoid_heterogeneous_forage(random_state=None): + """Requires a CMU humanoid to find all items of a particular type in a maze.""" + level = ('*******\n' + '* *\n' + '* P *\n' + '* *\n' + '* G *\n' + '* *\n' + '*******\n') + + # Build a position-controlled CMU humanoid walker. + walker = cmu_humanoid.CMUHumanoidPositionControlled( + observable_options={'egocentric_camera': dict(enabled=True)}) + + skybox_texture = labmaze_textures.SkyBox(style='sky_03') + wall_textures = labmaze_textures.WallTextures(style='style_01') + floor_textures = labmaze_textures.FloorTextures(style='style_01') + maze = fixed_maze.FixedMazeWithRandomGoals( + entity_layer=level, + variations_layer=None, + num_spawns=1, + num_objects=6, + ) + arena = mazes.MazeWithTargets( + maze=maze, + xy_scale=3.0, + z_height=2.0, + skybox_texture=skybox_texture, + wall_textures=wall_textures, + floor_textures=floor_textures, + ) + task = random_goal_maze.ManyHeterogeneousGoalsMaze( + walker=walker, + maze_arena=arena, + target_builders=[ + functools.partial( + target_sphere.TargetSphere, + radius=0.4, + rgb1=(0, 0.4, 0), + rgb2=(0, 0.7, 0)), + functools.partial( + target_sphere.TargetSphere, + radius=0.4, + rgb1=(0.4, 0, 0), + rgb2=(0.7, 0, 0)), + ], + randomize_spawn_rotation=False, + target_type_rewards=[30., -10.], + target_type_proportions=[1, 1], + shuffle_target_builders=True, + aliveness_reward=0.01, + control_timestep=.03, + ) + + return composer.Environment( + time_limit=25, + task=task, + random_state=random_state, + strip_singleton_obs_buffer_dim=True) diff --git a/DMC/src/env/dm_control/dm_control/locomotion/examples/basic_rodent_2020.py b/DMC/src/env/dm_control/dm_control/locomotion/examples/basic_rodent_2020.py new file mode 100644 index 0000000..1b6283d --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/examples/basic_rodent_2020.py @@ -0,0 +1,175 @@ +# Copyright 2020 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Produces reference environments for rodent tasks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +from dm_control import composer +from dm_control.composer.variation import distributions +from dm_control.locomotion.arenas import bowl +from dm_control.locomotion.arenas import corridors as corr_arenas +from dm_control.locomotion.arenas import floors +from dm_control.locomotion.arenas import labmaze_textures +from dm_control.locomotion.arenas import mazes +from dm_control.locomotion.props import target_sphere +from dm_control.locomotion.tasks import corridors as corr_tasks +from dm_control.locomotion.tasks import escape +from dm_control.locomotion.tasks import random_goal_maze +from dm_control.locomotion.tasks import reach +from dm_control.locomotion.walkers import rodent + +_CONTROL_TIMESTEP = .02 +_PHYSICS_TIMESTEP = 0.001 + + +def rodent_escape_bowl(random_state=None): + """Requires a rodent to climb out of a bowl-shaped terrain.""" + + # Build a position-controlled rodent walker. + walker = rodent.Rat( + observable_options={'egocentric_camera': dict(enabled=True)}) + + # Build a bowl-shaped arena. + arena = bowl.Bowl( + size=(20., 20.), + aesthetic='outdoor_natural') + + # Build a task that rewards the agent for being far from the origin. + task = escape.Escape( + walker=walker, + arena=arena, + physics_timestep=_PHYSICS_TIMESTEP, + control_timestep=_CONTROL_TIMESTEP) + + return composer.Environment(time_limit=20, + task=task, + random_state=random_state, + strip_singleton_obs_buffer_dim=True) + + +def rodent_run_gaps(random_state=None): + """Requires a rodent to run down a corridor with gaps.""" + + # Build a position-controlled rodent walker. + walker = rodent.Rat( + observable_options={'egocentric_camera': dict(enabled=True)}) + + # Build a corridor-shaped arena with gaps, where the sizes of the gaps and + # platforms are uniformly randomized. + arena = corr_arenas.GapsCorridor( + platform_length=distributions.Uniform(.4, .8), + gap_length=distributions.Uniform(.05, .2), + corridor_width=2, + corridor_length=40, + aesthetic='outdoor_natural') + + # Build a task that rewards the agent for running down the corridor at a + # specific velocity. + task = corr_tasks.RunThroughCorridor( + walker=walker, + arena=arena, + walker_spawn_position=(5, 0, 0), + walker_spawn_rotation=0, + target_velocity=1.0, + contact_termination=False, + terminate_at_height=-0.3, + physics_timestep=_PHYSICS_TIMESTEP, + control_timestep=_CONTROL_TIMESTEP) + + return composer.Environment(time_limit=30, + task=task, + random_state=random_state, + strip_singleton_obs_buffer_dim=True) + + +def rodent_maze_forage(random_state=None): + """Requires a rodent to find all items in a maze.""" + + # Build a position-controlled rodent walker. + walker = rodent.Rat( + observable_options={'egocentric_camera': dict(enabled=True)}) + + # Build a maze with rooms and targets. + wall_textures = labmaze_textures.WallTextures(style='style_01') + arena = mazes.RandomMazeWithTargets( + x_cells=11, + y_cells=11, + xy_scale=.5, + z_height=.3, + max_rooms=4, + room_min_size=4, + room_max_size=5, + spawns_per_room=1, + targets_per_room=3, + wall_textures=wall_textures, + aesthetic='outdoor_natural') + + # Build a task that rewards the agent for obtaining targets. + task = random_goal_maze.ManyGoalsMaze( + walker=walker, + maze_arena=arena, + target_builder=functools.partial( + target_sphere.TargetSphere, + radius=0.05, + height_above_ground=.125, + rgb1=(0, 0, 0.4), + rgb2=(0, 0, 0.7)), + target_reward_scale=50., + contact_termination=False, + physics_timestep=_PHYSICS_TIMESTEP, + control_timestep=_CONTROL_TIMESTEP) + + return composer.Environment(time_limit=30, + task=task, + random_state=random_state, + strip_singleton_obs_buffer_dim=True) + + +def rodent_two_touch(random_state=None): + """Requires a rodent to tap an orb, wait an interval, and tap it again.""" + + # Build a position-controlled rodent walker. + walker = rodent.Rat( + observable_options={'egocentric_camera': dict(enabled=True)}) + + # Build an open floor arena + arena = floors.Floor( + size=(10., 10.), + aesthetic='outdoor_natural') + + # Build a task that rewards the walker for touching/reaching orbs with a + # specific time interval between touches + task = reach.TwoTouch( + walker=walker, + arena=arena, + target_builders=[ + functools.partial(target_sphere.TargetSphereTwoTouch, radius=0.025), + ], + randomize_spawn_rotation=True, + target_type_rewards=[25.], + shuffle_target_builders=False, + target_area=(1.5, 1.5), + physics_timestep=_PHYSICS_TIMESTEP, + control_timestep=_CONTROL_TIMESTEP, + ) + + return composer.Environment(time_limit=30, + task=task, + random_state=random_state, + strip_singleton_obs_buffer_dim=True) diff --git a/DMC/src/env/dm_control/dm_control/locomotion/examples/examples_test.py b/DMC/src/env/dm_control/dm_control/locomotion/examples/examples_test.py new file mode 100644 index 0000000..2c210f1 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/examples/examples_test.py @@ -0,0 +1,89 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for `dm_control.locomotion.examples`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Internal dependencies. +from absl.testing import absltest +from absl.testing import parameterized +from dm_control.locomotion.examples import basic_cmu_2019 +from dm_control.locomotion.examples import basic_rodent_2020 + +import numpy as np +from six.moves import range + + +_NUM_EPISODES = 5 +_NUM_STEPS_PER_EPISODE = 10 + + +class ExampleEnvironmentsTest(parameterized.TestCase): + """Tests run on all the tasks registered.""" + + def _validate_observation(self, observation, observation_spec): + self.assertEqual(list(observation.keys()), list(observation_spec.keys())) + for name, array_spec in observation_spec.items(): + array_spec.validate(observation[name]) + + def _validate_reward_range(self, reward): + self.assertIsInstance(reward, float) + self.assertBetween(reward, 0, 1) + + def _validate_discount(self, discount): + self.assertIsInstance(discount, float) + self.assertBetween(discount, 0, 1) + + @parameterized.named_parameters( + ('cmu_humanoid_run_walls', basic_cmu_2019.cmu_humanoid_run_walls), + ('cmu_humanoid_run_gaps', basic_cmu_2019.cmu_humanoid_run_gaps), + ('cmu_humanoid_go_to_target', basic_cmu_2019.cmu_humanoid_go_to_target), + ('cmu_humanoid_maze_forage', basic_cmu_2019.cmu_humanoid_maze_forage), + ('cmu_humanoid_heterogeneous_forage', + basic_cmu_2019.cmu_humanoid_heterogeneous_forage), + ('rodent_escape_bowl', basic_rodent_2020.rodent_escape_bowl), + ('rodent_run_gaps', basic_rodent_2020.rodent_run_gaps), + ('rodent_maze_forage', basic_rodent_2020.rodent_maze_forage), + ('rodent_two_touch', basic_rodent_2020.rodent_two_touch), + ) + def test_env_runs(self, env_constructor): + """Tests that the environment runs and is coherent with its specs.""" + random_state = np.random.RandomState(99) + + env = env_constructor(random_state=random_state) + observation_spec = env.observation_spec() + action_spec = env.action_spec() + self.assertTrue(np.all(np.isfinite(action_spec.minimum))) + self.assertTrue(np.all(np.isfinite(action_spec.maximum))) + + # Run a partial episode, check observations, rewards, discount. + for _ in range(_NUM_EPISODES): + time_step = env.reset() + for _ in range(_NUM_STEPS_PER_EPISODE): + self._validate_observation(time_step.observation, observation_spec) + if time_step.first(): + self.assertIsNone(time_step.reward) + self.assertIsNone(time_step.discount) + else: + self._validate_reward_range(time_step.reward) + self._validate_discount(time_step.discount) + action = random_state.uniform(action_spec.minimum, action_spec.maximum) + env.step(action) + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/examples/explore.py b/DMC/src/env/dm_control/dm_control/locomotion/examples/explore.py new file mode 100644 index 0000000..a966283 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/examples/explore.py @@ -0,0 +1,28 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Simple script to launch viewer with an example environment.""" + +from absl import app + +from dm_control import viewer +from dm_control.locomotion.examples import basic_cmu_2019 + + +def main(unused_argv): + viewer.launch(environment_loader=basic_cmu_2019.cmu_humanoid_run_gaps) + +if __name__ == '__main__': + app.run(main) diff --git a/DMC/src/env/dm_control/dm_control/locomotion/gaps.png b/DMC/src/env/dm_control/dm_control/locomotion/gaps.png new file mode 100644 index 0000000..9a6c16d Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/gaps.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/props/__init__.py b/DMC/src/env/dm_control/dm_control/locomotion/props/__init__.py new file mode 100644 index 0000000..84e687e --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/props/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Props for Locomotion tasks.""" + +from dm_control.locomotion.props.target_sphere import TargetSphere +from dm_control.locomotion.props.target_sphere import TargetSphereTwoTouch diff --git a/DMC/src/env/dm_control/dm_control/locomotion/props/target_sphere.py b/DMC/src/env/dm_control/dm_control/locomotion/props/target_sphere.py new file mode 100644 index 0000000..0c55b0c --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/props/target_sphere.py @@ -0,0 +1,227 @@ +# Copyright 2020 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""A non-colliding sphere that is activated through touch.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from dm_control import composer +from dm_control import mjcf + + +class TargetSphere(composer.Entity): + """A non-colliding sphere that is activated through touch. + + Once the target has been reached, it remains in the "activated" state + for the remainder of the current episode. + + The target is automatically reset to "not activated" state at episode + initialization time. + """ + + def _build(self, + radius=0.6, + height_above_ground=1, + rgb1=(0, 0.4, 0), + rgb2=(0, 0.7, 0), + specific_collision_geom_ids=None, + name='target'): + """Builds this target sphere. + + Args: + radius: The radius (in meters) of this target sphere. + height_above_ground: The height (in meters) of this target above ground. + rgb1: A sequence of three floating point values between 0.0 and 1.0 + (inclusive) representing the color of the first element in the stripe + pattern of the target. + rgb2: A sequence of three floating point values between 0.0 and 1.0 + (inclusive) representing the color of the second element in the stripe + pattern of the target. + specific_collision_geom_ids: Only activate if collides with these geoms. + name: The name of this entity. + """ + self._mjcf_root = mjcf.RootElement(model=name) + self._texture = self._mjcf_root.asset.add( + 'texture', name='target_sphere', type='cube', + builtin='checker', rgb1=rgb1, rgb2=rgb2, + width='100', height='100') + self._material = self._mjcf_root.asset.add( + 'material', name='target_sphere', texture=self._texture) + self._geom = self._mjcf_root.worldbody.add( + 'geom', type='sphere', name='geom', gap=2*radius, + pos=[0, 0, height_above_ground], size=[radius], material=self._material) + self._geom_id = -1 + self._activated = False + self._specific_collision_geom_ids = specific_collision_geom_ids + + @property + def geom(self): + return self._geom + + @property + def material(self): + return self._material + + @property + def activated(self): + """Whether this target has been reached during this episode.""" + return self._activated + + def reset(self, physics): + self._activated = False + physics.bind(self._material).rgba[-1] = 1 + + @property + def mjcf_model(self): + return self._mjcf_root + + def initialize_episode_mjcf(self, unused_random_state): + self._activated = False + + def _update_activation(self, physics): + if not self._activated: + for contact in physics.data.contact: + if self._specific_collision_geom_ids: + has_specific_collision = ( + contact.geom1 in self._specific_collision_geom_ids or + contact.geom2 in self._specific_collision_geom_ids) + else: + has_specific_collision = True + if (has_specific_collision and + self._geom_id in (contact.geom1, contact.geom2)): + self._activated = True + physics.bind(self._material).rgba[-1] = 0 + + def initialize_episode(self, physics, unused_random_state): + self._geom_id = physics.model.name2id(self._geom.full_identifier, 'geom') + self._update_activation(physics) + + def after_substep(self, physics, unused_random_state): + self._update_activation(physics) + + +class TargetSphereTwoTouch(composer.Entity): + """A non-colliding sphere that is activated through touch. + + The target indicates if it has been touched at least once and touched at least + twice this episode with a two-bit activated state tuple. It remains activated + for the remainder of the current episode. + + The target is automatically reset at episode initialization. + """ + + def _build(self, + radius=0.6, + height_above_ground=1, + rgb_initial=((0, 0.4, 0), (0, 0.7, 0)), + rgb_interval=((1., 1., .4), (0.7, 0.7, 0.)), + rgb_final=((.4, 0.7, 1.), (0, 0.4, .7)), + touch_debounce=.2, + specific_collision_geom_ids=None, + name='target'): + """Builds this target sphere. + + Args: + radius: The radius (in meters) of this target sphere. + height_above_ground: The height (in meters) of this target above ground. + rgb_initial: A tuple of two colors for the stripe pattern of the target. + rgb_interval: A tuple of two colors for the stripe pattern of the target. + rgb_final: A tuple of two colors for the stripe pattern of the target. + touch_debounce: duration to not count second touch. + specific_collision_geom_ids: Only activate if collides with these geoms. + name: The name of this entity. + """ + self._mjcf_root = mjcf.RootElement(model=name) + self._texture_initial = self._mjcf_root.asset.add( + 'texture', name='target_sphere_init', type='cube', + builtin='checker', rgb1=rgb_initial[0], rgb2=rgb_initial[1], + width='100', height='100') + self._texture_interval = self._mjcf_root.asset.add( + 'texture', name='target_sphere_inter', type='cube', + builtin='checker', rgb1=rgb_interval[0], rgb2=rgb_interval[1], + width='100', height='100') + self._texture_final = self._mjcf_root.asset.add( + 'texture', name='target_sphere_final', type='cube', + builtin='checker', rgb1=rgb_final[0], rgb2=rgb_final[1], + width='100', height='100') + self._material = self._mjcf_root.asset.add( + 'material', name='target_sphere_init', texture=self._texture_initial) + self._geom = self._mjcf_root.worldbody.add( + 'geom', type='sphere', name='geom', gap=2*radius, + pos=[0, 0, height_above_ground], size=[radius], + material=self._material) + self._geom_id = -1 + self._touched_once = False + self._touched_twice = False + self._touch_debounce = touch_debounce + self._specific_collision_geom_ids = specific_collision_geom_ids + + @property + def geom(self): + return self._geom + + @property + def material(self): + return self._material + + @property + def activated(self): + """Whether this target has been reached during this episode.""" + return (self._touched_once, self._touched_twice) + + def reset(self, physics): + self._touched_once = False + self._touched_twice = False + self._geom.material = self._material + physics.bind(self._material).texid = physics.bind( + self._texture_initial).element_id + + @property + def mjcf_model(self): + return self._mjcf_root + + def initialize_episode_mjcf(self, unused_random_state): + self._touched_once = False + self._touched_twice = False + + def _update_activation(self, physics): + if not (self._touched_once and self._touched_twice): + for contact in physics.data.contact: + if self._specific_collision_geom_ids: + has_specific_collision = ( + contact.geom1 in self._specific_collision_geom_ids or + contact.geom2 in self._specific_collision_geom_ids) + else: + has_specific_collision = True + if (has_specific_collision and + self._geom_id in (contact.geom1, contact.geom2)): + if not self._touched_once: + self._touched_once = True + self._touch_time = physics.time() + physics.bind(self._material).texid = physics.bind( + self._texture_interval).element_id + if self._touched_once and ( + physics.time() > (self._touch_time + self._touch_debounce)): + self._touched_twice = True + physics.bind(self._material).texid = physics.bind( + self._texture_final).element_id + + def initialize_episode(self, physics, unused_random_state): + self._geom_id = physics.model.name2id(self._geom.full_identifier, 'geom') + self._update_activation(physics) + + def after_substep(self, physics, unused_random_state): + self._update_activation(physics) diff --git a/DMC/src/env/dm_control/dm_control/locomotion/props/target_sphere_test.py b/DMC/src/env/dm_control/dm_control/locomotion/props/target_sphere_test.py new file mode 100644 index 0000000..a83a39d --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/props/target_sphere_test.py @@ -0,0 +1,70 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for props.target_sphere.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Internal dependencies. +from absl.testing import absltest +from dm_control import composer + +from dm_control.entities.props import primitive +from dm_control.locomotion.arenas import floors +from dm_control.locomotion.props import target_sphere + + +class TargetSphereTest(absltest.TestCase): + + def testActivation(self): + target_radius = 0.6 + prop_radius = 0.1 + target_height = 1 + + arena = floors.Floor() + target = target_sphere.TargetSphere(radius=target_radius, + height_above_ground=target_height) + prop = primitive.Primitive(geom_type='sphere', size=[prop_radius]) + arena.attach(target) + arena.add_free_entity(prop) + + task = composer.NullTask(arena) + task.initialize_episode = ( + lambda physics, random_state: prop.set_pose(physics, [0, 0, 2])) + + env = composer.Environment(task) + env.reset() + + max_activated_height = target_height + target_radius + prop_radius + + while env.physics.bind(prop.geom).xpos[2] > max_activated_height: + self.assertFalse(target.activated) + self.assertEqual(env.physics.bind(target.material).rgba[-1], 1) + env.step([]) + + while env.physics.bind(prop.geom).xpos[2] > 0.2: + self.assertTrue(target.activated) + self.assertEqual(env.physics.bind(target.material).rgba[-1], 0) + env.step([]) + + # Target should be reset when the environment is reset. + env.reset() + self.assertFalse(target.activated) + self.assertEqual(env.physics.bind(target.material).rgba[-1], 1) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/README.md b/DMC/src/env/dm_control/dm_control/locomotion/soccer/README.md new file mode 100644 index 0000000..49b8504 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/README.md @@ -0,0 +1,64 @@ +# DeepMind MuJoCo Multi-Agent Soccer Environment. + +This submodule contains the components and environment described in ICLR 2019 +paper [Emergent Coordination through Competition][website]. + +# ![soccer](soccer.png) + +## Installation and requirements + +See [dm_control](../../../README.md#installation-and-requirements) for instructions. + +## Quickstart + +```python +import numpy as np +from dm_control.locomotion import soccer as dm_soccer + +# Load the 2-vs-2 soccer environment with episodes of 10 seconds: +env = dm_soccer.load(team_size=2, time_limit=10.) + +# Retrieves action_specs for all 4 players. +action_specs = env.action_spec() + +# Step through the environment for one episode with random actions. +time_step = env.reset() +while not time_step.last(): + actions = [] + for action_spec in action_specs: + action = np.random.uniform( + action_spec.minimum, action_spec.maximum, size=action_spec.shape) + actions.append(action) + time_step = env.step(actions) + + for i in range(len(action_specs)): + print( + "Player {}: reward = {}, discount = {}, observations = {}.".format( + i, time_step.reward[i], time_step.discount, + time_step.observation[i])) +``` + +## Rewards + +The environment provides a reward of +1 to each player when their team +scores a goal, -1 when their team concedes a goal, or 0 if neither team scored +on the current timestep. + +In addition to the sparse reward returned the environment, the player +observations also contain various environment statistics that may be used to +derive custom per-player shaping rewards (as was done in +http://arxiv.org/abs/1902.07151, where the environment reward was ignored). + +## Episode terminations + +Episodes will terminate immediately with a discount factor of 0 when either side +scores a goal. There is also a per-episode `time_limit` (45 seconds by default). +If neither team scores within this time then the episode will terminate with a +discount factor of 1. + +## Environment Viewer + +To visualize an example 2-vs-2 soccer environment in the `dm_control` +interactive viewer, execute `dm_control/locomotion/soccer/explore.py`. + +[website]: https://sites.google.com/corp/view/emergent-coordination/home diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/__init__.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/__init__.py new file mode 100644 index 0000000..20dc585 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/__init__.py @@ -0,0 +1,94 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Multi-agent MuJoCo soccer environment.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from dm_control import composer +from dm_control.locomotion.soccer.boxhead import BoxHead +from dm_control.locomotion.soccer.initializers import Initializer +from dm_control.locomotion.soccer.initializers import UniformInitializer +from dm_control.locomotion.soccer.observables import CoreObservablesAdder +from dm_control.locomotion.soccer.observables import InterceptionObservablesAdder +from dm_control.locomotion.soccer.observables import MultiObservablesAdder +from dm_control.locomotion.soccer.observables import ObservablesAdder +from dm_control.locomotion.soccer.pitch import Pitch +from dm_control.locomotion.soccer.pitch import RandomizedPitch +from dm_control.locomotion.soccer.soccer_ball import SoccerBall +from dm_control.locomotion.soccer.task import Task +from dm_control.locomotion.soccer.team import Player +from dm_control.locomotion.soccer.team import Team +from six.moves import range + +_RGBA_BLUE = [.1, .1, .8, 1.] +_RGBA_RED = [.8, .1, .1, 1.] + + +def _make_walker(name, walker_id, marker_rgba): + """Construct a BoxHead walker.""" + return BoxHead( + name=name, + walker_id=walker_id, + marker_rgba=marker_rgba, + ) + + +def _make_players(team_size): + """Construct home and away teams each of `team_size` players.""" + home_players = [] + away_players = [] + for i in range(team_size): + home_players.append( + Player(Team.HOME, _make_walker("home%d" % i, i, _RGBA_BLUE))) + away_players.append( + Player(Team.AWAY, _make_walker("away%d" % i, i, _RGBA_RED))) + return home_players + away_players + + +def load(team_size, + time_limit=45., + random_state=None, + disable_walker_contacts=False): + """Construct `team_size`-vs-`team_size` soccer environment. + + Args: + team_size: Integer, the number of players per team. Must be between 1 and + 11. + time_limit: Float, the maximum duration of each episode in seconds. + random_state: (optional) an int seed or `np.random.RandomState` instance. + disable_walker_contacts: (optional) if `True`, disable physical contacts + between walkers. + + Returns: + A `composer.Environment` instance. + + Raises: + ValueError: If `team_size` is not between 1 and 11. + """ + if team_size < 0 or team_size > 11: + raise ValueError( + "Team size must be between 1 and 11 (received %d)." % team_size) + + return composer.Environment( + task=Task( + players=_make_players(team_size), + arena=RandomizedPitch( + min_size=(32, 24), max_size=(48, 36), keep_aspect_ratio=True), + disable_walker_contacts=disable_walker_contacts), + time_limit=time_limit, + random_state=random_state) diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/boxhead.xml b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/boxhead.xml new file mode 100644 index 0000000..17cb28d --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/boxhead.xml @@ -0,0 +1,56 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/00.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/00.png new file mode 100644 index 0000000..b5fde93 Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/00.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/01.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/01.png new file mode 100644 index 0000000..f52bae4 Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/01.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/02.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/02.png new file mode 100644 index 0000000..bec2be6 Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/02.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/03.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/03.png new file mode 100644 index 0000000..654dba2 Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/03.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/04.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/04.png new file mode 100644 index 0000000..f320726 Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/04.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/05.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/05.png new file mode 100644 index 0000000..64d8cd6 Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/05.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/06.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/06.png new file mode 100644 index 0000000..ed97f8b Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/06.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/07.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/07.png new file mode 100644 index 0000000..4808c29 Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/07.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/08.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/08.png new file mode 100644 index 0000000..41d09b9 Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/08.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/09.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/09.png new file mode 100644 index 0000000..649723c Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/09.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/10.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/10.png new file mode 100644 index 0000000..d2cc0d0 Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/boxhead/digits/10.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/back.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/back.png new file mode 100644 index 0000000..c01b517 Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/back.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/down.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/down.png new file mode 100644 index 0000000..49ace5b Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/down.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/front.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/front.png new file mode 100644 index 0000000..c18fbf0 Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/front.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/left.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/left.png new file mode 100644 index 0000000..d3b14f7 Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/left.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/right.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/right.png new file mode 100644 index 0000000..69965a8 Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/right.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/up.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/up.png new file mode 100644 index 0000000..8c729d7 Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/assets/soccer_ball/up.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/boxhead.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/boxhead.py new file mode 100644 index 0000000..8963b7a --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/boxhead.py @@ -0,0 +1,292 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Walkers based on an actuated jumping ball.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from dm_control import composer +from dm_control import mjcf +from dm_control.composer.observation import observable +from dm_control.locomotion.walkers import legacy_base +import numpy as np +from PIL import Image +import six + +from dm_control.utils import io as resources + +_ASSETS_PATH = os.path.join(os.path.dirname(__file__), 'assets', 'boxhead') +_MAX_WALKER_ID = 10 +_INVALID_WALKER_ID = 'walker_id must be in [0-{}], got: {{}}.'.format( + _MAX_WALKER_ID) + + +def _compensate_gravity(physics, body_elements): + """Applies Cartesian forces to bodies in order to exactly counteract gravity. + + Note that this will also affect the output of pressure, force, or torque + sensors within the kinematic chain leading from the worldbody to the bodies + that are being gravity-compensated. + + Args: + physics: An `mjcf.Physics` instance to modify. + body_elements: An iterable of `mjcf.Element`s specifying the bodies to which + gravity compensation will be applied. + """ + gravity = np.hstack([physics.model.opt.gravity, [0, 0, 0]]) + bodies = physics.bind(body_elements) + bodies.xfrc_applied = -gravity * bodies.mass[..., None] + + +def _alpha_blend(foreground, background): + """Does alpha compositing of two RGBA images. + + Both inputs must be (..., 4) numpy arrays whose shapes are compatible for + broadcasting. They are assumed to contain float RGBA values in [0, 1]. + + Args: + foreground: foreground RGBA image. + background: background RGBA image. + + Returns: + A numpy array of shape (..., 4) containing the blended image. + """ + fg, bg = np.broadcast_arrays(foreground, background) + fg_rgb = fg[..., :3] + fg_a = fg[..., 3:] + bg_rgb = bg[..., :3] + bg_a = bg[..., 3:] + out = np.empty_like(bg) + out_a = out[..., 3:] + out_rgb = out[..., :3] + # https://en.wikipedia.org/wiki/Alpha_compositing#Alpha_blending + out_a[:] = fg_a + bg_a * (1. - fg_a) + out_rgb[:] = fg_rgb * fg_a + bg_rgb * bg_a * (1. - fg_a) + # Avoid division by zero if foreground and background are both transparent. + out_rgb[:] = np.where(out_a, out_rgb / out_a, out_rgb) + return out + + +def _asset_png_with_background_rgba_bytes(asset_fname, background_rgba): + """Decode PNG from asset file and add solid background.""" + + # Retrieve PNG image contents as a bytestring, convert to a numpy array. + contents = resources.GetResource(os.path.join(_ASSETS_PATH, asset_fname)) + digit_rgba = np.array(Image.open(six.BytesIO(contents)), dtype=np.double) + + # Add solid background with `background_rgba`. + blended = 255. * _alpha_blend(digit_rgba / 255., np.asarray(background_rgba)) + + # Encode composite image array to a PNG bytestring. + img = Image.fromarray(blended.astype(np.uint8), mode='RGBA') + buf = six.BytesIO() + img.save(buf, format='PNG') + png_encoding = buf.getvalue() + buf.close() + + return png_encoding + + +class BoxHeadObservables(legacy_base.WalkerObservables): + + def __init__(self, entity, camera_resolution): + self._camera_resolution = camera_resolution + super(BoxHeadObservables, self).__init__(entity) + + @composer.observable + def egocentric_camera(self): + width, height = self._camera_resolution + return observable.MJCFCamera(self._entity.egocentric_camera, + width=width, height=height) + + +class BoxHead(legacy_base.Walker): + """A rollable and jumpable ball with a head.""" + + def _build(self, + name='walker', + marker_rgba=None, + camera_control=False, + camera_resolution=(28, 28), + roll_gear=-60, + steer_gear=55, + walker_id=None, + initializer=None): + """Build a BoxHead. + + Args: + name: name of the walker. + marker_rgba: RGBA value set to walker.marker_geoms to distinguish between + walkers (in multi-agent setting). + camera_control: If `True`, the walker exposes two additional actuated + degrees of freedom to control the egocentric camera height and tilt. + camera_resolution: egocentric camera rendering resolution. + roll_gear: gear determining forward acceleration. + steer_gear: gear determining steering (spinning) torque. + walker_id: (Optional) An integer in [0-10], this number will be shown on + the walker's head. Defaults to `None` which does not show any number. + initializer: (Optional) A `WalkerInitializer` object. + + Raises: + ValueError: if received invalid walker_id. + """ + super(BoxHead, self)._build(initializer=initializer) + xml_path = os.path.join(_ASSETS_PATH, 'boxhead.xml') + self._mjcf_root = mjcf.from_xml_string(resources.GetResource(xml_path, 'r')) + if name: + self._mjcf_root.model = name + + if walker_id is not None and not 0 <= walker_id <= _MAX_WALKER_ID: + raise ValueError(_INVALID_WALKER_ID.format(walker_id)) + + self._walker_id = walker_id + if walker_id is not None: + png_bytes = _asset_png_with_background_rgba_bytes( + 'digits/%02d.png' % walker_id, marker_rgba) + head_texture = self._mjcf_root.asset.add( + 'texture', + name='head_texture', + type='2d', + file=mjcf.Asset(png_bytes, '.png')) + head_material = self._mjcf_root.asset.add( + 'material', name='head_material', texture=head_texture) + self._mjcf_root.find('geom', 'head').material = head_material + self._mjcf_root.find('geom', 'head').rgba = None + + self._mjcf_root.find('geom', 'top_down_cam_box').material = head_material + self._mjcf_root.find('geom', 'top_down_cam_box').rgba = None + + self._body_texture = self._mjcf_root.asset.add( + 'texture', + name='ball_body', + type='cube', + builtin='checker', + rgb1=marker_rgba[:-1] if marker_rgba else '.4 .4 .4', + rgb2='.8 .8 .8', + width='100', + height='100') + self._body_material = self._mjcf_root.asset.add( + 'material', name='ball_body', texture=self._body_texture) + self._mjcf_root.find('geom', 'shell').material = self._body_material + + # Set corresponding marker color if specified. + if marker_rgba is not None: + for geom in self.marker_geoms: + geom.set_attributes(rgba=marker_rgba) + + self._root_joints = None + self._camera_control = camera_control + self._camera_resolution = camera_resolution + if not camera_control: + for name in ('camera_pitch', 'camera_yaw'): + self._mjcf_root.find('actuator', name).remove() + self._mjcf_root.find('joint', name).remove() + self._roll_gear = roll_gear + self._steer_gear = steer_gear + self._mjcf_root.find('actuator', 'roll').gear[0] = self._roll_gear + self._mjcf_root.find('actuator', 'steer').gear[0] = self._steer_gear + + # Initialize previous action. + self._prev_action = np.zeros(shape=self.action_spec.shape, + dtype=self.action_spec.dtype) + + def _build_observables(self): + return BoxHeadObservables(self, camera_resolution=self._camera_resolution) + + @property + def marker_geoms(self): + geoms = [ + self._mjcf_root.find('geom', 'arm_l'), + self._mjcf_root.find('geom', 'arm_r'), + self._mjcf_root.find('geom', 'eye_l'), + self._mjcf_root.find('geom', 'eye_r'), + ] + if self._walker_id is None: + geoms.append(self._mjcf_root.find('geom', 'head')) + return geoms + + def create_root_joints(self, attachment_frame): + root_class = self._mjcf_root.find('default', 'root') + root_x = attachment_frame.add( + 'joint', name='root_x', type='slide', axis=[1, 0, 0], dclass=root_class) + root_y = attachment_frame.add( + 'joint', name='root_y', type='slide', axis=[0, 1, 0], dclass=root_class) + root_z = attachment_frame.add( + 'joint', name='root_z', type='slide', axis=[0, 0, 1], dclass=root_class) + self._root_joints = [root_x, root_y, root_z] + + def set_pose(self, physics, position=None, quaternion=None): + if position is not None: + if self._root_joints is not None: + physics.bind(self._root_joints).qpos = position + else: + super(BoxHead, self).set_pose(physics, position, quaternion=None) + physics.bind(self._mjcf_root.find_all('joint')).qpos = 0. + if quaternion is not None: + # This walker can only rotate along the z-axis, so we extract only that + # component from the quaternion. + z_angle = np.arctan2( + 2 * (quaternion[0] * quaternion[3] + quaternion[1] * quaternion[2]), + 1 - 2 * (quaternion[2] ** 2 + quaternion[3] ** 2)) + physics.bind(self._mjcf_root.find('joint', 'steer')).qpos = z_angle + + def initialize_episode(self, physics, unused_random_state): + if self._camera_control: + _compensate_gravity(physics, + self._mjcf_root.find('body', 'egocentric_camera')) + self._prev_action = np.zeros(shape=self.action_spec.shape, + dtype=self.action_spec.dtype) + + def apply_action(self, physics, action, random_state): + super(BoxHead, self).apply_action(physics, action, random_state) + + # Updates previous action. + self._prev_action[:] = action + + @property + def mjcf_model(self): + return self._mjcf_root + + @composer.cached_property + def actuators(self): + return self._mjcf_root.find_all('actuator') + + @composer.cached_property + def root_body(self): + return self._mjcf_root.find('body', 'head_body') + + @composer.cached_property + def end_effectors(self): + return (self._mjcf_root.find('body', 'head_body'),) + + @composer.cached_property + def observable_joints(self): + return (self._mjcf_root.find('joint', 'kick'),) + + @composer.cached_property + def egocentric_camera(self): + return self._mjcf_root.find('camera', 'egocentric') + + @composer.cached_property + def ground_contact_geoms(self): + return (self._mjcf_root.find('geom', 'shell'),) + + @property + def prev_action(self): + return self._prev_action diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/boxhead_test.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/boxhead_test.py new file mode 100644 index 0000000..684b4e7 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/boxhead_test.py @@ -0,0 +1,49 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for dm_control.locomotion.soccer.boxhead.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Internal dependencies. + +from absl.testing import absltest +from absl.testing import parameterized +from dm_control.locomotion.soccer import boxhead + + +class BoxheadTest(parameterized.TestCase): + + @parameterized.parameters( + dict(camera_control=True, walker_id=None), + dict(camera_control=False, walker_id=None), + dict(camera_control=True, walker_id=0), + dict(camera_control=False, walker_id=10)) + def test_instantiation(self, camera_control, walker_id): + boxhead.BoxHead(marker_rgba=[.8, .1, .1, 1.], + camera_control=camera_control, + walker_id=walker_id) + + @parameterized.parameters(-1, 11) + def test_invalid_walker_id(self, walker_id): + with self.assertRaisesWithLiteralMatch( + ValueError, boxhead._INVALID_WALKER_ID.format(walker_id)): + boxhead.BoxHead(walker_id=walker_id) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/explore.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/explore.py new file mode 100644 index 0000000..29d020d --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/explore.py @@ -0,0 +1,35 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Interactive viewer for MuJoCo soccer enviornmnet.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +from absl import app +from dm_control import viewer +from dm_control.locomotion import soccer + + +def main(argv): + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + viewer.launch(environment_loader=functools.partial(soccer.load, team_size=2)) + + +if __name__ == '__main__': + app.run(main) diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/initializers.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/initializers.py new file mode 100644 index 0000000..22785ce --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/initializers.py @@ -0,0 +1,73 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Soccer task episode initializers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import numpy as np +import six + + +_INIT_BALL_Z = 0.5 +_SPAWN_RATIO = 0.6 + + +@six.add_metaclass(abc.ABCMeta) +class Initializer(object): + + @abc.abstractmethod + def __call__(self, task, physics, random_state): + """Initialize episode for a task.""" + + +class UniformInitializer(Initializer): + """Uniformly initialize walkers and soccer ball over spawn_range.""" + + def __init__(self, spawn_ratio=_SPAWN_RATIO, init_ball_z=_INIT_BALL_Z): + self._spawn_ratio = spawn_ratio + self._init_ball_z = init_ball_z + + def _initialize_ball(self, ball, spawn_range, physics, random_state): + x, y = random_state.uniform(-spawn_range, spawn_range) + ball.set_pose(physics, [x, y, self._init_ball_z]) + # Note: this method is not always called immediately after `physics.reset()` + # so we need to explicitly zero out the velocity. + ball.set_velocity(physics, velocity=0., angular_velocity=0.) + + def _initialize_walker(self, walker, spawn_range, physics, random_state): + """Uniformly initialize walker in spawn_range.""" + walker.reinitialize_pose(physics, random_state) + x, y = random_state.uniform(-spawn_range, spawn_range) + (_, _, z), quat = walker.get_pose(physics) + walker.set_pose(physics, [x, y, z], quat) + rotation = random_state.uniform(-np.pi, np.pi) + quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)] + walker.shift_pose(physics, quaternion=quat) + # TODO(b/132759890): `walker.set_velocity` has no effect for walkers without + # freejoints, such as `BoxHead`. + # Note: this method is not always called immediately after `physics.reset()` + # so we need to explicitly zero out the velocity. + walker.set_velocity(physics, velocity=0., angular_velocity=0.) + + def __call__(self, task, physics, random_state): + spawn_range = np.asarray(task.arena.size) * self._spawn_ratio + self._initialize_ball(task.ball, spawn_range, physics, random_state) + for player in task.players: + self._initialize_walker(player.walker, spawn_range, physics, random_state) diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/loader_test.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/loader_test.py new file mode 100644 index 0000000..8ec2ec9 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/loader_test.py @@ -0,0 +1,107 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for dm_control.locomotion.soccer.load.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import logging +from absl.testing import absltest +from absl.testing import parameterized +from dm_control.locomotion import soccer +import numpy as np +from six.moves import range + + +class LoadTest(parameterized.TestCase): + + @parameterized.named_parameters( + ("2vs2_nocontacts", 2, True), ("2vs2_contacts", 2, False), + ("1vs1_nocontacts", 1, True), ("1vs1_contacts", 1, False)) + def test_load_env(self, team_size, disable_walker_contacts): + env = soccer.load(team_size=team_size, time_limit=2., + disable_walker_contacts=disable_walker_contacts) + action_specs = env.action_spec() + + random_state = np.random.RandomState(0) + time_step = env.reset() + while not time_step.last(): + actions = [] + for action_spec in action_specs: + action = random_state.uniform( + action_spec.minimum, action_spec.maximum, size=action_spec.shape) + actions.append(action) + time_step = env.step(actions) + + for i in range(len(action_specs)): + logging.info( + "Player %d: reward = %s, discount = %s, observations = %s.", i, + time_step.reward[i], time_step.discount, time_step.observation[i]) + + def assertSameObservation(self, expected_observation, actual_observation): + self.assertLen(actual_observation, len(expected_observation)) + for player_id in range(len(expected_observation)): + expected_player_observations = expected_observation[player_id] + actual_player_observations = actual_observation[player_id] + expected_keys = expected_player_observations.keys() + actual_keys = actual_player_observations.keys() + msg = ("Observation keys differ for player {}.\nExpected: {}.\nActual: {}" + .format(player_id, expected_keys, actual_keys)) + self.assertEqual(expected_keys, actual_keys, msg) + for key in expected_player_observations: + expected_array = expected_player_observations[key] + actual_array = actual_player_observations[key] + msg = ("Observation {!r} differs for player {}.\nExpected:\n{}\n" + "Actual:\n{}" + .format(key, player_id, expected_array, actual_array)) + np.testing.assert_array_equal(expected_array, actual_array, + err_msg=msg) + + @parameterized.parameters(True, False) + def test_same_first_observation_if_same_seed(self, disable_walker_contacts): + seed = 42 + timestep_1 = soccer.load( + team_size=2, + random_state=seed, + disable_walker_contacts=disable_walker_contacts).reset() + timestep_2 = soccer.load( + team_size=2, + random_state=seed, + disable_walker_contacts=disable_walker_contacts).reset() + self.assertSameObservation(timestep_1.observation, timestep_2.observation) + + @parameterized.parameters(True, False) + def test_different_first_observation_if_different_seed( + self, disable_walker_contacts): + timestep_1 = soccer.load( + team_size=2, + random_state=1, + disable_walker_contacts=disable_walker_contacts).reset() + timestep_2 = soccer.load( + team_size=2, + random_state=2, + disable_walker_contacts=disable_walker_contacts).reset() + try: + self.assertSameObservation(timestep_1.observation, timestep_2.observation) + except AssertionError: + pass + else: + self.fail("Observations are unexpectedly identical.") + + +if __name__ == "__main__": + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/observables.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/observables.py new file mode 100644 index 0000000..da3d406 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/observables.py @@ -0,0 +1,432 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Soccer observables modules.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +from dm_control.composer.observation import observable as base_observable +from dm_control.locomotion.soccer import team as team_lib +import numpy as np +import six +from six.moves import zip + + +@six.add_metaclass(abc.ABCMeta) +class ObservablesAdder(object): + """A callable that adds a set of per-player observables for a task.""" + + @abc.abstractmethod + def __call__(self, task, player): + """Adds observables to a player for the given task. + + Args: + task: A `soccer.Task` instance. + player: A `Walker` instance to which observables will be added. + """ + + +class MultiObservablesAdder(ObservablesAdder): + """Applies multiple `ObservablesAdder`s to a soccer task and player.""" + + def __init__(self, observables): + """Initializes a `MultiObservablesAdder` instance. + + Args: + observables: A list of `ObservablesAdder` instances. + """ + self._observables = observables + + def __call__(self, task, player): + """Adds observables to a player for the given task. + + Args: + task: A `soccer.Task` instance. + player: A `Walker` instance to which observables will be added. + """ + for observable in self._observables: + observable(task, player) + + +class CoreObservablesAdder(ObservablesAdder): + """Core set of per player observables.""" + + def __call__(self, task, player): + """Adds observables to a player for the given task. + + Args: + task: A `soccer.Task` instance. + player: A `Walker` instance to which observables will be added. + """ + # Enable proprioceptive observables. + self._add_player_proprio_observables(player) + + # Add egocentric observations of soccer ball. + self._add_player_observables_on_ball(player, task.ball) + + # Add egocentric observations of others. + teammate_id = 0 + opponent_id = 0 + for other in task.players: + if other is player: + continue + # Infer team prefix for `other` conditioned on `player.team`. + if player.team != other.team: + prefix = 'opponent_{}'.format(opponent_id) + opponent_id += 1 + else: + prefix = 'teammate_{}'.format(teammate_id) + teammate_id += 1 + + self._add_player_observables_on_other(player, other, prefix) + + self._add_player_arena_observables(player, task.arena) + + # Add per player game statistics. + self._add_player_stats_observables(task, player) + + def _add_player_observables_on_other(self, player, other, prefix): + """Add observables of another player in this player's egocentric frame. + + Args: + player: A `Walker` instance, the player we are adding observables to. + other: A `Walker` instance corresponding to a different player. + prefix: A string specifying a prefix to apply to the names of observables + belonging to `player`. + """ + if player is other: + raise ValueError('Cannot add egocentric observables of player on itself.') + # Origin callable in xpos, xvel for `player`. + xpos_xyz_callable = lambda p: p.bind(player.walker.root_body).xpos + xvel_xyz_callable = lambda p: p.bind(player.walker.root_body).cvel[3:] + # Egocentric observation of other's position, orientation and + # linear velocities. + def _cvel_observation(physics, other=other): + # Velocitmeter reads in local frame but we need world frame observable + # for egocentric transformation. + return physics.bind(other.walker.root_body).cvel[3:] + + def _egocentric_end_effectors_xpos(physics, other=other): + origin_xpos = xpos_xyz_callable(physics) + egocentric_end_effectors_xpos = [] + for end_effector_body in other.walker.end_effectors: + xpos = physics.bind(end_effector_body).xpos + delta = xpos - origin_xpos + ego_xpos = player.walker.transform_vec_to_egocentric_frame( + physics, delta) + egocentric_end_effectors_xpos.append(ego_xpos) + return np.concatenate(egocentric_end_effectors_xpos) + + player.walker.observables.add_egocentric_vector( + '{}_ego_linear_velocity'.format(prefix), + base_observable.Generic(_cvel_observation), + origin_callable=xvel_xyz_callable) + player.walker.observables.add_egocentric_vector( + '{}_ego_position'.format(prefix), + other.walker.observables.position, + origin_callable=xpos_xyz_callable) + player.walker.observables.add_egocentric_xmat( + '{}_ego_orientation'.format(prefix), + other.walker.observables.orientation) + + # Adds end effectors of the other agents in the player's egocentric frame. + player.walker.observables.add_observable( + '{}_ego_end_effectors_pos'.format(prefix), + base_observable.Generic(_egocentric_end_effectors_xpos)) + + # Adds end effectors of the other agents in the other's egocentric frame. + # A is seeing B's hand extended to B's right. + player.walker.observables.add_observable( + '{}_end_effectors_pos'.format(prefix), + other.walker.observables.end_effectors_pos) + + def _add_player_observables_on_ball(self, player, ball): + """Add observables of the soccer ball in this player's egocentric frame. + + Args: + player: A `Walker` instance, the player we are adding observations for. + ball: A `SoccerBall` instance. + """ + # Origin callables for egocentric transformations. + xpos_xyz_callable = lambda p: p.bind(player.walker.root_body).xpos + xvel_xyz_callable = lambda p: p.bind(player.walker.root_body).cvel[3:] + + # Add egocentric ball observations. + player.walker.observables.add_egocentric_vector( + 'ball_ego_angular_velocity', ball.observables.angular_velocity) + player.walker.observables.add_egocentric_vector( + 'ball_ego_position', + ball.observables.position, + origin_callable=xpos_xyz_callable) + player.walker.observables.add_egocentric_vector( + 'ball_ego_linear_velocity', + ball.observables.linear_velocity, + origin_callable=xvel_xyz_callable) + + def _add_player_proprio_observables(self, player): + """Add proprioceptive observables to the given player. + + Args: + player: A `Walker` instance, the player we are adding observations for. + """ + for observable in (player.walker.observables.proprioception + + player.walker.observables.kinematic_sensors): + observable.enabled = True + + # Also enable previous action observable as part of proprioception. + player.walker.observables.prev_action.enabled = True + + def _add_player_arena_observables(self, player, arena): + """Add observables of the arena. + + Args: + player: A `Walker` instance to which observables will be added. + arena: A `Pitch` instance. + """ + # Enable egocentric view of position detectors (goal, field). + # Corners named according to walker *facing towards opponent goal*. + clockwise_names = [ + 'team_goal_back_right', + 'team_goal_mid', + 'team_goal_front_left', + 'field_front_left', + 'opponent_goal_back_left', + 'opponent_goal_mid', + 'opponent_goal_front_right', + 'field_back_right', + ] + clockwise_features = [ + lambda _: arena.home_goal.lower[:2], + lambda _: arena.home_goal.mid, + lambda _: arena.home_goal.upper[:2], + lambda _: arena.field.upper, + lambda _: arena.away_goal.upper[:2], + lambda _: arena.away_goal.mid, + lambda _: arena.away_goal.lower[:2], + lambda _: arena.field.lower, + ] + xpos_xyz_callable = lambda p: p.bind(player.walker.root_body).xpos + xpos_xy_callable = lambda p: p.bind(player.walker.root_body).xpos[:2] + # A list of egocentric reference origin for each one of clockwise_features. + clockwise_origins = [ + xpos_xy_callable, + xpos_xyz_callable, + xpos_xy_callable, + xpos_xy_callable, + xpos_xy_callable, + xpos_xyz_callable, + xpos_xy_callable, + xpos_xy_callable, + ] + if player.team != team_lib.Team.HOME: + half = len(clockwise_features) // 2 + clockwise_features = clockwise_features[half:] + clockwise_features[:half] + clockwise_origins = clockwise_origins[half:] + clockwise_origins[:half] + + for name, feature, origin in zip(clockwise_names, clockwise_features, + clockwise_origins): + player.walker.observables.add_egocentric_vector( + name, base_observable.Generic(feature), origin_callable=origin) + + def _add_player_stats_observables(self, task, player): + """Add observables corresponding to game statistics. + + Args: + task: A `soccer.Task` instance. + player: A `Walker` instance to which observables will be added. + """ + + def _stats_vel_to_ball(physics): + dir_ = ( + physics.bind(task.ball.geom).xpos - + physics.bind(player.walker.root_body).xpos) + vel_to_ball = np.dot(dir_[:2] / (np.linalg.norm(dir_[:2]) + 1e-7), + physics.bind(player.walker.root_body).cvel[3:5]) + return np.sum(vel_to_ball) + + player.walker.observables.add_observable( + 'stats_vel_to_ball', base_observable.Generic(_stats_vel_to_ball)) + + def _stats_closest_vel_to_ball(physics): + """Velocity to the ball if this walker is the team's closest.""" + closest = None + min_team_dist_to_ball = np.inf + for player_ in task.players: + if player_.team == player.team: + dist_to_ball = np.linalg.norm( + physics.bind(task.ball.geom).xpos - + physics.bind(player_.walker.root_body).xpos) + if dist_to_ball < min_team_dist_to_ball: + min_team_dist_to_ball = dist_to_ball + closest = player_ + if closest is player: + return _stats_vel_to_ball(physics) + return 0. + + player.walker.observables.add_observable( + 'stats_closest_vel_to_ball', + base_observable.Generic(_stats_closest_vel_to_ball)) + + def _stats_veloc_forward(physics): + """Player's forward velocity.""" + return player.walker.observables.veloc_forward(physics) + + player.walker.observables.add_observable( + 'stats_veloc_forward', base_observable.Generic(_stats_veloc_forward)) + + def _stats_vel_ball_to_goal(physics): + """Ball velocity towards opponents' goal.""" + if player.team == team_lib.Team.HOME: + goal = task.arena.away_goal + else: + goal = task.arena.home_goal + + goal_center = (goal.upper + goal.lower) / 2. + direction = goal_center - physics.bind(task.ball.geom).xpos + ball_vel_observable = task.ball.observables.linear_velocity + ball_vel = ball_vel_observable.observation_callable(physics)() + + norm_dir = np.linalg.norm(direction) + normalized_dir = direction / norm_dir if norm_dir else direction + return np.sum(np.dot(normalized_dir, ball_vel)) + + player.walker.observables.add_observable( + 'stats_vel_ball_to_goal', + base_observable.Generic(_stats_vel_ball_to_goal)) + + def _stats_avg_teammate_dist(physics): + """Compute average distance from `walker` to its teammates.""" + teammate_dists = [] + for other in task.players: + if player is other: + continue + if other.team != player.team: + continue + dist = np.linalg.norm( + physics.bind(player.walker.root_body).xpos - + physics.bind(other.walker.root_body).xpos) + teammate_dists.append(dist) + return np.mean(teammate_dists) if teammate_dists else 0. + + player.walker.observables.add_observable( + 'stats_home_avg_teammate_dist', + base_observable.Generic(_stats_avg_teammate_dist)) + + def _stats_teammate_spread_out(physics): + """Compute average distance from `walker` to its teammates.""" + return _stats_avg_teammate_dist(physics) > 5. + + player.walker.observables.add_observable( + 'stats_teammate_spread_out', + base_observable.Generic(_stats_teammate_spread_out)) + + def _stats_home_score(unused_physics): + if (task.arena.detected_goal() and + task.arena.detected_goal() == player.team): + return 1. + return 0. + + player.walker.observables.add_observable( + 'stats_home_score', base_observable.Generic(_stats_home_score)) + + has_opponent = any([p.team != player.team for p in task.players]) + + def _stats_away_score(unused_physics): + if (has_opponent and task.arena.detected_goal() and + task.arena.detected_goal() != player.team): + return 1. + return 0. + + player.walker.observables.add_observable( + 'stats_away_score', base_observable.Generic(_stats_away_score)) + + +# TODO(b/124848293): add unit-test interception observables. +class InterceptionObservablesAdder(ObservablesAdder): + """Adds obervables representing interception events. + + These observables represent events where this player received the ball from + another player, or when an opponent intercepted the ball from this player's + team. For each type of event there are three different thresholds applied to + the distance travelled by the ball since it last made contact with a player + (5, 10, or 15 meters). + + For example, on a given timestep `stats_i_received_ball_10m` will be 1 if + * This player just made contact with the ball + * The last player to have made contact with the ball was a different player + * The ball travelled for at least 10 m since it last hit a player + and 0 otherwise. + + Conversely, `stats_opponent_intercepted_ball_10m` will be 1 if: + * An opponent just made contact with the ball + * The last player to have made contact with the ball was on this player's team + * The ball travelled for at least 10 m since it last hit a player + """ + + def __call__(self, task, player): + """Adds observables to a player for the given task. + + Args: + task: A `soccer.Task` instance. + player: A `Walker` instance to which observables will be added. + """ + + def _stats_i_received_ball(unused_physics): + if (task.ball.hit and task.ball.repossessed and + task.ball.last_hit is player): + return 1. + return 0. + + player.walker.observables.add_observable( + 'stats_i_received_ball', + base_observable.Generic(_stats_i_received_ball)) + + def _stats_opponent_intercepted_ball(unused_physics): + """Indicator on if an opponent intercepted the ball.""" + if (task.ball.hit and task.ball.intercepted and + task.ball.last_hit.team != player.team): + return 1. + return 0. + + player.walker.observables.add_observable( + 'stats_opponent_intercepted_ball', + base_observable.Generic(_stats_opponent_intercepted_ball)) + + for dist in [5, 10, 15]: + + def _stats_i_received_ball_dist(physics, dist=dist): + if (_stats_i_received_ball(physics) and + task.ball.dist_between_last_hits is not None and + task.ball.dist_between_last_hits > dist): + return 1. + return 0. + + player.walker.observables.add_observable( + 'stats_i_received_ball_%dm' % dist, + base_observable.Generic(_stats_i_received_ball_dist)) + + def _stats_opponent_intercepted_ball_dist(physics, dist=dist): + if (_stats_opponent_intercepted_ball(physics) and + task.ball.dist_between_last_hits is not None and + task.ball.dist_between_last_hits > dist): + return 1. + return 0. + + player.walker.observables.add_observable( + 'stats_opponent_intercepted_ball_%dm' % dist, + base_observable.Generic(_stats_opponent_intercepted_ball_dist)) diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/pitch.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/pitch.py new file mode 100644 index 0000000..2ee5295 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/pitch.py @@ -0,0 +1,346 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""A soccer pitch with home/away goals and one field with position detection.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import logging +from dm_control import composer +from dm_control.composer.variation import distributions +from dm_control.entities import props +from dm_control.locomotion.soccer import team +import numpy as np + + +_TOP_CAMERA_Y_PADDING_FACTOR = 1.1 +_TOP_CAMERA_DISTANCE = 100. +_WALL_HEIGHT = 10. +_WALL_THICKNESS = .5 +_SIDE_WIDTH = 32. / 6. +_GROUND_GEOM_GRID_RATIO = 1. / 100 # Grid size for lighting. +_FIELD_BOX_CONTACT_BIT = 1 << 7 # Use a higher bit to prevent potential clash. + +_DEFAULT_PITCH_SIZE = (12, 9) +_DEFAULT_GOAL_LENGTH_RATIO = 0.33 # Goal length / pitch width. + + +def _top_down_cam_fovy(size, top_camera_distance): + return (360 / np.pi) * np.arctan2(_TOP_CAMERA_Y_PADDING_FACTOR * max(size), + top_camera_distance) + + +def _wall_pos_xyaxes(size): + """Infers position and size of bounding walls given pitch size. + + Walls are placed around `ground_geom` that represents the pitch. Note that + the ball cannot travel beyond `field` but walkers can walk outside of the + `field` but not the surrounding walls. + + Args: + size: a tuple of (length, width) of the pitch. + + Returns: + a list of 4 tuples, each representing the position and xyaxes of a wall + plane. In order, walls are placed along x-negative, x-positive, y-negative, + y-positive relative the center of the pitch. + """ + return [ + ((0., -size[1], 0.), (-1, 0, 0, 0, 0, 1)), + ((0., size[1], 0.), (1, 0, 0, 0, 0, 1)), + ((-size[0], 0., 0.), (0, 1, 0, 0, 0, 1)), + ((size[0], 0., 0.), (0, -1, 0, 0, 0, 1)), + ] + + +def _roof_size(size): + return (size[0], size[1], _WALL_THICKNESS) + + +class Pitch(composer.Arena): + """A pitch with a plane, two goals and a field with position detection.""" + + def _build(self, + size=_DEFAULT_PITCH_SIZE, + goal_size=None, + top_camera_distance=_TOP_CAMERA_DISTANCE, + field_box=False, + name='pitch'): + """Construct a pitch with walls and position detectors. + + Args: + size: a tuple of (length, width) of the pitch. + goal_size: optional (depth, width, height) indicating the goal size. + If not specified, the goal size is inferred from pitch size with a fixed + default ratio. + top_camera_distance: the distance of the top-down camera to the pitch. + field_box: adds a "field box" that collides with the ball but not the + walkers. + name: the name of this arena. + """ + super(Pitch, self)._build(name=name) + self._size = size + self._goal_size = goal_size + self._top_camera_distance = top_camera_distance + + self._top_camera = self._mjcf_root.worldbody.add( + 'camera', + name='top_down', + pos=[0, 0, top_camera_distance], + zaxis=[0, 0, 1], + fovy=_top_down_cam_fovy(self._size, top_camera_distance)) + + self._mjcf_root.visual.headlight.set_attributes( + ambient=[.4, .4, .4], diffuse=[.8, .8, .8], specular=[.1, .1, .1]) + + # Ensure close up geoms are rendered by egocentric cameras. + self._mjcf_root.visual.map.znear = 0.0005 + + # Build groundplane. + if len(self._size) != 2: + raise ValueError('`size` should be a sequence of length 2: got {!r}' + .format(self._size)) + self._ground_texture = self._mjcf_root.asset.add( + 'texture', + type='2d', + builtin='checker', + name='groundplane', + rgb1=[0.3, 0.8, 0.3], + rgb2=[0.1, 0.6, 0.1], + width=300, + height=300, + mark='edge', + markrgb=[0.8, 0.8, 0.8]) + self._ground_material = self._mjcf_root.asset.add( + 'material', name='groundplane', texture=self._ground_texture) + self._ground_geom = self._mjcf_root.worldbody.add( + 'geom', + type='plane', + material=self._ground_material, + size=list(self._size) + [max(self._size) * _GROUND_GEOM_GRID_RATIO]) + + # Build walls. + self._walls = [] + for wall_pos, wall_xyaxes in _wall_pos_xyaxes(self._size): + self._walls.append( + self._mjcf_root.worldbody.add( + 'geom', + type='plane', + rgba=[.1, .1, .1, .8], + pos=wall_pos, + size=[1e-7, 1e-7, 1e-7], + xyaxes=wall_xyaxes)) + + # Build goal position detectors. + # If field_box is enabled, offset goal by 1.0 such that ball reaches the + # goal position detector before bouncing off the field_box. + self._fb_offset = 0.5 if field_box else 0.0 + goal_size = self._get_goal_size() + self._home_goal = props.PositionDetector( + pos=(-self._size[0] + goal_size[0] + self._fb_offset, 0, + goal_size[2]), + size=goal_size, + rgba=(0, 0, 1, 0.5), + visible=True, + name='home_goal') + self.attach(self._home_goal) + + self._away_goal = props.PositionDetector( + pos=(self._size[0] - goal_size[0] - self._fb_offset, 0, goal_size[2]), + size=goal_size, + rgba=(1, 0, 0, 0.5), + visible=True, + name='away_goal') + self.attach(self._away_goal) + + # Build inverted field position detectors. + self._field = props.PositionDetector( + pos=(0, 0), + size=(self._size[0] - 2 * goal_size[0], + self._size[1] - 2 * goal_size[0]), + rgba=(1, 0, 0, 0.1), + inverted=True, + visible=True, + name='field') + self.attach(self._field) + + # Build field box. + self._field_box = [] + if field_box: + for wall_pos, wall_xyaxes in _wall_pos_xyaxes( + (self._field.upper - self._field.lower) / 2.0): + self._field_box.append( + self._mjcf_root.worldbody.add( + 'geom', + type='plane', + rgba=[.3, .3, .3, .3], + pos=wall_pos, + size=[1e-7, 1e-7, 1e-7], + xyaxes=wall_xyaxes)) + + def _get_goal_size(self): + goal_size = self._goal_size + if goal_size is None: + goal_size = ( + _SIDE_WIDTH / 2, + self._size[1] * _DEFAULT_GOAL_LENGTH_RATIO, + _SIDE_WIDTH / 2, + ) + return goal_size + + def register_ball(self, ball): + self._home_goal.register_entities(ball) + self._away_goal.register_entities(ball) + + if self._field_box: + # Geoms a and b collides if: + # (a.contype & b.conaffinity) || (b.contype & a.conaffinity) != 0. + # See: http://www.mujoco.org/book/computation.html#Collision + ball.geom.contype = (ball.geom.contype or 1) | _FIELD_BOX_CONTACT_BIT + for wall in self._field_box: + wall.conaffinity = _FIELD_BOX_CONTACT_BIT + wall.contype = _FIELD_BOX_CONTACT_BIT + else: + self._field.register_entities(ball) + + def detected_goal(self): + """Returning the team that scored a goal.""" + if self._home_goal.detected_entities: + return team.Team.AWAY + if self._away_goal.detected_entities: + return team.Team.HOME + return None + + def detected_off_court(self): + return self._field.detected_entities + + @property + def size(self): + return self._size + + @property + def home_goal(self): + return self._home_goal + + @property + def away_goal(self): + return self._away_goal + + @property + def field(self): + return self._field + + @property + def ground_geom(self): + return self._ground_geom + + +class RandomizedPitch(Pitch): + """RandomizedPitch that randomizes its size between (min_size, max_size).""" + + def __init__(self, + min_size, + max_size, + randomizer=None, + keep_aspect_ratio=False, + goal_size=None, + field_box=False, + top_camera_distance=_TOP_CAMERA_DISTANCE, + name='randomized_pitch'): + """Construct a randomized pitch. + + Args: + min_size: a tuple of minimum (length, width) of the pitch. + max_size: a tuple of maximum (length, width) of the pitch. + randomizer: a callable that returns ratio between [0., 1.] that scales + between min_size, max_size. + keep_aspect_ratio: if `True`, keep the aspect ratio constant during + randomization. + goal_size: optional (depth, width, height) indicating the goal size. + If not specified, the goal size is inferred from pitch size with a fixed + default ratio. + field_box: optional indicating if we should construct field box containing + the ball (but not the walkers). + top_camera_distance: the distance of the top-down camera to the pitch. + name: the name of this arena. + """ + super(RandomizedPitch, self).__init__( + size=max_size, + goal_size=goal_size, + top_camera_distance=top_camera_distance, + field_box=field_box, + name=name) + + self._min_size = min_size + self._max_size = max_size + + self._randomizer = randomizer or distributions.Uniform() + self._keep_aspect_ratio = keep_aspect_ratio + + # Sample a new size and regenerate the soccer pitch. + logging.info('%s between (%s, %s) with %s', self.__class__.__name__, + min_size, max_size, self._randomizer) + + def _resize_goals(self, goal_size): + self._home_goal.resize( + pos=(-self._size[0] + goal_size[0] + self._fb_offset, 0, goal_size[2]), + size=goal_size) + self._away_goal.resize( + pos=(self._size[0] - goal_size[0] - self._fb_offset, 0, goal_size[2]), + size=goal_size) + + def initialize_episode_mjcf(self, random_state): + super(RandomizedPitch, self).initialize_episode_mjcf(random_state) + min_len, min_wid = self._min_size + max_len, max_wid = self._max_size + + if self._keep_aspect_ratio: + len_ratio = self._randomizer(random_state=random_state) + wid_ratio = len_ratio + else: + len_ratio = self._randomizer(random_state=random_state) + wid_ratio = self._randomizer(random_state=random_state) + + self._size = (min_len + len_ratio * (max_len - min_len), + min_wid + wid_ratio * (max_wid - min_wid)) + + # Reset top_down camera field of view. + self._top_camera.fovy = _top_down_cam_fovy(self._size, + self._top_camera_distance) + + # Resize ground geom size. + self._ground_geom.size = list( + self._size) + [max(self._size) * _GROUND_GEOM_GRID_RATIO] + + # Resize and reposition walls and roof geoms. + for i, (wall_pos, _) in enumerate(_wall_pos_xyaxes(self._size)): + self._walls[i].pos = wall_pos + + goal_size = self._get_goal_size() + self._resize_goals(goal_size) + + # Resize inverted field position detectors. + self._field.resize( + pos=(0, 0), + size=(self._size[0] - 2 * goal_size[0], + self._size[1] - 2 * goal_size[0])) + + # Resize and reposition field box geoms. + if self._field_box: + for i, (pos, _) in enumerate( + _wall_pos_xyaxes((self._field.upper - self._field.lower) / 2.0)): + self._field_box[i].pos = pos diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/pitch_test.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/pitch_test.py new file mode 100644 index 0000000..d4e9141 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/pitch_test.py @@ -0,0 +1,86 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for dm_control.locomotion.soccer.pitch.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from absl.testing import parameterized +from dm_control import composer +from dm_control.composer.variation import distributions +from dm_control.entities import props +from dm_control.locomotion.soccer import pitch as pitch_lib +from dm_control.locomotion.soccer import team as team_lib +import numpy as np + + +class PitchTest(parameterized.TestCase): + + def _pitch_with_ball(self, pitch_size, ball_pos): + pitch = pitch_lib.Pitch(size=pitch_size) + self.assertEqual(pitch.size, pitch_size) + + sphere = props.Primitive(geom_type='sphere', size=(0.1,), pos=ball_pos) + pitch.register_ball(sphere) + pitch.attach(sphere) + + env = composer.Environment( + composer.NullTask(pitch), random_state=np.random.RandomState(42)) + env.reset() + return pitch + + def test_pitch_none_detected(self): + pitch = self._pitch_with_ball((12, 9), (0, 0, 0)) + self.assertEmpty(pitch.detected_off_court()) + self.assertIsNone(pitch.detected_goal()) + + def test_pitch_detected_off_court(self): + pitch = self._pitch_with_ball((12, 9), (20, 0, 0)) + self.assertLen(pitch.detected_off_court(), 1) + self.assertIsNone(pitch.detected_goal()) + + def test_pitch_detected_away_goal(self): + pitch = self._pitch_with_ball((12, 9), (-9.5, 0, 1)) + self.assertLen(pitch.detected_off_court(), 1) + self.assertEqual(team_lib.Team.AWAY, pitch.detected_goal()) + + def test_pitch_detected_home_goal(self): + pitch = self._pitch_with_ball((12, 9), (9.5, 0, 1)) + self.assertLen(pitch.detected_off_court(), 1) + self.assertEqual(team_lib.Team.HOME, pitch.detected_goal()) + + @parameterized.parameters((True, distributions.Uniform()), + (False, distributions.Uniform())) + def test_randomize_pitch(self, keep_aspect_ratio, randomizer): + pitch = pitch_lib.RandomizedPitch( + min_size=(4, 3), + max_size=(8, 6), + randomizer=randomizer, + keep_aspect_ratio=keep_aspect_ratio) + pitch.initialize_episode_mjcf(np.random.RandomState(42)) + + self.assertBetween(pitch.size[0], 4, 8) + self.assertBetween(pitch.size[1], 3, 6) + + if keep_aspect_ratio: + self.assertAlmostEqual((pitch.size[0] - 4) / (8. - 4.), + (pitch.size[1] - 3) / (6. - 3.)) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/soccer.png b/DMC/src/env/dm_control/dm_control/locomotion/soccer/soccer.png new file mode 100644 index 0000000..3511794 Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/soccer/soccer.png differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/soccer_ball.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/soccer_ball.py new file mode 100644 index 0000000..5ac8a6e --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/soccer_ball.py @@ -0,0 +1,236 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""A soccer ball that keeps track of ball-player contacts.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from dm_control import mjcf +from dm_control.entities import props +import numpy as np + +from dm_control.utils import io as resources + +_ASSETS_PATH = os.path.join(os.path.dirname(__file__), 'assets', 'soccer_ball') + + +def _get_texture(name): + contents = resources.GetResource( + os.path.join(_ASSETS_PATH, '{}.png'.format(name))) + return mjcf.Asset(contents, '.png') + + +class SoccerBall(props.Primitive): + """A soccer ball that keeps track of entities that come into contact.""" + + def _build(self, radius=0.35, mass=0.045, name='soccer_ball'): + """Builds this soccer ball. + + Args: + radius: The radius (in meters) of this target sphere. + mass: Mass (in kilograms) of the ball. + name: The name of this entity. + """ + super(SoccerBall, self)._build( + geom_type='sphere', size=(radius,), name=name) + texture = self._mjcf_root.asset.add( + 'texture', + name='soccer_ball', + type='cube', + fileup=_get_texture('up'), + filedown=_get_texture('down'), + filefront=_get_texture('front'), + fileback=_get_texture('back'), + fileleft=_get_texture('left'), + fileright=_get_texture('right')) + material = self._mjcf_root.asset.add( + 'material', name='soccer_ball', texture=texture) + self._geom.set_attributes( + pos=[0, 0, radius], + size=[radius], + condim=6, + friction=[.7, .075, .075], + mass=mass, + material=material) + + # Add some tracking cameras for visualization and logging. + self._mjcf_root.worldbody.add( + 'camera', + name='ball_cam_near', + pos=[0, -2, 2], + zaxis=[0, -1, 1], + fovy=70, + mode='trackcom') + self._mjcf_root.worldbody.add( + 'camera', + name='ball_cam', + pos=[0, -7, 7], + zaxis=[0, -1, 1], + fovy=70, + mode='trackcom') + self._mjcf_root.worldbody.add( + 'camera', + name='ball_cam_far', + pos=[0, -10, 10], + zaxis=[0, -1, 1], + fovy=70, + mode='trackcom') + + # Keep track of entities to team mapping. + self._players = [] + + # Initialize tracker attributes. + self.initialize_entity_trackers() + + def register_player(self, player): + self._players.append(player) + + def initialize_entity_trackers(self): + self._last_hit = None + self._hit = False + self._repossessed = False + self._intercepted = False + + # Tracks distance traveled by the ball in between consecutive hits. + self._pos_at_last_step = None + self._dist_since_last_hit = None + self._dist_between_last_hits = None + + def initialize_episode(self, physics, unused_random_state): + self._geom_id = physics.model.name2id(self._geom.full_identifier, 'geom') + self._geom_id_to_player = {} + for player in self._players: + geoms = player.walker.mjcf_model.find_all('geom') + for geom in geoms: + geom_id = physics.model.name2id(geom.full_identifier, 'geom') + self._geom_id_to_player[geom_id] = player + + self.initialize_entity_trackers() + + def after_substep(self, physics, unused_random_state): + """Resolve contacts and update ball-player contact trackers.""" + if self._hit: + # Ball has already registered a valid contact within step (during one of + # previous after_substep calls). + return + + # Iterate through all contacts to find the first contact between the ball + # and one of the registered entities. + for contact in physics.data.contact: + # Keep contacts that involve the ball and one of the registered entities. + has_self = False + for geom_id in (contact.geom1, contact.geom2): + if geom_id == self._geom_id: + has_self = True + else: + player = self._geom_id_to_player.get(geom_id) + + if has_self and player: + # Detected a contact between the ball and an registered player. + if self._last_hit is not None: + self._intercepted = player.team != self._last_hit.team + else: + self._intercepted = True + + # Register repossessed before updating last_hit player. + self._repossessed = player is not self._last_hit + self._last_hit = player + # Register hit event. + self._hit = True + break + + def before_step(self, physics, random_state): + super(SoccerBall, self).before_step(physics, random_state) + # Reset per simulation step indicator. + self._hit = False + self._repossessed = False + self._intercepted = False + + def after_step(self, physics, random_state): + super(SoccerBall, self).after_step(physics, random_state) + pos = physics.bind(self._geom).xpos + if self._hit: + # SoccerBall is hit on this step. Update dist_between_last_hits + # to dist_since_last_hit before resetting dist_since_last_hit. + self._dist_between_last_hits = self._dist_since_last_hit + self._dist_since_last_hit = 0. + self._pos_at_last_step = pos.copy() + + if self._dist_since_last_hit is not None: + # Accumulate distance traveled since last hit event. + self._dist_since_last_hit += np.linalg.norm(pos - self._pos_at_last_step) + + self._pos_at_last_step = pos.copy() + + @property + def last_hit(self): + """The player that last came in contact with the ball or `None`.""" + return self._last_hit + + @property + def hit(self): + """Indicates if the ball is hit during the last simulation step. + + For a timeline shown below: + ..., agent.step, simulation, agent.step, ... + + Returns: + True: if the ball is hit by a registered player during simulation step. + False: if not. + """ + return self._hit + + @property + def repossessed(self): + """Indicates if the ball has been repossessed by a different player. + + For a timeline shown below: + ..., agent.step, simulation, agent.step, ... + + Returns: + True if the ball is hit by a registered player during simulation step + and that player is different from `last_hit`. + False: if the ball is not hit, or the ball is hit by `last_hit` player. + """ + return self._repossessed + + @property + def intercepted(self): + """Indicates if the ball has been intercepted by a different team. + + For a timeline shown below: + ..., agent.step, simulation, agent.step, ... + + Returns: + True: if the ball is hit for the first time, or repossessed by an player + from a different team. + False: if the ball is not hit, not repossessed, or repossessed by a + teammate to `last_hit`. + """ + return self._intercepted + + @property + def dist_between_last_hits(self): + """Distance between last consecutive hits. + + Returns: + Distance between last two consecutive hit events or `None` if there has + not been two consecutive hits on the ball. + """ + return self._dist_between_last_hits diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/soccer_ball_test.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/soccer_ball_test.py new file mode 100644 index 0000000..372b703 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/soccer_ball_test.py @@ -0,0 +1,73 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for dm_control.locomotion.soccer.soccer_ball.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from dm_control import composer +from dm_control import mjcf +from dm_control.entities import props +from dm_control.locomotion.soccer import soccer_ball +from dm_control.locomotion.soccer import team +import numpy as np + + +class SoccerBallTest(absltest.TestCase): + + def test_detect_hit(self): + arena = composer.Arena() + ball = soccer_ball.SoccerBall(radius=0.35, mass=0.045, name='test_ball') + player = team.Player( + team=team.Team.HOME, + walker=props.Primitive(geom_type='sphere', size=(0.1,), name='home')) + arena.add_free_entity(player.walker) + ball.register_player(player) + arena.add_free_entity(ball) + + random_state = np.random.RandomState(42) + physics = mjcf.Physics.from_mjcf_model(arena.mjcf_model) + physics.step() + + ball.initialize_episode(physics, random_state) + ball.before_step(physics, random_state) + self.assertEqual(ball.hit, False) + self.assertEqual(ball.repossessed, False) + self.assertEqual(ball.intercepted, False) + self.assertIsNone(ball.last_hit) + self.assertIsNone(ball.dist_between_last_hits) + + ball.after_substep(physics, random_state) + ball.after_step(physics, random_state) + + self.assertEqual(ball.hit, True) + self.assertEqual(ball.repossessed, True) + self.assertEqual(ball.intercepted, True) + self.assertEqual(ball.last_hit, player) + # Only one hit registered. + self.assertIsNone(ball.dist_between_last_hits) + + def test_has_tracking_cameras(self): + ball = soccer_ball.SoccerBall(radius=0.35, mass=0.045, name='test_ball') + expected_camera_names = ['ball_cam_near', 'ball_cam', 'ball_cam_far'] + camera_names = [cam.name for cam in ball.mjcf_model.find_all('camera')] + self.assertCountEqual(expected_camera_names, camera_names) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/task.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/task.py new file mode 100644 index 0000000..e3f5044 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/task.py @@ -0,0 +1,191 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +""""A task where players play a soccer game.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from dm_control import composer +from dm_control.locomotion.soccer import initializers +from dm_control.locomotion.soccer import observables as observables_lib +from dm_control.locomotion.soccer import soccer_ball +from dm_env import specs +import numpy as np +from six.moves import zip + +_THROW_IN_BALL_Z = 0.5 + + +def _disable_geom_contacts(entities): + for entity in entities: + mjcf_model = entity.mjcf_model + for geom in mjcf_model.find_all("geom"): + geom.set_attributes(contype=0) + + +class Task(composer.Task): + """A task where two teams of walkers play soccer.""" + + def __init__(self, + players, + arena, + ball=None, + initializer=None, + observables=None, + disable_walker_contacts=False, + nconmax_per_player=200, + njmax_per_player=200, + control_timestep=0.025): + """Construct an instance of soccer.Task. + + This task implements the high-level game logic of multi-agent MuJoCo soccer. + + Args: + players: a sequence of `soccer.Player` instances, representing + participants to the game from both teams. + arena: an instance of `soccer.Pitch`, implementing the physical geoms and + the sensors associated with the pitch. + ball: optional instance of `soccer.SoccerBall`, implementing the physical + geoms and sensors associated with the soccer ball. If None, defaults to + using `soccer_ball.SoccerBall()`. + initializer: optional instance of `soccer.Initializer` that initializes + the task at the start of each episode. If None, defaults to + `initializers.UniformInitializer()`. + observables: optional instance of `soccer.ObservablesAdder` that adds + observables for each player. If None, defaults to + `observables.CoreObservablesAdder()`. + disable_walker_contacts: if `True`, disable physical contacts between + players. + nconmax_per_player: allocated maximum number of contacts per player. It + may be necessary to increase this value if you encounter errors due to + `mjWARN_CONTACTFULL`. + njmax_per_player: allocated maximum number of scalar constraints per + player. It may be necessary to increase this value if you encounter + errors due to `mjWARN_CNSTRFULL`. + control_timestep: control timestep of the agent. + """ + self.arena = arena + self.players = players + + self._initializer = initializer or initializers.UniformInitializer() + self._observables = observables or observables_lib.CoreObservablesAdder() + + if disable_walker_contacts: + _disable_geom_contacts([p.walker for p in self.players]) + + # Create ball and attach ball to arena. + self.ball = ball or soccer_ball.SoccerBall() + self.arena.add_free_entity(self.ball) + self.arena.register_ball(self.ball) + + # Register soccer ball contact tracking for players. + for player in self.players: + player.walker.create_root_joints(self.arena.attach(player.walker)) + self.ball.register_player(player) + # Add per-walkers observables. + self._observables(self, player) + + self.set_timesteps( + physics_timestep=0.005, control_timestep=control_timestep) + self.root_entity.mjcf_model.size.nconmax = nconmax_per_player * len(players) + self.root_entity.mjcf_model.size.njmax = njmax_per_player * len(players) + + @property + def observables(self): + observables = [] + for player in self.players: + observables.append( + player.walker.observables.as_dict(fully_qualified=False)) + return observables + + def _throw_in(self, physics, random_state, ball): + x, y, _ = physics.bind(ball.geom).xpos + shrink_x, shrink_y = random_state.uniform([0.7, 0.7], [0.9, 0.9]) + ball.set_pose(physics, [x * shrink_x, y * shrink_y, _THROW_IN_BALL_Z]) + ball.set_velocity( + physics, velocity=np.zeros(3), angular_velocity=np.zeros(3)) + ball.initialize_entity_trackers() + + def initialize_episode_mjcf(self, random_state): + self.arena.initialize_episode_mjcf(random_state) + + def initialize_episode(self, physics, random_state): + self.arena.initialize_episode(physics, random_state) + self._initializer(self, physics, random_state) + + @property + def root_entity(self): + return self.arena + + def get_reward(self, physics): + """Returns a list of per-player rewards. + + Each player will receive a reward of: + +1 if their team scored a goal + -1 if their team conceded a goal + 0 if no goals were scored on this timestep. + + Note: the observations also contain various environment statistics that may + be used to derive per-player rewards (as done in + http://arxiv.org/abs/1902.07151). + + Args: + physics: An instance of `Physics`. + + Returns: + A list of 0-dimensional numpy arrays, one per player. + """ + scoring_team = self.arena.detected_goal() + if not scoring_team: + return [np.zeros((), dtype=np.float32) for _ in self.players] + + rewards = [] + for p in self.players: + if p.team == scoring_team: + rewards.append(np.ones((), dtype=np.float32)) + else: + rewards.append(-np.ones((), dtype=np.float32)) + return rewards + + def get_reward_spec(self): + return [ + specs.Array(name="reward", shape=(), dtype=np.float32) + for _ in self.players + ] + + def get_discount(self, physics): + if self.arena.detected_goal(): + return np.zeros((), np.float32) + return np.ones((), np.float32) + + def get_discount_spec(self): + return specs.Array(name="discount", shape=(), dtype=np.float32) + + def should_terminate_episode(self, physics): + """Returns True if a goal was scored by either team.""" + return self.arena.detected_goal() is not None + + def before_step(self, physics, actions, random_state): + for player, action in zip(self.players, actions): + player.walker.apply_action(physics, action, random_state) + + if self.arena.detected_off_court(): + self._throw_in(physics, random_state, self.ball) + + def action_spec(self, physics): + """Return multi-agent action_spec.""" + return [player.walker.action_spec for player in self.players] diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/task_test.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/task_test.py new file mode 100644 index 0000000..c4dfacc --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/task_test.py @@ -0,0 +1,543 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for locomotion.tasks.soccer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest +# Internal dependencies. + +from absl.testing import absltest +from absl.testing import parameterized +from dm_control import composer +from dm_control import mjcf +from dm_control.locomotion import soccer +from dm_control.locomotion.soccer import initializers +from dm_control.mujoco.wrapper import mjbindings +import numpy as np +from six.moves import range +from six.moves import zip + +RGBA_BLUE = [.1, .1, .8, 1.] +RGBA_RED = [.8, .1, .1, 1.] + + +def _walker(name, walker_id, marker_rgba): + return soccer.BoxHead( + name=name, + walker_id=walker_id, + marker_rgba=marker_rgba, + ) + + +def _team_players(team_size, team, team_name, team_color): + team_of_players = [] + for i in range(team_size): + team_of_players.append( + soccer.Player(team, _walker("%s%d" % (team_name, i), i, team_color))) + return team_of_players + + +def _home_team(team_size): + return _team_players(team_size, soccer.Team.HOME, "home", RGBA_BLUE) + + +def _away_team(team_size): + return _team_players(team_size, soccer.Team.AWAY, "away", RGBA_RED) + + +def _env(players, disable_walker_contacts=True, observables=None, + random_state=42, **task_kwargs): + return composer.Environment( + task=soccer.Task( + players=players, + arena=soccer.Pitch((20, 15)), + observables=observables, + disable_walker_contacts=disable_walker_contacts, + **task_kwargs + ), + random_state=random_state, + time_limit=1) + + +def _observables_adder(observables_adder): + if observables_adder == "core": + return soccer.CoreObservablesAdder() + if observables_adder == "core_interception": + return soccer.MultiObservablesAdder( + [soccer.CoreObservablesAdder(), + soccer.InterceptionObservablesAdder()]) + raise ValueError("Unrecognized observable_adder %s" % observables_adder) + + +class TaskTest(parameterized.TestCase): + + def _assert_all_count_equal(self, list_of_lists): + """Check all lists in the list are count equal.""" + if not list_of_lists: + return + + first = sorted(list_of_lists[0]) + for other in list_of_lists[1:]: + self.assertCountEqual(first, other) + + @parameterized.named_parameters( + ("1vs1_core", 1, "core", 33, True), + ("2vs2_core", 2, "core", 43, True), + ("1vs1_interception", 1, "core_interception", 41, True), + ("2vs2_interception", 2, "core_interception", 51, True), + ("1vs1_core_contact", 1, "core", 33, False), + ("2vs2_core_contact", 2, "core", 43, False), + ("1vs1_interception_contact", 1, "core_interception", 41, False), + ("2vs2_interception_contact", 2, "core_interception", 51, False), + ) + def test_step_environment(self, team_size, observables_adder, num_obs, + disable_walker_contacts): + env = _env( + _home_team(team_size) + _away_team(team_size), + observables=_observables_adder(observables_adder), + disable_walker_contacts=disable_walker_contacts) + self.assertLen(env.action_spec(), 2 * team_size) + self.assertLen(env.observation_spec(), 2 * team_size) + + actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()] + + timestep = env.reset() + + for observation, spec in zip(timestep.observation, env.observation_spec()): + self.assertLen(spec, num_obs) + self.assertCountEqual(list(observation.keys()), list(spec.keys())) + for key in observation.keys(): + self.assertEqual(observation[key].shape, spec[key].shape) + + while not timestep.last(): + timestep = env.step(actions) + + # TODO(b/124848293): consolidate environment stepping loop for task tests. + @parameterized.named_parameters( + ("1vs2", 1, 2, 38), + ("2vs1", 2, 1, 38), + ("3vs0", 3, 0, 38), + ("0vs2", 0, 2, 33), + ("2vs2", 2, 2, 43), + ("0vs0", 0, 0, None), + ) + def test_num_players(self, home_size, away_size, num_observations): + env = _env(_home_team(home_size) + _away_team(away_size)) + self.assertLen(env.action_spec(), home_size + away_size) + self.assertLen(env.observation_spec(), home_size + away_size) + + actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()] + + timestep = env.reset() + + # Members of the same team should have identical specs. + self._assert_all_count_equal( + [spec.keys() for spec in env.observation_spec()[:home_size]]) + self._assert_all_count_equal( + [spec.keys() for spec in env.observation_spec()[-away_size:]]) + + for observation, spec in zip(timestep.observation, env.observation_spec()): + self.assertCountEqual(list(observation.keys()), list(spec.keys())) + for key in observation.keys(): + self.assertEqual(observation[key].shape, spec[key].shape) + + self.assertLen(spec, num_observations) + + while not timestep.last(): + timestep = env.step(actions) + + self.assertLen(timestep.observation, home_size + away_size) + + self.assertLen(timestep.reward, home_size + away_size) + for player_spec, player_reward in zip(env.reward_spec(), timestep.reward): + player_spec.validate(player_reward) + + discount_spec = env.discount_spec() + discount_spec.validate(timestep.discount) + + def test_all_contacts(self): + env = _env(_home_team(1) + _away_team(1)) + + def _all_contact_configuration(physics, unused_random_state): + walkers = [p.walker for p in env.task.players] + ball = env.task.ball + + x, y, rotation = 0., 0., np.pi / 6. + ball.set_pose(physics, [x, y, 5.]) + ball.set_velocity( + physics, velocity=np.zeros(3), angular_velocity=np.zeros(3)) + + x, y, rotation = 0., 0., np.pi / 3. + quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)] + walkers[0].set_pose(physics, [x, y, 3.], quat) + walkers[0].set_velocity( + physics, velocity=np.zeros(3), angular_velocity=np.zeros(3)) + + x, y, rotation = 0., 0., np.pi / 3. + np.pi + quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)] + walkers[1].set_pose(physics, [x, y, 1.], quat) + walkers[1].set_velocity( + physics, velocity=np.zeros(3), angular_velocity=np.zeros(3)) + + env.add_extra_hook("initialize_episode", _all_contact_configuration) + + actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()] + + timestep = env.reset() + while not timestep.last(): + timestep = env.step(actions) + + def test_symmetric_observations(self): + env = _env(_home_team(1) + _away_team(1)) + + def _symmetric_configuration(physics, unused_random_state): + walkers = [p.walker for p in env.task.players] + ball = env.task.ball + + x, y, rotation = 0., 0., np.pi / 6. + ball.set_pose(physics, [x, y, 0.5]) + ball.set_velocity( + physics, velocity=np.zeros(3), angular_velocity=np.zeros(3)) + + x, y, rotation = 5., 3., np.pi / 3. + quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)] + walkers[0].set_pose(physics, [x, y, 0.], quat) + walkers[0].set_velocity( + physics, velocity=np.zeros(3), angular_velocity=np.zeros(3)) + + x, y, rotation = -5., -3., np.pi / 3. + np.pi + quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)] + walkers[1].set_pose(physics, [x, y, 0.], quat) + walkers[1].set_velocity( + physics, velocity=np.zeros(3), angular_velocity=np.zeros(3)) + + env.add_extra_hook("initialize_episode", _symmetric_configuration) + + timestep = env.reset() + obs_a, obs_b = timestep.observation + self.assertCountEqual(list(obs_a.keys()), list(obs_b.keys())) + for k in sorted(obs_a.keys()): + o_a, o_b = obs_a[k], obs_b[k] + np.testing.assert_allclose(o_a, o_b, err_msg=k + " not equal.", atol=1e-6) + + def test_symmetric_dynamic_observations(self): + env = _env(_home_team(1) + _away_team(1)) + + def _symmetric_configuration(physics, unused_random_state): + walkers = [p.walker for p in env.task.players] + ball = env.task.ball + + x, y, rotation = 0., 0., np.pi / 6. + ball.set_pose(physics, [x, y, 0.5]) + # Ball shooting up. Walkers going tangent. + ball.set_velocity(physics, velocity=[0., 0., 1.], + angular_velocity=[0., 0., 0.]) + + x, y, rotation = 5., 3., np.pi / 3. + quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)] + walkers[0].set_pose(physics, [x, y, 0.], quat) + walkers[0].set_velocity(physics, velocity=[y, -x, 0.], + angular_velocity=[0., 0., 0.]) + + x, y, rotation = -5., -3., np.pi / 3. + np.pi + quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)] + walkers[1].set_pose(physics, [x, y, 0.], quat) + walkers[1].set_velocity(physics, velocity=[y, -x, 0.], + angular_velocity=[0., 0., 0.]) + + env.add_extra_hook("initialize_episode", _symmetric_configuration) + + timestep = env.reset() + obs_a, obs_b = timestep.observation + self.assertCountEqual(list(obs_a.keys()), list(obs_b.keys())) + for k in sorted(obs_a.keys()): + o_a, o_b = obs_a[k], obs_b[k] + np.testing.assert_allclose(o_a, o_b, err_msg=k + " not equal.", atol=1e-6) + + def test_prev_actions(self): + env = _env(_home_team(1) + _away_team(1)) + + actions = [] + for i, player in enumerate(env.task.players): + spec = player.walker.action_spec + actions.append((i + 1) * np.ones(spec.shape, dtype=spec.dtype)) + + env.reset() + timestep = env.step(actions) + + for walker_idx, obs in enumerate(timestep.observation): + np.testing.assert_allclose( + np.squeeze(obs["prev_action"], axis=0), + actions[walker_idx], + err_msg="Walker {}: incorrect previous action.".format(walker_idx)) + + @parameterized.named_parameters( + dict(testcase_name="1vs2_draw", + home_size=1, away_size=2, ball_vel_x=0, expected_home_score=0), + dict(testcase_name="1vs2_home_score", + home_size=1, away_size=2, ball_vel_x=50, expected_home_score=1), + dict(testcase_name="2vs1_away_score", + home_size=2, away_size=1, ball_vel_x=-50, expected_home_score=-1), + dict(testcase_name="3vs0_home_score", + home_size=3, away_size=0, ball_vel_x=50, expected_home_score=1), + dict(testcase_name="0vs2_home_score", + home_size=0, away_size=2, ball_vel_x=50, expected_home_score=1), + dict(testcase_name="2vs2_away_score", + home_size=2, away_size=2, ball_vel_x=-50, expected_home_score=-1), + ) + def test_scoring_rewards( + self, home_size, away_size, ball_vel_x, expected_home_score): + env = _env(_home_team(home_size) + _away_team(away_size)) + + def _score_configuration(physics, random_state): + del random_state # Unused. + # Send the ball shooting towards either the home or away goal. + env.task.ball.set_pose(physics, [0., 0., 0.5]) + env.task.ball.set_velocity(physics, + velocity=[ball_vel_x, 0., 0.], + angular_velocity=[0., 0., 0.]) + + env.add_extra_hook("initialize_episode", _score_configuration) + + actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()] + + # Disable contacts and gravity so that the ball follows a straight path. + with env.physics.model.disable("contact", "gravity"): + + timestep = env.reset() + with self.subTest("Reward and discount are None on the first timestep"): + self.assertTrue(timestep.first()) + self.assertIsNone(timestep.reward) + self.assertIsNone(timestep.discount) + + # Step until the episode ends. + timestep = env.step(actions) + while not timestep.last(): + self.assertTrue(timestep.mid()) + # For non-terminal timesteps, the reward should always be 0 and the + # discount should always be 1. + np.testing.assert_array_equal(np.hstack(timestep.reward), 0.) + self.assertEqual(timestep.discount, 1.) + timestep = env.step(actions) + + # If a goal was scored then the epsiode should have ended with a discount of + # 0. If neither team scored and the episode ended due to hitting the time + # limit then the discount should be 1. + with self.subTest("Correct terminal discount"): + if expected_home_score != 0: + expected_discount = 0. + else: + expected_discount = 1. + self.assertEqual(timestep.discount, expected_discount) + + with self.subTest("Correct terminal reward"): + reward = np.hstack(timestep.reward) + np.testing.assert_array_equal(reward[:home_size], expected_home_score) + np.testing.assert_array_equal(reward[home_size:], -expected_home_score) + + def test_throw_in(self): + env = _env(_home_team(1) + _away_team(1)) + + def _throw_in_configuration(physics, unused_random_state): + walkers = [p.walker for p in env.task.players] + ball = env.task.ball + + x, y, rotation = 0., 3., np.pi / 6. + ball.set_pose(physics, [x, y, 0.5]) + # Ball shooting out of bounds. + ball.set_velocity(physics, velocity=[0., 50., 0.], + angular_velocity=[0., 0., 0.]) + + x, y, rotation = 0., -3., np.pi / 3. + quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)] + walkers[0].set_pose(physics, [x, y, 0.], quat) + walkers[0].set_velocity(physics, velocity=[0., 0., 0.], + angular_velocity=[0., 0., 0.]) + x, y, rotation = 0., -5., np.pi / 3. + quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)] + walkers[1].set_pose(physics, [x, y, 0.], quat) + walkers[1].set_velocity(physics, velocity=[0., 0., 0.], + angular_velocity=[0., 0., 0.]) + + env.add_extra_hook("initialize_episode", _throw_in_configuration) + + actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()] + + timestep = env.reset() + + while not timestep.last(): + timestep = env.step(actions) + + terminal_ball_vel = np.linalg.norm( + timestep.observation[0]["ball_ego_linear_velocity"]) + self.assertAlmostEqual(terminal_ball_vel, 0.) + + @parameterized.named_parameters(("score", 50., 0.), ("timeout", 0., 1.)) + def test_terminal_discount(self, init_ball_vel_x, expected_terminal_discount): + env = _env(_home_team(1) + _away_team(1)) + + def _initial_configuration(physics, unused_random_state): + walkers = [p.walker for p in env.task.players] + ball = env.task.ball + + x, y, rotation = 0., 0., np.pi / 6. + ball.set_pose(physics, [x, y, 0.5]) + # Ball shooting up. Walkers going tangent. + ball.set_velocity(physics, velocity=[init_ball_vel_x, 0., 0.], + angular_velocity=[0., 0., 0.]) + + x, y, rotation = 0., -3., np.pi / 3. + quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)] + walkers[0].set_pose(physics, [x, y, 0.], quat) + walkers[0].set_velocity(physics, velocity=[0., 0., 0.], + angular_velocity=[0., 0., 0.]) + x, y, rotation = 0., 3., np.pi / 3. + quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)] + walkers[1].set_pose(physics, [x, y, 0.], quat) + walkers[1].set_velocity(physics, velocity=[0., 0., 0.], + angular_velocity=[0., 0., 0.]) + + env.add_extra_hook("initialize_episode", _initial_configuration) + + actions = [np.zeros(s.shape, s.dtype) for s in env.action_spec()] + + timestep = env.reset() + + while not timestep.last(): + timestep = env.step(actions) + + self.assertEqual(timestep.discount, expected_terminal_discount) + + +class UniformInitializerTest(parameterized.TestCase): + + @parameterized.parameters([0.3, 0.7]) + def test_walker_position(self, spawn_ratio): + initializer = initializers.UniformInitializer(spawn_ratio=spawn_ratio) + env = _env(_home_team(2) + _away_team(2), initializer=initializer) + root_bodies = [p.walker.root_body for p in env.task.players] + xy_bounds = np.asarray(env.task.arena.size) * spawn_ratio + env.reset() + xy = env.physics.bind(root_bodies).xpos[:, :2].copy() + with self.subTest("X and Y positions within bounds"): + if np.any(abs(xy) > xy_bounds): + self.fail("Walker(s) spawned out of bounds. Expected abs(xy) " + "<= {}, got:\n{}".format(xy_bounds, xy)) + env.reset() + xy2 = env.physics.bind(root_bodies).xpos[:, :2].copy() + with self.subTest("X and Y positions change after reset"): + if np.any(xy == xy2): + self.fail("Walker(s) have the same X and/or Y coordinates before and " + "after reset. Before: {}, after: {}.".format(xy, xy2)) + + def test_walker_rotation(self): + initializer = initializers.UniformInitializer() + env = _env(_home_team(2) + _away_team(2), initializer=initializer) + + def quats_to_eulers(quats): + eulers = np.empty((len(quats), 3), dtype=np.double) + dt = 1. + for i, quat in enumerate(quats): + mjbindings.mjlib.mju_quat2Vel(eulers[i], quat, dt) + return eulers + + # TODO(b/132671988): Switch to using `get_pose` to get the quaternion once + # `BoxHead.get_pose` and `BoxHead.set_pose` are + # implemented in a consistent way. + def get_quat(walker): + return env.physics.bind(walker.root_body).xquat + + env.reset() + quats = [get_quat(p.walker) for p in env.task.players] + eulers = quats_to_eulers(quats) + with self.subTest("Rotation is about the Z-axis only"): + np.testing.assert_array_equal(eulers[:, :2], 0.) + + env.reset() + quats2 = [get_quat(p.walker) for p in env.task.players] + eulers2 = quats_to_eulers(quats2) + with self.subTest("Rotation about Z changes after reset"): + if np.any(eulers[:, 2] == eulers2[:, 2]): + self.fail("Walker(s) have the same rotation about Z before and " + "after reset. Before: {}, after: {}." + .format(eulers[:, 2], eulers2[:, 2])) + + # TODO(b/132759890): Remove `expectedFailure` decorator once `set_velocity` + # works correctly for the `BoxHead` walker. + @unittest.expectedFailure + def test_walker_velocity(self): + initializer = initializers.UniformInitializer() + env = _env(_home_team(2) + _away_team(2), initializer=initializer) + root_joints = [] + non_root_joints = [] + for player in env.task.players: + attachment_frame = mjcf.get_attachment_frame(player.walker.mjcf_model) + root_joints.extend( + attachment_frame.find_all("joint", immediate_children_only=True)) + non_root_joints.extend(player.walker.mjcf_model.find_all("joint")) + # Assign a non-zero sentinel value to the velocities of all root and + # non-root joints. + sentinel_velocity = 3.14 + env.physics.bind(root_joints + non_root_joints).qvel = sentinel_velocity + # The initializer should zero the velocities of the root joints, but not the + # non-root joints. + initializer(env.task, env.physics, env.random_state) + np.testing.assert_array_equal(env.physics.bind(non_root_joints).qvel, + sentinel_velocity) + np.testing.assert_array_equal(env.physics.bind(root_joints).qvel, 0.) + + @parameterized.parameters([ + dict(spawn_ratio=0.3, init_ball_z=0.4), + dict(spawn_ratio=0.5, init_ball_z=0.6), + ]) + def test_ball_position(self, spawn_ratio, init_ball_z): + initializer = initializers.UniformInitializer( + spawn_ratio=spawn_ratio, init_ball_z=init_ball_z) + env = _env(_home_team(2) + _away_team(2), initializer=initializer) + xy_bounds = np.asarray(env.task.arena.size) * spawn_ratio + env.reset() + position, _ = env.task.ball.get_pose(env.physics) + xyz = position.copy() + with self.subTest("X and Y positions within bounds"): + if np.any(abs(xyz[:2]) > xy_bounds): + self.fail("Ball spawned out of bounds. Expected abs(xy) " + "<= {}, got:\n{}".format(xy_bounds, xyz[:2])) + with self.subTest("Z position equal to `init_ball_z`"): + self.assertEqual(xyz[2], init_ball_z) + env.reset() + position, _ = env.task.ball.get_pose(env.physics) + xyz2 = position.copy() + with self.subTest("X and Y positions change after reset"): + if np.any(xyz[:2] == xyz2[:2]): + self.fail("Ball has the same XY position before and after reset. " + "Before: {}, after: {}.".format(xyz[:2], xyz2[:2])) + + def test_ball_velocity(self): + initializer = initializers.UniformInitializer() + env = _env(_home_team(1) + _away_team(1), initializer=initializer) + ball_root_joint = mjcf.get_frame_freejoint(env.task.ball.mjcf_model) + # Set the velocities of the ball root joint to a non-zero sentinel value. + env.physics.bind(ball_root_joint).qvel = 3.14 + initializer(env.task, env.physics, env.random_state) + # The initializer should set the ball velocity to zero. + ball_velocity = env.physics.bind(ball_root_joint).qvel + np.testing.assert_array_equal(ball_velocity, 0.) + +if __name__ == "__main__": + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/soccer/team.py b/DMC/src/env/dm_control/dm_control/locomotion/soccer/team.py new file mode 100644 index 0000000..ce49e3b --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/soccer/team.py @@ -0,0 +1,32 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Define teams and players participating in a match.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import enum + + +class Team(enum.Enum): + HOME = 0 + AWAY = 1 + + +Player = collections.namedtuple('Player', ['team', 'walker']) diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/__init__.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/__init__.py new file mode 100644 index 0000000..f83b97a --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tasks in the Locomotion library.""" + + +from dm_control.locomotion.tasks.corridors import RunThroughCorridor +from dm_control.locomotion.tasks.escape import Escape +from dm_control.locomotion.tasks.go_to_target import GoToTarget +from dm_control.locomotion.tasks.random_goal_maze import ManyGoalsMaze +from dm_control.locomotion.tasks.random_goal_maze import ManyHeterogeneousGoalsMaze +from dm_control.locomotion.tasks.random_goal_maze import RepeatSingleGoalMaze +from dm_control.locomotion.tasks.random_goal_maze import RepeatSingleGoalMazeAugmentedWithTargets +from dm_control.locomotion.tasks.reach import TwoTouch + diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/corridors.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/corridors.py new file mode 100644 index 0000000..f865643 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/corridors.py @@ -0,0 +1,161 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Corridor-based locomotion tasks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from dm_control import composer +from dm_control.composer import variation +from dm_control.utils import rewards +import numpy as np + + +class RunThroughCorridor(composer.Task): + """A task that requires a walker to run through a corridor. + + This task rewards an agent for controlling a walker to move at a specific + target velocity along the corridor, and for minimising the magnitude of the + control signals used to achieve this. + """ + + def __init__(self, + walker, + arena, + walker_spawn_position=(0, 0, 0), + walker_spawn_rotation=None, + target_velocity=3.0, + contact_termination=True, + terminate_at_height=-0.5, + physics_timestep=0.005, + control_timestep=0.025): + """Initializes this task. + + Args: + walker: an instance of `locomotion.walkers.base.Walker`. + arena: an instance of `locomotion.arenas.corridors.Corridor`. + walker_spawn_position: a sequence of 3 numbers, or a `composer.Variation` + instance that generates such sequences, specifying the position at + which the walker is spawned at the beginning of an episode. + walker_spawn_rotation: a number, or a `composer.Variation` instance that + generates a number, specifying the yaw angle offset (in radians) that is + applied to the walker at the beginning of an episode. + target_velocity: a number specifying the target velocity (in meters per + second) for the walker. + contact_termination: whether to terminate if a non-foot geom touches the + ground. + terminate_at_height: a number specifying the height of end effectors below + which the episode terminates. + physics_timestep: a number specifying the timestep (in seconds) of the + physics simulation. + control_timestep: a number specifying the timestep (in seconds) at which + the agent applies its control inputs (in seconds). + """ + + self._arena = arena + self._walker = walker + self._walker.create_root_joints(self._arena.attach(self._walker)) + self._walker_spawn_position = walker_spawn_position + self._walker_spawn_rotation = walker_spawn_rotation + + enabled_observables = [] + enabled_observables += self._walker.observables.proprioception + enabled_observables += self._walker.observables.kinematic_sensors + enabled_observables += self._walker.observables.dynamic_sensors + enabled_observables.append(self._walker.observables.sensors_touch) + enabled_observables.append(self._walker.observables.egocentric_camera) + for observable in enabled_observables: + observable.enabled = True + + self._vel = target_velocity + self._contact_termination = contact_termination + self._terminate_at_height = terminate_at_height + + self.set_timesteps( + physics_timestep=physics_timestep, control_timestep=control_timestep) + + @property + def root_entity(self): + return self._arena + + def initialize_episode_mjcf(self, random_state): + self._arena.regenerate(random_state) + self._arena.mjcf_model.visual.map.znear = 0.00025 + self._arena.mjcf_model.visual.map.zfar = 4. + + def initialize_episode(self, physics, random_state): + self._walker.reinitialize_pose(physics, random_state) + if self._walker_spawn_rotation: + rotation = variation.evaluate( + self._walker_spawn_rotation, random_state=random_state) + quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)] + else: + quat = None + self._walker.shift_pose( + physics, + position=variation.evaluate( + self._walker_spawn_position, random_state=random_state), + quaternion=quat, + rotate_velocity=True) + + self._failure_termination = False + walker_foot_geoms = set(self._walker.ground_contact_geoms) + walker_nonfoot_geoms = [ + geom for geom in self._walker.mjcf_model.find_all('geom') + if geom not in walker_foot_geoms] + self._walker_nonfoot_geomids = set( + physics.bind(walker_nonfoot_geoms).element_id) + self._ground_geomids = set( + physics.bind(self._arena.ground_geoms).element_id) + + def _is_disallowed_contact(self, contact): + set1, set2 = self._walker_nonfoot_geomids, self._ground_geomids + return ((contact.geom1 in set1 and contact.geom2 in set2) or + (contact.geom1 in set2 and contact.geom2 in set1)) + + def before_step(self, physics, action, random_state): + self._walker.apply_action(physics, action, random_state) + + def after_step(self, physics, random_state): + self._failure_termination = False + if self._contact_termination: + for c in physics.data.contact: + if self._is_disallowed_contact(c): + self._failure_termination = True + break + if self._terminate_at_height is not None: + if any(physics.bind(self._walker.end_effectors).xpos[:, -1] < + self._terminate_at_height): + self._failure_termination = True + + def get_reward(self, physics): + walker_xvel = physics.bind(self._walker.root_body).subtree_linvel[0] + xvel_term = rewards.tolerance( + walker_xvel, (self._vel, self._vel), + margin=self._vel, + sigmoid='linear', + value_at_margin=0.0) + return xvel_term + + def should_terminate_episode(self, physics): + return self._failure_termination + + def get_discount(self, physics): + if self._failure_termination: + return 0. + else: + return 1. diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/corridors_test.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/corridors_test.py new file mode 100644 index 0000000..c7ca7c9 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/corridors_test.py @@ -0,0 +1,138 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for dm_control.locomotion.tasks.corridors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from absl.testing import parameterized +from dm_control import composer +from dm_control import mjcf +from dm_control.composer.variation import deterministic +from dm_control.composer.variation import rotations +from dm_control.locomotion.arenas import corridors as corridor_arenas +from dm_control.locomotion.tasks import corridors as corridor_tasks +from dm_control.locomotion.walkers import cmu_humanoid +import numpy as np +from six.moves import range + + +class CorridorsTest(parameterized.TestCase): + + @parameterized.parameters( + dict(position_offset=(0, 0, 0), + rotate_180_degrees=False, + use_variations=False), + dict(position_offset=(1, 2, 3), + rotate_180_degrees=True, + use_variations=True)) + def test_walker_is_correctly_reinitialized( + self, position_offset, rotate_180_degrees, use_variations): + walker_spawn_position = position_offset + + if not rotate_180_degrees: + walker_spawn_rotation = None + else: + walker_spawn_rotation = np.pi + + if use_variations: + walker_spawn_position = deterministic.Constant(position_offset) + walker_spawn_rotation = deterministic.Constant(walker_spawn_rotation) + + walker = cmu_humanoid.CMUHumanoid() + arena = corridor_arenas.EmptyCorridor() + task = corridor_tasks.RunThroughCorridor( + walker=walker, + arena=arena, + walker_spawn_position=walker_spawn_position, + walker_spawn_rotation=walker_spawn_rotation) + + # Randomize the initial pose and joint positions in order to check that they + # are set correctly by `initialize_episode`. + random_state = np.random.RandomState(12345) + task.initialize_episode_mjcf(random_state) + physics = mjcf.Physics.from_mjcf_model(task.root_entity.mjcf_model) + + walker_joints = walker.mjcf_model.find_all('joint') + physics.bind(walker_joints).qpos = random_state.uniform( + size=len(walker_joints)) + walker.set_pose(physics, + position=random_state.uniform(size=3), + quaternion=rotations.UniformQuaternion()(random_state)) + + task.initialize_episode(physics, random_state) + physics.forward() + + with self.subTest('Correct joint positions'): + walker_qpos = physics.bind(walker_joints).qpos + if walker.upright_pose.qpos is not None: + np.testing.assert_array_equal(walker_qpos, walker.upright_pose.qpos) + else: + walker_qpos0 = physics.bind(walker_joints).qpos0 + np.testing.assert_array_equal(walker_qpos, walker_qpos0) + + walker_xpos, walker_xquat = walker.get_pose(physics) + + with self.subTest('Correct position'): + expected_xpos = walker.upright_pose.xpos + np.array(position_offset) + np.testing.assert_array_equal(walker_xpos, expected_xpos) + + with self.subTest('Correct orientation'): + upright_xquat = walker.upright_pose.xquat.copy() + upright_xquat /= np.linalg.norm(walker.upright_pose.xquat) + if rotate_180_degrees: + expected_xquat = (-upright_xquat[3], -upright_xquat[2], + upright_xquat[1], upright_xquat[0]) + else: + expected_xquat = upright_xquat + np.testing.assert_allclose(walker_xquat, expected_xquat) + + def test_termination_and_discount(self): + walker = cmu_humanoid.CMUHumanoid() + arena = corridor_arenas.EmptyCorridor() + task = corridor_tasks.RunThroughCorridor(walker, arena) + + random_state = np.random.RandomState(12345) + env = composer.Environment(task, random_state=random_state) + env.reset() + + zero_action = np.zeros_like(env.physics.data.ctrl) + + # Walker starts in upright position. + # Should not trigger failure termination in the first few steps. + for _ in range(5): + env.step(zero_action) + self.assertFalse(task.should_terminate_episode(env.physics)) + self.assertEqual(task.get_discount(env.physics), 1) + + # Rotate the walker upside down and run the physics until it makes contact. + current_time = env.physics.data.time + walker.shift_pose(env.physics, position=(0, 0, 10), quaternion=(0, 1, 0, 0)) + env.physics.forward() + while env.physics.data.ncon == 0: + env.physics.step() + env.physics.data.time = current_time + + # Should now trigger a failure termination. + env.step(zero_action) + self.assertTrue(task.should_terminate_episode(env.physics)) + self.assertEqual(task.get_discount(env.physics), 0) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/escape.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/escape.py new file mode 100644 index 0000000..1651a34 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/escape.py @@ -0,0 +1,188 @@ +# Copyright 2020 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Escape locomotion tasks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from dm_control import composer +from dm_control import mjcf +from dm_control.composer.observation import observable as base_observable +from dm_control.rl import control +from dm_control.utils import rewards + +import numpy as np + +# Constants related to terrain generation. +_HEIGHTFIELD_ID = 0 + + +class Escape(composer.Task): + """A task solved by escaping a starting area (e.g. bowl-shaped terrain).""" + + def __init__(self, + walker, + arena, + walker_spawn_position=(0, 0, 0), + walker_spawn_rotation=None, + physics_timestep=0.005, + control_timestep=0.025): + """Initializes this task. + + Args: + walker: an instance of `locomotion.walkers.base.Walker`. + arena: an instance of `locomotion.arenas`. + walker_spawn_position: a sequence of 3 numbers, or a `composer.Variation` + instance that generates such sequences, specifying the position at + which the walker is spawned at the beginning of an episode. + walker_spawn_rotation: a number, or a `composer.Variation` instance that + generates a number, specifying the yaw angle offset (in radians) that is + applied to the walker at the beginning of an episode. + physics_timestep: a number specifying the timestep (in seconds) of the + physics simulation. + control_timestep: a number specifying the timestep (in seconds) at which + the agent applies its control inputs (in seconds). + """ + + self._arena = arena + self._walker = walker + self._walker.create_root_joints(self._arena.attach(self._walker)) + self._walker_spawn_position = walker_spawn_position + self._walker_spawn_rotation = walker_spawn_rotation + + enabled_observables = [] + enabled_observables += self._walker.observables.proprioception + enabled_observables += self._walker.observables.kinematic_sensors + enabled_observables += self._walker.observables.dynamic_sensors + enabled_observables.append(self._walker.observables.sensors_touch) + enabled_observables.append(self._walker.observables.egocentric_camera) + for observable in enabled_observables: + observable.enabled = True + + if 'CMUHumanoid' in str(type(self._walker)): + core_body = 'walker/root' + self._reward_body = 'walker/root' + elif 'Rat' in str(type(self._walker)): + core_body = 'walker/torso' + self._reward_body = 'walker/head' + else: + raise ValueError('Expects Rat or CMUHumanoid.') + + def _origin(physics): + """Returns origin position in the torso frame.""" + torso_frame = physics.named.data.xmat[core_body].reshape(3, 3) + torso_pos = physics.named.data.xpos[core_body] + return -torso_pos.dot(torso_frame) + + self._walker.observables.add_observable( + 'origin', base_observable.Generic(_origin)) + + self.set_timesteps( + physics_timestep=physics_timestep, control_timestep=control_timestep) + + @property + def root_entity(self): + return self._arena + + def initialize_episode_mjcf(self, random_state): + if hasattr(self._arena, 'regenerate'): + self._arena.regenerate(random_state) + self._arena.mjcf_model.visual.map.znear = 0.00025 + self._arena.mjcf_model.visual.map.zfar = 50. + + def initialize_episode(self, physics, random_state): + super(Escape, self).initialize_episode(physics, random_state) + + # Initial configuration. + orientation = random_state.randn(4) + orientation /= np.linalg.norm(orientation) + _find_non_contacting_height(physics, self._walker, orientation) + + def get_reward(self, physics): + # Escape reward term. + terrain_size = physics.model.hfield_size[_HEIGHTFIELD_ID, 0] + escape_reward = rewards.tolerance( + np.asarray(np.linalg.norm( + physics.named.data.site_xpos[self._reward_body])), + bounds=(terrain_size, float('inf')), + margin=terrain_size, + value_at_margin=0, + sigmoid='linear') + upright_reward = _upright_reward(physics, self._walker, deviation_angle=30) + return upright_reward * escape_reward + + def get_discount(self, physics): + return 1. + + +def _find_non_contacting_height(physics, walker, orientation, + x_pos=0.0, y_pos=0.0, maxiter=1000): + """Find a height with no contacts given a body orientation. + + Args: + physics: An instance of `Physics`. + walker: the focal walker. + orientation: A quaternion. + x_pos: A float. Position along global x-axis. + y_pos: A float. Position along global y-axis. + maxiter: maximum number of iterations to try + """ + z_pos = 0.0 # Start embedded in the floor. + num_contacts = 1 + count = 1 + # Move up in 1cm increments until no contacts. + while num_contacts > 0: + try: + with physics.reset_context(): + freejoint = mjcf.get_frame_freejoint(walker.mjcf_model) + physics.bind(freejoint).qpos[:3] = x_pos, y_pos, z_pos + physics.bind(freejoint).qpos[3:] = orientation + except control.PhysicsError: + # We may encounter a PhysicsError here due to filling the contact + # buffer, in which case we simply increment the height and continue. + pass + num_contacts = physics.data.ncon + z_pos += 0.01 + count += 1 + if count > maxiter: + raise ValueError( + 'maxiter reached: possibly contacts in null pose of body.' + ) + + +def _upright_reward(physics, walker, deviation_angle=0): + """Returns a reward proportional to how upright the torso is. + + Args: + physics: an instance of `Physics`. + walker: the focal walker. + deviation_angle: A float, in degrees. The reward is 0 when the torso is + exactly upside-down and 1 when the torso's z-axis is less than + `deviation_angle` away from the global z-axis. + """ + deviation = np.cos(np.deg2rad(deviation_angle)) + upright_torso = physics.bind(walker.root_body).xmat[-1] + if hasattr(walker, 'pelvis_body'): + upright_pelvis = physics.bind(walker.pelvis_body).xmat[-1] + upright_zz = np.stack([upright_torso, upright_pelvis]) + else: + upright_zz = upright_torso + upright = rewards.tolerance(upright_zz, + bounds=(deviation, float('inf')), + sigmoid='linear', + margin=1 + deviation, + value_at_margin=0) + return np.min(upright) diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/escape_test.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/escape_test.py new file mode 100644 index 0000000..1efecca --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/escape_test.py @@ -0,0 +1,90 @@ +# Copyright 2020 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for locomotion.tasks.escape.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest + +from dm_control import composer +from dm_control.locomotion.arenas import bowl +from dm_control.locomotion.tasks import escape +from dm_control.locomotion.walkers import rodent + +import numpy as np +from six.moves import range + +_CONTROL_TIMESTEP = .02 +_PHYSICS_TIMESTEP = 0.001 + + +class EscapeTest(absltest.TestCase): + + def test_observables(self): + walker = rodent.Rat() + + # Build a corridor-shaped arena that is obstructed by walls. + arena = bowl.Bowl( + size=(20., 20.), + aesthetic='outdoor_natural') + + # Build a task that rewards the agent for running down the corridor at a + # specific velocity. + task = escape.Escape( + walker=walker, + arena=arena, + physics_timestep=_PHYSICS_TIMESTEP, + control_timestep=_CONTROL_TIMESTEP) + + random_state = np.random.RandomState(12345) + env = composer.Environment(task, random_state=random_state) + timestep = env.reset() + + self.assertIn('walker/joints_pos', timestep.observation) + + def test_contact(self): + walker = rodent.Rat() + + # Build a corridor-shaped arena that is obstructed by walls. + arena = bowl.Bowl( + size=(20., 20.), + aesthetic='outdoor_natural') + + # Build a task that rewards the agent for running down the corridor at a + # specific velocity. + task = escape.Escape( + walker=walker, + arena=arena, + physics_timestep=_PHYSICS_TIMESTEP, + control_timestep=_CONTROL_TIMESTEP) + + random_state = np.random.RandomState(12345) + env = composer.Environment(task, random_state=random_state) + env.reset() + + zero_action = np.zeros_like(env.physics.data.ctrl) + + # Walker starts in upright position. + # Should not trigger failure termination in the first few steps. + for _ in range(5): + env.step(zero_action) + self.assertFalse(task.should_terminate_episode(env.physics)) + np.testing.assert_array_equal(task.get_discount(env.physics), 1) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/go_to_target.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/go_to_target.py new file mode 100644 index 0000000..173608d --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/go_to_target.py @@ -0,0 +1,220 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Task for a walker to move to a target.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from dm_control import composer +from dm_control.composer import variation +from dm_control.composer.observation import observable +from dm_control.composer.variation import distributions +import numpy as np + +DEFAULT_DISTANCE_TOLERANCE_TO_TARGET = 1.0 + + +class GoToTarget(composer.Task): + """A task that requires a walker to move towards a target.""" + + def __init__(self, + walker, + arena, + moving_target=False, + target_relative=False, + target_relative_dist=1.5, + steps_before_moving_target=10, + distance_tolerance=DEFAULT_DISTANCE_TOLERANCE_TO_TARGET, + target_spawn_position=None, + walker_spawn_position=None, + walker_spawn_rotation=None, + physics_timestep=0.005, + control_timestep=0.025): + """Initializes this task. + + Args: + walker: an instance of `locomotion.walkers.base.Walker`. + arena: an instance of `locomotion.arenas.floors.Floor`. + moving_target: bool, Whether the target should move after receiving the + walker reaches it. + target_relative: bool, Whether the target be set relative to its current + position. + target_relative_dist: float, new target distance range if + using target_relative. + steps_before_moving_target: int, the number of steps before the target + moves, if moving_target==True. + distance_tolerance: Accepted to distance to the target position before + providing reward. + target_spawn_position: a sequence of 2 numbers, or a `composer.Variation` + instance that generates such sequences, specifying the position at + which the target is spawned at the beginning of an episode. + If None, the entire arena is used to generate random target positions. + walker_spawn_position: a sequence of 2 numbers, or a `composer.Variation` + instance that generates such sequences, specifying the position at + which the walker is spawned at the beginning of an episode. + If None, the entire arena is used to generate random spawn positions. + walker_spawn_rotation: a number, or a `composer.Variation` instance that + generates a number, specifying the yaw angle offset (in radians) that is + applied to the walker at the beginning of an episode. + physics_timestep: a number specifying the timestep (in seconds) of the + physics simulation. + control_timestep: a number specifying the timestep (in seconds) at which + the agent applies its control inputs (in seconds). + """ + + self._arena = arena + self._walker = walker + self._walker.create_root_joints(self._arena.attach(self._walker)) + + arena_position = distributions.Uniform( + low=-np.array(arena.size) / 2, high=np.array(arena.size) / 2) + if target_spawn_position is not None: + self._target_spawn_position = target_spawn_position + else: + self._target_spawn_position = arena_position + + if walker_spawn_position is not None: + self._walker_spawn_position = walker_spawn_position + else: + self._walker_spawn_position = arena_position + + self._walker_spawn_rotation = walker_spawn_rotation + + self._distance_tolerance = distance_tolerance + self._moving_target = moving_target + self._target_relative = target_relative + self._target_relative_dist = target_relative_dist + self._steps_before_moving_target = steps_before_moving_target + self._reward_step_counter = 0 + + self._target = self.root_entity.mjcf_model.worldbody.add( + 'site', + name='target', + type='sphere', + pos=(0., 0., 0.), + size=(0.1,), + rgba=(0.9, 0.6, 0.6, 1.0)) + + enabled_observables = [] + enabled_observables += self._walker.observables.proprioception + enabled_observables += self._walker.observables.kinematic_sensors + enabled_observables += self._walker.observables.dynamic_sensors + enabled_observables.append(self._walker.observables.sensors_touch) + for obs in enabled_observables: + obs.enabled = True + + walker.observables.add_egocentric_vector( + 'target', + observable.MJCFFeature('pos', self._target), + origin_callable=lambda physics: physics.bind(walker.root_body).xpos) + + self.set_timesteps( + physics_timestep=physics_timestep, control_timestep=control_timestep) + + @property + def root_entity(self): + return self._arena + + def target_position(self, physics): + return np.array(physics.bind(self._target).pos) + + def initialize_episode_mjcf(self, random_state): + self._arena.regenerate(random_state=random_state) + + target_x, target_y = variation.evaluate( + self._target_spawn_position, random_state=random_state) + self._target.pos = [target_x, target_y, 0.] + + def initialize_episode(self, physics, random_state): + self._walker.reinitialize_pose(physics, random_state) + if self._walker_spawn_rotation: + rotation = variation.evaluate( + self._walker_spawn_rotation, random_state=random_state) + quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)] + else: + quat = None + walker_x, walker_y = variation.evaluate( + self._walker_spawn_position, random_state=random_state) + self._walker.shift_pose( + physics, + position=[walker_x, walker_y, 0.], + quaternion=quat, + rotate_velocity=True) + + self._failure_termination = False + walker_foot_geoms = set(self._walker.ground_contact_geoms) + walker_nonfoot_geoms = [ + geom for geom in self._walker.mjcf_model.find_all('geom') + if geom not in walker_foot_geoms] + self._walker_nonfoot_geomids = set( + physics.bind(walker_nonfoot_geoms).element_id) + self._ground_geomids = set( + physics.bind(self._arena.ground_geoms).element_id) + self._ground_geomids.add(physics.bind(self._target).element_id) + + def _is_disallowed_contact(self, contact): + set1, set2 = self._walker_nonfoot_geomids, self._ground_geomids + return ((contact.geom1 in set1 and contact.geom2 in set2) or + (contact.geom1 in set2 and contact.geom2 in set1)) + + def should_terminate_episode(self, physics): + return self._failure_termination + + def get_discount(self, physics): + if self._failure_termination: + return 0. + else: + return 1. + + def get_reward(self, physics): + reward = 0. + distance = np.linalg.norm( + physics.bind(self._target).pos[:2] - + physics.bind(self._walker.root_body).xpos[:2]) + if distance < self._distance_tolerance: + reward = 1. + if self._moving_target: + self._reward_step_counter += 1 + return reward + + def before_step(self, physics, action, random_state): + self._walker.apply_action(physics, action, random_state) + + def after_step(self, physics, random_state): + self._failure_termination = False + for contact in physics.data.contact: + if self._is_disallowed_contact(contact): + self._failure_termination = True + break + if (self._moving_target and + self._reward_step_counter >= self._steps_before_moving_target): + + # Reset the target position. + if self._target_relative: + walker_pos = physics.bind(self._walker.root_body).xpos[:2] + target_x, target_y = random_state.uniform( + -np.array([self._target_relative_dist, self._target_relative_dist]), + np.array([self._target_relative_dist, self._target_relative_dist])) + target_x += walker_pos[0] + target_y += walker_pos[1] + else: + target_x, target_y = variation.evaluate( + self._target_spawn_position, random_state=random_state) + physics.bind(self._target).pos = [target_x, target_y, 0.] + + # Reset the number of steps at the target for the moving target. + self._reward_step_counter = 0 diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/go_to_target_test.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/go_to_target_test.py new file mode 100644 index 0000000..183f0df --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/go_to_target_test.py @@ -0,0 +1,160 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for locomotion.tasks.go_to_target.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest + +from dm_control import composer +from dm_control.locomotion.arenas import floors +from dm_control.locomotion.tasks import go_to_target +from dm_control.locomotion.walkers import cmu_humanoid +import numpy as np +from six.moves import range + + +class GoToTargetTest(absltest.TestCase): + + def test_observables(self): + walker = cmu_humanoid.CMUHumanoid() + arena = floors.Floor() + task = go_to_target.GoToTarget( + walker=walker, arena=arena, moving_target=False) + + random_state = np.random.RandomState(12345) + env = composer.Environment(task, random_state=random_state) + timestep = env.reset() + + self.assertIn('walker/target', timestep.observation) + + def test_target_position_randomized_on_reset(self): + walker = cmu_humanoid.CMUHumanoid() + arena = floors.Floor() + task = go_to_target.GoToTarget( + walker=walker, arena=arena, moving_target=False) + random_state = np.random.RandomState(12345) + env = composer.Environment(task, random_state=random_state) + env.reset() + first_target_position = task.target_position(env.physics) + env.reset() + second_target_position = task.target_position(env.physics) + self.assertFalse(np.all(first_target_position == second_target_position), + 'Target positions are unexpectedly identical.') + + def test_reward_fixed_target(self): + walker = cmu_humanoid.CMUHumanoid() + arena = floors.Floor() + task = go_to_target.GoToTarget( + walker=walker, arena=arena, moving_target=False) + + random_state = np.random.RandomState(12345) + env = composer.Environment(task, random_state=random_state) + env.reset() + + target_position = task.target_position(env.physics) + zero_action = np.zeros_like(env.physics.data.ctrl) + for _ in range(2): + timestep = env.step(zero_action) + self.assertEqual(timestep.reward, 0) + walker_pos = env.physics.bind(walker.root_body).xpos + walker.set_pose( + env.physics, + position=[target_position[0], target_position[1], walker_pos[2]]) + env.physics.forward() + + # Receive reward while the agent remains at that location. + timestep = env.step(zero_action) + self.assertEqual(timestep.reward, 1) + + # Target position should not change. + np.testing.assert_array_equal(target_position, + task.target_position(env.physics)) + + def test_reward_moving_target(self): + walker = cmu_humanoid.CMUHumanoid() + arena = floors.Floor() + + steps_before_moving_target = 2 + task = go_to_target.GoToTarget( + walker=walker, + arena=arena, + moving_target=True, + steps_before_moving_target=steps_before_moving_target) + random_state = np.random.RandomState(12345) + env = composer.Environment(task, random_state=random_state) + env.reset() + + target_position = task.target_position(env.physics) + zero_action = np.zeros_like(env.physics.data.ctrl) + for _ in range(2): + timestep = env.step(zero_action) + self.assertEqual(timestep.reward, 0) + + walker_pos = env.physics.bind(walker.root_body).xpos + walker.set_pose( + env.physics, + position=[target_position[0], target_position[1], walker_pos[2]]) + env.physics.forward() + + # Receive reward while the agent remains at that location. + for _ in range(steps_before_moving_target): + timestep = env.step(zero_action) + self.assertEqual(timestep.reward, 1) + np.testing.assert_array_equal(target_position, + task.target_position(env.physics)) + + # After taking > steps_before_moving_target, the target should move and + # reward should be 0. + timestep = env.step(zero_action) + self.assertEqual(timestep.reward, 0) + + def test_termination_and_discount(self): + walker = cmu_humanoid.CMUHumanoid() + arena = floors.Floor() + task = go_to_target.GoToTarget(walker=walker, arena=arena) + + random_state = np.random.RandomState(12345) + env = composer.Environment(task, random_state=random_state) + env.reset() + + zero_action = np.zeros_like(env.physics.data.ctrl) + + # Walker starts in upright position. + # Should not trigger failure termination in the first few steps. + for _ in range(5): + env.step(zero_action) + self.assertFalse(task.should_terminate_episode(env.physics)) + np.testing.assert_array_equal(task.get_discount(env.physics), 1) + + # Rotate the walker upside down and run the physics until it makes contact. + current_time = env.physics.data.time + walker.shift_pose(env.physics, position=(0, 0, 10), quaternion=(0, 1, 0, 0)) + env.physics.forward() + while env.physics.data.ncon == 0: + env.physics.step() + env.physics.data.time = current_time + + # Should now trigger a failure termination. + env.step(zero_action) + self.assertTrue(task.should_terminate_episode(env.physics)) + np.testing.assert_array_equal(task.get_discount(env.physics), 0) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/random_goal_maze.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/random_goal_maze.py new file mode 100644 index 0000000..6e8ef52 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/random_goal_maze.py @@ -0,0 +1,546 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""A task consisting of finding goals/targets in a random maze.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import itertools + +from dm_control import composer +from dm_control import mjcf +from dm_control.composer.observation import observable as observable_lib +from dm_control.locomotion.props import target_sphere +from dm_control.mujoco.wrapper import mjbindings +import numpy as np +from six.moves import range +from six.moves import zip + +_NUM_RAYS = 10 + +# Aliveness in [-1., 0.]. +DEFAULT_ALIVE_THRESHOLD = -0.5 + +DEFAULT_PHYSICS_TIMESTEP = 0.001 +DEFAULT_CONTROL_TIMESTEP = 0.025 + + +class NullGoalMaze(composer.Task): + """A base task for maze with goals.""" + + def __init__(self, + walker, + maze_arena, + randomize_spawn_position=True, + randomize_spawn_rotation=True, + rotation_bias_factor=0, + aliveness_reward=0.0, + aliveness_threshold=DEFAULT_ALIVE_THRESHOLD, + contact_termination=True, + enable_global_task_observables=False, + physics_timestep=DEFAULT_PHYSICS_TIMESTEP, + control_timestep=DEFAULT_CONTROL_TIMESTEP): + """Initializes goal-directed maze task. + + Args: + walker: The body to navigate the maze. + maze_arena: The physical maze arena object. + randomize_spawn_position: Flag to randomize position of spawning. + randomize_spawn_rotation: Flag to randomize orientation of spawning. + rotation_bias_factor: A non-negative number that concentrates initial + orientation away from walls. When set to zero, the initial orientation + is uniformly random. The larger the value of this number, the more + likely it is that the initial orientation would face the direction that + is farthest away from a wall. + aliveness_reward: Reward for being alive. + aliveness_threshold: Threshold if should terminate based on walker + aliveness feature. + contact_termination: whether to terminate if a non-foot geom touches the + ground. + enable_global_task_observables: Flag to provide task observables that + contain global information, including map layout. + physics_timestep: timestep of simulation. + control_timestep: timestep at which agent changes action. + """ + self._walker = walker + self._maze_arena = maze_arena + self._walker.create_root_joints(self._maze_arena.attach(self._walker)) + + self._randomize_spawn_position = randomize_spawn_position + self._randomize_spawn_rotation = randomize_spawn_rotation + self._rotation_bias_factor = rotation_bias_factor + + self._aliveness_reward = aliveness_reward + self._aliveness_threshold = aliveness_threshold + self._contact_termination = contact_termination + self._discount = 1.0 + + self.set_timesteps( + physics_timestep=physics_timestep, control_timestep=control_timestep) + + self._walker.observables.egocentric_camera.height = 64 + self._walker.observables.egocentric_camera.width = 64 + + for observable in (self._walker.observables.proprioception + + self._walker.observables.kinematic_sensors + + self._walker.observables.dynamic_sensors): + observable.enabled = True + self._walker.observables.egocentric_camera.enabled = True + + if enable_global_task_observables: + # Reveal maze text map as observable. + maze_obs = observable_lib.Generic( + lambda _: self._maze_arena.maze.entity_layer) + maze_obs.enabled = True + + # absolute walker position + def get_walker_pos(physics): + walker_pos = physics.bind(self._walker.root_body).xpos + return walker_pos + absolute_position = observable_lib.Generic(get_walker_pos) + absolute_position.enabled = True + + # absolute walker orientation + def get_walker_ori(physics): + walker_ori = np.reshape( + physics.bind(self._walker.root_body).xmat, (3, 3)) + return walker_ori + absolute_orientation = observable_lib.Generic(get_walker_ori) + absolute_orientation.enabled = True + + # grid element of player in maze cell: i,j cell in maze layout + def get_walker_ij(physics): + walker_xypos = physics.bind(self._walker.root_body).xpos[:-1] + walker_rel_origin = ( + (walker_xypos + + np.sign(walker_xypos) * self._maze_arena.xy_scale / 2) / + (self._maze_arena.xy_scale)).astype(int) + x_offset = (self._maze_arena.maze.width - 1) / 2 + y_offset = (self._maze_arena.maze.height - 1) / 2 + walker_ij = walker_rel_origin + np.array([x_offset, y_offset]) + return walker_ij + absolute_position_discrete = observable_lib.Generic(get_walker_ij) + absolute_position_discrete.enabled = True + + self._task_observables = collections.OrderedDict({ + 'maze_layout': maze_obs, + 'absolute_position': absolute_position, + 'absolute_orientation': absolute_orientation, + 'location_in_maze': absolute_position_discrete, # from bottom left + }) + else: + self._task_observables = collections.OrderedDict({}) + + @property + def task_observables(self): + return self._task_observables + + @property + def name(self): + return 'goal_maze' + + @property + def root_entity(self): + return self._maze_arena + + def initialize_episode_mjcf(self, unused_random_state): + self._maze_arena.regenerate() + + def _respawn(self, physics, random_state): + self._walker.reinitialize_pose(physics, random_state) + + if self._randomize_spawn_position: + self._spawn_position = self._maze_arena.spawn_positions[ + random_state.randint(0, len(self._maze_arena.spawn_positions))] + + if self._randomize_spawn_rotation: + # Move walker up out of the way before raycasting. + self._walker.shift_pose(physics, [0.0, 0.0, 100.0]) + + distances = [] + geomid_out = np.array([-1], dtype=np.intc) + for i in range(_NUM_RAYS): + theta = 2 * np.pi * i / _NUM_RAYS + pos = np.array([self._spawn_position[0], self._spawn_position[1], 0.1], + dtype=np.float64) + vec = np.array([np.cos(theta), np.sin(theta), 0], dtype=np.float64) + dist = mjbindings.mjlib.mj_ray( + physics.model.ptr, physics.data.ptr, pos, vec, + None, 1, -1, geomid_out) + distances.append(dist) + + def remap_with_bias(x): + """Remaps values [-1, 1] -> [-1, 1] with bias.""" + return np.tanh((1 + self._rotation_bias_factor) * np.arctanh(x)) + + max_theta = 2 * np.pi * np.argmax(distances) / _NUM_RAYS + rotation = max_theta + np.pi * ( + 1 + remap_with_bias(random_state.uniform(-1, 1))) + + quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)] + + # Move walker back down. + self._walker.shift_pose(physics, [0.0, 0.0, -100.0]) + else: + quat = None + + self._walker.shift_pose( + physics, [self._spawn_position[0], self._spawn_position[1], 0.0], + quat, + rotate_velocity=True) + + def initialize_episode(self, physics, random_state): + super(NullGoalMaze, self).initialize_episode(physics, random_state) + self._respawn(physics, random_state) + self._discount = 1.0 + + walker_foot_geoms = set(self._walker.ground_contact_geoms) + walker_nonfoot_geoms = [ + geom for geom in self._walker.mjcf_model.find_all('geom') + if geom not in walker_foot_geoms] + self._walker_nonfoot_geomids = set( + physics.bind(walker_nonfoot_geoms).element_id) + self._ground_geomids = set( + physics.bind(self._maze_arena.ground_geoms).element_id) + + def _is_disallowed_contact(self, contact): + set1, set2 = self._walker_nonfoot_geomids, self._ground_geomids + return ((contact.geom1 in set1 and contact.geom2 in set2) or + (contact.geom1 in set2 and contact.geom2 in set1)) + + def after_step(self, physics, random_state): + self._failure_termination = False + if self._contact_termination: + for c in physics.data.contact: + if self._is_disallowed_contact(c): + self._failure_termination = True + break + + def should_terminate_episode(self, physics): + if self._walker.aliveness(physics) < self._aliveness_threshold: + self._failure_termination = True + if self._failure_termination: + self._discount = 0.0 + return True + else: + return False + + def get_reward(self, physics): + del physics + return self._aliveness_reward + + def get_discount(self, physics): + del physics + return self._discount + + +class RepeatSingleGoalMaze(NullGoalMaze): + """Requires an agent to repeatedly find the same goal in a maze.""" + + def __init__(self, + walker, + maze_arena, + target=target_sphere.TargetSphere(), + target_reward_scale=1.0, + randomize_spawn_position=True, + randomize_spawn_rotation=True, + rotation_bias_factor=0, + aliveness_reward=0.0, + aliveness_threshold=DEFAULT_ALIVE_THRESHOLD, + contact_termination=True, + max_repeats=0, + enable_global_task_observables=False, + physics_timestep=DEFAULT_PHYSICS_TIMESTEP, + control_timestep=DEFAULT_CONTROL_TIMESTEP): + super(RepeatSingleGoalMaze, self).__init__( + walker=walker, + maze_arena=maze_arena, + randomize_spawn_position=randomize_spawn_position, + randomize_spawn_rotation=randomize_spawn_rotation, + rotation_bias_factor=rotation_bias_factor, + aliveness_reward=aliveness_reward, + aliveness_threshold=aliveness_threshold, + contact_termination=contact_termination, + enable_global_task_observables=enable_global_task_observables, + physics_timestep=physics_timestep, + control_timestep=control_timestep) + self._target = target + self._rewarded_this_step = False + self._maze_arena.attach(target) + self._target_reward_scale = target_reward_scale + self._max_repeats = max_repeats + self._targets_obtained = 0 + + if enable_global_task_observables: + xpos_origin_callable = lambda phys: phys.bind(walker.root_body).xpos + + def _target_pos(physics, target=target): + return physics.bind(target.geom).xpos + + walker.observables.add_egocentric_vector( + 'target_0', + observable_lib.Generic(_target_pos), + origin_callable=xpos_origin_callable) + + def initialize_episode_mjcf(self, random_state): + super(RepeatSingleGoalMaze, self).initialize_episode_mjcf(random_state) + self._target_position = self._maze_arena.target_positions[ + random_state.randint(0, len(self._maze_arena.target_positions))] + mjcf.get_attachment_frame( + self._target.mjcf_model).pos = self._target_position + + def initialize_episode(self, physics, random_state): + super(RepeatSingleGoalMaze, self).initialize_episode(physics, random_state) + self._rewarded_this_step = False + self._targets_obtained = 0 + + def after_step(self, physics, random_state): + super(RepeatSingleGoalMaze, self).after_step(physics, random_state) + if self._target.activated: + self._rewarded_this_step = True + self._targets_obtained += 1 + if self._targets_obtained <= self._max_repeats: + self._respawn(physics, random_state) + self._target.reset(physics) + else: + self._rewarded_this_step = False + + def should_terminate_episode(self, physics): + if super(RepeatSingleGoalMaze, self).should_terminate_episode(physics): + return True + if self._targets_obtained > self._max_repeats: + return True + + def get_reward(self, physics): + del physics + if self._rewarded_this_step: + target_reward = self._target_reward_scale + else: + target_reward = 0.0 + return target_reward + self._aliveness_reward + + +class ManyHeterogeneousGoalsMaze(NullGoalMaze): + """Requires an agent to find multiple goals with different rewards.""" + + def __init__(self, + walker, + maze_arena, + target_builders, + target_type_rewards, + target_type_proportions, + shuffle_target_builders=False, + randomize_spawn_position=True, + randomize_spawn_rotation=True, + rotation_bias_factor=0, + aliveness_reward=0.0, + aliveness_threshold=DEFAULT_ALIVE_THRESHOLD, + contact_termination=True, + physics_timestep=DEFAULT_PHYSICS_TIMESTEP, + control_timestep=DEFAULT_CONTROL_TIMESTEP): + super(ManyHeterogeneousGoalsMaze, self).__init__( + walker=walker, + maze_arena=maze_arena, + randomize_spawn_position=randomize_spawn_position, + randomize_spawn_rotation=randomize_spawn_rotation, + rotation_bias_factor=rotation_bias_factor, + aliveness_reward=aliveness_reward, + aliveness_threshold=aliveness_threshold, + contact_termination=contact_termination, + physics_timestep=physics_timestep, + control_timestep=control_timestep) + self._active_targets = [] + self._target_builders = target_builders + self._target_type_rewards = tuple(target_type_rewards) + self._target_type_fractions = ( + np.array(target_type_proportions, dtype=float) / + np.sum(target_type_proportions)) + self._shuffle_target_builders = shuffle_target_builders + + def _get_targets(self, total_target_count, random_state): + # Multiply total target count by the fraction for each type, rounded down. + target_numbers = np.array([int(frac * total_target_count) + for frac in self._target_type_fractions]) + + # Calculate deviations from the ideal ratio incurred by rounding. + errors = (self._target_type_fractions - + target_numbers / float(total_target_count)) + + # Sort the target types by deviations from ideal ratios. + target_types_sorted_by_errors = list(np.argsort(errors)) + + # Top up individual target classes until we reach the desired total, + # starting from the class that is furthest away from the ideal ratio. + current_total = np.sum(target_numbers) + while current_total < total_target_count: + target_numbers[target_types_sorted_by_errors.pop()] += 1 + current_total += 1 + + if self._shuffle_target_builders: + random_state.shuffle(self._target_builders) + + all_targets = [] + for target_type, num in enumerate(target_numbers): + targets = [] + target_builder = self._target_builders[target_type] + for i in range(num): + target = target_builder(name='target_{}_{}'.format(target_type, i)) + targets.append(target) + all_targets.append(targets) + return all_targets + + def initialize_episode_mjcf(self, random_state): + super( + ManyHeterogeneousGoalsMaze, self).initialize_episode_mjcf(random_state) + for target in itertools.chain(*self._active_targets): + target.detach() + target_positions = list(self._maze_arena.target_positions) + random_state.shuffle(target_positions) + all_targets = self._get_targets(len(target_positions), random_state) + for pos, target in zip(target_positions, itertools.chain(*all_targets)): + self._maze_arena.attach(target) + mjcf.get_attachment_frame(target.mjcf_model).pos = pos + target.initialize_episode_mjcf(random_state) + self._active_targets = all_targets + self._target_rewarded = [[False] * len(targets) for targets in all_targets] + + def get_reward(self, physics): + del physics + reward = self._aliveness_reward + for target_type, targets in enumerate(self._active_targets): + for i, target in enumerate(targets): + if target.activated and not self._target_rewarded[target_type][i]: + reward += self._target_type_rewards[target_type] + self._target_rewarded[target_type][i] = True + return reward + + def should_terminate_episode(self, physics): + if super(ManyHeterogeneousGoalsMaze, + self).should_terminate_episode(physics): + return True + else: + for target in itertools.chain(*self._active_targets): + if not target.activated: + return False + # All targets have been activated: successful termination. + return True + + +class ManyGoalsMaze(ManyHeterogeneousGoalsMaze): + """Requires an agent to find all goals in a random maze.""" + + def __init__(self, + walker, + maze_arena, + target_builder, + target_reward_scale=1.0, + randomize_spawn_position=True, + randomize_spawn_rotation=True, + rotation_bias_factor=0, + aliveness_reward=0.0, + aliveness_threshold=DEFAULT_ALIVE_THRESHOLD, + contact_termination=True, + physics_timestep=DEFAULT_PHYSICS_TIMESTEP, + control_timestep=DEFAULT_CONTROL_TIMESTEP): + super(ManyGoalsMaze, self).__init__( + walker=walker, + maze_arena=maze_arena, + target_builders=[target_builder], + target_type_rewards=[target_reward_scale], + target_type_proportions=[1], + randomize_spawn_position=randomize_spawn_position, + randomize_spawn_rotation=randomize_spawn_rotation, + rotation_bias_factor=rotation_bias_factor, + aliveness_reward=aliveness_reward, + aliveness_threshold=aliveness_threshold, + contact_termination=contact_termination, + physics_timestep=physics_timestep, + control_timestep=control_timestep) + + +class RepeatSingleGoalMazeAugmentedWithTargets(RepeatSingleGoalMaze): + """Augments the single goal maze with many lower reward targets.""" + + def __init__(self, + walker, + main_target, + maze_arena, + num_subtargets=20, + target_reward_scale=10.0, + subtarget_reward_scale=1.0, + subtarget_colors=((0, 0, 0.4), (0, 0, 0.7)), + randomize_spawn_position=True, + randomize_spawn_rotation=True, + rotation_bias_factor=0, + aliveness_reward=0.0, + aliveness_threshold=DEFAULT_ALIVE_THRESHOLD, + contact_termination=True, + physics_timestep=DEFAULT_PHYSICS_TIMESTEP, + control_timestep=DEFAULT_CONTROL_TIMESTEP): + super(RepeatSingleGoalMazeAugmentedWithTargets, self).__init__( + walker=walker, + target=main_target, + maze_arena=maze_arena, + target_reward_scale=target_reward_scale, + randomize_spawn_position=randomize_spawn_position, + randomize_spawn_rotation=randomize_spawn_rotation, + rotation_bias_factor=rotation_bias_factor, + aliveness_reward=aliveness_reward, + aliveness_threshold=aliveness_threshold, + contact_termination=contact_termination, + physics_timestep=physics_timestep, + control_timestep=control_timestep) + self._subtarget_reward_scale = subtarget_reward_scale + self._subtargets = [] + for i in range(num_subtargets): + subtarget = target_sphere.TargetSphere( + radius=0.4, rgb1=subtarget_colors[0], rgb2=subtarget_colors[1], + name='subtarget_{}'.format(i) + ) + self._subtargets.append(subtarget) + self._maze_arena.attach(subtarget) + self._subtarget_rewarded = None + + def initialize_episode_mjcf(self, random_state): + super(RepeatSingleGoalMazeAugmentedWithTargets, + self).initialize_episode_mjcf(random_state) + subtarget_positions = self._maze_arena.target_positions + for pos, subtarget in zip(subtarget_positions, self._subtargets): + mjcf.get_attachment_frame(subtarget.mjcf_model).pos = pos + self._subtarget_rewarded = [False] * len(self._subtargets) + + def get_reward(self, physics): + main_reward = super(RepeatSingleGoalMazeAugmentedWithTargets, + self).get_reward(physics) + subtarget_reward = 0 + for i, subtarget in enumerate(self._subtargets): + if subtarget.activated and not self._subtarget_rewarded[i]: + subtarget_reward += 1 + self._subtarget_rewarded[i] = True + subtarget_reward *= self._subtarget_reward_scale + return main_reward + subtarget_reward + + def should_terminate_episode(self, physics): + if super(RepeatSingleGoalMazeAugmentedWithTargets, + self).should_terminate_episode(physics): + return True + else: + for subtarget in self._subtargets: + if not subtarget.activated: + return False + # All subtargets have been activated. + return True diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/random_goal_maze_test.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/random_goal_maze_test.py new file mode 100644 index 0000000..bef643c --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/random_goal_maze_test.py @@ -0,0 +1,135 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for locomotion.tasks.random_goal_maze.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +from absl.testing import absltest + +from dm_control import composer +from dm_control.locomotion.arenas import labmaze_textures +from dm_control.locomotion.arenas import mazes +from dm_control.locomotion.props import target_sphere +from dm_control.locomotion.tasks import random_goal_maze +from dm_control.locomotion.walkers import cmu_humanoid + +import numpy as np +from six.moves import range + + +class RandomGoalMazeTest(absltest.TestCase): + + def test_observables(self): + walker = cmu_humanoid.CMUHumanoid() + + # Build a maze with rooms and targets. + skybox_texture = labmaze_textures.SkyBox(style='sky_03') + wall_textures = labmaze_textures.WallTextures(style='style_01') + floor_textures = labmaze_textures.FloorTextures(style='style_01') + arena = mazes.RandomMazeWithTargets( + x_cells=11, + y_cells=11, + xy_scale=3, + max_rooms=4, + room_min_size=4, + room_max_size=5, + spawns_per_room=1, + targets_per_room=3, + skybox_texture=skybox_texture, + wall_textures=wall_textures, + floor_textures=floor_textures, + ) + + task = random_goal_maze.ManyGoalsMaze( + walker=walker, + maze_arena=arena, + target_builder=functools.partial( + target_sphere.TargetSphere, + radius=0.4, + rgb1=(0, 0, 0.4), + rgb2=(0, 0, 0.7)), + control_timestep=.03, + physics_timestep=.005, + ) + random_state = np.random.RandomState(12345) + env = composer.Environment(task, random_state=random_state) + timestep = env.reset() + + self.assertIn('walker/joints_pos', timestep.observation) + + def test_termination_and_discount(self): + walker = cmu_humanoid.CMUHumanoid() + + # Build a maze with rooms and targets. + skybox_texture = labmaze_textures.SkyBox(style='sky_03') + wall_textures = labmaze_textures.WallTextures(style='style_01') + floor_textures = labmaze_textures.FloorTextures(style='style_01') + arena = mazes.RandomMazeWithTargets( + x_cells=11, + y_cells=11, + xy_scale=3, + max_rooms=4, + room_min_size=4, + room_max_size=5, + spawns_per_room=1, + targets_per_room=3, + skybox_texture=skybox_texture, + wall_textures=wall_textures, + floor_textures=floor_textures, + ) + + task = random_goal_maze.ManyGoalsMaze( + walker=walker, + maze_arena=arena, + target_builder=functools.partial( + target_sphere.TargetSphere, + radius=0.4, + rgb1=(0, 0, 0.4), + rgb2=(0, 0, 0.7)), + control_timestep=.03, + physics_timestep=.005, + ) + + random_state = np.random.RandomState(12345) + env = composer.Environment(task, random_state=random_state) + env.reset() + + zero_action = np.zeros_like(env.physics.data.ctrl) + + # Walker starts in upright position. + # Should not trigger failure termination in the first few steps. + for _ in range(5): + env.step(zero_action) + self.assertFalse(task.should_terminate_episode(env.physics)) + np.testing.assert_array_equal(task.get_discount(env.physics), 1) + + # Rotate the walker upside down and run the physics until it makes contact. + current_time = env.physics.data.time + walker.shift_pose(env.physics, position=(0, 0, 10), quaternion=(0, 1, 0, 0)) + env.physics.forward() + while env.physics.data.ncon == 0: + env.physics.step() + env.physics.data.time = current_time + + # Should now trigger a failure termination. + env.step(zero_action) + self.assertTrue(task.should_terminate_episode(env.physics)) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/reach.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/reach.py new file mode 100644 index 0000000..e575691 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/reach.py @@ -0,0 +1,294 @@ +# Copyright 2020 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""A (visuomotor) task consisting of reaching to targets for reward.""" + + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import itertools + +from dm_control import composer +from dm_control.composer.observation import observable as dm_observable +import enum +import numpy as np +from six.moves import range +from six.moves import zip + +DEFAULT_ALIVE_THRESHOLD = -1.0 +DEFAULT_PHYSICS_TIMESTEP = 0.005 +DEFAULT_CONTROL_TIMESTEP = 0.03 + + +class TwoTouchState(enum.IntEnum): + PRE_TOUCH = 0 + TOUCHED_ONCE = 1 + TOUCHED_TWICE = 2 # at appropriate time + TOUCHED_TOO_SOON = 3 + NO_SECOND_TOUCH = 4 + + +class TwoTouch(composer.Task): + """Task with target to tap with short delay (for Rat).""" + + def __init__(self, + walker, + arena, + target_builders, + target_type_rewards, + shuffle_target_builders=False, + randomize_spawn_position=False, + randomize_spawn_rotation=True, + rotation_bias_factor=0, + aliveness_reward=0.0, + touch_interval=0.8, + interval_tolerance=0.1, # consider making a curriculum + failure_timeout=1.2, + reset_delay=0., + z_height=.14, # 5.5" in real experiments + target_area=(), + physics_timestep=DEFAULT_PHYSICS_TIMESTEP, + control_timestep=DEFAULT_CONTROL_TIMESTEP): + self._walker = walker + self._arena = arena + self._walker.create_root_joints(self._arena.attach(self._walker)) + if 'CMUHumanoid' in str(type(self._walker)): + self._lhand_body = walker.mjcf_model.find('body', 'lhand') + self._rhand_body = walker.mjcf_model.find('body', 'rhand') + elif 'Rat' in str(type(self._walker)): + self._lhand_body = walker.mjcf_model.find('body', 'hand_L') + self._rhand_body = walker.mjcf_model.find('body', 'hand_R') + else: + raise ValueError('Expects Rat or CMUHumanoid.') + self._lhand_geoms = self._lhand_body.find_all('geom') + self._rhand_geoms = self._rhand_body.find_all('geom') + + self._targets = [] + self._target_builders = target_builders + self._target_type_rewards = tuple(target_type_rewards) + self._shuffle_target_builders = shuffle_target_builders + + self._randomize_spawn_position = randomize_spawn_position + self._spawn_position = [0.0, 0.0] # x, y + self._randomize_spawn_rotation = randomize_spawn_rotation + self._rotation_bias_factor = rotation_bias_factor + + self._aliveness_reward = aliveness_reward + self._discount = 1.0 + + self._touch_interval = touch_interval + self._interval_tolerance = interval_tolerance + self._failure_timeout = failure_timeout + self._reset_delay = reset_delay + self._target_positions = [] + self._state_logic = TwoTouchState.PRE_TOUCH + + self._z_height = z_height + arena_size = self._arena.size + if target_area: + self._target_area = target_area + else: + self._target_area = [1/2*arena_size[0], 1/2*arena_size[1]] + target_x = 1. + target_y = 1. + self._target_positions.append((target_x, target_y, self._z_height)) + + self.set_timesteps( + physics_timestep=physics_timestep, control_timestep=control_timestep) + + self._task_observables = collections.OrderedDict() + def task_state(physics): + del physics + return np.array([self._state_logic]) + self._task_observables['task_logic'] = dm_observable.Generic(task_state) + + self._walker.observables.egocentric_camera.height = 64 + self._walker.observables.egocentric_camera.width = 64 + + for observable in (self._walker.observables.proprioception + + self._walker.observables.kinematic_sensors + + self._walker.observables.dynamic_sensors + + list(self._task_observables.values())): + observable.enabled = True + self._walker.observables.egocentric_camera.enabled = True + + def _get_targets(self, total_target_count, random_state): + # Multiply total target count by the fraction for each type, rounded down. + target_numbers = np.array([1, len(self._target_positions)-1]) + + if self._shuffle_target_builders: + random_state.shuffle(self._target_builders) + + all_targets = [] + for target_type, num in enumerate(target_numbers): + targets = [] + if num < 1: + break + target_builder = self._target_builders[target_type] + for i in range(num): + target = target_builder(name='target_{}_{}'.format(target_type, i)) + targets.append(target) + all_targets.append(targets) + return all_targets + + @property + def name(self): + return 'two_touch' + + @property + def task_observables(self): + return self._task_observables + + @property + def root_entity(self): + return self._arena + + def _randomize_targets(self, physics, random_state=np.random): + for ii in range(len(self._target_positions)): + target_x = self._target_area[0]*random_state.uniform(-1., 1.) + target_y = self._target_area[1]*random_state.uniform(-1., 1.) + self._target_positions[ii] = (target_x, target_y, self._z_height) + target_positions = np.copy(self._target_positions) + random_state.shuffle(target_positions) + all_targets = self._targets + for pos, target in zip(target_positions, itertools.chain(*all_targets)): + target.reset(physics) + physics.bind(target.geom).pos = pos + self._targets = all_targets + self._target_rewarded_once = [ + [False] * len(targets) for targets in all_targets] + self._target_rewarded_twice = [ + [False] * len(targets) for targets in all_targets] + self._first_touch_time = None + self._second_touch_time = None + self._do_time_out = False + self._state_logic = TwoTouchState.PRE_TOUCH + + def initialize_episode_mjcf(self, random_state): + self._arena.regenerate(random_state) + for target in itertools.chain(*self._targets): + target.detach() + target_positions = np.copy(self._target_positions) + random_state.shuffle(target_positions) + all_targets = self._get_targets(len(self._target_positions), random_state) + for pos, target in zip(target_positions, itertools.chain(*all_targets)): + self._arena.attach(target) + target.geom.pos = pos + target.initialize_episode_mjcf(random_state) + self._targets = all_targets + + def _respawn_walker(self, physics, random_state): + self._walker.reinitialize_pose(physics, random_state) + + if self._randomize_spawn_position: + self._spawn_position = self._arena.spawn_positions[ + random_state.randint(0, len(self._arena.spawn_positions))] + + if self._randomize_spawn_rotation: + rotation = 2*np.pi*np.random.uniform() + quat = [np.cos(rotation / 2), 0, 0, np.sin(rotation / 2)] + + self._walker.shift_pose( + physics, + [self._spawn_position[0], self._spawn_position[1], 0.0], + quat, + rotate_velocity=True) + + def initialize_episode(self, physics, random_state): + super(TwoTouch, self).initialize_episode(physics, random_state) + self._respawn_walker(physics, random_state) + self._state_logic = TwoTouchState.PRE_TOUCH + self._discount = 1.0 + self._lhand_geomids = set(physics.bind(self._lhand_geoms).element_id) + self._rhand_geomids = set(physics.bind(self._rhand_geoms).element_id) + self._hand_geomids = self._lhand_geomids | self._rhand_geomids + self._randomize_targets(physics) + self._must_randomize_targets = False + for target in itertools.chain(*self._targets): + target._specific_collision_geom_ids = self._hand_geomids # pylint: disable=protected-access + + def before_step(self, physics, action, random_state): + super(TwoTouch, self).before_step(physics, action, random_state) + if self._must_randomize_targets: + self._randomize_targets(physics) + self._must_randomize_targets = False + + def should_terminate_episode(self, physics): + failure_termination = False + if failure_termination: + self._discount = 0.0 + return True + else: + return False + + def get_reward(self, physics): + reward = self._aliveness_reward + lhand_pos = physics.bind(self._lhand_body).xpos + rhand_pos = physics.bind(self._rhand_body).xpos + target_pos = physics.bind(self._targets[0][0].geom).xpos + lhand_rew = np.exp(-3.*sum(np.abs(lhand_pos-target_pos))) + rhand_rew = np.exp(-3.*sum(np.abs(rhand_pos-target_pos))) + closeness_reward = np.maximum(lhand_rew, rhand_rew) + reward += .01*closeness_reward*self._target_type_rewards[0] + if self._state_logic == TwoTouchState.PRE_TOUCH: + # touch the first time + for target_type, targets in enumerate(self._targets): + for i, target in enumerate(targets): + if (target.activated[0] and + not self._target_rewarded_once[target_type][i]): + self._first_touch_time = physics.time() + self._state_logic = TwoTouchState.TOUCHED_ONCE + self._target_rewarded_once[target_type][i] = True + reward += self._target_type_rewards[target_type] + elif self._state_logic == TwoTouchState.TOUCHED_ONCE: + for target_type, targets in enumerate(self._targets): + for i, target in enumerate(targets): + if (target.activated[1] and + not self._target_rewarded_twice[target_type][i]): + self._second_touch_time = physics.time() + self._state_logic = TwoTouchState.TOUCHED_TWICE + self._target_rewarded_twice[target_type][i] = True + # check if touched too soon + if ((self._second_touch_time - self._first_touch_time) < + (self._touch_interval - self._interval_tolerance)): + self._do_time_out = True + self._state_logic = TwoTouchState.TOUCHED_TOO_SOON + # check if touched at correct time + elif ((self._second_touch_time - self._first_touch_time) <= + (self._touch_interval + self._interval_tolerance)): + reward += self._target_type_rewards[target_type] + # check if no second touch within time interval + if ((physics.time() - self._first_touch_time) > + (self._touch_interval + self._interval_tolerance)): + self._do_time_out = True + self._state_logic = TwoTouchState.NO_SECOND_TOUCH + self._second_touch_time = physics.time() + elif (self._state_logic == TwoTouchState.TOUCHED_TWICE or + self._state_logic == TwoTouchState.TOUCHED_TOO_SOON or + self._state_logic == TwoTouchState.NO_SECOND_TOUCH): + # hold here due to timeout + if self._do_time_out: + if physics.time() > (self._second_touch_time + self._failure_timeout): + self._do_time_out = False + # reset/re-randomize + elif physics.time() > (self._second_touch_time + self._reset_delay): + self._must_randomize_targets = True + return reward + + def get_discount(self, physics): + del physics + return self._discount diff --git a/DMC/src/env/dm_control/dm_control/locomotion/tasks/reach_test.py b/DMC/src/env/dm_control/dm_control/locomotion/tasks/reach_test.py new file mode 100644 index 0000000..2132409 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/tasks/reach_test.py @@ -0,0 +1,66 @@ +# Copyright 2020 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for locomotion.tasks.reach.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +from absl.testing import absltest + +from dm_control import composer +from dm_control.locomotion.arenas import floors +from dm_control.locomotion.props import target_sphere +from dm_control.locomotion.tasks import reach +from dm_control.locomotion.walkers import rodent + +import numpy as np + +_CONTROL_TIMESTEP = .02 +_PHYSICS_TIMESTEP = 0.001 + + +class ReachTest(absltest.TestCase): + + def test_observables(self): + walker = rodent.Rat() + + arena = floors.Floor( + size=(10., 10.), + aesthetic='outdoor_natural') + + task = reach.TwoTouch( + walker=walker, + arena=arena, + target_builders=[ + functools.partial(target_sphere.TargetSphereTwoTouch, radius=0.025), + ], + randomize_spawn_rotation=True, + target_type_rewards=[25.], + shuffle_target_builders=False, + target_area=(1.5, 1.5), + physics_timestep=_PHYSICS_TIMESTEP, + control_timestep=_CONTROL_TIMESTEP, + ) + random_state = np.random.RandomState(12345) + env = composer.Environment(task, random_state=random_state) + timestep = env.reset() + + self.assertIn('walker/joints_pos', timestep.observation) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/__init__.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/__init__.py new file mode 100644 index 0000000..939a03b --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Walkers for Locomotion tasks.""" + +from dm_control.locomotion.walkers.rodent import Rat diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/assets/humanoid_CMU.xml b/DMC/src/env/dm_control/dm_control/locomotion/walkers/assets/humanoid_CMU.xml new file mode 100644 index 0000000..64c9db9 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/assets/humanoid_CMU.xml @@ -0,0 +1,297 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/assets/rodent.xml b/DMC/src/env/dm_control/dm_control/locomotion/walkers/assets/rodent.xml new file mode 100644 index 0000000..7268f3c --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/assets/rodent.xml @@ -0,0 +1,609 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/assets/rodent_walker_skin.skn b/DMC/src/env/dm_control/dm_control/locomotion/walkers/assets/rodent_walker_skin.skn new file mode 100644 index 0000000..53a7e7a Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/walkers/assets/rodent_walker_skin.skn differ diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/base.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/base.py new file mode 100644 index 0000000..8890dab --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/base.py @@ -0,0 +1,196 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Base class for Walkers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import collections + +from dm_control import composer +from dm_control.composer.observation import observable + +from dm_env import specs +import numpy as np +import six + + +def _make_readonly_float64_copy(value): + if np.isscalar(value): + return np.float64(value) + else: + out = np.array(value, dtype=np.float64) + out.flags.writeable = False + return out + + +class WalkerPose(collections.namedtuple( + 'WalkerPose', ('qpos', 'xpos', 'xquat'))): + """A named tuple representing a walker's joint and Cartesian pose.""" + + __slots__ = () + + def __new__(cls, qpos=None, xpos=(0, 0, 0), xquat=(1, 0, 0, 0)): + """Creates a new WalkerPose. + + Args: + qpos: The joint position for the pose, or `None` if the `qpos0` values in + the `mjModel` should be used. + xpos: A Cartesian displacement, for example if the walker should be lifted + or lowered by a specific amount for this pose. + xquat: A quaternion displacement for the root body. + + Returns: + A new instance of `WalkerPose`. + """ + return super(WalkerPose, cls).__new__( + cls, + qpos=_make_readonly_float64_copy(qpos) if qpos is not None else None, + xpos=_make_readonly_float64_copy(xpos), + xquat=_make_readonly_float64_copy(xquat)) + + def __eq__(self, other): + return (np.all(self.qpos == other.qpos) and + np.all(self.xpos == other.xpos) and + np.all(self.xquat == other.xquat)) + + +@six.add_metaclass(abc.ABCMeta) +class Walker(composer.Robot): + """Abstract base class for Walker robots.""" + + def create_root_joints(self, attachment_frame): + attachment_frame.add('freejoint') + + def _build_observables(self): + return WalkerObservables(self) + + def transform_vec_to_egocentric_frame(self, physics, vec_in_world_frame): + """Linearly transforms a world-frame vector into walker's egocentric frame. + + Note that this function does not perform an affine transformation of the + vector. In other words, the input vector is assumed to be specified with + respect to the same origin as this walker's egocentric frame. This function + can also be applied to matrices whose innermost dimensions are either 2 or + 3. In this case, a matrix with the same leading dimensions is returned + where the innermost vectors are replaced by their values computed in the + egocentric frame. + + Args: + physics: An `mjcf.Physics` instance. + vec_in_world_frame: A NumPy array with last dimension of shape (2,) or + (3,) that represents a vector quantity in the world frame. + + Returns: + The same quantity as `vec_in_world_frame` but reexpressed in this + entity's egocentric frame. The returned np.array has the same shape as + np.asarray(vec_in_world_frame). + + Raises: + ValueError: if `vec_in_world_frame` does not have shape ending with (2,) + or (3,). + """ + return super(Walker, self).global_vector_to_local_frame( + physics, vec_in_world_frame) + + def transform_xmat_to_egocentric_frame(self, physics, xmat): + """Transforms another entity's `xmat` into this walker's egocentric frame. + + This function takes another entity's (E) xmat, which is an SO(3) matrix + from E's frame to the world frame, and turns it to a matrix that transforms + from E's frame into this walker's egocentric frame. + + Args: + physics: An `mjcf.Physics` instance. + xmat: A NumPy array of shape (3, 3) or (9,) that represents another + entity's xmat. + + Returns: + The `xmat` reexpressed in this entity's egocentric frame. The returned + np.array has the same shape as np.asarray(xmat). + + Raises: + ValueError: if `xmat` does not have shape (3, 3) or (9,). + """ + return super(Walker, self).global_xmat_to_local_frame(physics, xmat) + + @abc.abstractproperty + def root_body(self): + raise NotImplementedError + + @abc.abstractproperty + def observable_joints(self): + raise NotImplementedError + + @property + def action_spec(self): + minimum, maximum = zip(*[ + a.ctrlrange if a.ctrlrange is not None else (-1., 1.) + for a in self.actuators + ]) + return specs.BoundedArray( + shape=(len(self.actuators),), + dtype=np.float, + minimum=minimum, + maximum=maximum, + name='\t'.join([actuator.name for actuator in self.actuators])) + + def apply_action(self, physics, action, random_state): + """Apply action to walker's actuators.""" + del random_state + physics.bind(self.actuators).ctrl = action + + +class WalkerObservables(composer.Observables): + """Base class for Walker obserables.""" + + @composer.observable + def joints_pos(self): + return observable.MJCFFeature('qpos', self._entity.observable_joints) + + @composer.observable + def sensors_gyro(self): + return observable.MJCFFeature('sensordata', + self._entity.mjcf_model.sensor.gyro) + + @composer.observable + def sensors_accelerometer(self): + return observable.MJCFFeature('sensordata', + self._entity.mjcf_model.sensor.accelerometer) + + # Semantic groupings of Walker observables. + def _collect_from_attachments(self, attribute_name): + out = [] + for entity in self._entity.iter_entities(exclude_self=True): + out.extend(getattr(entity.observables, attribute_name, [])) + return out + + @property + def proprioception(self): + return ([self.joints_pos] + + self._collect_from_attachments('proprioception')) + + @property + def kinematic_sensors(self): + return ([self.sensors_gyro, self.sensors_accelerometer] + + self._collect_from_attachments('kinematic_sensors')) + + @property + def dynamic_sensors(self): + return self._collect_from_attachments('dynamic_sensors') + diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/base_test.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/base_test.py new file mode 100644 index 0000000..0794cec --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/base_test.py @@ -0,0 +1,98 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for dm_control.locomotion.walkers.base.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from dm_control import mjcf +from dm_control.locomotion.walkers import base + +import numpy as np + + +class FakeWalker(base.Walker): + + def _build(self): + self._mjcf_root = mjcf.RootElement(model='walker') + self._torso_body = self._mjcf_root.worldbody.add( + 'body', name='torso', xyaxes=[0, 1, 0, -1, 0, 0]) + + @property + def mjcf_model(self): + return self._mjcf_root + + @property + def actuators(self): + return [] + + @property + def root_body(self): + return self._torso_body + + @property + def observable_joints(self): + return [] + + +class BaseWalkerTest(absltest.TestCase): + + def testTransformVectorToEgocentricFrame(self): + walker = FakeWalker() + physics = mjcf.Physics.from_mjcf_model(walker.mjcf_model) + + # 3D vectors + np.testing.assert_allclose( + walker.transform_vec_to_egocentric_frame(physics, [0, 1, 0]), [1, 0, 0], + atol=1e-10) + np.testing.assert_allclose( + walker.transform_vec_to_egocentric_frame(physics, [-1, 0, 0]), + [0, 1, 0], + atol=1e-10) + np.testing.assert_allclose( + walker.transform_vec_to_egocentric_frame(physics, [0, 0, 1]), [0, 0, 1], + atol=1e-10) + + # 2D vectors; z-component is ignored + np.testing.assert_allclose( + walker.transform_vec_to_egocentric_frame(physics, [0, 1]), [1, 0], + atol=1e-10) + np.testing.assert_allclose( + walker.transform_vec_to_egocentric_frame(physics, [-1, 0]), [0, 1], + atol=1e-10) + + def testTransformMatrixToEgocentricFrame(self): + walker = FakeWalker() + physics = mjcf.Physics.from_mjcf_model(walker.mjcf_model) + + rotation_atob = np.array([[0, 1, 0], [0, 0, -1], [-1, 0, 0]]) + ego_rotation_atob = np.array([[0, 0, -1], [0, -1, 0], [-1, 0, 0]]) + + np.testing.assert_allclose( + walker.transform_xmat_to_egocentric_frame(physics, rotation_atob), + ego_rotation_atob, atol=1e-10) + + flat_rotation_atob = np.reshape(rotation_atob, -1) + flat_rotation_ego_atob = np.reshape(ego_rotation_atob, -1) + np.testing.assert_allclose( + walker.transform_xmat_to_egocentric_frame(physics, flat_rotation_atob), + flat_rotation_ego_atob, atol=1e-10) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/cmu_humanoid.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/cmu_humanoid.py new file mode 100644 index 0000000..a8000b4 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/cmu_humanoid.py @@ -0,0 +1,352 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""A CMU humanoid walker.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import collections +import os + +from dm_control import composer +from dm_control import mjcf +from dm_control.composer.observation import observable +from dm_control.locomotion.walkers import base +from dm_control.locomotion.walkers import legacy_base +from dm_control.locomotion.walkers import scaled_actuators +from dm_control.mujoco import wrapper as mj_wrapper +import numpy as np +import six +from six.moves import zip + +_XML_PATH = os.path.join(os.path.dirname(__file__), 'assets/humanoid_CMU.xml') + +_WALKER_GEOM_GROUP = 2 + +_CMU_MOCAP_JOINTS = ( + 'lfemurrz', 'lfemurry', 'lfemurrx', 'ltibiarx', 'lfootrz', 'lfootrx', + 'ltoesrx', 'rfemurrz', 'rfemurry', 'rfemurrx', 'rtibiarx', 'rfootrz', + 'rfootrx', 'rtoesrx', 'lowerbackrz', 'lowerbackry', 'lowerbackrx', + 'upperbackrz', 'upperbackry', 'upperbackrx', 'thoraxrz', 'thoraxry', + 'thoraxrx', 'lowerneckrz', 'lowerneckry', 'lowerneckrx', 'upperneckrz', + 'upperneckry', 'upperneckrx', 'headrz', 'headry', 'headrx', 'lclaviclerz', + 'lclaviclery', 'lhumerusrz', 'lhumerusry', 'lhumerusrx', 'lradiusrx', + 'lwristry', 'lhandrz', 'lhandrx', 'lfingersrx', 'lthumbrz', 'lthumbrx', + 'rclaviclerz', 'rclaviclery', 'rhumerusrz', 'rhumerusry', 'rhumerusrx', + 'rradiusrx', 'rwristry', 'rhandrz', 'rhandrx', 'rfingersrx', 'rthumbrz', + 'rthumbrx') + + +# pylint: disable=bad-whitespace +PositionActuatorParams = collections.namedtuple( + 'PositionActuatorParams', ['name', 'forcerange', 'kp']) +_POSITION_ACTUATORS = [ + PositionActuatorParams('headrx', [-20, 20 ], 20 ), + PositionActuatorParams('headry', [-20, 20 ], 20 ), + PositionActuatorParams('headrz', [-20, 20 ], 20 ), + PositionActuatorParams('lclaviclery', [-20, 20 ], 20 ), + PositionActuatorParams('lclaviclerz', [-20, 20 ], 20 ), + PositionActuatorParams('lfemurrx', [-120, 120], 120), + PositionActuatorParams('lfemurry', [-80, 80 ], 80 ), + PositionActuatorParams('lfemurrz', [-80, 80 ], 80 ), + PositionActuatorParams('lfingersrx', [-20, 20 ], 20 ), + PositionActuatorParams('lfootrx', [-50, 50 ], 50 ), + PositionActuatorParams('lfootrz', [-50, 50 ], 50 ), + PositionActuatorParams('lhandrx', [-20, 20 ], 20 ), + PositionActuatorParams('lhandrz', [-20, 20 ], 20 ), + PositionActuatorParams('lhumerusrx', [-60, 60 ], 60 ), + PositionActuatorParams('lhumerusry', [-60, 60 ], 60 ), + PositionActuatorParams('lhumerusrz', [-60, 60 ], 60 ), + PositionActuatorParams('lowerbackrx', [-120, 120], 150), + PositionActuatorParams('lowerbackry', [-120, 120], 150), + PositionActuatorParams('lowerbackrz', [-120, 120], 150), + PositionActuatorParams('lowerneckrx', [-20, 20 ], 20 ), + PositionActuatorParams('lowerneckry', [-20, 20 ], 20 ), + PositionActuatorParams('lowerneckrz', [-20, 20 ], 20 ), + PositionActuatorParams('lradiusrx', [-60, 60 ], 60 ), + PositionActuatorParams('lthumbrx', [-20, 20 ], 20) , + PositionActuatorParams('lthumbrz', [-20, 20 ], 20 ), + PositionActuatorParams('ltibiarx', [-80, 80 ], 80 ), + PositionActuatorParams('ltoesrx', [-20, 20 ], 20 ), + PositionActuatorParams('lwristry', [-20, 20 ], 20 ), + PositionActuatorParams('rclaviclery', [-20, 20 ], 20 ), + PositionActuatorParams('rclaviclerz', [-20, 20 ], 20 ), + PositionActuatorParams('rfemurrx', [-120, 120], 120), + PositionActuatorParams('rfemurry', [-80, 80 ], 80 ), + PositionActuatorParams('rfemurrz', [-80, 80 ], 80 ), + PositionActuatorParams('rfingersrx', [-20, 20 ], 20 ), + PositionActuatorParams('rfootrx', [-50, 50 ], 50 ), + PositionActuatorParams('rfootrz', [-50, 50 ], 50 ), + PositionActuatorParams('rhandrx', [-20, 20 ], 20 ), + PositionActuatorParams('rhandrz', [-20, 20 ], 20 ), + PositionActuatorParams('rhumerusrx', [-60, 60 ], 60 ), + PositionActuatorParams('rhumerusry', [-60, 60 ], 60 ), + PositionActuatorParams('rhumerusrz', [-60, 60 ], 60 ), + PositionActuatorParams('rradiusrx', [-60, 60 ], 60 ), + PositionActuatorParams('rthumbrx', [-20, 20 ], 20 ), + PositionActuatorParams('rthumbrz', [-20, 20 ], 20 ), + PositionActuatorParams('rtibiarx', [-80, 80 ], 80 ), + PositionActuatorParams('rtoesrx', [-20, 20 ], 20 ), + PositionActuatorParams('rwristry', [-20, 20 ], 20 ), + PositionActuatorParams('thoraxrx', [-80, 80 ], 100), + PositionActuatorParams('thoraxry', [-80, 80 ], 100), + PositionActuatorParams('thoraxrz', [-80, 80 ], 100), + PositionActuatorParams('upperbackrx', [-80, 80 ], 80 ), + PositionActuatorParams('upperbackry', [-80, 80 ], 80 ), + PositionActuatorParams('upperbackrz', [-80, 80 ], 80 ), + PositionActuatorParams('upperneckrx', [-20, 20 ], 20 ), + PositionActuatorParams('upperneckry', [-20, 20 ], 20 ), + PositionActuatorParams('upperneckrz', [-20, 20 ], 20 ), +] +# pylint: enable=bad-whitespace + +_UPRIGHT_POS = (0.0, 0.0, 0.94) +_UPRIGHT_QUAT = (0.859, 1.0, 1.0, 0.859) + +# Height of head above which the humanoid is considered standing. +_STAND_HEIGHT = 1.5 + +_TORQUE_THRESHOLD = 60 + + +@six.add_metaclass(abc.ABCMeta) +class _CMUHumanoidBase(legacy_base.Walker): + """The abstract base class for walkers compatible with the CMU humanoid.""" + + def _build(self, + name='walker', + marker_rgba=None, + initializer=None): + self._mjcf_root = mjcf.from_path(self._xml_path) + if name: + self._mjcf_root.model = name + + # Set corresponding marker color if specified. + if marker_rgba is not None: + for geom in self.marker_geoms: + geom.set_attributes(rgba=marker_rgba) + + self._actuator_order = np.argsort(_CMU_MOCAP_JOINTS) + self._inverse_order = np.argsort(self._actuator_order) + + super(_CMUHumanoidBase, self)._build(initializer=initializer) + + def _build_observables(self): + return CMUHumanoidObservables(self) + + @abc.abstractproperty + def _xml_path(self): + raise NotImplementedError + + @composer.cached_property + def mocap_joints(self): + return tuple( + self._mjcf_root.find('joint', name) for name in _CMU_MOCAP_JOINTS) + + @property + def actuator_order(self): + """Index of joints from the CMU mocap dataset sorted alphabetically by name. + + Actuators in this walkers are ordered alphabetically by name. This property + provides a mapping between from actuator ordering to canonical CMU ordering. + + Returns: + A list of integers corresponding to joint indices from the CMU dataset. + Specifically, the n-th element in the list is the index of the CMU joint + index that corresponds to the n-th actuator in this walker. + """ + return self._actuator_order + + @property + def actuator_to_joint_order(self): + """Index of actuators corresponding to each CMU mocap joint. + + Actuators in this walkers are ordered alphabetically by name. This property + provides a mapping between from canonical CMU ordering to actuator ordering. + + Returns: + A list of integers corresponding to actuator indices within this walker. + Specifically, the n-th element in the list is the index of the actuator + in this walker that corresponds to the n-th joint from the CMU mocap + dataset. + """ + return self._inverse_order + + @property + def upright_pose(self): + return base.WalkerPose(xpos=_UPRIGHT_POS, xquat=_UPRIGHT_QUAT) + + @property + def mjcf_model(self): + return self._mjcf_root + + @composer.cached_property + def actuators(self): + return tuple(self._mjcf_root.find_all('actuator')) + + @composer.cached_property + def root_body(self): + return self._mjcf_root.find('body', 'root') + + @composer.cached_property + def head(self): + return self._mjcf_root.find('body', 'head') + + @composer.cached_property + def left_arm_root(self): + return self._mjcf_root.find('body', 'lclavicle') + + @composer.cached_property + def right_arm_root(self): + return self._mjcf_root.find('body', 'rclavicle') + + @composer.cached_property + def ground_contact_geoms(self): + return tuple(self._mjcf_root.find('body', 'lfoot').find_all('geom') + + self._mjcf_root.find('body', 'rfoot').find_all('geom')) + + @composer.cached_property + def standing_height(self): + return _STAND_HEIGHT + + @composer.cached_property + def end_effectors(self): + return (self._mjcf_root.find('body', 'rradius'), + self._mjcf_root.find('body', 'lradius'), + self._mjcf_root.find('body', 'rfoot'), + self._mjcf_root.find('body', 'lfoot')) + + @composer.cached_property + def observable_joints(self): + return tuple(actuator.joint for actuator in self.actuators + if actuator.joint is not None) + + @composer.cached_property + def bodies(self): + return tuple(self._mjcf_root.find_all('body')) + + @composer.cached_property + def egocentric_camera(self): + return self._mjcf_root.find('camera', 'egocentric') + + @composer.cached_property + def body_camera(self): + return self._mjcf_root.find('camera', 'bodycam') + + @property + def marker_geoms(self): + return (self._mjcf_root.find('geom', 'rradius'), + self._mjcf_root.find('geom', 'lradius')) + + +class CMUHumanoid(_CMUHumanoidBase): + """A CMU humanoid walker.""" + + @property + def _xml_path(self): + return _XML_PATH + + +class CMUHumanoidPositionControlled(CMUHumanoid): + """A position-controlled CMU humanoid with control range scaled to [-1, 1].""" + + def _build(self, *args, **kwargs): + super(CMUHumanoidPositionControlled, self)._build(*args, **kwargs) + self._mjcf_root.default.general.forcelimited = 'true' + self._mjcf_root.actuator.motor.clear() + for actuator_params in _POSITION_ACTUATORS: + associated_joint = self._mjcf_root.find('joint', actuator_params.name) + scaled_actuators.add_position_actuator( + name=actuator_params.name, + target=associated_joint, + kp=actuator_params.kp, + qposrange=associated_joint.range, + ctrlrange=(-1, 1), + forcerange=actuator_params.forcerange) + limits = zip(*(actuator.joint.range for actuator in self.actuators)) # pylint: disable=not-an-iterable + lower, upper = (np.array(limit) for limit in limits) + self._scale = upper - lower + self._offset = upper + lower + + def cmu_pose_to_actuation(self, target_pose): + """Creates the control signal corresponding a CMU mocap joints pose. + + Args: + target_pose: An array containing the target position for each joint. + These must be given in "canonical CMU order" rather than "qpos order", + i.e. the order of `target_pose[self.actuator_order]` should correspond + to the order of `physics.bind(self.actuators).ctrl`. + + Returns: + An array of the same shape as `target_pose` containing inputs for position + controllers. Writing these values into `physics.bind(self.actuators).ctrl` + will cause the actuators to drive joints towards `target_pose`. + """ + return (2 * target_pose[self.actuator_order] - self._offset) / self._scale + + +class CMUHumanoidObservables(legacy_base.WalkerObservables): + """Observables for the Humanoid.""" + + @composer.observable + def body_camera(self): + options = mj_wrapper.MjvOption() + + # Don't render this walker's geoms. + options.geomgroup[_WALKER_GEOM_GROUP] = 0 + return observable.MJCFCamera( + self._entity.body_camera, width=64, height=64, scene_option=options) + + @composer.observable + def head_height(self): + return observable.MJCFFeature('xpos', self._entity.head)[2] + + @composer.observable + def sensors_torque(self): + return observable.MJCFFeature( + 'sensordata', self._entity.mjcf_model.sensor.torque, + corruptor=lambda v, random_state: np.tanh(2 * v / _TORQUE_THRESHOLD)) + + @composer.observable + def actuator_activation(self): + return observable.MJCFFeature('act', + self._entity.mjcf_model.find_all('actuator')) + + @composer.observable + def appendages_pos(self): + """Equivalent to `end_effectors_pos` with the head's position appended.""" + def relative_pos_in_egocentric_frame(physics): + end_effectors_with_head = ( + self._entity.end_effectors + (self._entity.head,)) + end_effector = physics.bind(end_effectors_with_head).xpos + torso = physics.bind(self._entity.root_body).xpos + xmat = np.reshape(physics.bind(self._entity.root_body).xmat, (3, 3)) + return np.reshape(np.dot(end_effector - torso, xmat), -1) + return observable.Generic(relative_pos_in_egocentric_frame) + + @property + def proprioception(self): + return [ + self.joints_pos, + self.joints_vel, + self.actuator_activation, + self.body_height, + self.end_effectors_pos, + self.appendages_pos, + self.world_zaxis + ] + self._collect_from_attachments('proprioception') diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/cmu_humanoid_test.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/cmu_humanoid_test.py new file mode 100644 index 0000000..677bd1a --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/cmu_humanoid_test.py @@ -0,0 +1,166 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for the CMU humanoid.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from absl.testing import parameterized +from dm_control import mjcf +from dm_control.composer.observation.observable import base as observable_base +from dm_control.locomotion.walkers import cmu_humanoid +import numpy as np +from six.moves import range +from six.moves import zip + + +class CMUHumanoidTest(parameterized.TestCase): + + @parameterized.parameters([ + cmu_humanoid.CMUHumanoid, + cmu_humanoid.CMUHumanoidPositionControlled, + ]) + def test_can_compile_and_step_simulation(self, walker_type): + walker = walker_type() + physics = mjcf.Physics.from_mjcf_model(walker.mjcf_model) + for _ in range(100): + physics.step() + + @parameterized.parameters([ + cmu_humanoid.CMUHumanoid, + cmu_humanoid.CMUHumanoidPositionControlled, + ]) + def test_actuators_sorted_alphabetically(self, walker_type): + walker = walker_type() + actuator_names = [ + actuator.name for actuator in walker.mjcf_model.find_all('actuator')] + np.testing.assert_array_equal(actuator_names, sorted(actuator_names)) + + def test_actuator_to_mocap_joint_mapping(self): + walker = cmu_humanoid.CMUHumanoid() + + with self.subTest('Forward mapping'): + for actuator_num, cmu_mocap_joint_num in enumerate(walker.actuator_order): + self.assertEqual(walker.actuator_to_joint_order[cmu_mocap_joint_num], + actuator_num) + + with self.subTest('Inverse mapping'): + for cmu_mocap_joint_num, actuator_num in enumerate( + walker.actuator_to_joint_order): + self.assertEqual(walker.actuator_order[actuator_num], + cmu_mocap_joint_num) + + def test_cmu_humanoid_position_controlled_has_correct_actuators(self): + walker_torque = cmu_humanoid.CMUHumanoid() + walker_pos = cmu_humanoid.CMUHumanoidPositionControlled() + + actuators_torque = walker_torque.mjcf_model.find_all('actuator') + actuators_pos = walker_pos.mjcf_model.find_all('actuator') + + actuator_pos_params = { + params.name: params for params in cmu_humanoid._POSITION_ACTUATORS} + + self.assertEqual(len(actuators_torque), len(actuators_pos)) + + for actuator_torque, actuator_pos in zip(actuators_torque, actuators_pos): + self.assertEqual(actuator_pos.name, actuator_torque.name) + self.assertEqual(actuator_pos.joint.full_identifier, + actuator_torque.joint.full_identifier) + self.assertEqual(actuator_pos.tag, 'general') + self.assertEqual(actuator_pos.ctrllimited, 'true') + np.testing.assert_array_equal(actuator_pos.ctrlrange, (-1, 1)) + + expected_params = actuator_pos_params[actuator_pos.name] + self.assertEqual(actuator_pos.biasprm[1], -expected_params.kp) + np.testing.assert_array_equal(actuator_pos.forcerange, + expected_params.forcerange) + + @parameterized.parameters([ + 'body_camera', + 'egocentric_camera', + 'head', + 'left_arm_root', + 'right_arm_root', + 'root_body', + ]) + def test_get_element_property(self, name): + attribute_value = getattr(cmu_humanoid.CMUHumanoid(), name) + self.assertIsInstance(attribute_value, mjcf.Element) + + @parameterized.parameters([ + 'actuators', + 'bodies', + 'end_effectors', + 'marker_geoms', + 'mocap_joints', + 'observable_joints', + ]) + def test_get_element_tuple_property(self, name): + attribute_value = getattr(cmu_humanoid.CMUHumanoid(), name) + self.assertNotEmpty(attribute_value) + for item in attribute_value: + self.assertIsInstance(item, mjcf.Element) + + def test_set_name(self): + name = 'fred' + walker = cmu_humanoid.CMUHumanoid(name=name) + self.assertEqual(walker.mjcf_model.model, name) + + def test_set_marker_rgba(self): + marker_rgba = (1., 0., 1., 0.5) + walker = cmu_humanoid.CMUHumanoid(marker_rgba=marker_rgba) + for marker_geom in walker.marker_geoms: + np.testing.assert_array_equal(marker_geom.rgba, marker_rgba) + + @parameterized.parameters( + 'actuator_activation', + 'appendages_pos', + 'body_camera', + 'head_height', + 'sensors_torque', + ) + def test_evaluate_observable(self, name): + walker = cmu_humanoid.CMUHumanoid() + observable = getattr(walker.observables, name) + physics = mjcf.Physics.from_mjcf_model(walker.mjcf_model) + observation = observable(physics) + self.assertIsInstance(observation, (float, np.ndarray)) + + def test_proprioception(self): + walker = cmu_humanoid.CMUHumanoid() + for item in walker.observables.proprioception: + self.assertIsInstance(item, observable_base.Observable) + + def test_cmu_pose_to_actuation(self): + walker = cmu_humanoid.CMUHumanoidPositionControlled() + random_state = np.random.RandomState(123) + + expected_actuation = random_state.uniform(-1, 1, len(walker.actuator_order)) + + cmu_limits = zip(*(joint.range for joint in walker.mocap_joints)) + cmu_lower, cmu_upper = (np.array(limit) for limit in cmu_limits) + cmu_pose = cmu_lower + (cmu_upper - cmu_lower) * ( + 1 + expected_actuation[walker.actuator_to_joint_order]) / 2 + + actual_actuation = walker.cmu_pose_to_actuation(cmu_pose) + + np.testing.assert_allclose(actual_actuation, expected_actuation) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/initializers/__init__.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/initializers/__init__.py new file mode 100644 index 0000000..766fa7c --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/initializers/__init__.py @@ -0,0 +1,61 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Initializers for the locomotion walkers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import numpy as np +import six + + +@six.add_metaclass(abc.ABCMeta) +class WalkerInitializer(object): + """The abstract base class for a walker initializer.""" + + @abc.abstractmethod + def initialize_pose(self, physics, walker, random_state): + raise NotImplementedError + + +class UprightInitializer(WalkerInitializer): + """An initializer that uses the walker-declared upright pose.""" + + def initialize_pose(self, physics, walker, random_state): + all_joints_binding = physics.bind(walker.mjcf_model.find_all('joint')) + qpos, xpos, xquat = walker.upright_pose + if qpos is None: + all_joints_binding.qpos = all_joints_binding.qpos0 + else: + all_joints_binding.qpos = qpos + walker.set_pose(physics, position=xpos, quaternion=xquat) + walker.set_velocity( + physics, velocity=np.zeros(3), angular_velocity=np.zeros(3)) + + +class RandomlySampledInitializer(WalkerInitializer): + """Initializer that random selects between many initializers.""" + + def __init__(self, initializers): + self._initializers = initializers + self.num_initializers = len(initializers) + + def initialize_pose(self, physics, walker, random_state): + random_initalizer_idx = np.random.randint(0, self.num_initializers) + self._initializers[random_initalizer_idx].initialize_pose( + physics, walker, random_state) diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/legacy_base.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/legacy_base.py new file mode 100644 index 0000000..995fe02 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/legacy_base.py @@ -0,0 +1,347 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Base class for Walkers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +from dm_control import composer +from dm_control.composer.observation import observable +from dm_control.locomotion.walkers import base +from dm_control.locomotion.walkers import initializers +from dm_control.mujoco.wrapper.mjbindings import mjlib + +import numpy as np + +_RANGEFINDER_SCALE = 10.0 +_TOUCH_THRESHOLD = 1e-3 + + +class Walker(base.Walker): + """Legacy base class for Walker robots.""" + + def _build(self, initializer=None): + try: + self._initializers = tuple(initializer) + except TypeError: + self._initializers = (initializer or initializers.UprightInitializer(),) + + @property + def upright_pose(self): + return base.WalkerPose() + + def _build_observables(self): + return WalkerObservables(self) + + def reinitialize_pose(self, physics, random_state): + for initializer in self._initializers: + initializer.initialize_pose(physics, self, random_state) + + def aliveness(self, physics): + """A measure of the aliveness of the walker. + + Aliveness measure could be used for deciding on termination (ant flipped + over and it's impossible for it to recover), or used as a shaping reward + to maintain an alive pose that we desired (humanoids remaining upright). + + Args: + physics: an instance of `Physics`. + + Returns: + a `float` in the range of [-1., 0.] where -1 means not alive and 0. means + alive. In walkers for which the concept of aliveness does not make sense, + the default implementation is to always return 0.0. + """ + return 0. + + @abc.abstractproperty + def ground_contact_geoms(self): + """Geoms in this walker that are expected to be in contact with the ground. + + This property is used by some tasks to determine contact-based failure + termination. It should only contain geoms that are expected to be in + contact with the ground during "normal" locomotion. For example, for a + humanoid model, this property would be expected to contain only the geoms + that make up the two feet. + + Note that certain specialized tasks may also allow geoms that are not listed + here to be in contact with the ground. For example, a humanoid cartwheel + task would also allow the hands to touch the ground in addition to the feet. + """ + raise NotImplementedError + + def after_compile(self, physics, unused_random_state): + super(Walker, self).after_compile(physics, unused_random_state) + self._end_effector_geom_ids = set() + for eff_body in self.end_effectors: + eff_geom = eff_body.find_all('geom') + self._end_effector_geom_ids |= set(physics.bind(eff_geom).element_id) + self._body_geom_ids = set( + physics.bind(geom).element_id + for geom in self.mjcf_model.find_all('geom')) + self._body_geom_ids.difference_update(self._end_effector_geom_ids) + + @property + def end_effector_geom_ids(self): + return self._end_effector_geom_ids + + @property + def body_geom_ids(self): + return self._body_geom_ids + + def end_effector_contacts(self, physics): + """Collect the contacts with the end effectors. + + This function returns any contacts being made with any of the end effectors, + both the other geom with which contact is being made as well as the + magnitude. + + Args: + physics: an instance of `Physics`. + + Returns: + a dict with as key a tuple of geom ids, of which one is an end effector, + and as value the total magnitude of all contacts between these geoms + """ + return self.collect_contacts(physics, self._end_effector_geom_ids) + + def body_contacts(self, physics): + """Collect the contacts with the body. + + This function returns any contacts being made with any of body geoms, except + the end effectors, both the other geom with which contact is being made as + well as the magnitude. + + Args: + physics: an instance of `Physics`. + + Returns: + a dict with as key a tuple of geom ids, of which one is a body geom, + and as value the total magnitude of all contacts between these geoms + """ + return self.collect_contacts(physics, self._body_geom_ids) + + def collect_contacts(self, physics, geom_ids): + contacts = {} + forcetorque = np.zeros(6) + for i, contact in enumerate(physics.data.contact): + if ((contact.geom1 in geom_ids) or + (contact.geom2 in geom_ids)) and contact.dist < contact.includemargin: + mjlib.mj_contactForce(physics.model.ptr, physics.data.ptr, i, + forcetorque) + contacts[(contact.geom1, contact.geom2)] = (forcetorque[0] + + contacts.get( + (contact.geom1, + contact.geom2), 0.)) + return contacts + + @abc.abstractproperty + def end_effectors(self): + raise NotImplementedError + + @abc.abstractproperty + def egocentric_camera(self): + raise NotImplementedError + + @composer.cached_property + def touch_sensors(self): + return self._mjcf_root.sensor.get_children('touch') + + @property + def prev_action(self): + """Returns the actuation actions applied in the previous step. + + Concrete walker implementations should provide caching mechanism themselves + in order to access this observable (for example, through `apply_action`). + """ + raise NotImplementedError + + def after_substep(self, physics, random_state): + del random_state # Unused. + # As of MuJoCo v2.0, updates to `mjData->subtree_linvel` will be skipped + # unless these quantities are needed by the simulation. We need these in + # order to calculate `torso_{x,y}vel`, so we therefore call `mj_subtreeVel` + # explicitly. + # TODO(b/123065920): Consider using a `subtreelinvel` sensor instead. + mjlib.mj_subtreeVel(physics.model.ptr, physics.data.ptr) + + +class WalkerObservables(base.WalkerObservables): + """Legacy base class for Walker obserables.""" + + @composer.observable + def joints_vel(self): + return observable.MJCFFeature('qvel', self._entity.observable_joints) + + @composer.observable + def body_height(self): + return observable.MJCFFeature('xpos', self._entity.root_body)[2] + + @composer.observable + def end_effectors_pos(self): + """Position of end effectors relative to torso, in the egocentric frame.""" + def relative_pos_in_egocentric_frame(physics): + end_effector = physics.bind(self._entity.end_effectors).xpos + torso = physics.bind(self._entity.root_body).xpos + xmat = np.reshape(physics.bind(self._entity.root_body).xmat, (3, 3)) + return np.reshape(np.dot(end_effector - torso, xmat), -1) + return observable.Generic(relative_pos_in_egocentric_frame) + + @composer.observable + def world_zaxis(self): + """The world's z-vector in this Walker's torso frame.""" + return observable.MJCFFeature('xmat', self._entity.root_body)[6:] + + @composer.observable + def sensors_velocimeter(self): + return observable.MJCFFeature('sensordata', + self._entity.mjcf_model.sensor.velocimeter) + + @composer.observable + def sensors_force(self): + return observable.MJCFFeature('sensordata', + self._entity.mjcf_model.sensor.force) + + @composer.observable + def sensors_torque(self): + return observable.MJCFFeature('sensordata', + self._entity.mjcf_model.sensor.torque) + + @composer.observable + def sensors_touch(self): + return observable.MJCFFeature( + 'sensordata', + self._entity.mjcf_model.sensor.touch, + corruptor= + lambda v, random_state: np.array(v > _TOUCH_THRESHOLD, dtype=np.float)) + + @composer.observable + def sensors_rangefinder(self): + def tanh_rangefinder(physics): + raw = physics.bind(self._entity.mjcf_model.sensor.rangefinder).sensordata + raw = np.array(raw) + raw[raw == -1.0] = np.inf + return _RANGEFINDER_SCALE * np.tanh(raw / _RANGEFINDER_SCALE) + return observable.Generic(tanh_rangefinder) + + @composer.observable + def egocentric_camera(self): + return observable.MJCFCamera(self._entity.egocentric_camera, + width=64, height=64) + + @composer.observable + def position(self): + return observable.MJCFFeature('xpos', self._entity.root_body) + + @composer.observable + def orientation(self): + return observable.MJCFFeature('xmat', self._entity.root_body) + + def add_egocentric_vector(self, + name, + world_frame_observable, + enabled=True, + origin_callable=None, + **kwargs): + + def _egocentric(physics, origin_callable=origin_callable): + vec = world_frame_observable.observation_callable(physics)() + origin_callable = origin_callable or (lambda physics: np.zeros(vec.size)) + delta = vec - origin_callable(physics) + return self._entity.transform_vec_to_egocentric_frame(physics, delta) + + self._observables[name] = observable.Generic(_egocentric, **kwargs) + self._observables[name].enabled = enabled + + def add_egocentric_xmat(self, name, xmat_observable, enabled=True, **kwargs): + + def _egocentric(physics): + return self._entity.transform_xmat_to_egocentric_frame( + physics, + xmat_observable.observation_callable(physics)()) + + self._observables[name] = observable.Generic(_egocentric, **kwargs) + self._observables[name].enabled = enabled + + # Semantic groupings of Walker observables. + def _collect_from_attachments(self, attribute_name): + out = [] + for entity in self._entity.iter_entities(exclude_self=True): + out.extend(getattr(entity.observables, attribute_name, [])) + return out + + @property + def proprioception(self): + return ([self.joints_pos, self.joints_vel, + self.body_height, self.end_effectors_pos, self.world_zaxis] + + self._collect_from_attachments('proprioception')) + + @property + def kinematic_sensors(self): + return ([self.sensors_gyro, self.sensors_velocimeter, + self.sensors_accelerometer] + + self._collect_from_attachments('kinematic_sensors')) + + @property + def dynamic_sensors(self): + return ([self.sensors_force, self.sensors_torque, self.sensors_touch] + + self._collect_from_attachments('dynamic_sensors')) + + # Convenience observables for defining rewards and terminations. + @composer.observable + def veloc_strafe(self): + return observable.MJCFFeature( + 'sensordata', self._entity.mjcf_model.sensor.velocimeter)[1] + + @composer.observable + def veloc_up(self): + return observable.MJCFFeature( + 'sensordata', self._entity.mjcf_model.sensor.velocimeter)[2] + + @composer.observable + def veloc_forward(self): + return observable.MJCFFeature( + 'sensordata', self._entity.mjcf_model.sensor.velocimeter)[0] + + @composer.observable + def gyro_backward_roll(self): + return observable.MJCFFeature( + 'sensordata', self._entity.mjcf_model.sensor.gyro)[0] + + @composer.observable + def gyro_rightward_roll(self): + return observable.MJCFFeature( + 'sensordata', self._entity.mjcf_model.sensor.gyro)[1] + + @composer.observable + def gyro_anticlockwise_spin(self): + return observable.MJCFFeature( + 'sensordata', self._entity.mjcf_model.sensor.gyro)[2] + + @composer.observable + def torso_xvel(self): + return observable.MJCFFeature('subtree_linvel', self._entity.root_body)[0] + + @composer.observable + def torso_yvel(self): + return observable.MJCFFeature('subtree_linvel', self._entity.root_body)[1] + + @composer.observable + def prev_action(self): + return observable.Generic(lambda _: self._entity.prev_action) diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/rodent.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/rodent.py new file mode 100644 index 0000000..9842718 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/rodent.py @@ -0,0 +1,320 @@ +# Copyright 2020 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""A Rodent walker.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import re + +from dm_control import composer +from dm_control import mjcf +from dm_control.composer.observation import observable +from dm_control.locomotion.walkers import base +from dm_control.locomotion.walkers import legacy_base +from dm_control.mujoco import wrapper as mj_wrapper +import numpy as np + +_XML_PATH = os.path.join(os.path.dirname(__file__), + 'assets/rodent.xml') + +_RAT_MOCAP_JOINTS = [ + 'vertebra_1_extend', 'vertebra_2_bend', 'vertebra_3_twist', + 'vertebra_4_extend', 'vertebra_5_bend', 'vertebra_6_twist', + 'hip_L_supinate', 'hip_L_abduct', 'hip_L_extend', 'knee_L', 'ankle_L', + 'toe_L', 'hip_R_supinate', 'hip_R_abduct', 'hip_R_extend', 'knee_R', + 'ankle_R', 'toe_R', 'vertebra_C1_extend', 'vertebra_C1_bend', + 'vertebra_C2_extend', 'vertebra_C2_bend', 'vertebra_C3_extend', + 'vertebra_C3_bend', 'vertebra_C4_extend', 'vertebra_C4_bend', + 'vertebra_C5_extend', 'vertebra_C5_bend', 'vertebra_C6_extend', + 'vertebra_C6_bend', 'vertebra_C7_extend', 'vertebra_C9_bend', + 'vertebra_C11_extend', 'vertebra_C13_bend', 'vertebra_C15_extend', + 'vertebra_C17_bend', 'vertebra_C19_extend', 'vertebra_C21_bend', + 'vertebra_C23_extend', 'vertebra_C25_bend', 'vertebra_C27_extend', + 'vertebra_C29_bend', 'vertebra_cervical_5_extend', + 'vertebra_cervical_4_bend', 'vertebra_cervical_3_twist', + 'vertebra_cervical_2_extend', 'vertebra_cervical_1_bend', + 'vertebra_axis_twist', 'vertebra_atlant_extend', 'atlas', 'mandible', + 'scapula_L_supinate', 'scapula_L_abduct', 'scapula_L_extend', 'shoulder_L', + 'shoulder_sup_L', 'elbow_L', 'wrist_L', 'finger_L', 'scapula_R_supinate', + 'scapula_R_abduct', 'scapula_R_extend', 'shoulder_R', 'shoulder_sup_R', + 'elbow_R', 'wrist_R', 'finger_R' +] + + +_UPRIGHT_POS = (0.0, 0.0, 0.0) +_UPRIGHT_QUAT = (1., 0., 0., 0.) +_TORQUE_THRESHOLD = 60 + + +class Rat(legacy_base.Walker): + """A position-controlled rat with control range scaled to [-1, 1].""" + + def _build(self, + params=None, + name='walker', + initializer=None): + self.params = params + self._mjcf_root = mjcf.from_path(_XML_PATH) + if name: + self._mjcf_root.model = name + + self.body_sites = [] + super(Rat, self)._build(initializer=initializer) + + @property + def upright_pose(self): + """Reset pose to upright position.""" + return base.WalkerPose(xpos=_UPRIGHT_POS, xquat=_UPRIGHT_QUAT) + + @property + def mjcf_model(self): + """Return the model root.""" + return self._mjcf_root + + @composer.cached_property + def actuators(self): + """Return all actuators.""" + return tuple(self._mjcf_root.find_all('actuator')) + + @composer.cached_property + def root_body(self): + """Return the body.""" + return self._mjcf_root.find('body', 'torso') + + @composer.cached_property + def pelvis_body(self): + """Return the body.""" + return self._mjcf_root.find('body', 'pelvis') + + @composer.cached_property + def head(self): + """Return the head.""" + return self._mjcf_root.find('body', 'skull') + + @composer.cached_property + def left_arm_root(self): + """Return the left arm.""" + return self._mjcf_root.find('body', 'scapula_L') + + @composer.cached_property + def right_arm_root(self): + """Return the right arm.""" + return self._mjcf_root.find('body', 'scapula_R') + + @composer.cached_property + def ground_contact_geoms(self): + """Return ground contact geoms.""" + return tuple( + self._mjcf_root.find('body', 'foot_L').find_all('geom') + + self._mjcf_root.find('body', 'foot_R').find_all('geom')) + + @composer.cached_property + def standing_height(self): + """Return standing height.""" + return self.params['_STAND_HEIGHT'] + + @composer.cached_property + def end_effectors(self): + """Return end effectors.""" + return (self._mjcf_root.find('body', 'lower_arm_R'), + self._mjcf_root.find('body', 'lower_arm_L'), + self._mjcf_root.find('body', 'foot_R'), + self._mjcf_root.find('body', 'foot_L')) + + @composer.cached_property + def observable_joints(self): + """Return observable joints.""" + return tuple(actuator.joint + for actuator in self.actuators # This lint is mistaken; pylint: disable=not-an-iterable + if actuator.joint is not None) + + @composer.cached_property + def observable_tendons(self): + return self._mjcf_root.find_all('tendon') + + @composer.cached_property + def mocap_joints(self): + return tuple( + self._mjcf_root.find('joint', name) for name in _RAT_MOCAP_JOINTS) + + @composer.cached_property + def mocap_joint_order(self): + return tuple([jnt.name for jnt in self.mocap_joints]) # This lint is mistaken; pylint: disable=not-an-iterable + + @composer.cached_property + def bodies(self): + """Return all bodies.""" + return tuple(self._mjcf_root.find_all('body')) + + @composer.cached_property + def mocap_bodies(self): + """Return bodies for mocap comparison.""" + return tuple(body for body in self._mjcf_root.find_all('body') + if not re.match(r'(vertebra|hand|toe)', body.name)) + + @composer.cached_property + def primary_joints(self): + """Return primary (non-vertebra) joints.""" + return tuple(jnt for jnt in self._mjcf_root.find_all('joint') + if 'vertebra' not in jnt.name) + + @composer.cached_property + def vertebra_joints(self): + """Return vertebra joints.""" + return tuple(jnt for jnt in self._mjcf_root.find_all('joint') + if 'vertebra' in jnt.name) + + @composer.cached_property + def primary_joint_order(self): + joint_names = self.mocap_joint_order + primary_names = tuple([jnt.name for jnt in self.primary_joints]) # pylint: disable=not-an-iterable + primary_order = [] + for nm in primary_names: + primary_order.append(joint_names.index(nm)) + return primary_order + + @composer.cached_property + def vertebra_joint_order(self): + joint_names = self.mocap_joint_order + vertebra_names = tuple([jnt.name for jnt in self.vertebra_joints]) # pylint: disable=not-an-iterable + vertebra_order = [] + for nm in vertebra_names: + vertebra_order.append(joint_names.index(nm)) + return vertebra_order + + @composer.cached_property + def egocentric_camera(self): + """Return the egocentric camera.""" + return self._mjcf_root.find('camera', 'egocentric') + + @property + def _xml_path(self): + """Return the path to th model .xml file.""" + return self.params['_XML_PATH'] + + @composer.cached_property + def joint_actuators(self): + """Return all joint actuators.""" + return tuple([act for act in self._mjcf_root.find_all('actuator') + if act.joint]) + + @composer.cached_property + def joint_actuators_range(self): + act_joint_range = [] + for act in self.joint_actuators: # This lint is mistaken; pylint: disable=not-an-iterable + associated_joint = self._mjcf_root.find('joint', act.name) + act_range = associated_joint.dclass.joint.range + act_joint_range.append(act_range) + return act_joint_range + + def pose_to_actuation(self, pose): + # holds for joint actuators, find desired torque = 0 + # u_ref = [2 q_ref - (r_low + r_up) ]/(r_up - r_low) + r_lower = np.array([ajr[0] for ajr in self.joint_actuators_range]) # This lint is mistaken; pylint: disable=not-an-iterable + r_upper = np.array([ajr[1] for ajr in self.joint_actuators_range]) # This lint is mistaken; pylint: disable=not-an-iterable + num_tendon_actuators = len(self.actuators) - len(self.joint_actuators) + tendon_actions = np.zeros(num_tendon_actuators) + return np.hstack([tendon_actions, (2*pose[self.joint_actuator_order]- + (r_lower+r_upper))/(r_upper-r_lower)]) + + @composer.cached_property + def joint_actuator_order(self): + joint_names = self.mocap_joint_order + joint_actuator_names = tuple([act.name for act in self.joint_actuators]) # This lint is mistaken; pylint: disable=not-an-iterable + actuator_order = [] + for nm in joint_actuator_names: + actuator_order.append(joint_names.index(nm)) + return actuator_order + + def _build_observables(self): + return RodentObservables(self) + + +class RodentObservables(legacy_base.WalkerObservables): + """Observables for the Rat.""" + + @composer.observable + def head_height(self): + """Observe the head height.""" + return observable.MJCFFeature('xpos', self._entity.head)[2] + + @composer.observable + def sensors_torque(self): + """Observe the torque sensors.""" + return observable.MJCFFeature( + 'sensordata', + self._entity.mjcf_model.sensor.torque, + corruptor=lambda v, random_state: np.tanh(2 * v / _TORQUE_THRESHOLD) + ) + + @composer.observable + def tendons_pos(self): + return observable.MJCFFeature('length', self._entity.observable_tendons) + + @composer.observable + def tendons_vel(self): + return observable.MJCFFeature('velocity', self._entity.observable_tendons) + + @composer.observable + def actuator_activation(self): + """Observe the actuator activation.""" + model = self._entity.mjcf_model + return observable.MJCFFeature('act', model.find_all('actuator')) + + @composer.observable + def appendages_pos(self): + """Equivalent to `end_effectors_pos` with head's position appended.""" + + def relative_pos_in_egocentric_frame(physics): + end_effectors_with_head = ( + self._entity.end_effectors + (self._entity.head,)) + end_effector = physics.bind(end_effectors_with_head).xpos + torso = physics.bind(self._entity.root_body).xpos + xmat = \ + np.reshape(physics.bind(self._entity.root_body).xmat, (3, 3)) + return np.reshape(np.dot(end_effector - torso, xmat), -1) + + return observable.Generic(relative_pos_in_egocentric_frame) + + @property + def proprioception(self): + """Return proprioceptive information.""" + return [ + self.joints_pos, self.joints_vel, + self.tendons_pos, self.tendons_vel, + self.actuator_activation, + self.body_height, self.end_effectors_pos, self.appendages_pos, + self.world_zaxis + ] + self._collect_from_attachments('proprioception') + + @composer.observable + def egocentric_camera(self): + """Observable of the egocentric camera.""" + + if not hasattr(self, '_scene_options'): + # Don't render this walker's geoms. + self._scene_options = mj_wrapper.MjvOption() + collision_geom_group = 2 + self._scene_options.geomgroup[collision_geom_group] = 0 + cosmetic_geom_group = 1 + self._scene_options.geomgroup[cosmetic_geom_group] = 0 + + return observable.MJCFCamera(self._entity.egocentric_camera, + width=64, height=64, + scene_option=self._scene_options + ) diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/rodent_test.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/rodent_test.py new file mode 100644 index 0000000..42b31d2 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/rodent_test.py @@ -0,0 +1,116 @@ +# Copyright 2020 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Tests for the Rodent.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from absl.testing import parameterized +from dm_control import composer +from dm_control import mjcf +from dm_control.composer.observation.observable import base as observable_base +from dm_control.locomotion.arenas import corridors as corr_arenas +from dm_control.locomotion.tasks import corridors as corr_tasks +from dm_control.locomotion.walkers import rodent + +import numpy as np +from six.moves import range + +_CONTROL_TIMESTEP = .02 +_PHYSICS_TIMESTEP = 0.001 + + +def _get_rat_corridor_physics(): + walker = rodent.Rat() + arena = corr_arenas.EmptyCorridor() + task = corr_tasks.RunThroughCorridor( + walker=walker, + arena=arena, + walker_spawn_position=(5, 0, 0), + walker_spawn_rotation=0, + physics_timestep=_PHYSICS_TIMESTEP, + control_timestep=_CONTROL_TIMESTEP) + + env = composer.Environment( + time_limit=30, + task=task, + strip_singleton_obs_buffer_dim=True) + + return walker, env + + +class RatTest(parameterized.TestCase): + + def test_can_compile_and_step_simulation(self): + _, env = _get_rat_corridor_physics() + physics = env.physics + for _ in range(100): + physics.step() + + @parameterized.parameters([ + 'egocentric_camera', + 'head', + 'left_arm_root', + 'right_arm_root', + 'root_body', + 'pelvis_body', + ]) + def test_get_element_property(self, name): + attribute_value = getattr(rodent.Rat(), name) + self.assertIsInstance(attribute_value, mjcf.Element) + + @parameterized.parameters([ + 'actuators', + 'bodies', + 'mocap_bodies', + 'end_effectors', + 'mocap_joints', + 'observable_joints', + ]) + def test_get_element_tuple_property(self, name): + attribute_value = getattr(rodent.Rat(), name) + self.assertNotEmpty(attribute_value) + for item in attribute_value: + self.assertIsInstance(item, mjcf.Element) + + def test_set_name(self): + name = 'fred' + walker = rodent.Rat(name=name) + self.assertEqual(walker.mjcf_model.model, name) + + @parameterized.parameters( + 'tendons_pos', + 'tendons_vel', + 'actuator_activation', + 'appendages_pos', + 'head_height', + 'sensors_torque', + ) + def test_evaluate_observable(self, name): + walker, env = _get_rat_corridor_physics() + physics = env.physics + observable = getattr(walker.observables, name) + observation = observable(physics) + self.assertIsInstance(observation, (float, np.ndarray)) + + def test_proprioception(self): + walker = rodent.Rat() + for item in walker.observables.proprioception: + self.assertIsInstance(item, observable_base.Observable) + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/scaled_actuators.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/scaled_actuators.py new file mode 100644 index 0000000..806f5c3 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/scaled_actuators.py @@ -0,0 +1,131 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Position & velocity actuators whose controls are scaled to a given range.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +_DISALLOWED_KWARGS = frozenset( + ['biastype', 'gainprm', 'biasprm', 'ctrllimited', + 'joint', 'tendon', 'site', 'slidersite', 'cranksite']) +_ALLOWED_TAGS = frozenset(['joint', 'tendon', 'site']) + +_GOT_INVALID_KWARGS = 'Received invalid keyword argument(s): {}' +_GOT_INVALID_TARGET = '`target` tag type should be one of {}: got {{}}'.format( + sorted(_ALLOWED_TAGS)) + + +def _check_target_and_kwargs(target, **kwargs): + invalid_kwargs = _DISALLOWED_KWARGS.intersection(kwargs) + if invalid_kwargs: + raise TypeError(_GOT_INVALID_KWARGS.format(sorted(invalid_kwargs))) + if target.tag not in _ALLOWED_TAGS: + raise TypeError(_GOT_INVALID_TARGET.format(target)) + + +def add_position_actuator(target, qposrange, ctrlrange=(-1, 1), + kp=1.0, **kwargs): + """Adds a scaled position actuator that is bound to the specified element. + + This is equivalent to MuJoCo's built-in `` actuator where an affine + transformation is pre-applied to the control signal, such that the minimum + control value corresponds to the minimum desired position, and the + maximum control value corresponds to the maximum desired position. + + Args: + target: A PyMJCF joint, tendon, or site element object that is to be + controlled. + qposrange: A sequence of two numbers specifying the allowed range of target + position. + ctrlrange: A sequence of two numbers specifying the allowed range of + this actuator's control signal. + kp: The gain parameter of this position actuator. + **kwargs: Additional MJCF attributes for this actuator element. + The following attributes are disallowed: `['biastype', 'gainprm', + 'biasprm', 'ctrllimited', 'joint', 'tendon', 'site', + 'slidersite', 'cranksite']`. + + Returns: + A PyMJCF actuator element that has been added to the MJCF model containing + the specified `target`. + + Raises: + TypeError: `kwargs` contains an unrecognized or disallowed MJCF attribute, + or `target` is not an allowed MJCF element type. + """ + _check_target_and_kwargs(target, **kwargs) + kwargs[target.tag] = target + + slope = (qposrange[1] - qposrange[0]) / (ctrlrange[1] - ctrlrange[0]) + g0 = kp * slope + b0 = kp * (qposrange[0] - slope * ctrlrange[0]) + b1 = -kp + b2 = 0 + return target.root.actuator.add('general', + biastype='affine', + gainprm=[g0], + biasprm=[b0, b1, b2], + ctrllimited=True, + ctrlrange=ctrlrange, + **kwargs) + + +def add_velocity_actuator(target, qvelrange, ctrlrange=(-1, 1), + kv=1.0, **kwargs): + """Adds a scaled velocity actuator that is bound to the specified element. + + This is equivalent to MuJoCo's built-in `` actuator where an affine + transformation is pre-applied to the control signal, such that the minimum + control value corresponds to the minimum desired velocity, and the + maximum control value corresponds to the maximum desired velocity. + + Args: + target: A PyMJCF joint, tendon, or site element object that is to be + controlled. + qvelrange: A sequence of two numbers specifying the allowed range of target + velocity. + ctrlrange: A sequence of two numbers specifying the allowed range of + this actuator's control signal. + kv: The gain parameter of this velocity actuator. + **kwargs: Additional MJCF attributes for this actuator element. + The following attributes are disallowed: `['biastype', 'gainprm', + 'biasprm', 'ctrllimited', 'joint', 'tendon', 'site', + 'slidersite', 'cranksite']`. + + Returns: + A PyMJCF actuator element that has been added to the MJCF model containing + the specified `target`. + + Raises: + TypeError: `kwargs` contains an unrecognized or disallowed MJCF attribute, + or `target` is not an allowed MJCF element type. + """ + _check_target_and_kwargs(target, **kwargs) + kwargs[target.tag] = target + + slope = (qvelrange[1] - qvelrange[0]) / (ctrlrange[1] - ctrlrange[0]) + g0 = kv * slope + b0 = kv * (qvelrange[0] - slope * ctrlrange[0]) + b1 = 0 + b2 = -kv + return target.root.actuator.add('general', + biastype='affine', + gainprm=[g0], + biasprm=[b0, b1, b2], + ctrllimited=True, + ctrlrange=ctrlrange, + **kwargs) diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walkers/scaled_actuators_test.py b/DMC/src/env/dm_control/dm_control/locomotion/walkers/scaled_actuators_test.py new file mode 100644 index 0000000..1ff0f20 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/locomotion/walkers/scaled_actuators_test.py @@ -0,0 +1,135 @@ +# Copyright 2019 The dm_control Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Tests for scaled actuators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import absltest +from dm_control import mjcf +from dm_control.locomotion.walkers import scaled_actuators +import numpy as np +from six.moves import range + + +class ScaledActuatorsTest(absltest.TestCase): + + def setUp(self): + super(ScaledActuatorsTest, self).setUp() + self._mjcf_model = mjcf.RootElement() + self._min = -1.4 + self._max = 2.3 + self._gain = 1.7 + self._scaled_min = -0.8 + self._scaled_max = 1.3 + self._range = self._max - self._min + self._scaled_range = self._scaled_max - self._scaled_min + self._joints = [] + for _ in range(2): + body = self._mjcf_model.worldbody.add('body') + body.add('geom', type='sphere', size=[1]) + self._joints.append(body.add('joint', type='hinge')) + self._scaled_actuator_joint = self._joints[0] + self._standard_actuator_joint = self._joints[1] + self._random_state = np.random.RandomState(3474) + + def _set_actuator_controls(self, physics, normalized_ctrl, + scaled_actuator=None, standard_actuator=None): + if scaled_actuator is not None: + physics.bind(scaled_actuator).ctrl = ( + normalized_ctrl * self._scaled_range + self._scaled_min) + if standard_actuator is not None: + physics.bind(standard_actuator).ctrl = ( + normalized_ctrl * self._range + self._min) + + def _assert_same_qfrc_actuator(self, physics, joint1, joint2): + np.testing.assert_allclose(physics.bind(joint1).qfrc_actuator, + physics.bind(joint2).qfrc_actuator) + + def test_position_actuator(self): + scaled_actuator = scaled_actuators.add_position_actuator( + target=self._scaled_actuator_joint, kp=self._gain, + qposrange=(self._min, self._max), + ctrlrange=(self._scaled_min, self._scaled_max)) + standard_actuator = self._mjcf_model.actuator.add( + 'position', joint=self._standard_actuator_joint, kp=self._gain, + ctrllimited=True, ctrlrange=(self._min, self._max)) + physics = mjcf.Physics.from_mjcf_model(self._mjcf_model) + + # Zero torque. + physics.bind(self._scaled_actuator_joint).qpos = ( + 0.2345 * self._range + self._min) + self._set_actuator_controls(physics, 0.2345, scaled_actuator) + np.testing.assert_allclose( + physics.bind(self._scaled_actuator_joint).qfrc_actuator, 0, atol=1e-15) + + for _ in range(100): + normalized_ctrl = self._random_state.uniform() + physics.bind(self._joints).qpos = ( + self._random_state.uniform() * self._range + self._min) + self._set_actuator_controls(physics, normalized_ctrl, + scaled_actuator, standard_actuator) + self._assert_same_qfrc_actuator( + physics, self._scaled_actuator_joint, self._standard_actuator_joint) + + def test_velocity_actuator(self): + scaled_actuator = scaled_actuators.add_velocity_actuator( + target=self._scaled_actuator_joint, kv=self._gain, + qvelrange=(self._min, self._max), + ctrlrange=(self._scaled_min, self._scaled_max)) + standard_actuator = self._mjcf_model.actuator.add( + 'velocity', joint=self._standard_actuator_joint, kv=self._gain, + ctrllimited=True, ctrlrange=(self._min, self._max)) + physics = mjcf.Physics.from_mjcf_model(self._mjcf_model) + + # Zero torque. + physics.bind(self._scaled_actuator_joint).qvel = ( + 0.5432 * self._range + self._min) + self._set_actuator_controls(physics, 0.5432, scaled_actuator) + np.testing.assert_allclose( + physics.bind(self._scaled_actuator_joint).qfrc_actuator, 0, atol=1e-15) + + for _ in range(100): + normalized_ctrl = self._random_state.uniform() + physics.bind(self._joints).qvel = ( + self._random_state.uniform() * self._range + self._min) + self._set_actuator_controls(physics, normalized_ctrl, + scaled_actuator, standard_actuator) + self._assert_same_qfrc_actuator( + physics, self._scaled_actuator_joint, self._standard_actuator_joint) + + def test_invalid_kwargs(self): + invalid_kwargs = dict(joint=self._scaled_actuator_joint, ctrllimited=False) + with self.assertRaisesWithLiteralMatch( + TypeError, + scaled_actuators._GOT_INVALID_KWARGS.format(sorted(invalid_kwargs))): + scaled_actuators.add_position_actuator( + target=self._scaled_actuator_joint, + qposrange=(self._min, self._max), + **invalid_kwargs) + + def test_invalid_target(self): + invalid_target = self._mjcf_model.worldbody + with self.assertRaisesWithLiteralMatch( + TypeError, + scaled_actuators._GOT_INVALID_TARGET.format(invalid_target)): + scaled_actuators.add_position_actuator( + target=invalid_target, qposrange=(self._min, self._max)) + + +if __name__ == '__main__': + absltest.main() diff --git a/DMC/src/env/dm_control/dm_control/locomotion/walls.png b/DMC/src/env/dm_control/dm_control/locomotion/walls.png new file mode 100644 index 0000000..fc15386 Binary files /dev/null and b/DMC/src/env/dm_control/dm_control/locomotion/walls.png differ diff --git a/DMC/src/env/dm_control/dm_control/mjcf/README.md b/DMC/src/env/dm_control/dm_control/mjcf/README.md new file mode 100644 index 0000000..8418b31 --- /dev/null +++ b/DMC/src/env/dm_control/dm_control/mjcf/README.md @@ -0,0 +1,498 @@ +# PyMJCF + +IMPORTANT: If you find yourself stuck while using PyMJCF, check out the various +IMPORTANT boxes on this page and the [Common gotchas](#common-gotchas) section +at the bottom to see if any of them is relevant. + +This library provides a Python object model for MuJoCo's XML-based +[MJCF](http://www.mujoco.org/book/modeling.html) physics modeling language. The +goal of the library is to allow users to easily interact with and modify MJCF +models in Python, similarly to what the JavaScript DOM does for HTML. + +A key feature of this library is the ability to easily compose multiple separate +MJCF models into a larger one. Disambiguation of duplicated names from different +models, or multiple instances of the same model, is handled automatically. + +The following snippet provides a quick example of this library's typical use +case. Here, the `UpperBody` class can simply instantiate two copies of `Arm`, +thus reducing code duplication. The names of bodies, joints, or geoms of each +`Arm` are automatically prefixed by their parent's names, and so no name +collision occurs. + +```python +from dm_control import mjcf + +class Arm(object): + + def __init__(self, name): + self.mjcf_model = mjcf.RootElement(model=name) + + self.upper_arm = self.mjcf_model.worldbody.add('body', name='upper_arm') + self.shoulder = self.upper_arm.add('joint', name='shoulder', type='ball') + self.upper_arm.add('geom', name='upper_arm', type='capsule', + pos=[0, 0, -0.15], size=[0.045, 0.15]) + + self.forearm = self.upper_arm.add('body', name='forearm', pos=[0, 0, -0.3]) + self.elbow = self.forearm.add('joint', name='elbow', + type='hinge', axis=[0, 1, 0]) + self.forearm.add('geom', name='forearm', type='capsule', + pos=[0, 0, -0.15], size=[0.045, 0.15]) + +class UpperBody(object): + + def __init__(self): + self.mjcf_model = mjcf.RootElement() + self.mjcf_model.worldbody.add( + 'geom', name='torso', type='box', size=[0.15, 0.045, 0.25]) + left_shoulder_site = self.mjcf_model.worldbody.add( + 'site', size=[1e-6]*3, pos=[-0.15, 0, 0.25]) + right_shoulder_site = self.mjcf_model.worldbody.add( + 'site', size=[1e-6]*3, pos=[0.15, 0, 0.25]) + + self.left_arm = Arm(name='left_arm') + left_shoulder_site.attach(self.left_arm.mjcf_model) + + self.right_arm = Arm(name='right_arm') + right_shoulder_site.attach(self.right_arm.mjcf_model) + +body = UpperBody() +physics = mjcf.Physics.from_mjcf_model(body.mjcf_model) +``` + +## Basic operations + +### Creating an MJCF model + +In PyMJCF, the basic building block of a model is an `mjcf.Element`. This +corresponds to an element in the generated XML. However, user code _cannot_ +instantiate a generic `mjcf.Element` object directly. + +A valid model always consists of a single root `` element. This is +represented as the special `mjcf.RootElement` type in PyMJCF, which _can_ be +instantiated in user code to create an empty model. + +```python +from dm_control import mjcf + +mjcf_model = mjcf.RootElement() +print(mjcf_model) # MJCF Element: +``` + +### Adding new elements + +Attributes of the new element can be passed as kwargs: + +```python +my_box = mjcf_model.worldbody.add('geom', name='my_box', + type='box', pos=[0, .1, 0]) +print(my_box) # MJCF Element: +``` + +### Parsing an existing XML document + +Alternatively, if an existing XML file already exists, PyMJCF can parse it to +create a Python object: + +```python +from dm_control import mjcf + +# Parse from path +mjcf_model = mjcf.from_path(filename) + +# Parse from file +with open(filename) as f: + mjcf_model = mjcf.from_file(f) + +# Parse from string +with open(filename) as f: + xml_string = f.read() +mjcf_model = mjcf.from_xml_string(xml_string) + +print(type(mjcf_model)) # +``` + +### Traversing through a model + +Consider the following MJCF model: + +```xml + + + + + + + + + + + + + + + + + +``` + +The child elements and XML attributes of an `Element` object are exposed as +Python attributes. These attributes all have the same names as their XML +counterparts, with one exception: the `class` XML attribute is named `dclass` in +order to avoid a clash with the Python `class` keyword: + +```python +my_geom = mjcf_model.worldbody.body['foo'].body['bar'].geom['my_geom'] +print(isinstance(mjcf_model, mjcf.Element)) # True +print(my_geom.name) # 'my_geom' +print(my_geom.pos) # np.array([0., 1., 2.], dtype=np.float) +print(my_geom.class) # SyntaxError +print(my_geom.dclass) # 'brick' +``` + +Note that attribute values in the object model are **not** affected by defaults: + +```python +print(mjcf_model.default.default['brick'].geom.rgba) # [1, 0, 0, 1] +print(my_geom.rgba) # None +``` + +### Finding elements without traversing + +We can also find elements directly without having to traverse through the object +hierarchy: + +```python +found_geom = mjcf_model.find('geom', 'my_geom') +print(found_geom == my_geom) # True +``` + +Find all elements of a given type: + +```python +# Note that is also considered a joint +joints = mjcf_model.find_all('joint') +print(len(joints)) # 2 +print(joints[0] == mjcf_model.worldbody.body['foo'].freejoint) # True +print(joints[1] == mjcf_model.worldbody.body['foo'].body['bar'].joint[0]) # True +``` + +Note that the order of elements returned by `find_all` is the same as the order +in which they are declared in the model. + +### Modifying XML attributes + +Attributes can be modified, added, or removed: + +```python +my_geom.pos = [1, 2, 3] +print(my_geom.pos) # np.array([1., 2., 3.], dtype=np.float) +my_geom.quat = [0, 1, 0, 0] +print(my_geom.quat) # np.array([0., 1., 0., 0.], dtype=np.float) +del my_geom.quat +print(my_geom.quat) # None +``` + +Schema violations result in errors: + +```python +print(my_geom.poss) # raise AttributeError (no child or attribute called poss) +my_geom.pos = 'invalid' # raise ValueError (assigning string to array) +my_geom.pos = [1, 2, 3, 4, 5, 6] # raise ValueError (array length is too long) + +# raise ValueError (mass is a required attribute of ) +del mjcf_model.find('body', 'foo').inertial.mass +``` + +### Uniqueness of identifiers + +PyMJCF enforces the uniqueness of "identifier" attributes within a model. +Identifiers consist of the `class` attribute of a ``, and all `name` +attributes. Their uniqueness is only enforced within a particular namespace. For +example, a `` is allowed to have the same name as a ``, whereas +`` and `` actuators cannot have the same name. + +```python +mjcf_model.worldbody.add('geom', name='my_geom') +foo = mjcf_model.worldbody.find('body', 'foo') +foo.add('my_geom') # Error, duplicated geom name +foo.add('foo') # OK, a geom can have the same name as a body +mjcf_model.find('geom', 'foo').name = 'my_geom' # Error, duplicated geom name +``` + +### Reference attributes + +Some attributes are references to other elements. For example, the `joint` +attribute of an actuator refers to a `` element in the model. + +An `mjcf.Element` can be directly assigned to these reference attributes: + +```python +my_hinge = mjcf_model.find('joint', 'my_hinge') +my_actuator = mjcf_model.actuator.add('velocity', joint=my_hinge) +``` + +This is the recommended way to assign reference attributes, since it guarantees +that the reference is not invalidated if the referenced element is renamed. +Alternatively, a string can also be assigned to reference attributes. In this +case, PyMJCF does **not** attempt to verify that the named element actually +exists in the model. + +IMPORTANT: If the element being referenced is in a different model to the +reference attribute (e.g. in an attached model), the reference **must** be +created by directly assigning an `mjcf.Element` object to the attribute rather +than a string. Strings assigned to reference attributes cannot contain '/', +since they are automatically scoped by PyMJCF upon attachment. + +## Attaching models + +In this section we will refer to an `mjcf.RootElement` simply as a "model". +Models can be _attached_ to other models in order to create compositional +scenes. + +```python +arena = mjcf.RootElement() +arena.worldbody.add('geom', name='ground', type='plane', size=[10, 10, 1]) + +robot = mjcf.from_xml_file('robot.xml') +arena.attach(robot) +``` + +We refer to `arena` as the _parent model_, and `robot` as the _child model_ (or +the _attached model_). + +### Attachment frames + +When a model is attached to a site, an empty body is created in the parent +model. This empty body is called an _attachment frame_. + +The attachment frame is created as a child of the body that contains the +attachment site, and it has the same position and orientation as the site. When +the XML is generated, the attachment frame's contents shadow the contents of the +attached model's ``. The attachment frame's name in the generated XML +is the child's `fully/qualified/prefix/`. The trailing slash ensures that the +attachment frame's name never collides with a user-defined body. + +More concretely, if we have the following parent and child models: + +```xml + + + + + + + + + + + + + + +``` + +Then the final generated XML will be: + +```xml + + + + + + + + + + + + +``` + +IMPORTANT: The attachment frame is created _transparently_ to the user. In +particular, it is NOT treated as a regular `body` by PyMJCF. Its name in the +generated XML should be considered implementation detail and should NOT be +relied on. + +Having said that, it is sometimes necessary to access the attachment frame, for +example to add a joint between the parent and the child model. The easiest way +to do this is to hold a reference to the object returned by a call to `attach`: + +```python +attachment_frame = parent_model.attach('child') +attachment_frame.add('freejoint') +``` + +Alternatively, if a model has already been attached, the `find` function can be +used with the `attachment_frame` namespace in order to retrieve the attachment +frame. The `get_attachment_frame` convenience function in `mjcf.traversal_utils` +can find the child model's attachment frame without needing access to the parent +model. + +```python +frame_1 = parent_model.find('attachment_frame', 'child') + +# Convenience function: get the attachment frame directly from a child model +frame_2 = mjcf.traversal_utils.get_attachment_frame(child_model) +print(frame_1 == frame_2) # True +``` + +IMPORTANT: To encourage good modeling practices, the only allowed direct +children of an attachment frame are `` and ``. Other types of +elements should instead add be added to the `` of the attached model. + +### Element ownership + +IMPORTANT: Elements of child models do **not** appear when traversing through +the parent model. + +### Default classes + +PyMJCF ensures that default classes of a parent model _never_ affect any of its +child models. This minimises the possibility that two models become subtly +"incompatible", as a model always behaves in the same way regardless of what it +is attached to. + +The way that PyMJCF achieves this in practice is to move everything in a model's +global `` context into a default class named `/`. In other words, a +PyMJCF-generated model never has anything in the global default context. +Instead, the generated model always looks like: + +```xml + + + + + + + + + +``` + +IMPORTANT: This transformation is _transparent_ to the user. Within Python, the +above geom rgba setting is accessed as if it were a global default, i.e. +`mjcf_model.default.geom.rgba`. Generally speaking, users should never have to +worry about PyMJCF's internal handling of defaults. + +When a model is attached, its `/` default class turns into +`fully/qualified/prefix/`. The trailing slash ensures that this transformation +never conflicts with a user-named default class. More specifically, if we have +the following parent and child models: + +```xml + + + + + + + + + + + + + + + + + +``` + +Then the final generated XML will be: + +```xml + + + + + + + + + + + + + + + + + +``` + +### Global options + +A model cannot be attached to another model if _any_ of the global options are +different. Global options consist of attributes of ``, `