Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions grassmann_tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import dataclasses
import functools
import typing
import math
import torch


Expand Down Expand Up @@ -488,6 +489,73 @@ def matmul(self, other: GrassmannTensor) -> GrassmannTensor:
_tensor=tensor,
)

def svd(
self,
free_names_u: tuple[int, ...],
*,
full_matrices: bool = False, # When full_matrices=True, the gradient with respect to U and Vh will be ignored
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

full matrics可以直接是False,tensor下面,不会用到full matrics的svd。

) -> tuple[GrassmannTensor, GrassmannTensor, GrassmannTensor]:
"""
This function is used to computes the singular value decomposition of a grassmann tensor.
The SVD are implemented by follow steps:
1. Split the legs into left and right;
2. Merge the tensor with two groups.
3. Compute the singular value decomposition.
4. Split the legs into original left and right.
The returned tensors U and V are not unique, nor are they continuous with respect to self.
Due to this lack of uniqueness, different hardware and software may compute different singular vectors.
Gradients computed using U or Vh will only be finite when A does not have repeated singular values.
Furthermore, if the distance between any two singular values is close to zero, the gradient
will be numerically unstable, as it depends on the singular values
"""
left_legs = tuple(int(i) for i in free_names_u)
right_legs = tuple(i for i in range(self.tensor.dim()) if i not in left_legs)
assert set(left_legs) | set(right_legs) == set(range(self.tensor.dim())), (
"Left/right must cover all tensor legs."
)

order = left_legs + right_legs
tensor = self.permute(order)

left_dim = math.prod(tensor.tensor.shape[: len(left_legs)])
right_dim = math.prod(tensor.tensor.shape[len(left_legs) :])

tensor = tensor.reshape((left_dim, right_dim))

U, S, Vh = torch.linalg.svd(tensor.tensor, full_matrices=full_matrices)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里tensor是一个2分块的矩阵,你需要分别进行svd

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不然的话,原来的分块矩阵进行svd后就不是分块的了

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,已在e8297cb中提交了修改。


k = min(tensor.tensor.shape[0], tensor.tensor.shape[-1])
k_index = tensor.tensor.shape.index(k)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

分别进行svd后,需要允许有cut dimension的操作,这个在tn中很常见。大概就是删掉最小几个singular value,只保留最大的若干个,这个个数使用参数传进来,默认不进行cut,这里两个分块需要分别cut。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,已在e8297cb中提交了修改。


U = GrassmannTensor(
_arrow=(True, True), _edges=(tensor.edges[0], tensor.edges[k_index]), _tensor=U
)
S = GrassmannTensor(
_arrow=(
False,
True,
),
_edges=(tensor.edges[k_index], tensor.edges[k_index]),
_tensor=torch.diag(S),
)
Vh = GrassmannTensor(
_arrow=(False, True), _edges=(tensor.edges[k_index], tensor.edges[-1]), _tensor=Vh
)
# Split
left_arrow = [self.arrow[i] for i in left_legs]
left_edges = [self.edges[i] for i in left_legs]

right_arrow = [self.arrow[i] for i in right_legs]
right_edges = [self.edges[i] for i in right_legs]

U = U.reshape(tuple(left_edges + [tensor.edges[k_index]]))
U._arrow = tuple(left_arrow + [True])

Vh = Vh.reshape(tuple([tensor.edges[k_index]] + right_edges))
Vh._arrow = tuple([False] + right_arrow)

return U, S, Vh

def __post_init__(self) -> None:
assert len(self._arrow) == self._tensor.dim(), (
f"Arrow length ({len(self._arrow)}) must match tensor dimensions ({self._tensor.dim()})."
Expand Down
32 changes: 32 additions & 0 deletions tests/svd_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch
import math
from grassmann_tensor import GrassmannTensor


def test_svd() -> None:
gt = GrassmannTensor(
(True, True, True, True),
((8, 8), (4, 4), (2, 2), (1, 1)),
torch.randn([16, 8, 4, 2], dtype=torch.float64),
)
U, S, Vh = gt.svd((0, 3))

# reshape U
# left_arrow = U.arrow[:-1]
left_dim = math.prod(U.tensor.shape[:-1])
left_edge = list(U.edges[:-1])
U = U.reshape((left_dim, -1))

# reshape Vh
# right_arrow = Vh.arrow[1:]
right_dim = math.prod(Vh.tensor.shape[1:])
right_edge = list(Vh.edges[1:])
Vh = Vh.reshape((-1, right_dim))

US = GrassmannTensor.matmul(U, S)
USV = GrassmannTensor.matmul(US, Vh)

USV = USV.reshape(tuple(left_edge + right_edge))
USV = USV.permute((0, 2, 3, 1))

assert torch.allclose(gt.tensor, USV.tensor)