-
Notifications
You must be signed in to change notification settings - Fork 0
dev(svd): add support for svd #70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
a760dbf
30003b4
bfa26e2
754dce6
e8297cb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |
| import dataclasses | ||
| import functools | ||
| import typing | ||
| import math | ||
| import torch | ||
|
|
||
|
|
||
|
|
@@ -488,6 +489,73 @@ def matmul(self, other: GrassmannTensor) -> GrassmannTensor: | |
| _tensor=tensor, | ||
| ) | ||
|
|
||
| def svd( | ||
| self, | ||
| free_names_u: tuple[int, ...], | ||
| *, | ||
| full_matrices: bool = False, # When full_matrices=True, the gradient with respect to U and Vh will be ignored | ||
| ) -> tuple[GrassmannTensor, GrassmannTensor, GrassmannTensor]: | ||
| """ | ||
| This function is used to computes the singular value decomposition of a grassmann tensor. | ||
| The SVD are implemented by follow steps: | ||
| 1. Split the legs into left and right; | ||
| 2. Merge the tensor with two groups. | ||
| 3. Compute the singular value decomposition. | ||
| 4. Split the legs into original left and right. | ||
| The returned tensors U and V are not unique, nor are they continuous with respect to self. | ||
| Due to this lack of uniqueness, different hardware and software may compute different singular vectors. | ||
| Gradients computed using U or Vh will only be finite when A does not have repeated singular values. | ||
| Furthermore, if the distance between any two singular values is close to zero, the gradient | ||
| will be numerically unstable, as it depends on the singular values | ||
| """ | ||
| left_legs = tuple(int(i) for i in free_names_u) | ||
| right_legs = tuple(i for i in range(self.tensor.dim()) if i not in left_legs) | ||
| assert set(left_legs) | set(right_legs) == set(range(self.tensor.dim())), ( | ||
| "Left/right must cover all tensor legs." | ||
| ) | ||
|
|
||
| order = left_legs + right_legs | ||
| tensor = self.permute(order) | ||
|
|
||
| left_dim = math.prod(tensor.tensor.shape[: len(left_legs)]) | ||
| right_dim = math.prod(tensor.tensor.shape[len(left_legs) :]) | ||
|
|
||
| tensor = tensor.reshape((left_dim, right_dim)) | ||
|
|
||
| U, S, Vh = torch.linalg.svd(tensor.tensor, full_matrices=full_matrices) | ||
|
||
|
|
||
| k = min(tensor.tensor.shape[0], tensor.tensor.shape[-1]) | ||
| k_index = tensor.tensor.shape.index(k) | ||
|
||
|
|
||
| U = GrassmannTensor( | ||
| _arrow=(True, True), _edges=(tensor.edges[0], tensor.edges[k_index]), _tensor=U | ||
| ) | ||
| S = GrassmannTensor( | ||
| _arrow=( | ||
| False, | ||
| True, | ||
| ), | ||
| _edges=(tensor.edges[k_index], tensor.edges[k_index]), | ||
| _tensor=torch.diag(S), | ||
| ) | ||
| Vh = GrassmannTensor( | ||
| _arrow=(False, True), _edges=(tensor.edges[k_index], tensor.edges[-1]), _tensor=Vh | ||
| ) | ||
| # Split | ||
| left_arrow = [self.arrow[i] for i in left_legs] | ||
| left_edges = [self.edges[i] for i in left_legs] | ||
|
|
||
| right_arrow = [self.arrow[i] for i in right_legs] | ||
| right_edges = [self.edges[i] for i in right_legs] | ||
|
|
||
| U = U.reshape(tuple(left_edges + [tensor.edges[k_index]])) | ||
| U._arrow = tuple(left_arrow + [True]) | ||
|
|
||
| Vh = Vh.reshape(tuple([tensor.edges[k_index]] + right_edges)) | ||
| Vh._arrow = tuple([False] + right_arrow) | ||
|
|
||
| return U, S, Vh | ||
|
|
||
| def __post_init__(self) -> None: | ||
| assert len(self._arrow) == self._tensor.dim(), ( | ||
| f"Arrow length ({len(self._arrow)}) must match tensor dimensions ({self._tensor.dim()})." | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,32 @@ | ||
| import torch | ||
| import math | ||
| from grassmann_tensor import GrassmannTensor | ||
|
|
||
|
|
||
| def test_svd() -> None: | ||
| gt = GrassmannTensor( | ||
| (True, True, True, True), | ||
| ((8, 8), (4, 4), (2, 2), (1, 1)), | ||
| torch.randn([16, 8, 4, 2], dtype=torch.float64), | ||
| ) | ||
| U, S, Vh = gt.svd((0, 3)) | ||
|
|
||
| # reshape U | ||
| # left_arrow = U.arrow[:-1] | ||
| left_dim = math.prod(U.tensor.shape[:-1]) | ||
| left_edge = list(U.edges[:-1]) | ||
| U = U.reshape((left_dim, -1)) | ||
|
|
||
| # reshape Vh | ||
| # right_arrow = Vh.arrow[1:] | ||
| right_dim = math.prod(Vh.tensor.shape[1:]) | ||
| right_edge = list(Vh.edges[1:]) | ||
| Vh = Vh.reshape((-1, right_dim)) | ||
|
|
||
| US = GrassmannTensor.matmul(U, S) | ||
| USV = GrassmannTensor.matmul(US, Vh) | ||
|
|
||
| USV = USV.reshape(tuple(left_edge + right_edge)) | ||
| USV = USV.permute((0, 2, 3, 1)) | ||
|
|
||
| assert torch.allclose(gt.tensor, USV.tensor) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
full matrics可以直接是False,tensor下面,不会用到full matrics的svd。