diff --git a/tests/unit/transform/test_base.py b/tests/unit/transform/test_base.py index b0ae37d5..c3c14739 100644 --- a/tests/unit/transform/test_base.py +++ b/tests/unit/transform/test_base.py @@ -11,7 +11,7 @@ class FakeTransform(Transform[_B, _C]): """ - Fake Transform to test `required_keys` and `output_keys` when composing and conjuncting. + Fake ``Transform`` to test `required_keys` and `output_keys` when composing and conjuncting. """ def __init__(self, required_keys: set[Tensor], output_keys: set[Tensor]): @@ -34,6 +34,11 @@ def output_keys(self) -> set[Tensor]: def test_apply_keys(): + """ + Tests that a ``Transform`` checks that the provided dictionary to the `__apply__` function + contains keys that correspond exactly to `required_keys`. + """ + t1 = torch.randn([2]) t2 = torch.randn([3]) transform = FakeTransform({t1}, {t1, t2}) @@ -51,6 +56,11 @@ def test_apply_keys(): def test_compose_keys_match(): + """ + Tests that the composition of ``Transform``s checks that the inner transform's `output_keys` + match with the outer transform's `required_keys`. + """ + t1 = torch.randn([2]) t2 = torch.randn([3]) transform1 = FakeTransform({t1}, {t1, t2}) @@ -63,6 +73,11 @@ def test_compose_keys_match(): def test_conjunct_required_keys(): + """ + Tests that the conjunction of ``Transform``s checks that the provided transforms all havve the + same `required_keys`. + """ + t1 = torch.randn([2]) t2 = torch.randn([3]) @@ -80,6 +95,11 @@ def test_conjunct_required_keys(): def test_conjunct_wrong_output_keys(): + """ + Tests that the conjunction of ``Transform``s checks that the transforms `output_keys` are + disjoint. + """ + t1 = torch.randn([2]) t2 = torch.randn([3]) @@ -97,6 +117,11 @@ def test_conjunct_wrong_output_keys(): def test_conjunction_empty_transforms(): + """ + Tests that it is possible to take the conjunction of no transform, this should return an empty + dictionary. + """ + conjunction = Conjunction([]) assert len(conjunction(TensorDict({}))) == 0