Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 15, 2024
1 parent 0dc4622 commit e5a358b
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 16 deletions.
9 changes: 5 additions & 4 deletions sota-implementations/dqn/dqn_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ def update(sampled_tensordict):
storing_device=device,
max_frames_per_traj=-1,
init_random_frames=init_random_frames,
compile_policy={"mode": compile_mode} if compile_mode is not None else False,
compile_policy={"mode": compile_mode, "fullgraph": True}
if compile_mode is not None
else False,
cudagraph_policy=cfg.compile.cudagraphs,
)

Expand Down Expand Up @@ -212,9 +214,8 @@ def update(sampled_tensordict):
# Get and log q-values, loss, epsilon, sampling time and training time
log_info.update(
{
"train/q_values": (data["action_value"] * data["action"]).sum().item()
/ frames_per_batch,
"train/q_loss": q_losses.mean().item(),
"train/q_values": data["chosen_action_value"].sum() / frames_per_batch,
"train/q_loss": q_losses.mean(),
"train/epsilon": greedy_module.eps,
}
)
Expand Down
9 changes: 6 additions & 3 deletions sota-implementations/dqn/dqn_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,19 @@ def main(cfg: "DictConfig"): # noqa: F821
device = torch.device(device)

# Make the components
model = make_dqn_model(cfg.env.env_name)
model = make_dqn_model(cfg.env.env_name, device=device)

greedy_module = EGreedyModule(
annealing_num_steps=cfg.collector.annealing_frames,
eps_init=cfg.collector.eps_start,
eps_end=cfg.collector.eps_end,
spec=model.spec,
device=device,
)
model_explore = TensorDictSequential(
model,
greedy_module,
).to(device)
)

# Create the replay buffer
replay_buffer = TensorDictReplayBuffer(
Expand Down Expand Up @@ -135,7 +136,9 @@ def update(sampled_tensordict):
storing_device="cpu",
max_frames_per_traj=-1,
init_random_frames=cfg.collector.init_random_frames,
compile_policy={"mode": compile_mode} if compile_mode is not None else False,
compile_policy={"mode": compile_mode, "fullgraph": True}
if compile_mode is not None
else False,
cudagraph_policy=cfg.compile.cudagraphs,
)

Expand Down
2 changes: 1 addition & 1 deletion sota-implementations/dqn/utils_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def make_dqn_modules_pixels(proof_environment, device):
)
qvalue_module = QValueActor(
module=torch.nn.Sequential(cnn, mlp),
spec=Composite(action=action_spec),
spec=Composite(action=action_spec).to(device),
in_keys=["pixels"],
)
return qvalue_module
Expand Down
11 changes: 6 additions & 5 deletions sota-implementations/dqn/utils_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def make_env(env_name="CartPole-v1", device="cpu", from_pixels=False):
# --------------------------------------------------------------------


def make_dqn_modules(proof_environment):
def make_dqn_modules(proof_environment, device):

# Define input shape
input_shape = proof_environment.observation_spec["observation"].shape
Expand All @@ -45,19 +45,20 @@ def make_dqn_modules(proof_environment):
activation_class=torch.nn.ReLU,
out_features=num_outputs,
num_cells=[120, 84],
device=device,
)

qvalue_module = QValueActor(
module=mlp,
spec=Composite(action=action_spec),
spec=Composite(action=action_spec).to(device),
in_keys=["observation"],
)
return qvalue_module


def make_dqn_model(env_name):
proof_environment = make_env(env_name, device="cpu")
qvalue_module = make_dqn_modules(proof_environment)
def make_dqn_model(env_name, device):
proof_environment = make_env(env_name, device=device)
qvalue_module = make_dqn_modules(proof_environment, device=device)
del proof_environment
return qvalue_module

Expand Down
12 changes: 10 additions & 2 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,11 @@ class CategoricalBox(Box):
n: int
register = invertible_dict()

def __post_init__(self):
# n could be a numpy array or a tensor, making compile go a bit crazy
# We want to make sure we're working with a regular integer
self.__dict__["n"] = int(self.n)

def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> CategoricalBox:
return deepcopy(self)

Expand Down Expand Up @@ -3313,7 +3318,10 @@ def __init__(
)
self.update_mask(mask)
self._provisional_n = None
self._undefined_n = self.space.n < 0

@torch.compiler.assume_constant_result
def _undefined_n(self):
return self.space.n == -1

def enumerate(self) -> torch.Tensor:
dtype = self.dtype
Expand Down Expand Up @@ -3380,7 +3388,7 @@ def set_provisional_n(self, n: int):
self._provisional_n = n

def rand(self, shape: torch.Size = None) -> torch.Tensor:
if self._undefined_n:
if self._undefined_n():
if self._provisional_n is None:
raise RuntimeError(
"Cannot generate random categorical samples for undefined cardinality (n=-1). "
Expand Down
3 changes: 2 additions & 1 deletion torchrl/modules/tensordict_module/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def step(self, frames: int = 1) -> None:
)

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if exploration_type() == ExplorationType.RANDOM or exploration_type() is None:
expl = exploration_type()
if expl in (ExplorationType.RANDOM, None):
if isinstance(self.action_key, tuple) and len(self.action_key) > 1:
action_tensordict = tensordict.get(self.action_key[:-1])
action_key = self.action_key[-1]
Expand Down

0 comments on commit e5a358b

Please sign in to comment.