diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py
index fa22531..e4de632 100644
--- a/grassmann_tensor/tensor.py
+++ b/grassmann_tensor/tensor.py
@@ -9,6 +9,7 @@
 import dataclasses
 import functools
 import typing
+import math
 
 import torch
 
@@ -295,9 +296,28 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
             tensor = self.tensor.reshape(())
             return GrassmannTensor(_arrow=(), _edges=(), _tensor=tensor)
 
+        if new_shape == (1,) and int(self.tensor.numel()) == 1:
+            eo = self._calculate_even_odd()
+            new_shape = (eo,)
+
         cursor_plan: int = 0
         cursor_self: int = 0
         while cursor_plan != len(new_shape) or cursor_self != self.tensor.dim():
+            if cursor_self == self.tensor.dim() and cursor_plan != len(new_shape):
+                new_shape_check = new_shape[cursor_plan]
+                if (isinstance(new_shape_check, int) and new_shape_check == 1) or (
+                    new_shape_check == (1, 0)
+                ):
+                    arrow.append(False)
+                    edges.append((1, 0))
+                    shape.append(1)
+                    cursor_plan += 1
+                    continue
+                raise AssertionError(
+                    "New shape exceeds after exhausting self dimensions: "
+                    f"edges={self.edges}, new_shape={new_shape}"
+                )
+
             if cursor_plan != len(new_shape) and new_shape[cursor_plan] == -1:
                 # Does not change
                 arrow.append(self.arrow[cursor_self])
@@ -306,7 +326,11 @@
                 cursor_self += 1
                 cursor_plan += 1
                 continue
-            elif cursor_plan != len(new_shape) and new_shape[cursor_plan] == (1, 0):
+            elif (
+                cursor_plan != len(new_shape)
+                and new_shape[cursor_plan] == (1, 0)
+                and cursor_plan < len(new_shape) - 1
+            ):
                 # A trivial plan edge
                 arrow.append(False)
                 edges.append((1, 0))
@@ -532,6 +556,146 @@ def matmul(self, other: GrassmannTensor) -> GrassmannTensor:
             _tensor=tensor,
         )
 
+    def svd(
+        self,
+        free_names_u: tuple[int, ...],
+        *,
+        cutoff: int | None | tuple[int, int] = None,
+    ) -> tuple[GrassmannTensor, GrassmannTensor, GrassmannTensor]:
+        """
+        Compute the singular value decomposition of a Grassmann tensor.
+        The SVD is implemented in the following steps:
+        1. Split the legs into a left and a right group.
+        2. Merge each group into a single edge, giving a matrix.
+        3. Split the merged matrix into its even and odd parity blocks.
+        4. Compute the singular value decomposition of each block.
+        5. Use cutoff to keep the largest singular values within each parity block.
+        6. Assemble the block results into U, S and Vh.
+        7. Split the merged legs of U and Vh back into the original left and right legs.
+        The returned tensors U and Vh are not unique, nor are they continuous with respect to self.
+        Due to this lack of uniqueness, different hardware and software may compute different singular vectors.
+        Gradients computed using U or Vh are only finite when self does not have repeated singular values.
+        Furthermore, if the distance between any two singular values is close to zero, the gradient
+        becomes numerically unstable, as it depends on the spacing between the singular values.
+        """
+        left_legs = tuple(int(i) for i in free_names_u)
+        right_legs = tuple(i for i in range(self.tensor.dim()) if i not in left_legs)
+        assert set(left_legs) | set(right_legs) == set(range(self.tensor.dim())), (
+            "Left/right must cover all tensor legs."
+        )
+
+        if isinstance(cutoff, tuple):
+            assert len(cutoff) == 2, "The length of cutoff must be 2 if cutoff is a tuple."
+
+        order = left_legs + right_legs
+        tensor = self.permute(order)
+
+        left_dim = math.prod(tensor.tensor.shape[: len(left_legs)])
+        right_dim = math.prod(tensor.tensor.shape[len(left_legs) :])
+
+        tensor = tensor.reshape((left_dim, right_dim))
+
+        (even_left, odd_left) = tensor.edges[0]
+        (even_right, odd_right) = tensor.edges[1]
+        even_tensor = tensor.tensor[:even_left, :even_right]
+        odd_tensor = tensor.tensor[even_left:, even_right:]
+
+        if even_tensor.numel() > 0:
+            U_even, S_even, Vh_even = torch.linalg.svd(even_tensor, full_matrices=False)
+        else:
+            U_even = even_tensor.new_zeros((even_left, 0))
+            S_even = even_tensor.new_zeros((0,))
+            Vh_even = even_tensor.new_zeros((0, even_right))
+
+        if odd_tensor.numel() > 0:
+            U_odd, S_odd, Vh_odd = torch.linalg.svd(odd_tensor, full_matrices=False)
+        else:
+            U_odd = odd_tensor.new_zeros((odd_left, 0))
+            S_odd = odd_tensor.new_zeros((0,))
+            Vh_odd = odd_tensor.new_zeros((0, odd_right))
+
+        n_even, n_odd = S_even.shape[0], S_odd.shape[0]
+
+        if cutoff is None:
+            k_even, k_odd = n_even, n_odd
+        elif isinstance(cutoff, int):
+            if n_even == 0 and n_odd == 0:
+                raise RuntimeError("Both parity blocks are empty. Cannot form SVD.")
+            assert cutoff > 0, f"Cutoff must be greater than 0, but got {cutoff}"
+            k_even = min(cutoff, n_even)
+            k_odd = min(cutoff, n_odd)
+        elif isinstance(cutoff, tuple):
+            assert len(cutoff) == 2, "The length of cutoff must be 2 if cutoff is a tuple."
+            if n_even == 0 and n_odd == 0:
+                raise RuntimeError("Both parity blocks are empty. Cannot form SVD.")
+            k_even = max(0, min(int(cutoff[0]), n_even))
+            k_odd = max(0, min(int(cutoff[1]), n_odd))
+        else:
+            raise ValueError(
+                f"Cutoff must be an integer or a tuple of two integers, but got {cutoff}"
+            )
+
+        assert (k_even > 0 or n_even == 0) and (k_odd > 0 or n_odd == 0), (
+            "Per-block cutoff must be compatible with the available singular values"
+        )
+
+        keep_even = torch.zeros(n_even, dtype=torch.bool, device=S_even.device)
+        keep_odd = torch.zeros(n_odd, dtype=torch.bool, device=S_odd.device)
+        if k_even > 0:
+            keep_even[:k_even] = True
+        if k_odd > 0:
+            keep_odd[:k_odd] = True
+
+        U_even_trunc = U_even[:, keep_even]
+        S_even_trunc = S_even[keep_even]
+        Vh_even_trunc = Vh_even[keep_even, :]
+
+        U_odd_trunc = U_odd[:, keep_odd]
+        S_odd_trunc = S_odd[keep_odd]
+        Vh_odd_trunc = Vh_odd[keep_odd, :]
+
+        U_tensor = torch.block_diag(U_even_trunc, U_odd_trunc)  # type: ignore[no-untyped-call]
+        S_tensor = torch.cat([S_even_trunc, S_odd_trunc], dim=0)
+        Vh_tensor = torch.block_diag(Vh_even_trunc, Vh_odd_trunc)  # type: ignore[no-untyped-call]
+
+        U_edges = (
+            (U_even_trunc.shape[0], U_odd_trunc.shape[0]),
+            (U_even_trunc.shape[1], U_odd_trunc.shape[1]),
+        )
+        S_edges = (
+            (U_even_trunc.shape[1], U_odd_trunc.shape[1]),
+            (Vh_even_trunc.shape[0], Vh_odd_trunc.shape[0]),
+        )
+        Vh_edges = (
+            (Vh_even_trunc.shape[0], Vh_odd_trunc.shape[0]),
+            (Vh_even_trunc.shape[1], Vh_odd_trunc.shape[1]),
+        )
+
+        U = GrassmannTensor(_arrow=(True, True), _edges=U_edges, _tensor=U_tensor)
+        S = GrassmannTensor(
+            _arrow=(
+                False,
+                True,
+            ),
+            _edges=S_edges,
+            _tensor=torch.diag(S_tensor),
+        )
+        Vh = GrassmannTensor(_arrow=(False, True), _edges=Vh_edges, _tensor=Vh_tensor)
+        # Split the merged legs back into the original left and right legs
+        left_arrow = [self.arrow[i] for i in left_legs]
+        left_edges = [self.edges[i] for i in left_legs]
+
+        right_arrow = [self.arrow[i] for i in right_legs]
+        right_edges = [self.edges[i] for i in right_legs]
+
+        U = U.reshape((*left_edges, U_edges[1]))
+        U._arrow = tuple(left_arrow + [True])
+
+        Vh = Vh.reshape((Vh_edges[0], *right_edges))
+        Vh._arrow = tuple([False] + right_arrow)
+
+        return U, S, Vh
+
     def __post_init__(self) -> None:
         assert len(self._arrow) == self._tensor.dim(), (
             f"Arrow length ({len(self._arrow)}) must match tensor dimensions ({self._tensor.dim()})."
diff --git a/tests/reshape_test.py b/tests/reshape_test.py
index a1dac6f..2d2d452 100644
--- a/tests/reshape_test.py
+++ b/tests/reshape_test.py
@@ -197,3 +197,37 @@ def test_reshape_with_none_edge_assertion() -> None:
     _ = GrassmannTensor((), (), torch.tensor(2333)).reshape((1, -1))
     with pytest.raises(AssertionError, match="Ambiguous integer dim"):
         _ = GrassmannTensor((), (), torch.tensor(2333)).reshape((2, 2))
+
+
+@pytest.mark.parametrize(
+    "arrow, edges, tensor",
+    [
+        ((True, True), ((0, 1), (0, 1)), torch.tensor([[2333]])),
+        ((True, True, True), ((0, 1), (1, 0), (0, 1)), torch.tensor([[[2333]]])),
+    ],
+)
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (1,),
+        (1, 1),
+        (1, 1, 1),
+        (1, 1, 1, 1),
+    ],
+)
+def test_reshape_with_one_dimension(
+    arrow: tuple[bool, ...],
+    edges: tuple[tuple[int, int], ...],
+    tensor: torch.Tensor,
+    shape: tuple[int, ...],
+) -> None:
+    a = GrassmannTensor(arrow, edges, tensor).reshape(shape)
+    assert (
+        len(a.arrow) == len(shape) and len(a.edges) == len(shape) and a.tensor.dim() == len(shape)
+    )
+
+
+def test_reshape_trailing_nontrivial_dim_raises() -> None:
+    a = GrassmannTensor((True,), ((2, 2),), torch.randn([4]))
+    with pytest.raises(AssertionError, match="New shape exceeds after exhausting self dimensions"):
+        _ = a.reshape((-1, (2, 2)))
diff --git a/tests/svd_test.py b/tests/svd_test.py
new file mode 100644
index 0000000..4660385
--- /dev/null
+++ b/tests/svd_test.py
@@ -0,0 +1,253 @@
+import torch
+import pytest
+from _pytest.mark.structures import ParameterSet
+import math
+import itertools
+from typing import TypeAlias, Iterable, Any
+
+from grassmann_tensor import GrassmannTensor
+
+Arrow: TypeAlias = tuple[bool, ...]
+Edges: TypeAlias = tuple[tuple[int, int], ...]
+Tensor: TypeAlias = torch.Tensor
+Cutoff: TypeAlias = int | tuple[int, int] | None
+Tau: TypeAlias = float
+FreeNamesU: TypeAlias = tuple[int, ...]
+
+SVDCases = Iterable[ParameterSet]
+
+
+def get_total_singular(edges: Edges, free_names_u: FreeNamesU) -> tuple[int, int]:
+    even_singular = min(GrassmannTensor.calculate_even_odd(tuple(edges[i] for i in free_names_u)))
+    set_all = set(range(len(edges)))
+    rest_idx = sorted(set_all - set(free_names_u))
+    odd_singular = min(GrassmannTensor.calculate_even_odd(tuple(edges[i] for i in rest_idx)))
+    return even_singular, odd_singular
+
+
+def tau_for_cutoff(c: int, total: int, alpha: float = 0.8, slack: float = 1.05) -> float:
+    cut = 0
+    if isinstance(c, int):
+        cut = c
+    lo, hi = 1e-8, 1e-1
+    x = (total - cut) / max(1, total - 1)
+    return (lo + (hi - lo) * (x**alpha)) * slack
+
+
+def choose_free_names(n_edges: int, limit: int = 8) -> list[FreeNamesU]:
+    combos = [
+        tuple(c) for r in range(1, n_edges) for c in itertools.combinations(range(n_edges), r)
+    ]
+    return combos[:limit]
+
+
+BASE_GT_CASES: list[tuple[Arrow, Edges, Tensor]] = [
+    ((True, True), ((2, 2), (4, 4)), torch.randn(4, 8, dtype=torch.float64)),
+    ((True, True, True), ((2, 2), (4, 4), (8, 8)), torch.randn(4, 8, 16, dtype=torch.float64)),
+    (
+        (True, True, True, True),
+        ((2, 2), (4, 4), (8, 8), (16, 16)),
+        torch.randn(4, 8, 16, 32, dtype=torch.float64),
+    ),
+]
+
+
+def svd_cases() -> SVDCases:
+    params = []
+    for arrow, edges, tensor in BASE_GT_CASES:
+        for fnu in choose_free_names(len(edges)):
+            even_singular, odd_singular = get_total_singular(edges, fnu)
+            max_singular = max(even_singular, odd_singular)
+            total = even_singular + odd_singular
+            cutoff_list = [
+                None,
+                max_singular,
+                max_singular - 1,
+                (even_singular, odd_singular),
+            ]
+            for cutoff in cutoff_list:
+                if cutoff is None:
+                    kept = total
+                elif isinstance(cutoff, int):
+                    k = cutoff
+                    kept = min(k, even_singular) + min(k, odd_singular)
+                else:
+                    ke = min(int(cutoff[0]), even_singular)
+                    ko = min(int(cutoff[1]), odd_singular)
+                    kept = ke + ko
+                tau = tau_for_cutoff(kept, total)
+                params.append(
+                    pytest.param(
+                        arrow,
+                        edges,
+                        tensor,
+                        cutoff,
+                        tau,
+                        fnu,
+                        id=f"edges={tuple(edges)}|fnu={fnu}|cut={cutoff}|tau={tau:.2e}",
+                    )
+                )
+    return params
+
+
+@pytest.mark.parametrize(
+    "arrow, edges, tensor, cutoff, tau, free_names_u",
+    svd_cases(),
+)
+@pytest.mark.repeat(20)
+def test_svd(
+    arrow: Arrow,
+    edges: Edges,
+    tensor: Tensor,
+    cutoff: Cutoff,
+    tau: Tau,
+    free_names_u: FreeNamesU,
+) -> None:
+    gt = GrassmannTensor(arrow, edges, tensor)
+    U, S, Vh = gt.svd(free_names_u, cutoff=cutoff)
+
+    # reshape U
+    left_dim = math.prod(U.tensor.shape[:-1])
+    left_edge = list(U.edges[:-1])
+    U = U.reshape((left_dim, -1))
+
+    # reshape Vh
+    right_dim = math.prod(Vh.tensor.shape[1:])
+    right_edge = list(Vh.edges[1:])
+    Vh = Vh.reshape((-1, right_dim))
+
+    US = GrassmannTensor.matmul(U, S)
+    USV = GrassmannTensor.matmul(US, Vh)
+
+    set_all = set(range(len(edges)))
+    set_u = set(free_names_u)
+    set_v = sorted(set_all - set_u)
+    perm_order = list(free_names_u) + list(set_v)
+    inv_perm = [perm_order.index(i) for i in range(len(edges))]
+
+    USV = USV.reshape(tuple(left_edge + right_edge))
+    USV = USV.permute(tuple(inv_perm))
+
+    masked = gt.update_mask().tensor
+    den = masked.norm()
+    eps = torch.finfo(masked.dtype).eps
+    rel_err = (masked - USV.tensor).norm() / max(den, eps)
+    assert rel_err <= tau
+
+
+@pytest.mark.parametrize(
+    "arrow, edges, tensor, cutoff, tau, free_names_u",
+    svd_cases(),
+)
+@pytest.mark.parametrize(
+    "incompatible_cutoff",
+    [
+        -1,
+        0,
+        (
+            1,
+            2,
+            3,
+        ),
+        "string",
+        {"key", "value"},
+        [1, 2, 3],
+        {1, 2},
+        object(),
+    ],
+)
+def test_svd_with_incompatible_cutoff(
+    arrow: Arrow,
+    edges: Edges,
+    tensor: Tensor,
+    cutoff: Cutoff,
+    tau: Tau,
+    free_names_u: FreeNamesU,
+    incompatible_cutoff: Any,
+) -> None:
+    gt = GrassmannTensor(arrow, edges, tensor)
+    if isinstance(incompatible_cutoff, int):
+        with pytest.raises(AssertionError, match="Cutoff must be greater than 0"):
+            _, _, _ = gt.svd(free_names_u, cutoff=incompatible_cutoff)
+    elif isinstance(incompatible_cutoff, tuple):
+        with pytest.raises(
+            AssertionError, match="The length of cutoff must be 2 if cutoff is a tuple"
+        ):
+            _, _, _ = gt.svd(free_names_u, cutoff=incompatible_cutoff)
+    else:
+        with pytest.raises(
+            ValueError, match="Cutoff must be an integer or a tuple of two integers"
+        ):
+            _, _, _ = gt.svd(free_names_u, cutoff=incompatible_cutoff)
+
+
+@pytest.mark.parametrize("a,b", [(3, 5), (1, 1), (8, 2)])
+def test_svd_both_blocks_empty_raises_with_int_cutoff(a: int, b: int) -> None:
+    # edges: left=(even_left=0, odd_left=a), right=(even_right=b, odd_right=0)
+    # tensor shape must be (a, b)
+    arrow = (True, True)
+    edges = ((0, a), (b, 0))
+    tensor = torch.randn(a, b, dtype=torch.float64)
+
+    gt = GrassmannTensor(arrow, edges, tensor)
+
+    free_names_u = (0,)
+    with pytest.raises(RuntimeError, match="Both parity blocks are empty. Cannot form SVD."):
+        _ = gt.svd(free_names_u, cutoff=1)
+
+
+@pytest.mark.parametrize("a,b", [(3, 5), (2, 4), (7, 3)])
+def test_svd_both_blocks_empty_raises_with_tuple_cutoff(a: int, b: int) -> None:
+    arrow = (True, True)
+    edges = ((0, a), (b, 0))
+    tensor = torch.randn(a, b, dtype=torch.float64)
+
+    gt = GrassmannTensor(arrow, edges, tensor)
+
+    free_names_u = (0,)
+    with pytest.raises(RuntimeError, match="Both parity blocks are empty. Cannot form SVD."):
+        _ = gt.svd(free_names_u, cutoff=(1, 1))
+
+
+@pytest.mark.parametrize(
+    "a,b,c,k",
+    [
+        (3, 5, 7, 2),
+        (4, 1, 2, 3),
+    ],
+)
+def test_svd_int_cutoff_even_block_empty_select_from_odd_only(
+    a: int, b: int, c: int, k: int
+) -> None:
+    arrow = (True, True)
+    edges = ((0, a), (b, c))
+    tensor = torch.randn(a, b + c, dtype=torch.float64)
+
+    gt = GrassmannTensor(arrow, edges, tensor)
+    U, S, Vh = gt.svd((0,), cutoff=k)
+
+    expected_k = min(k, min(a, c))
+    assert U.edges[-1] == (0, expected_k)
+    assert Vh.edges[0] == (0, expected_k)
+    assert S.edges == ((0, expected_k), (0, expected_k))
+
+
+@pytest.mark.parametrize(
+    "a,b,k",
+    [
+        (5, 4, 2),
+        (7, 3, 5),
+    ],
+)
+def test_svd_int_cutoff_odd_block_empty_select_from_even_only(a: int, b: int, k: int) -> None:
+    arrow = (True, True)
+    edges = ((a, 0), (b, 0))
+    tensor = torch.randn(a, b, dtype=torch.float64)
+
+    gt = GrassmannTensor(arrow, edges, tensor)
+    U, S, Vh = gt.svd((0,), cutoff=k)
+
+    expected_k = min(k, min(a, b))
+    assert U.edges[-1] == (expected_k, 0)
+    assert Vh.edges[0] == (expected_k, 0)
+    assert S.edges == ((expected_k, 0), (expected_k, 0))
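
Usage note (not part of the patch): a minimal sketch of the new GrassmannTensor.svd API, mirroring the calls already made in tests/svd_test.py. The tensor values and the (2, 2) cutoff below are illustrative assumptions, not values taken from the diff.

import torch
from grassmann_tensor import GrassmannTensor

# Illustrative input: arrows, (even, odd) edge dimensions, then the dense data.
gt = GrassmannTensor((True, True), ((2, 2), (4, 4)), torch.randn(4, 8, dtype=torch.float64))

# Put leg 0 on U.  cutoff=None keeps every singular value, an int caps each
# parity block independently, and a (k_even, k_odd) tuple caps them separately.
U, S, Vh = gt.svd((0,), cutoff=(2, 2))

# U @ S @ Vh reproduces the parity-consistent part of gt, i.e. it matches
# gt.update_mask(), which is what tests/svd_test.py::test_svd verifies.
reconstructed = U.matmul(S).matmul(Vh)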