Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Sep 17, 2024
2 parents baeb8c9 + 663062a commit 0f1c06a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .github/unittest/linux/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,5 @@ dependencies:
- transformers
- ninja
- timm
- gymnasium[atari,accept-rom-license]
- mo-gymnasium[mujoco]
5 changes: 0 additions & 5 deletions .github/unittest/linux/scripts/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,6 @@ conda env update --file "${this_dir}/environment.yml" --prune
conda deactivate
conda activate "${env_dir}"

echo "installing gymnasium"
pip3 install "gymnasium[atari,accept-rom-license]"
pip3 install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
pip3 install "mujoco" -U

# sanity check: remove?
python3 -c """
import dm_control
Expand Down
11 changes: 11 additions & 0 deletions torchrl/modules/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,17 @@ def sample(
) -> torch.Tensor:
...

@property
def deterministic_sample(self):
return self.mode

@property
def mode(self) -> torch.Tensor:
if hasattr(self, "logits"):
return (self.logits == self.logits.max(-1, True)[0]).to(torch.long)
else:
return (self.probs == self.probs.max(-1, True)[0]).to(torch.long)

def log_prob(self, value: torch.Tensor) -> torch.Tensor:
return super().log_prob(value.argmax(dim=-1))

Expand Down

0 comments on commit 0f1c06a

Please sign in to comment.