dev(svd): add support for svd #70
base: main
Conversation
gausshj commented Oct 9, 2025
- Add support for svd
- Add test for svd

Signed-off-by: Gausshj <[email protected]>
Codecov Report: ✅ All modified and coverable lines are covered by tests.
Let's review the business logic first; start by fixing the numerical SVD part. Once you've revised that, we can talk about code style.
After the numerical SVD, the part that restores the grassmann tensor looks fine.
grassmann_tensor/tensor.py (Outdated)

    tensor = tensor.reshape((left_dim, right_dim))
    ...
    U, S, Vh = torch.linalg.svd(tensor.tensor, full_matrices=full_matrices)
Here tensor is a matrix with two parity blocks; you need to perform the SVD on each block separately.
Otherwise, after the SVD the originally block-diagonal matrix will no longer be block-diagonal.
OK, the fix has been committed in e8297cb.
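The block-wise decomposition the reviewer asks for can be sketched as follows. This is only an illustrative sketch, not the PR's actual code: the function name `blockwise_svd` and the assumption that the two parity blocks are already available as separate matrices are mine.

```python
import torch

def blockwise_svd(even_block: torch.Tensor, odd_block: torch.Tensor):
    """SVD each parity block separately so the result stays block-diagonal.

    Applying torch.linalg.svd to the full block-diagonal matrix would mix
    the even and odd sectors; decomposing per block preserves parity.
    """
    U_e, S_e, Vh_e = torch.linalg.svd(even_block, full_matrices=False)
    U_o, S_o, Vh_o = torch.linalg.svd(odd_block, full_matrices=False)
    # Recombine: U and Vh are block-diagonal in the parity basis,
    # so U @ diag(S) @ Vh reproduces block_diag(even_block, odd_block).
    U = torch.block_diag(U_e, U_o)
    S = torch.cat([S_e, S_o])
    Vh = torch.block_diag(Vh_e, Vh_o)
    return U, S, Vh
```

Reassembling with `torch.block_diag` makes the parity structure of the factors explicit, which is the point of the reviewer's comment.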
grassmann_tensor/tensor.py (Outdated)

    U, S, Vh = torch.linalg.svd(tensor.tensor, full_matrices=full_matrices)
    ...
    k = min(tensor.tensor.shape[0], tensor.tensor.shape[-1])
    k_index = tensor.tensor.shape.index(k)
After doing the SVD on each block separately, we also need to support a cut-dimension operation, which is very common in tensor networks: drop the smallest singular values and keep only the largest ones. The number to keep is passed in as a parameter, and by default no cut is performed. The two blocks need to be cut separately.
OK, the fix has been committed in e8297cb.
- Perform SVD separately on even/odd parity blocks instead of the entire rank-2 tensor
- Update the test case with new implementation
Previously, the SVD was applied directly to the full Grassmann tensor,
which ignored the parity block structure and produced incorrect
decompositions. Now the tensor is split into even/odd blocks before
performing SVD, then recombined via block_diag to ensure correct
parity preservation.
Signed-off-by: Gausshj <[email protected]>
- Correct cutoff logic and support None as no-truncation mode
- Add parameterized test cases for svd
- Fix missing coverage on exception and boundary paths

Signed-off-by: Gauss <[email protected]>
- Resolve coverage issues of svd function
- Correct type annotations in test cases
- Reformat the code of svd

Signed-off-by: Gauss <[email protected]>
Force-pushed from 8bb8f29 to e8297cb
From your code it looks like you understand how the SVD should be done. However, what we are building is a fermi tensor with broadcast support, so a few things need to be different.
    self,
    free_names_u: tuple[int, ...],
    *,
    full_matrices: bool = False,  # When full_matrices=True, the gradient with respect to U and Vh will be ignored
full_matrices can simply be fixed to False; nothing built on top of tensor will ever use the full-matrices SVD.
    tensor = tensor.reshape((left_dim, right_dim))
    ...
    tensor.update_mask()
This update_mask call is redundant.
    keep_even = torch.ones_like(S_even, dtype=torch.bool, device=S_even.device)
    keep_odd = torch.ones_like(S_odd, dtype=torch.bool, device=S_odd.device)
else:
    S_cat = torch.cat([S_even, S_odd])
Your thinking is right: the most general SVD would indeed be done this way. But we need a broadcast (batched) SVD, and a single global cutoff would let each batch entry end up with a different selection between S_even and S_odd. So instead we should keep the largest few values within S_even and the largest few within S_odd, independently of each other. The cutoff parameter of this function can have type None | int | tuple[int, int].