Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into release/0.3.0
Browse files: browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 31, 2024
2 parents 8c7f303 + b5e90c4 commit 56e0641
Show file tree
Hide file tree
Showing 70 changed files with 729 additions and 513 deletions.
2 changes: 2 additions & 0 deletions .github/unittest/linux_examples/scripts/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn_atari.
buffer.batch_size=10 \
device=cuda:0 \
loss.num_updates=1 \
logger.backend= \
buffer.buffer_size=120
python .github/unittest/helpers/coverage_run_parallel.py examples/cql/discrete_cql_online.py \
collector.total_frames=48 \
Expand Down Expand Up @@ -256,6 +257,7 @@ python .github/unittest/helpers/coverage_run_parallel.py examples/dqn/dqn_atari.
buffer.batch_size=10 \
device=cuda:0 \
loss.num_updates=1 \
logger.backend= \
buffer.buffer_size=120
python .github/unittest/helpers/coverage_run_parallel.py examples/redq/redq.py \
num_workers=2 \
Expand Down
17 changes: 10 additions & 7 deletions examples/bandits/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@

import torch
import tqdm
from torch import nn

from tensordict.nn import TensorDictSequential
from torch import nn
from torchrl.envs.libs.openml import OpenMLEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import DistributionalQValueActor, EGreedyWrapper, MLP, QValueActor
from torchrl.modules import DistributionalQValueActor, EGreedyModule, MLP, QValueActor
from torchrl.objectives import DistributionalDQNLoss, DQNLoss

parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -85,12 +86,14 @@
actor(env.reset())
loss = DQNLoss(actor, loss_function="smooth_l1", action_space=env.action_spec)
loss.make_value_estimator(gamma=0.0)
policy = EGreedyWrapper(
policy = TensorDictSequential(
actor,
eps_init=eps_greedy,
eps_end=0.0,
annealing_num_steps=n_steps,
spec=env.action_spec,
EGreedyModule(
eps_init=eps_greedy,
eps_end=0.0,
annealing_num_steps=n_steps,
spec=env.action_spec,
),
)
optim = torch.optim.Adam(loss.parameters(), lr, weight_decay=wd)

Expand Down
4 changes: 2 additions & 2 deletions examples/cql/discrete_cql_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
env:
name: CartPole-v1
task: ""
backend: gym
backend: gymnasium
n_samples_stats: 1000
max_episode_steps: 200
seed: 0
Expand Down Expand Up @@ -36,7 +36,7 @@ replay_buffer:
prb: 0
buffer_prefetch: 64
size: 1_000_000
scratch_dir: ${env.exp_name}_${env.seed}
scratch_dir: null

# Optimization
optim:
Expand Down
2 changes: 1 addition & 1 deletion examples/cql/discrete_cql_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def main(cfg: "DictConfig"): # noqa: F821
batch_size=cfg.optim.batch_size,
prb=cfg.replay_buffer.prb,
buffer_size=cfg.replay_buffer.size,
buffer_scratch_dir=cfg.replay_buffer.scratch_dir,
scratch_dir=cfg.replay_buffer.scratch_dir,
device="cpu",
)

Expand Down
2 changes: 1 addition & 1 deletion examples/cql/offline_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ env:
library: gym
n_samples_stats: 1000
seed: 0
backend: gym # D4RL uses gym so we make sure gymnasium is hidden
backend: gymnasium

# logger
logger:
Expand Down
2 changes: 1 addition & 1 deletion examples/cql/online_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ env:
seed: 0
train_num_envs: 1
eval_num_envs: 1
backend: gym
backend: gymnasium

# Collector
collector:
Expand Down
7 changes: 3 additions & 4 deletions examples/cql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def make_replay_buffer(
batch_size,
prb=False,
buffer_size=1000000,
buffer_scratch_dir=None,
scratch_dir=None,
device="cpu",
prefetch=3,
):
Expand All @@ -133,7 +133,7 @@ def make_replay_buffer(
prefetch=prefetch,
storage=LazyMemmapStorage(
buffer_size,
scratch_dir=buffer_scratch_dir,
scratch_dir=scratch_dir,
device=device,
),
batch_size=batch_size,
Expand All @@ -144,7 +144,7 @@ def make_replay_buffer(
prefetch=prefetch,
storage=LazyMemmapStorage(
buffer_size,
scratch_dir=buffer_scratch_dir,
scratch_dir=scratch_dir,
device=device,
),
batch_size=batch_size,
Expand Down Expand Up @@ -320,7 +320,6 @@ def make_discrete_loss(loss_cfg, model):
model,
loss_function=loss_cfg.loss_function,
delay_value=True,
gamma=loss_cfg.gamma,
)
loss_module.make_value_estimator(gamma=loss_cfg.gamma)
target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau)
Expand Down
2 changes: 1 addition & 1 deletion examples/ddpg/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ collector:
replay_buffer:
size: 1000000
prb: 0 # use prioritized experience replay
scratch_dir: ${logger.exp_name}_${env.seed}
scratch_dir: null

# optimization
optim:
Expand Down
2 changes: 1 addition & 1 deletion examples/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def main(cfg: "DictConfig"): # noqa: F821
batch_size=cfg.optim.batch_size,
prb=cfg.replay_buffer.prb,
buffer_size=cfg.replay_buffer.size,
buffer_scratch_dir=cfg.replay_buffer.scratch_dir,
scratch_dir=cfg.replay_buffer.scratch_dir,
device="cpu",
)

Expand Down
6 changes: 3 additions & 3 deletions examples/ddpg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def make_replay_buffer(
batch_size,
prb=False,
buffer_size=1000000,
buffer_scratch_dir=None,
scratch_dir=None,
device="cpu",
prefetch=3,
):
Expand All @@ -131,7 +131,7 @@ def make_replay_buffer(
prefetch=prefetch,
storage=LazyMemmapStorage(
buffer_size,
scratch_dir=buffer_scratch_dir,
scratch_dir=scratch_dir,
device=device,
),
batch_size=batch_size,
Expand All @@ -142,7 +142,7 @@ def make_replay_buffer(
prefetch=prefetch,
storage=LazyMemmapStorage(
buffer_size,
scratch_dir=buffer_scratch_dir,
scratch_dir=scratch_dir,
device=device,
),
batch_size=batch_size,
Expand Down
2 changes: 1 addition & 1 deletion examples/decision_transformer/dt_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ replay_buffer:
stacked_frames: 20
buffer_prefetch: 64
capacity: 1_000_000
buffer_scratch_dir:
scratch_dir:
device: cpu
prefetch: 3

Expand Down
2 changes: 1 addition & 1 deletion examples/decision_transformer/odt_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ replay_buffer:
stacked_frames: 20
buffer_prefetch: 64
capacity: 1_000_000
buffer_scratch_dir:
scratch_dir:
device: cuda:0
prefetch: 3

Expand Down
2 changes: 1 addition & 1 deletion examples/decision_transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def make_online_replay_buffer(offline_buffer, rb_cfg, reward_scaling=0.001):
)
storage = LazyMemmapStorage(
max_size=rb_cfg.capacity,
scratch_dir=rb_cfg.buffer_scratch_dir,
scratch_dir=rb_cfg.scratch_dir,
device=rb_cfg.device,
)

Expand Down
2 changes: 1 addition & 1 deletion examples/discrete_sac/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ collector:
replay_buffer:
prb: 0 # use prioritized experience replay
size: 1000000
scratch_dir: ${logger.exp_name}_${env.seed}
scratch_dir: null

# optim
optim:
Expand Down
2 changes: 1 addition & 1 deletion examples/discrete_sac/discrete_sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def main(cfg: "DictConfig"): # noqa: F821
batch_size=cfg.optim.batch_size,
prb=cfg.replay_buffer.prb,
buffer_size=cfg.replay_buffer.size,
buffer_scratch_dir=cfg.replay_buffer.scratch_dir,
scratch_dir=cfg.replay_buffer.scratch_dir,
device="cpu",
)

Expand Down
6 changes: 3 additions & 3 deletions examples/discrete_sac/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,14 @@ def make_replay_buffer(
batch_size,
prb=False,
buffer_size=1000000,
buffer_scratch_dir=None,
scratch_dir=None,
device="cpu",
prefetch=3,
):
with (
tempfile.TemporaryDirectory()
if buffer_scratch_dir is None
else nullcontext(buffer_scratch_dir)
if scratch_dir is None
else nullcontext(scratch_dir)
) as scratch_dir:
if prb:
replay_buffer = TensorDictPrioritizedReplayBuffer(
Expand Down
2 changes: 1 addition & 1 deletion examples/distributed/collectors/multi_nodes/ray_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
if __name__ == "__main__":

# 1. Define Hyperparameters
device = "cpu" # if not torch.has_cuda else "cuda:0"
device = "cpu" # if not torch.cuda.device_count() else "cuda:0"
num_cells = 256
max_grad_norm = 1.0
frame_skip = 1
Expand Down
6 changes: 3 additions & 3 deletions examples/iql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def make_replay_buffer(
batch_size,
prb=False,
buffer_size=1000000,
buffer_scratch_dir=None,
scratch_dir=None,
device="cpu",
prefetch=3,
):
Expand All @@ -137,7 +137,7 @@ def make_replay_buffer(
prefetch=prefetch,
storage=LazyMemmapStorage(
buffer_size,
scratch_dir=buffer_scratch_dir,
scratch_dir=scratch_dir,
device=device,
),
batch_size=batch_size,
Expand All @@ -148,7 +148,7 @@ def make_replay_buffer(
prefetch=prefetch,
storage=LazyMemmapStorage(
buffer_size,
scratch_dir=buffer_scratch_dir,
scratch_dir=scratch_dir,
device=device,
),
batch_size=batch_size,
Expand Down
22 changes: 12 additions & 10 deletions examples/multiagent/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import hydra
import torch

from tensordict.nn import TensorDictModule
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data import TensorDictReplayBuffer
Expand All @@ -17,7 +17,7 @@
from torchrl.envs import RewardSum, TransformedEnv
from torchrl.envs.libs.vmas import VmasEnv
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.modules import EGreedyWrapper, QValueModule, SafeSequential
from torchrl.modules import EGreedyModule, QValueModule, SafeSequential
from torchrl.modules.models.multiagent import MultiAgentMLP
from torchrl.objectives import DQNLoss, SoftUpdate, ValueEstimators
from utils.logging import init_logging, log_evaluation, log_training
Expand All @@ -31,7 +31,7 @@ def rendering_callback(env, td):
@hydra.main(version_base="1.1", config_path=".", config_name="iql")
def train(cfg: "DictConfig"): # noqa: F821
# Device
cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0"
cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0"
cfg.env.device = cfg.train.device

# Seeding
Expand Down Expand Up @@ -96,13 +96,15 @@ def train(cfg: "DictConfig"): # noqa: F821
)
qnet = SafeSequential(module, value_module)

qnet_explore = EGreedyWrapper(
qnet_explore = TensorDictSequential(
qnet,
eps_init=0.3,
eps_end=0,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
spec=env.unbatched_action_spec,
EGreedyModule(
eps_init=0.3,
eps_end=0,
annealing_num_steps=int(cfg.collector.total_frames * (1 / 2)),
action_key=env.action_key,
spec=env.unbatched_action_spec,
),
)

collector = SyncDataCollector(
Expand Down Expand Up @@ -174,7 +176,7 @@ def train(cfg: "DictConfig"): # noqa: F821
optim.zero_grad()
target_net_updater.step()

qnet_explore.step(frames=current_frames) # Update exploration annealing
qnet_explore[1].step(frames=current_frames) # Update exploration annealing
collector.update_policy_weights_()

training_time = time.time() - training_start
Expand Down
2 changes: 1 addition & 1 deletion examples/multiagent/maddpg_iddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def rendering_callback(env, td):
@hydra.main(version_base="1.1", config_path=".", config_name="maddpg_iddpg")
def train(cfg: "DictConfig"): # noqa: F821
# Device
cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0"
cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0"
cfg.env.device = cfg.train.device

# Seeding
Expand Down
2 changes: 1 addition & 1 deletion examples/multiagent/mappo_ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def rendering_callback(env, td):
@hydra.main(version_base="1.1", config_path=".", config_name="mappo_ippo")
def train(cfg: "DictConfig"): # noqa: F821
# Device
cfg.train.device = "cpu" if not torch.has_cuda else "cuda:0"
cfg.train.device = "cpu" if not torch.cuda.device_count() else "cuda:0"
cfg.env.device = cfg.train.device

# Seeding
Expand Down
Loading

0 comments on commit 56e0641

Please sign in to comment.