Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 15, 2024
1 parent ad97990 commit a4abc7b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
5 changes: 4 additions & 1 deletion sota-implementations/gail/gail.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
from torchrl.record.loggers import generate_exp_name, get_logger


torch.set_float32_matmul_precision("high")


@hydra.main(config_path="", config_name="config")
def main(cfg: "DictConfig"): # noqa: F821
set_gym_backend(cfg.env.backend).set()
Expand Down Expand Up @@ -260,7 +263,7 @@ def update(data, expert_data, num_network_updates=num_network_updates):
expert_data = expert_data.to(device)

metadata = update(data, expert_data)
d_loss = metadata["d_loss"]
d_loss = metadata["dloss"]
alpha = metadata["alpha"]

# Get training rewards and episode lengths
Expand Down
4 changes: 2 additions & 2 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3319,9 +3319,9 @@ def __init__(
self.update_mask(mask)
self._provisional_n = None

@torch.compiler.assume_constant_result
@property
def _undefined_n(self):
return self.space.n == -1
return self.space.n < 0

def enumerate(self) -> torch.Tensor:
dtype = self.dtype
Expand Down

0 comments on commit a4abc7b

Please sign in to comment.