Skip to content

Commit

Permalink
[Lint] Add TorchFix linter (#1580)
Browse files Browse the repository at this point in the history
  • Loading branch information
kit1980 authored Oct 1, 2023
1 parent a02679b commit db1a7d4
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
additional_dependencies:
- flake8-bugbear==22.10.27
- flake8-comprehensions==3.10.1

- torchfix==0.0.2

- repo: https://github.com/PyCQA/pydocstyle
rev: 6.1.1
Expand Down
11 changes: 10 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,18 @@ per-file-ignores =
test/smoke_test_deps.py: F401
test_*.py: F841, E731, E266
test/opengl_rendering.py: F401
test/test_modules.py: F841, E731, E266, TOR101
test/test_tensordictmodules.py: F841, E731, E266, TOR101
torchrl/objectives/cql.py: TOR101
torchrl/objectives/deprecated.py: TOR101
torchrl/objectives/iql.py: TOR101
torchrl/objectives/redq.py: TOR101
torchrl/objectives/sac.py: TOR101
torchrl/objectives/td3.py: TOR101
torchrl/objectives/value/advantages.py: TOR101

exclude = venv
extend-select = B901, C401, C408, C409
extend-select = B901, C401, C408, C409, TOR0, TOR1, TOR2

[pydocstyle]
;select = D417 # Missing argument descriptions in the docstring
Expand Down
4 changes: 2 additions & 2 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6723,8 +6723,8 @@ def test_vip_parallel_reward(self, model, device, dtype_fixture): # noqa
with pytest.raises(AssertionError):
torch.testing.assert_close(cur_embedding[:, 1:], last_embedding[:, :-1])

explicit_reward = -torch.norm(cur_embedding - goal_embedding, dim=-1) - (
-torch.norm(last_embedding - goal_embedding, dim=-1)
explicit_reward = -torch.linalg.norm(cur_embedding - goal_embedding, dim=-1) - (
-torch.linalg.norm(last_embedding - goal_embedding, dim=-1)
)
torch.testing.assert_close(explicit_reward, td["next", "reward"].squeeze())

Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/transforms/vip.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,8 @@ def _step(
cur_embedding = next_tensordict.get(self.out_keys[0])
if last_embedding is not None:
goal_embedding = tensordict["goal_embedding"]
reward = -torch.norm(cur_embedding - goal_embedding, dim=-1) - (
-torch.norm(last_embedding - goal_embedding, dim=-1)
reward = -torch.linalg.norm(cur_embedding - goal_embedding, dim=-1) - (
-torch.linalg.norm(last_embedding - goal_embedding, dim=-1)
)
next_tensordict.set("reward", reward)
return next_tensordict
Expand Down

0 comments on commit db1a7d4

Please sign in to comment.