Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 13, 2024
2 parents d89066b + 9ee1ae7 commit a6d2cf6
Show file tree
Hide file tree
Showing 59 changed files with 3,564 additions and 453 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ conda deactivate && conda activate ./env
python -c "import mlagents_envs"

python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestUnityMLAgents --runslow
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_transforms.py --instafail -v --durations 200 --capture no -k test_transform_env[unity]

coverage combine
coverage xml -i
4 changes: 2 additions & 2 deletions .github/unittest/linux_sota/scripts/test_sota.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,12 @@
logger.backend=
""",
"dreamer": """python sota-implementations/dreamer/dreamer.py \
collector.total_frames=200 \
collector.total_frames=600 \
collector.init_random_frames=10 \
collector.frames_per_batch=200 \
env.n_parallel_envs=1 \
optimization.optim_steps_per_batch=1 \
logger.video=True \
logger.video=False \
logger.backend=csv \
replay_buffer.buffer_size=120 \
replay_buffer.batch_size=24 \
Expand Down
18 changes: 14 additions & 4 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,23 @@ jobs:
with:
repository: pytorch/rl
upload-artifact: docs
runner: "linux.g5.4xlarge.nvidia.gpu"
docker-image: "nvidia/cudagl:11.4.0-base"
timeout: 120
script: |
set -e
set -v
apt-get update && apt-get install -y -f git wget gcc g++ dialog apt-utils
# apt-get update && apt-get install -y -f git wget gcc g++ dialog apt-utils
yum makecache
# yum install -y glfw glew mesa-libGL mesa-libGL-devel mesa-libOSMesa-devel egl-utils freeglut
# Install Mesa and OpenGL Libraries:
yum install -y glfw mesa-libGL mesa-libGL-devel egl-utils freeglut mesa-libGLU mesa-libEGL
# Install DRI Drivers:
yum install -y mesa-dri-drivers
# Install Xvfb for Headless Environments:
yum install -y xorg-x11-server-Xvfb
# xhost +local:docker
# Xvfb :1 -screen 0 1024x768x24 &
# export DISPLAY=:1
root_dir="$(pwd)"
conda_dir="${root_dir}/conda"
env_dir="${root_dir}/env"
Expand All @@ -51,7 +61,7 @@ jobs:
conda activate "${env_dir}"
# 2. upgrade pip, ninja and packaging
apt-get install python3-pip unzip -y -f
# apt-get install python3-pip unzip -y -f
python3 -m pip install --upgrade pip
python3 -m pip install setuptools ninja packaging cmake -U
Expand Down
3 changes: 3 additions & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,11 @@ TorchRL offers a series of custom built-in environments.
:toctree: generated/
:template: rl_template.rst

ChessEnv
PendulumEnv
TicTacToeEnv
LLMHashingEnv


Multi-agent environments
------------------------
Expand Down
8 changes: 4 additions & 4 deletions docs/source/reference/trainers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ Hooks can be split into 3 categories: **data processing** (``"batch_process"`` a
constants update), data subsampling (:class:``~torchrl.trainers.BatchSubSampler``) and such.

- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger
some information retrieved from that data. Examples include the ``Recorder`` hook, the reward
logger (``LogReward``) and such. Hooks should return a dictionary (or a None value) containing the
some information retrieved from that data. Examples include the ``LogValidationReward`` hook, the reward
logger (``LogScalar``) and such. Hooks should return a dictionary (or a None value) containing the
data to log. The key ``"log_pbar"`` is reserved to boolean values indicating if the logged value
should be displayed on the progression bar printed on the training log.

Expand Down Expand Up @@ -174,9 +174,9 @@ Trainer and hooks
BatchSubSampler
ClearCudaCache
CountFramesLog
LogReward
LogScalar
OptimizerHook
Recorder
LogValidationReward
ReplayBufferTrainer
RewardNormalizer
SelectKeys
Expand Down
99 changes: 99 additions & 0 deletions examples/replay-buffers/catframes-in-buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import (
CatFrames,
Compose,
DMControlEnv,
StepCounter,
ToTensorImage,
TransformedEnv,
UnsqueezeTransform,
)

# Number of frames to stack together
frame_stack = 4
# Dimension along which the stack should occur
stack_dim = -4
# Max size of the buffer
max_size = 100_000
# Batch size of the replay buffer
training_batch_size = 32

seed = 123


def main():
catframes = CatFrames(
N=frame_stack,
dim=stack_dim,
in_keys=["pixels_trsf"],
out_keys=["pixels_trsf"],
)
env = TransformedEnv(
DMControlEnv(
env_name="cartpole",
task_name="balance",
device="cpu",
from_pixels=True,
pixels_only=True,
),
Compose(
ToTensorImage(
from_int=True,
dtype=torch.float32,
in_keys=["pixels"],
out_keys=["pixels_trsf"],
shape_tolerant=True,
),
UnsqueezeTransform(
dim=stack_dim, in_keys=["pixels_trsf"], out_keys=["pixels_trsf"]
),
catframes,
StepCounter(),
),
)
env.set_seed(seed)

transform, sampler = catframes.make_rb_transform_and_sampler(
batch_size=training_batch_size,
traj_key=("collector", "traj_ids"),
strict_length=True,
)

rb_transforms = Compose(
ToTensorImage(
from_int=True,
dtype=torch.float32,
in_keys=["pixels", ("next", "pixels")],
out_keys=["pixels_trsf", ("next", "pixels_trsf")],
shape_tolerant=True,
), # C W' H' -> C W' H' (unchanged due to shape_tolerant)
UnsqueezeTransform(
dim=stack_dim,
in_keys=["pixels_trsf", ("next", "pixels_trsf")],
out_keys=["pixels_trsf", ("next", "pixels_trsf")],
), # 1 C W' H'
transform,
)

rb = ReplayBuffer(
storage=LazyTensorStorage(max_size=max_size, device="cpu"),
sampler=sampler,
batch_size=training_batch_size,
transform=rb_transforms,
)

data = env.rollout(1000, break_when_any_done=False)
rb.extend(data)

training_batch = rb.sample()
print(training_batch)


if __name__ == "__main__":
main()
89 changes: 89 additions & 0 deletions examples/replay-buffers/filter-imcomplete-trajs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Efficient Trajectory Sampling with CompletedTrajRepertoire
This example demonstrates how to design a custom transform that filters trajectories during sampling,
ensuring that only completed trajectories are present in sampled batches. This can be particularly useful
when dealing with environments where some trajectories might be corrupted or never reach a done state,
which could skew the learning process or lead to biased models. For instance, in robotics or autonomous
driving, a trajectory might be interrupted due to external factors such as hardware failures or human
intervention, resulting in incomplete or inconsistent data. By filtering out these incomplete trajectories,
we can improve the quality of the training data and increase the robustness of our models.
"""

import torch
from tensordict import TensorDictBase
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import GymEnv, TrajCounter, Transform


class CompletedTrajectoryRepertoire(Transform):
"""
A transform that keeps track of completed trajectories and filters them out during sampling.
"""

def __init__(self):
super().__init__()
self.completed_trajectories = set()
self.repertoire_tensor = torch.zeros((), dtype=torch.int64)

def _update_repertoire(self, tensordict: TensorDictBase) -> None:
"""Updates the repertoire of completed trajectories."""
done = tensordict["next", "terminated"].squeeze(-1)
traj = tensordict["next", "traj_count"][done].view(-1)
if traj.numel():
self.completed_trajectories = self.completed_trajectories.union(
traj.tolist()
)
self.repertoire_tensor = torch.tensor(
list(self.completed_trajectories), dtype=torch.int64
)

def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Updates the repertoire of completed trajectories during insertion."""
self._update_repertoire(tensordict)
return tensordict

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Filters out incomplete trajectories during sampling."""
traj = tensordict["next", "traj_count"]
traj = traj.unsqueeze(-1)
has_traj = (traj == self.repertoire_tensor).any(-1)
has_traj = has_traj.view(tensordict.shape)
return tensordict[has_traj]


def main():
# Create a CartPole environment with trajectory counting
env = GymEnv("CartPole-v1").append_transform(TrajCounter())

# Create a replay buffer with the completed trajectory repertoire transform
buffer = ReplayBuffer(
storage=LazyTensorStorage(1_000_000), transform=CompletedTrajectoryRepertoire()
)

# Roll out the environment for 1000 steps
while True:
rollout = env.rollout(1000, break_when_any_done=False)
if not rollout["next", "done"][-1].item():
break

# Extend the replay buffer with the rollout
buffer.extend(rollout)

# Get the last trajectory count
last_traj_count = rollout[-1]["next", "traj_count"].item()
print(f"Incomplete trajectory: {last_traj_count}")

# Sample from the replay buffer 10 times
for _ in range(10):
sample_traj_counts = buffer.sample(32)["next", "traj_count"].unique()
print(f"Sampled trajectories: {sample_traj_counts}")
assert last_traj_count not in sample_traj_counts


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions sota-implementations/cql/cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def update(sampled_tensordict):
if collected_frames >= init_random_frames:
log_loss_td = TensorDict(batch_size=[num_updates], device=device)
for j in range(num_updates):
pbar.set_description(f"optim iter {j}")
with timeit("rb - sample"):
# sample from replay buffer
sampled_tensordict = replay_buffer.sample().to(device)
Expand Down
29 changes: 16 additions & 13 deletions sota-implementations/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@

import torch.nn
import torch.optim
from tensordict import TensorDict, TensorDictParams
from tensordict.nn import TensorDictModule, TensorDictSequential
from tensordict.nn.distributions import NormalParamExtractor
from tensordict.tensorclass import NonTensorData

from torchrl.collectors import SyncDataCollector
from torchrl.data import (
Expand Down Expand Up @@ -219,17 +217,22 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"):
spec=action_spec,
distribution_class=TanhNormal,
# Wrapping the kwargs in a TensorDictParams such that these items are
# send to device when necessary
distribution_kwargs=TensorDictParams(
TensorDict(
{
"low": torch.as_tensor(action_spec.space.low, device=device),
"high": torch.as_tensor(action_spec.space.high, device=device),
"tanh_loc": NonTensorData(False),
}
),
no_convert=True,
),
# send to device when necessary - not compatible with compile yet
# distribution_kwargs=TensorDictParams(
# TensorDict(
# {
# "low": torch.as_tensor(action_spec.space.low, device=device),
# "high": torch.as_tensor(action_spec.space.high, device=device),
# "tanh_loc": NonTensorData(False),
# }
# ),
# no_convert=True,
# ),
distribution_kwargs={
"low": action_spec.space.low.to(device),
"high": action_spec.space.high.to(device),
"tanh_loc": False,
},
default_interaction_type=ExplorationType.RANDOM,
)

Expand Down
17 changes: 7 additions & 10 deletions sota-implementations/crossq/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from torchrl._utils import timeit
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.objectives import group_optimizers

from torchrl.record.loggers import generate_exp_name, get_logger
from utils import (
Expand Down Expand Up @@ -112,28 +113,26 @@ def main(cfg: "DictConfig"): # noqa: F821
optimizer_critic,
optimizer_alpha,
) = make_crossQ_optimizer(cfg, loss_module)
# optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
# del optimizer_actor, optimizer_critic, optimizer_alpha
optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_alpha)
del optimizer_actor, optimizer_critic, optimizer_alpha

def update_qloss(sampled_tensordict):
optimizer_critic.zero_grad(set_to_none=True)
optimizer.zero_grad(set_to_none=True)
td_loss = {}
q_loss, value_meta = loss_module.qvalue_loss(sampled_tensordict)
sampled_tensordict.set(loss_module.tensor_keys.priority, value_meta["td_error"])
q_loss = q_loss.mean()

# Update critic
q_loss.backward()
optimizer_critic.step()
optimizer.step()
td_loss["loss_qvalue"] = q_loss
td_loss["loss_actor"] = float("nan")
td_loss["loss_alpha"] = float("nan")
return TensorDict(td_loss, device=device).detach()

def update_all(sampled_tensordict: TensorDict):
optimizer_critic.zero_grad(set_to_none=True)
optimizer_actor.zero_grad(set_to_none=True)
optimizer_alpha.zero_grad(set_to_none=True)
optimizer.zero_grad(set_to_none=True)

td_loss = {}
q_loss, value_meta = loss_module.qvalue_loss(sampled_tensordict)
Expand All @@ -148,9 +147,7 @@ def update_all(sampled_tensordict: TensorDict):

# Updates
(q_loss + actor_loss + actor_loss).backward()
optimizer_critic.step()
optimizer_actor.step()
optimizer_alpha.step()
optimizer.step()

# Update critic
td_loss["loss_qvalue"] = q_loss
Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/ddpg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def make_ddpg_agent(cfg, train_env, eval_env, device):
"""Make DDPG agent."""
# Define Actor Network
in_keys = ["observation"]
action_spec = train_env.single_action_spec
action_spec = train_env.action_spec_unbatched
actor_net_kwargs = {
"num_cells": cfg.network.hidden_sizes,
"out_features": action_spec.shape[-1],
Expand Down
Loading

0 comments on commit a6d2cf6

Please sign in to comment.