Skip to content

Commit

Permalink
Add docstring to all tests in test_base.py
Browse files Browse the repository at this point in the history
  • Loading branch information
PierreQuinton committed Jun 6, 2024
1 parent 7b7177f commit 8dd4cbd
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion tests/unit/transform/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class FakeTransform(Transform[_B, _C]):
"""
Fake Transform to test `required_keys` and `output_keys` when composing and conjuncting.
Fake ``Transform`` to test `required_keys` and `output_keys` when composing and conjuncting.
"""

def __init__(self, required_keys: set[Tensor], output_keys: set[Tensor]):
Expand All @@ -34,6 +34,11 @@ def output_keys(self) -> set[Tensor]:


def test_apply_keys():
"""
Tests that a ``Transform`` checks that the provided dictionary to the `__apply__` function
contains keys that correspond exactly to `required_keys`.
"""

t1 = torch.randn([2])
t2 = torch.randn([3])
transform = FakeTransform({t1}, {t1, t2})
Expand All @@ -51,6 +56,11 @@ def test_apply_keys():


def test_compose_keys_match():
"""
Tests that the composition of ``Transform``s checks that the inner transform's `output_keys`
match with the outer transform's `required_keys`.
"""

t1 = torch.randn([2])
t2 = torch.randn([3])
transform1 = FakeTransform({t1}, {t1, t2})
Expand All @@ -63,6 +73,11 @@ def test_compose_keys_match():


def test_conjunct_required_keys():
"""
Tests that the conjunction of ``Transform``s checks that the provided transforms all havve the
same `required_keys`.
"""

t1 = torch.randn([2])
t2 = torch.randn([3])

Expand All @@ -80,6 +95,11 @@ def test_conjunct_required_keys():


def test_conjunct_wrong_output_keys():
"""
Tests that the conjunction of ``Transform``s checks that the transforms `output_keys` are
disjoint.
"""

t1 = torch.randn([2])
t2 = torch.randn([3])

Expand All @@ -97,6 +117,11 @@ def test_conjunct_wrong_output_keys():


def test_conjunction_empty_transforms():
"""
Tests that it is possible to take the conjunction of no transform, this should return an empty
dictionary.
"""

conjunction = Conjunction([])

assert len(conjunction(TensorDict({}))) == 0

0 comments on commit 8dd4cbd

Please sign in to comment.