From fda47cbdb7af513c9474be25b477a37e1dfd54b2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 09:47:08 -0800 Subject: [PATCH 01/16] Update [ghstack-poisoned] --- sota-implementations/iql/discrete_iql.py | 169 +++++++++++++---------- sota-implementations/iql/iql_offline.py | 84 ++++++----- sota-implementations/iql/iql_online.py | 133 ++++++++---------- sota-implementations/iql/utils.py | 4 +- 4 files changed, 205 insertions(+), 185 deletions(-) diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index 79cf2114d40..203b916c4a2 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -13,16 +13,20 @@ """ from __future__ import annotations -import time +import warnings import hydra import numpy as np import torch import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict import TensorDict +from tensordict.nn import CudaGraphModule + +from torchrl._utils import timeit from torchrl.envs import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.objectives import group_optimizers from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( @@ -87,8 +91,19 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create model model = make_discrete_iql_model(cfg, train_env, eval_env, device) + compile_mode = None + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + # Create collector - collector = make_collector(cfg, train_env, actor_model_explore=model[0]) + collector = make_collector( + cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode + ) # Create loss loss_module, target_net_updater = make_discrete_loss(cfg.loss, model) @@ -97,6 +112,34 @@ def main(cfg: "DictConfig"): # noqa: F821 optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer( cfg.optim, loss_module ) + optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_value) + del optimizer_actor, optimizer_critic, optimizer_value + + def update(sampled_tensordict): + optimizer.zero_grad(set_to_none=True) + # compute losses + actor_loss, _ = loss_module.actor_loss(sampled_tensordict) + value_loss, _ = loss_module.value_loss(sampled_tensordict) + q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict) + (actor_loss + value_loss + q_loss).backward() + optimizer.step() + + # update qnet_target params + target_net_updater.step() + return TensorDict( + metadata.update( + {"actor_loss": actor_loss, "value_loss": value_loss, "q_loss": q_loss} + ) + ).detach() + + if cfg.compile.compile: + update = torch.compile(update, mode=compile_mode) + if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. 
Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) # Main loop collected_frames = 0 @@ -112,84 +155,52 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch eval_rollout_steps = cfg.collector.max_frames_per_traj - sampling_start = start_time = time.time() - for tensordict in collector: - sampling_time = time.time() - sampling_start - pbar.update(tensordict.numel()) + + collector_iter = iter(collector) + for _ in range(len(collector)): + with timeit("collection"): + tensordict = next(collector_iter) + current_frames = tensordict.numel() + pbar.update(current_frames) + # update weights of the inference policy collector.update_policy_weights_() - tensordict = tensordict.reshape(-1) - current_frames = tensordict.numel() - # add to replay buffer - replay_buffer.extend(tensordict.cpu()) + with timeit("buffer - extend"): + tensordict = tensordict.reshape(-1) + + # add to replay buffer + replay_buffer.extend(tensordict) collected_frames += current_frames # optimization steps - training_start = time.time() - if collected_frames >= init_random_frames: - for _ in range(num_updates): - # sample from replay buffer - sampled_tensordict = replay_buffer.sample().clone() - if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to( - device, non_blocking=True - ) - else: - sampled_tensordict = sampled_tensordict - # compute losses - actor_loss, _ = loss_module.actor_loss(sampled_tensordict) - optimizer_actor.zero_grad() - actor_loss.backward() - optimizer_actor.step() - - value_loss, _ = loss_module.value_loss(sampled_tensordict) - optimizer_value.zero_grad() - value_loss.backward() - optimizer_value.step() - - q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict) - optimizer_critic.zero_grad() - q_loss.backward() - optimizer_critic.step() - - # update qnet_target params - target_net_updater.step() - - # update priority - if prb: - sampled_tensordict.set( - loss_module.tensor_keys.priority, - metadata.pop("td_error").detach().max(0).values, - ) - replay_buffer.update_priority(sampled_tensordict) - - training_time = time.time() - training_start + with timeit("training"): + if collected_frames >= init_random_frames: + for _ in range(num_updates): + # sample from replay buffer + with timeit("buffer - sample"): + sampled_tensordict = replay_buffer.sample().to(device) + + with timeit("training - update"): + metadata = update(sampled_tensordict) + # update priority + if prb: + sampled_tensordict.set( + loss_module.tensor_keys.priority, + metadata.pop("td_error").detach().max(0).values, + ) + replay_buffer.update_priority(sampled_tensordict) + episode_rewards = tensordict["next", "episode_reward"][ tensordict["next", "done"] ] - # Logging metrics_to_log = {} - if len(episode_rewards) > 0: - episode_length = tensordict["next", "step_count"][ - tensordict["next", "done"] - ] - metrics_to_log["train/reward"] = episode_rewards.mean().item() - metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( - episode_length - ) - if collected_frames >= init_random_frames: - metrics_to_log["train/q_loss"] = q_loss.detach() - metrics_to_log["train/actor_loss"] = actor_loss.detach() - metrics_to_log["train/value_loss"] = value_loss.detach() - metrics_to_log["train/sampling_time"] = sampling_time - metrics_to_log["train/training_time"] = training_time - # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with 
set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_start = time.time() + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(), timeit("eval"): eval_rollout = eval_env.rollout( eval_rollout_steps, model[0], @@ -197,18 +208,28 @@ def main(cfg: "DictConfig"): # noqa: F821 break_when_any_done=True, ) eval_env.apply(dump_video) - eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward - metrics_to_log["eval/time"] = eval_time + + # Logging + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][ + tensordict["next", "done"] + ] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length + ) + if collected_frames >= init_random_frames: + metrics_to_log["train/q_loss"] = metadata["q_loss"] + metrics_to_log["train/actor_loss"] = metadata["actor_loss"] + metrics_to_log["train/value_loss"] = metadata["value_loss"] + metrics_to_log.update(timeit.todict(prefix="time")) if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) - sampling_start = time.time() + timeit.erase() collector.shutdown() - end_time = time.time() - execution_time = end_time - start_time - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") if __name__ == "__main__": diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py index 09cf9954b86..1853bc79a5f 100644 --- a/sota-implementations/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -11,16 +11,19 @@ """ from __future__ import annotations -import time +import warnings import hydra import numpy as np import torch import tqdm -from torchrl._utils import logger as torchrl_logger +from tensordict.nn import CudaGraphModule + +from torchrl._utils import timeit from torchrl.envs import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.objectives import group_optimizers from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( @@ -85,54 +88,62 @@ def main(cfg: "DictConfig"): # noqa: F821 optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer( cfg.optim, loss_module ) + optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_value) - pbar = tqdm.tqdm(total=cfg.optim.gradient_steps) - - gradient_steps = cfg.optim.gradient_steps - evaluation_interval = cfg.logger.eval_iter - eval_steps = cfg.logger.eval_steps - - # Training loop - start_time = time.time() - for i in range(gradient_steps): - pbar.update(1) - # sample data - data = replay_buffer.sample() - - if data.device != device: - data = data.to(device, non_blocking=True) - + def update(data): + optimizer.zero_grad(set_to_none=True) # compute losses loss_info = loss_module(data) actor_loss = loss_info["loss_actor"] value_loss = loss_info["loss_value"] q_loss = loss_info["loss_qvalue"] - optimizer_actor.zero_grad() - actor_loss.backward() - optimizer_actor.step() - - optimizer_value.zero_grad() - value_loss.backward() - optimizer_value.step() - - optimizer_critic.zero_grad() - q_loss.backward() - optimizer_critic.step() + (actor_loss + value_loss + q_loss).backward() + optimizer.step() # update qnet_target params target_net_updater.step() + return loss_info.detach() + + compile_mode = None + if cfg.compile.compile: + compile_mode = cfg.compile.compile_mode + if 
compile_mode in ("", None): + if cfg.compile.cudagraphs: + compile_mode = "default" + else: + compile_mode = "reduce-overhead" + + if cfg.compile.compile: + update = torch.compile(update, mode=compile_mode) + if cfg.compile.cudagraphs: + warnings.warn( + "CudaGraphModule is experimental and may lead to silently wrong results. Use with caution.", + category=UserWarning, + ) + update = CudaGraphModule(update, warmup=50) + + pbar = tqdm.tqdm(range(cfg.optim.gradient_steps)) + + evaluation_interval = cfg.logger.eval_iter + eval_steps = cfg.logger.eval_steps + + # Training loop + for i in pbar: + # sample data + with timeit("sample"): + data = replay_buffer.sample() + data = data.to(device) - # log metrics - to_log = { - "loss_actor": actor_loss.item(), - "loss_qvalue": q_loss.item(), - "loss_value": value_loss.item(), - } + with timeit("update"): + loss_info = update(data) # evaluation + to_log = loss_info.to_dict() if i % evaluation_interval == 0: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(), timeit("eval"): eval_td = eval_env.rollout( max_steps=eval_steps, policy=model[0], auto_cast_to_device=True ) @@ -147,7 +158,6 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_env.close() if not train_env.is_closed: train_env.close() - torchrl_logger.info(f"Training time: {time.time() - start_time}") if __name__ == "__main__": diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index 8497d24f106..de32678e6cd 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -13,16 +13,15 @@ """ from __future__ import annotations -import time - import hydra import numpy as np import torch import tqdm -from torchrl._utils import logger as torchrl_logger +from torchrl._utils import timeit from torchrl.envs import set_gym_backend from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.objectives import group_optimizers from torchrl.record.loggers import generate_exp_name, get_logger from utils import ( @@ -97,10 +96,11 @@ def main(cfg: "DictConfig"): # noqa: F821 optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer( cfg.optim, loss_module ) + optimizer = group_optimizers(optimizer_actor, optimizer_critic, optimizer_value) + del optimizer_actor, optimizer_critic, optimizer_value # Main loop collected_frames = 0 - pbar = tqdm.tqdm(total=cfg.collector.total_frames) init_random_frames = cfg.collector.init_random_frames num_updates = int( @@ -112,82 +112,61 @@ def main(cfg: "DictConfig"): # noqa: F821 eval_iter = cfg.logger.eval_iter frames_per_batch = cfg.collector.frames_per_batch eval_rollout_steps = cfg.collector.max_frames_per_traj - sampling_start = start_time = time.time() - for tensordict in collector: - sampling_time = time.time() - sampling_start - pbar.update(tensordict.numel()) + collector_iter = iter(collector) + pbar = tqdm.tqdm(range(collector.total_frames)) + for _ in range(len(collector)): + with timeit("collection"): + tensordict = next(collector_iter) + current_frames = tensordict.numel() + pbar.update(current_frames) # update weights of the inference policy collector.update_policy_weights_() - tensordict = tensordict.view(-1) - current_frames = tensordict.numel() - # add to replay buffer - replay_buffer.extend(tensordict.cpu()) + with timeit("rb - extend"): + # add to replay buffer + tensordict = tensordict.rehsape(-1) + replay_buffer.extend(tensordict.cpu()) 
collected_frames += current_frames # optimization steps - training_start = time.time() - if collected_frames >= init_random_frames: - for _ in range(num_updates): - # sample from replay buffer - sampled_tensordict = replay_buffer.sample().clone() - if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to( - device, non_blocking=True - ) - else: - sampled_tensordict = sampled_tensordict - # compute losses - loss_info = loss_module(sampled_tensordict) - actor_loss = loss_info["loss_actor"] - value_loss = loss_info["loss_value"] - q_loss = loss_info["loss_qvalue"] - - optimizer_actor.zero_grad() - actor_loss.backward() - optimizer_actor.step() - - optimizer_value.zero_grad() - value_loss.backward() - optimizer_value.step() - - optimizer_critic.zero_grad() - q_loss.backward() - optimizer_critic.step() - - # update qnet_target params - target_net_updater.step() - - # update priority - if prb: - replay_buffer.update_priority(sampled_tensordict) - training_time = time.time() - training_start + with timeit("training"): + if collected_frames >= init_random_frames: + for _ in range(num_updates): + with timeit("rb - sampling"): + # sample from replay buffer + sampled_tensordict = replay_buffer.sample().to(device) + + def update(sampled_tensordict): + optimizer.zero_grad() + # compute losses + loss_info = loss_module(sampled_tensordict) + actor_loss = loss_info["loss_actor"] + value_loss = loss_info["loss_value"] + q_loss = loss_info["loss_qvalue"] + + (actor_loss + value_loss + q_loss).backward() + optimizer.step() + + # update qnet_target params + target_net_updater.step() + return loss_info.detach() + + with timeit("update"): + loss_info = update(sampled_tensordict) + # update priority + if prb: + replay_buffer.update_priority(sampled_tensordict) episode_rewards = tensordict["next", "episode_reward"][ tensordict["next", "done"] ] # Logging metrics_to_log = {} - if len(episode_rewards) > 0: - episode_length = tensordict["next", "step_count"][ - tensordict["next", "done"] - ] - metrics_to_log["train/reward"] = episode_rewards.mean().item() - metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( - episode_length - ) - if collected_frames >= init_random_frames: - metrics_to_log["train/q_loss"] = q_loss.detach() - metrics_to_log["train/actor_loss"] = actor_loss.detach() - metrics_to_log["train/value_loss"] = value_loss.detach() - metrics_to_log["train/entropy"] = loss_info.get("entropy").detach() - metrics_to_log["train/sampling_time"] = sampling_time - metrics_to_log["train/training_time"] = training_time - # Evaluation if abs(collected_frames % eval_iter) < frames_per_batch: - with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad(): - eval_start = time.time() + with set_exploration_type( + ExplorationType.DETERMINISTIC + ), torch.no_grad(), timeit("evaluating"): eval_rollout = eval_env.rollout( eval_rollout_steps, model[0], @@ -195,25 +174,33 @@ def main(cfg: "DictConfig"): # noqa: F821 break_when_any_done=True, ) eval_env.apply(dump_video) - eval_time = time.time() - eval_start eval_reward = eval_rollout["next", "reward"].sum(-2).mean().item() metrics_to_log["eval/reward"] = eval_reward - metrics_to_log["eval/time"] = eval_time + if len(episode_rewards) > 0: + episode_length = tensordict["next", "step_count"][ + tensordict["next", "done"] + ] + metrics_to_log["train/reward"] = episode_rewards.mean().item() + metrics_to_log["train/episode_length"] = episode_length.sum().item() / len( + episode_length + ) + if collected_frames >= 
init_random_frames: + metrics_to_log["train/q_loss"] = loss_info["loss_qvalue"].detach() + metrics_to_log["train/actor_loss"] = loss_info["loss_actor"].detach() + metrics_to_log["train/value_loss"] = loss_info["loss_value"].detach() + metrics_to_log["train/entropy"] = loss_info.get("entropy").detach() + metrics_to_log.update(timeit.todict(prefix="time")) + if logger is not None: log_metrics(logger, metrics_to_log, collected_frames) - sampling_start = time.time() collector.shutdown() - end_time = time.time() - execution_time = end_time - start_time if not eval_env.is_closed: eval_env.close() if not train_env.is_closed: train_env.close() - torchrl_logger.info(f"Training took {execution_time:.2f} seconds to finish") - if __name__ == "__main__": main() diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index 261cb912de0..b74eb8e8f79 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -118,7 +118,7 @@ def make_environment(cfg, train_num_envs=1, eval_num_envs=1, logger=None): # --------------------------- -def make_collector(cfg, train_env, actor_model_explore): +def make_collector(cfg, train_env, actor_model_explore, compile_mode): """Make collector.""" device = cfg.collector.device if device in ("", None): @@ -134,6 +134,8 @@ def make_collector(cfg, train_env, actor_model_explore): max_frames_per_traj=cfg.collector.max_frames_per_traj, total_frames=cfg.collector.total_frames, device=device, + compile_policy={"mode": compile_mode} if compile_mode else False, + cuda=cfg.compile.cudagraphs, ) collector.set_seed(cfg.env.seed) return collector From c97088f7ceff5880d098397c91ce1168f54d0d92 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 20:35:48 -0800 Subject: [PATCH 02/16] Update [ghstack-poisoned] --- sota-implementations/iql/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index b74eb8e8f79..42ca848318a 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -135,7 +135,7 @@ def make_collector(cfg, train_env, actor_model_explore, compile_mode): total_frames=cfg.collector.total_frames, device=device, compile_policy={"mode": compile_mode} if compile_mode else False, - cuda=cfg.compile.cudagraphs, + cudagraph_policy=cfg.compile.cudagraphs, ) collector.set_seed(cfg.env.seed) return collector From ab3616c6dda8e96eb383bedcd549a616953c5e17 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 14 Dec 2024 20:41:18 -0800 Subject: [PATCH 03/16] Update [ghstack-poisoned] --- sota-implementations/iql/discrete_iql.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index 504e46e19df..93072f24f0a 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -129,11 +129,10 @@ def update(sampled_tensordict): # update qnet_target params target_net_updater.step() - return TensorDict( - metadata.update( - {"actor_loss": actor_loss, "value_loss": value_loss, "q_loss": q_loss} - ) - ).detach() + metadata.update( + {"actor_loss": actor_loss, "value_loss": value_loss, "q_loss": q_loss} + ) + return TensorDict(metadata).detach() if cfg.compile.compile: update = torch.compile(update, mode=compile_mode) From 65bb4f3ac93a7ca042df60820bb963bb29f2fbbb Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 10:22:53 -0800 Subject: [PATCH 04/16] 
Update [ghstack-poisoned] --- sota-implementations/iql/discrete_iql.py | 1 + sota-implementations/iql/iql_offline.py | 1 + sota-implementations/iql/iql_online.py | 1 + 3 files changed, 3 insertions(+) diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index 93072f24f0a..17153a59913 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -184,6 +184,7 @@ def update(sampled_tensordict): sampled_tensordict = replay_buffer.sample().to(device) with timeit("training - update"): + torch.compiler.cudagraph_mark_step_begin() metadata = update(sampled_tensordict) # update priority if prb: diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py index 94334929664..4adb0fce7e4 100644 --- a/sota-implementations/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -139,6 +139,7 @@ def update(data): data = data.to(device) with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() loss_info = update(data) # evaluation diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index eb063fcd139..499c2164b52 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -176,6 +176,7 @@ def update(sampled_tensordict): # sample from replay buffer sampled_tensordict = replay_buffer.sample().to(device) with timeit("update"): + torch.compiler.cudagraph_mark_step_begin() loss_info = update(sampled_tensordict) # update priority if prb: From 5ffd212a67d6f095eea36528d49a684debc79c33 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 10:24:37 -0800 Subject: [PATCH 05/16] Update [ghstack-poisoned] --- sota-implementations/iql/iql_online.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index 499c2164b52..6e9c8a0d8ea 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -164,7 +164,7 @@ def update(sampled_tensordict): with timeit("rb - extend"): # add to replay buffer - tensordict = tensordict.rehsape(-1) + tensordict = tensordict.reshape(-1) replay_buffer.extend(tensordict.cpu()) collected_frames += current_frames From bac5b006afb60f9b3c3662c52f09c6a7368b9283 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 10:28:10 -0800 Subject: [PATCH 06/16] Update [ghstack-poisoned] --- sota-implementations/iql/iql_online.py | 4 +++- sota-implementations/iql/utils.py | 15 +++++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index 6e9c8a0d8ea..f3537ef0d47 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -103,7 +103,9 @@ def main(cfg: "DictConfig"): # noqa: F821 compile_mode = "reduce-overhead" # Create collector - collector = make_collector(cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode) + collector = make_collector( + cfg, train_env, actor_model_explore=model[0], compile_mode=compile_mode + ) # Create loss loss_module, target_net_updater = make_loss(cfg.loss, model) diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index 42ca848318a..bc643eb6149 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -10,6 +10,7 @@ import torch.optim from tensordict.nn import InteractionType, 
TensorDictModule from tensordict.nn.distributions import NormalParamExtractor +from torch.distributions import Categorical from torchrl.collectors import SyncDataCollector from torchrl.data import ( @@ -36,7 +37,6 @@ from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( MLP, - OneHotCategorical, ProbabilisticActor, SafeModule, TanhNormal, @@ -44,7 +44,6 @@ ) from torchrl.objectives import DiscreteIQLLoss, HardUpdate, IQLLoss, SoftUpdate from torchrl.record import VideoRecorder - from torchrl.trainers.helpers.models import ACTIVATIONS @@ -58,7 +57,11 @@ def env_maker(cfg, device="cpu", from_pixels=False): if lib in ("gym", "gymnasium"): with set_gym_backend(lib): return GymEnv( - cfg.env.name, device=device, from_pixels=from_pixels, pixels_only=False + cfg.env.name, + device=device, + from_pixels=from_pixels, + pixels_only=False, + categorical_action_encoding=True, ) elif lib == "dm_control": env = DMControlEnv( @@ -221,8 +224,8 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"): spec=action_spec, distribution_class=TanhNormal, distribution_kwargs={ - "low": action_spec.space.low, - "high": action_spec.space.high, + "low": action_spec.space.low.to(device), + "high": action_spec.space.high.to(device), "tanh_loc": False, }, default_interaction_type=ExplorationType.RANDOM, @@ -318,7 +321,7 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device): module=actor_module, in_keys=["logits"], out_keys=["action"], - distribution_class=OneHotCategorical, + distribution_class=Categorical, distribution_kwargs={}, default_interaction_type=InteractionType.RANDOM, return_log_prob=False, From 8a623ac8d1a1955afc222d1647141358df6f6250 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 10:34:04 -0800 Subject: [PATCH 07/16] Update [ghstack-poisoned] --- sota-implementations/iql/iql_offline.py | 2 +- sota-implementations/iql/iql_online.py | 2 +- sota-implementations/iql/utils.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sota-implementations/iql/iql_offline.py b/sota-implementations/iql/iql_offline.py index 4adb0fce7e4..1a270ee8ccc 100644 --- a/sota-implementations/iql/iql_offline.py +++ b/sota-implementations/iql/iql_offline.py @@ -85,7 +85,7 @@ def main(cfg: "DictConfig"): # noqa: F821 model = make_iql_model(cfg, train_env, eval_env, device) # Create loss - loss_module, target_net_updater = make_loss(cfg.loss, model) + loss_module, target_net_updater = make_loss(cfg.loss, model, device=device) # Create optimizer optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer( diff --git a/sota-implementations/iql/iql_online.py b/sota-implementations/iql/iql_online.py index f3537ef0d47..4f6c765d1e8 100644 --- a/sota-implementations/iql/iql_online.py +++ b/sota-implementations/iql/iql_online.py @@ -108,7 +108,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create loss - loss_module, target_net_updater = make_loss(cfg.loss, model) + loss_module, target_net_updater = make_loss(cfg.loss, model, device=device) # Create optimizer optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer( diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index bc643eb6149..3c561f2b7dc 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -374,7 +374,7 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device): # --------- -def make_loss(loss_cfg, model): +def make_loss(loss_cfg, model, device): loss_module = IQLLoss( model[0], 
model[1], @@ -383,7 +383,7 @@ def make_loss(loss_cfg, model): temperature=loss_cfg.temperature, expectile=loss_cfg.expectile, ) - loss_module.make_value_estimator(gamma=loss_cfg.gamma) + loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device) target_net_updater = SoftUpdate(loss_module, tau=loss_cfg.tau) return loss_module, target_net_updater From 5f4cdfe3f8cea5b9299f980c89867a32135f6e14 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 11:39:30 -0800 Subject: [PATCH 08/16] Update [ghstack-poisoned] --- sota-implementations/iql/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index 3c561f2b7dc..9da7e433abb 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -184,7 +184,8 @@ def make_offline_replay_buffer(rb_cfg): dataset_id=rb_cfg.dataset, split_trajs=False, batch_size=rb_cfg.batch_size, - sampler=SamplerWithoutReplacement(drop_last=False), + # We use drop_last to avoid recompiles (and dynamic shapes) + sampler=SamplerWithoutReplacement(drop_last=True), prefetch=4, direct_download=True, ) From 430fa6b6fff49c5e44f08ffe0a3eecb794825781 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 11:47:29 -0800 Subject: [PATCH 09/16] Update [ghstack-poisoned] --- sota-implementations/iql/utils.py | 47 +++++++++++++------------------ 1 file changed, 20 insertions(+), 27 deletions(-) diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index 9da7e433abb..a0d5d881254 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -298,19 +298,16 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device): """Make discrete IQL agent.""" # Define Actor Network in_keys = ["observation"] - action_spec = train_env.action_spec - if train_env.batch_size: - action_spec = action_spec[(0,) * len(train_env.batch_size)] + action_spec = train_env.action_spec_unbatched # Define Actor Network in_keys = ["observation"] - actor_net_kwargs = { - "num_cells": cfg.model.hidden_sizes, - "out_features": action_spec.shape[-1], - "activation_class": ACTIVATIONS[cfg.model.activation], - } - - actor_net = MLP(**actor_net_kwargs) + actor_net = MLP( + num_cells=cfg.model.hidden_sizes, + out_features=action_spec.space.n, + activation_class=ACTIVATIONS[cfg.model.activation], + device=device, + ) actor_module = SafeModule( module=actor_net, @@ -318,7 +315,7 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device): out_keys=["logits"], ) actor = ProbabilisticActor( - spec=Composite(action=eval_env.action_spec), + spec=Composite(action=eval_env.action_spec_unbatched).to(device), module=actor_module, in_keys=["logits"], out_keys=["action"], @@ -329,15 +326,12 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device): ) # Define Critic Network - qvalue_net_kwargs = { - "num_cells": cfg.model.hidden_sizes, - "out_features": action_spec.shape[-1], - "activation_class": ACTIVATIONS[cfg.model.activation], - } qvalue_net = MLP( - **qvalue_net_kwargs, + num_cells=cfg.model.hidden_sizes, + out_features=action_spec.space.n, + activation_class=ACTIVATIONS[cfg.model.activation], + device=device, ) - qvalue = TensorDictModule( in_keys=["observation"], out_keys=["state_action_value"], @@ -345,26 +339,25 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device): ) # Define Value Network - value_net_kwargs = { - "num_cells": cfg.model.hidden_sizes, - "out_features": 1, - "activation_class": 
ACTIVATIONS[cfg.model.activation], - } - value_net = MLP(**value_net_kwargs) + value_net = MLP( + num_cells=cfg.model.hidden_sizes, + out_features=1, + activation_class=ACTIVATIONS[cfg.model.activation], + device=device, + ) value_net = TensorDictModule( in_keys=["observation"], out_keys=["state_value"], module=value_net, ) - model = torch.nn.ModuleList([actor, qvalue, value_net]).to(device) + model = torch.nn.ModuleList([actor, qvalue, value_net]) # init nets with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): - td = eval_env.reset() + td = eval_env.fake_tensordict() td = td.to(device) for net in model: net(td) - del td eval_env.close() return model From 06dfd214d638932238a57c46e225ab6f5ac7ee50 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 11:49:30 -0800 Subject: [PATCH 10/16] Update [ghstack-poisoned] --- sota-implementations/iql/utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index a0d5d881254..68acb9bfed4 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -250,12 +250,10 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"): model = torch.nn.ModuleList([actor, qvalue, value_net]).to(device) # init nets with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): - td = eval_env.reset() + td = eval_env.fake_tensordict() td = td.to(device) for net in model: net(td) - del td - eval_env.close() return model @@ -358,7 +356,6 @@ def make_discrete_iql_model(cfg, train_env, eval_env, device): td = td.to(device) for net in model: net(td) - eval_env.close() return model From c6551fa27b3a53676d6aa25cf8b53789e3e7f35c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 11:55:40 -0800 Subject: [PATCH 11/16] Update [ghstack-poisoned] --- sota-implementations/iql/utils.py | 1 + torchrl/data/utils.py | 2 +- torchrl/objectives/iql.py | 14 ++++++++++---- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index 68acb9bfed4..168416d80da 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -388,6 +388,7 @@ def make_discrete_loss(loss_cfg, model): loss_function=loss_cfg.loss_function, temperature=loss_cfg.temperature, expectile=loss_cfg.expectile, + action_space="categorical", ) loss_module.make_value_estimator(gamma=loss_cfg.gamma) target_net_updater = HardUpdate( diff --git a/torchrl/data/utils.py b/torchrl/data/utils.py index db2c8afca10..d43cbd7810d 100644 --- a/torchrl/data/utils.py +++ b/torchrl/data/utils.py @@ -307,7 +307,7 @@ def _process_action_space_spec(action_space, spec): return action_space, spec -def _find_action_space(action_space): +def _find_action_space(action_space) -> str: if isinstance(action_space, TensorSpec): if isinstance(action_space, Composite): if "action" in action_space.keys(): diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 71d1a22e17b..48667728071 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -782,7 +782,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: # Min Q value td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params) - state_action_value = td_q.get(self.tensor_keys.state_action_value) + state_action_value = td_q.get(self.tensor_keys.chosen_state_action_value) action = 
tensordict.get(self.tensor_keys.action) if self.action_space == "categorical": if action.shape != state_action_value.shape: @@ -791,9 +791,11 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: chosen_state_action_value = torch.gather( state_action_value, -1, index=action ).squeeze(-1) - else: + elif self.action_space == "one_hot": action = action.to(torch.float) chosen_state_action_value = (state_action_value * action).sum(-1) + else: + raise RuntimeError(f"Unknown action space {self.action_space}.") min_Q, _ = torch.min(chosen_state_action_value, dim=0) if log_prob.shape != min_Q.shape: raise RuntimeError( @@ -834,9 +836,11 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: chosen_state_action_value = torch.gather( state_action_value, -1, index=action ).squeeze(-1) - else: + elif self.action_space == "one_hot": action = action.to(torch.float) chosen_state_action_value = (state_action_value * action).sum(-1) + else: + raise RuntimeError(f"Unknown action space {self.action_space}.") min_Q, _ = torch.min(chosen_state_action_value, dim=0) # state value td_copy = tensordict.select(*self.value_network.in_keys, strict=False) @@ -867,9 +871,11 @@ def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: # unsqueeze the action if it lacks on trailing singleton dim action = action.unsqueeze(-1) pred_val = torch.gather(state_action_value, -1, index=action).squeeze(-1) - else: + elif self.action_space == "one_hot": action = action.to(torch.float) pred_val = (state_action_value * action).sum(-1) + else: + raise RuntimeError(f"Unknown action space {self.action_space}.") td_error = (pred_val - target_value.expand_as(pred_val)).pow(2) loss_qval = distance_loss( From 38ae5801ec3c1033c13f8a2e79533e4311101f16 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 11:57:21 -0800 Subject: [PATCH 12/16] Update [ghstack-poisoned] --- torchrl/objectives/iql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 48667728071..9b83995b1cd 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -782,7 +782,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: # Min Q value td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False) td_q = self._vmap_qvalue_networkN0(td_q, self.target_qvalue_network_params) - state_action_value = td_q.get(self.tensor_keys.chosen_state_action_value) + state_action_value = td_q.get(self.tensor_keys.state_action_value) action = tensordict.get(self.tensor_keys.action) if self.action_space == "categorical": if action.shape != state_action_value.shape: From 0a292b196a3eaa8107c801c141642219cb6e32f8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 11:59:12 -0800 Subject: [PATCH 13/16] Update [ghstack-poisoned] --- torchrl/objectives/iql.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 9b83995b1cd..26cd8e2e89a 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -785,7 +785,7 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: state_action_value = td_q.get(self.tensor_keys.state_action_value) action = tensordict.get(self.tensor_keys.action) if self.action_space == "categorical": - if action.shape != state_action_value.shape: + if action.ndim < state_action_value.ndim: # unsqueeze the action if it lacks on trailing singleton dim action = 
action.unsqueeze(-1) chosen_state_action_value = torch.gather( @@ -830,7 +830,7 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: state_action_value = td_q.get(self.tensor_keys.state_action_value) action = tensordict.get(self.tensor_keys.action) if self.action_space == "categorical": - if action.shape != state_action_value.shape: + if action.ndim < state_action_value.ndim: # unsqueeze the action if it lacks on trailing singleton dim action = action.unsqueeze(-1) chosen_state_action_value = torch.gather( @@ -867,7 +867,7 @@ def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: state_action_value = td_q.get(self.tensor_keys.state_action_value) action = tensordict.get(self.tensor_keys.action) if self.action_space == "categorical": - if action.shape != state_action_value.shape: + if action.ndim < state_action_value.ndim: # unsqueeze the action if it lacks on trailing singleton dim action = action.unsqueeze(-1) pred_val = torch.gather(state_action_value, -1, index=action).squeeze(-1) From 7769d9675d7ba46ec12788fb4174329db7f6d2e9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 12:03:57 -0800 Subject: [PATCH 14/16] Update [ghstack-poisoned] --- torchrl/objectives/iql.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 26cd8e2e89a..039d5fc1c34 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -785,12 +785,15 @@ def actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: state_action_value = td_q.get(self.tensor_keys.state_action_value) action = tensordict.get(self.tensor_keys.action) if self.action_space == "categorical": - if action.ndim < state_action_value.ndim: + if action.ndim < (state_action_value.ndim - (td_q.ndim - tensordict.ndim)): # unsqueeze the action if it lacks on trailing singleton dim action = action.unsqueeze(-1) - chosen_state_action_value = torch.gather( - state_action_value, -1, index=action - ).squeeze(-1) + chosen_state_action_value = torch.vmap( + lambda state_action_value, action: torch.gather( + state_action_value, -1, index=action + ).squeeze(-1), + (0, None), + )(state_action_value, action) elif self.action_space == "one_hot": action = action.to(torch.float) chosen_state_action_value = (state_action_value * action).sum(-1) @@ -830,12 +833,17 @@ def value_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: state_action_value = td_q.get(self.tensor_keys.state_action_value) action = tensordict.get(self.tensor_keys.action) if self.action_space == "categorical": - if action.ndim < state_action_value.ndim: + if action.ndim < ( + state_action_value.ndim - (td_q.ndim - tensordict.ndim) + ): # unsqueeze the action if it lacks on trailing singleton dim action = action.unsqueeze(-1) - chosen_state_action_value = torch.gather( - state_action_value, -1, index=action - ).squeeze(-1) + chosen_state_action_value = torch.vmap( + lambda state_action_value, action: torch.gather( + state_action_value, -1, index=action + ).squeeze(-1), + (0, None), + )(state_action_value, action) elif self.action_space == "one_hot": action = action.to(torch.float) chosen_state_action_value = (state_action_value * action).sum(-1) @@ -867,10 +875,15 @@ def qvalue_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]: state_action_value = td_q.get(self.tensor_keys.state_action_value) action = tensordict.get(self.tensor_keys.action) if self.action_space == "categorical": - if 
action.ndim < state_action_value.ndim: + if action.ndim < (state_action_value.ndim - (td_q.ndim - tensordict.ndim)): # unsqueeze the action if it lacks on trailing singleton dim action = action.unsqueeze(-1) - pred_val = torch.gather(state_action_value, -1, index=action).squeeze(-1) + pred_val = torch.vmap( + lambda state_action_value, action: torch.gather( + state_action_value, -1, index=action + ).squeeze(-1), + (0, None), + )(state_action_value, action) elif self.action_space == "one_hot": action = action.to(torch.float) pred_val = (state_action_value * action).sum(-1) From c2888a0dc61f367aa2b9c5ead8c567bccb360550 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 12:06:35 -0800 Subject: [PATCH 15/16] Update [ghstack-poisoned] --- sota-implementations/iql/discrete_iql.py | 2 +- sota-implementations/iql/discrete_iql.yaml | 2 +- sota-implementations/iql/utils.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sota-implementations/iql/discrete_iql.py b/sota-implementations/iql/discrete_iql.py index 17153a59913..e51bd25a8a8 100644 --- a/sota-implementations/iql/discrete_iql.py +++ b/sota-implementations/iql/discrete_iql.py @@ -109,7 +109,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # Create loss - loss_module, target_net_updater = make_discrete_loss(cfg.loss, model) + loss_module, target_net_updater = make_discrete_loss(cfg.loss, model, device=device) # Create optimizer optimizer_actor, optimizer_critic, optimizer_value = make_iql_optimizer( diff --git a/sota-implementations/iql/discrete_iql.yaml b/sota-implementations/iql/discrete_iql.yaml index 81b5a88e59f..3f53ab9a68a 100644 --- a/sota-implementations/iql/discrete_iql.yaml +++ b/sota-implementations/iql/discrete_iql.yaml @@ -62,5 +62,5 @@ loss: compile: compile: False - compile_mode: + compile_mode: default cudagraphs: False diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index 168416d80da..04cc2b250ab 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -380,7 +380,7 @@ def make_loss(loss_cfg, model, device): return loss_module, target_net_updater -def make_discrete_loss(loss_cfg, model): +def make_discrete_loss(loss_cfg, model, device): loss_module = DiscreteIQLLoss( model[0], model[1], @@ -390,7 +390,7 @@ def make_discrete_loss(loss_cfg, model): expectile=loss_cfg.expectile, action_space="categorical", ) - loss_module.make_value_estimator(gamma=loss_cfg.gamma) + loss_module.make_value_estimator(gamma=loss_cfg.gammam, device=device) target_net_updater = HardUpdate( loss_module, value_network_update_interval=loss_cfg.hard_update_interval ) From b73eea2940581d42e4babf70b60c06e631f130f8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 15 Dec 2024 13:07:40 -0800 Subject: [PATCH 16/16] Update [ghstack-poisoned] --- sota-implementations/iql/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/iql/utils.py b/sota-implementations/iql/utils.py index 04cc2b250ab..519d4350536 100644 --- a/sota-implementations/iql/utils.py +++ b/sota-implementations/iql/utils.py @@ -390,7 +390,7 @@ def make_discrete_loss(loss_cfg, model, device): expectile=loss_cfg.expectile, action_space="categorical", ) - loss_module.make_value_estimator(gamma=loss_cfg.gammam, device=device) + loss_module.make_value_estimator(gamma=loss_cfg.gamma, device=device) target_net_updater = HardUpdate( loss_module, value_network_update_interval=loss_cfg.hard_update_interval )
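
For reference, a minimal sketch of the fused update-step pattern these patches converge on, mirroring the update() defined in discrete_iql.py above. The helper name make_update and its signature are illustrative only; loss_module, target_net_updater and the per-network optimizers are assumed to be built by utils.py as in the scripts, and only APIs already used in the diffs (group_optimizers, torch.compile, CudaGraphModule, cudagraph_mark_step_begin) appear here.

    import warnings

    import torch
    from tensordict import TensorDict
    from tensordict.nn import CudaGraphModule
    from torchrl.objectives import group_optimizers


    def make_update(loss_module, target_net_updater, *optimizers,
                    compile_mode=None, cudagraphs=False):
        """Build a single update callable stepping actor, critic and value together."""
        # One grouped optimizer replaces the three separate zero_grad/step calls.
        optimizer = group_optimizers(*optimizers)

        def update(sampled_tensordict):
            optimizer.zero_grad(set_to_none=True)
            # One backward pass over the summed IQL losses.
            actor_loss, _ = loss_module.actor_loss(sampled_tensordict)
            value_loss, _ = loss_module.value_loss(sampled_tensordict)
            q_loss, metadata = loss_module.qvalue_loss(sampled_tensordict)
            (actor_loss + value_loss + q_loss).backward()
            optimizer.step()
            # Target Q-network update after the optimizer step.
            target_net_updater.step()
            metadata.update(
                {"actor_loss": actor_loss, "value_loss": value_loss, "q_loss": q_loss}
            )
            return TensorDict(metadata).detach()

        if compile_mode is not None:
            update = torch.compile(update, mode=compile_mode)
        if cudagraphs:
            warnings.warn(
                "CudaGraphModule is experimental and may lead to silently wrong results."
            )
            update = CudaGraphModule(update, warmup=50)
        return update

In the training loop the step boundary is marked before each compiled call, as in patch 04:

    torch.compiler.cudagraph_mark_step_begin()
    metadata = update(replay_buffer.sample().to(device))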