Skip to content

Commit

Permalink
Allow Conjunctions to take no Transforms
Browse files Browse the repository at this point in the history
* Remove constraint of Conjunction to have non-empty list of transforms.

* Add tests to base classes of Transforms.
  • Loading branch information
PierreQuinton authored Jun 6, 2024
1 parent d692469 commit f9853e1
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 6 deletions.
10 changes: 4 additions & 6 deletions src/torchjd/transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,12 @@ def output_keys(self) -> set[Tensor]:

class Conjunction(Transform[_A, _B]):
def __init__(self, transforms: Sequence[Transform[_A, _B]]):
if len(transforms) == 0:
raise ValueError("Parameter `transforms` cannot be empty.")

self.transforms = transforms

self._required_keys = transforms[0].required_keys

for transform in transforms[1:]:
self._required_keys = set(
key for transform in transforms for key in transform.required_keys
)
for transform in transforms:
if transform.required_keys != self.required_keys:
raise ValueError("All transforms should require the same set of keys.")

Expand Down
127 changes: 127 additions & 0 deletions tests/unit/transform/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import typing

import pytest
import torch
from torch import Tensor

from torchjd.transform._utils import _B, _C
from torchjd.transform.base import Conjunction, Transform
from torchjd.transform.tensor_dict import TensorDict


class FakeTransform(Transform[_B, _C]):
"""
Fake ``Transform`` to test `required_keys` and `output_keys` when composing and conjuncting.
"""

def __init__(self, required_keys: set[Tensor], output_keys: set[Tensor]):
self._required_keys = required_keys
self._output_keys = output_keys

def _compute(self, input: _B) -> _C:
# ignore the input, create a dictionary with the right keys as an output.
# cast the type for the purpose of type-checking.
output_dict = {key: torch.empty(0) for key in self._output_keys}
return typing.cast(_C, output_dict)

@property
def required_keys(self) -> set[Tensor]:
return self._required_keys

@property
def output_keys(self) -> set[Tensor]:
return self._output_keys


def test_apply_keys():
"""
Tests that a ``Transform`` checks that the provided dictionary to the `__apply__` function
contains keys that correspond exactly to `required_keys`.
"""

t1 = torch.randn([2])
t2 = torch.randn([3])
transform = FakeTransform({t1}, {t1, t2})

transform(TensorDict({t1: t2}))

with pytest.raises(ValueError):
transform(TensorDict({t2: t1}))

with pytest.raises(ValueError):
transform(TensorDict({}))

with pytest.raises(ValueError):
transform(TensorDict({t1: t2, t2: t1}))


def test_compose_keys_match():
"""
Tests that the composition of ``Transform``s checks that the inner transform's `output_keys`
match with the outer transform's `required_keys`.
"""

t1 = torch.randn([2])
t2 = torch.randn([3])
transform1 = FakeTransform({t1}, {t1, t2})
transform2 = FakeTransform({t2}, {t1})

transform1 << transform2

with pytest.raises(ValueError):
transform2 << transform1


def test_conjunct_required_keys():
"""
Tests that the conjunction of ``Transform``s checks that the provided transforms all have the
same `required_keys`.
"""

t1 = torch.randn([2])
t2 = torch.randn([3])

transform1 = FakeTransform({t1}, set())
transform2 = FakeTransform({t1}, set())
transform3 = FakeTransform({t2}, set())

transform1 | transform2

with pytest.raises(ValueError):
transform2 | transform3

with pytest.raises(ValueError):
transform1 | transform2 | transform3


def test_conjunct_wrong_output_keys():
"""
Tests that the conjunction of ``Transform``s checks that the transforms `output_keys` are
disjoint.
"""

t1 = torch.randn([2])
t2 = torch.randn([3])

transform1 = FakeTransform(set(), {t1, t2})
transform2 = FakeTransform(set(), {t1})
transform3 = FakeTransform(set(), {t2})

transform2 | transform3

with pytest.raises(ValueError):
transform1 | transform3

with pytest.raises(ValueError):
transform1 | transform2 | transform3


def test_conjunction_empty_transforms():
"""
Tests that it is possible to take the conjunction of no transform. This should return an empty
dictionary.
"""

conjunction = Conjunction([])

assert len(conjunction(TensorDict({}))) == 0

0 comments on commit f9853e1

Please sign in to comment.