Efficient Block-Diagonal Matrix Operations for Wigner D Computation #250
base: main
Pull Request Overview
This PR refactors Wigner D matrix computation to use batched block-diagonal operations instead of per-`l` loops, significantly improving performance for large `l_max` or batch sizes.
- Generate all small rotation matrices in one go and assemble them into block-diagonal tensors
- Perform a single batched matrix multiplication sequence for the full Wigner D
- Group irreps by quantum number and rotate features in bulk rather than per-`l`
Comments suppressed due to low confidence (3)
dptb/nn/tensor_product.py:17
- The docstring for `build_z_rot_multi` mentions an `l_max` parameter which is not present; update the parameter list and descriptions to match the actual signature.
l_max: int
dptb/nn/tensor_product.py:55
- Add unit tests comparing `batch_wigner_D` against the original `wigner_D` for small `l_max` and random angles to ensure the batched implementation matches legacy outputs.
def batch_wigner_D(l_max, alpha, beta, gamma, _Jd):
dptb/nn/tensor_product.py:257
- This loop is rotating data in `out`, which is still zero; it should be reshaping and rotating `x_` (the input features) rather than `out` to preserve the computed values.
x_slice = out[:, slice_in].reshape(n, mul, -1)
# Assign values to the diagonal
M_total[:, global_row, global_col_diag] = cos_val[idx_l, :, idx_row].transpose(0,1)
# Assign values to non-overlapping anti-diagonals
overlap_mask = (global_row == global_col_anti)
Why does it require a mask here?
The 'overlap_mask' is used to avoid assigning sine values to matrix entries that are already filled with cosine values on the diagonal. It ensures that sine terms are only written to true off-diagonal (anti-diagonal) positions to prevent overwriting or conflict.
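The overlap arises because each `(2l+1) x (2l+1)` block has an odd size, so its diagonal and anti-diagonal share exactly one entry, the centre, where the frequency is 0. The following is a minimal, dependency-free sketch of that idea for a single block; the function name `z_rot_block` and the `skip_overlap` flag are hypothetical and are not part of the PR, which performs the same masking on global indices in `M_total`.

```python
import math

def z_rot_block(l, angle, skip_overlap=True):
    """Build one (2l+1)x(2l+1) z-rotation block, writing cos first, then sin.

    Hypothetical helper for illustration only.
    """
    size = 2 * l + 1
    freqs = [l - i for i in range(size)]       # l, l-1, ..., -l
    M = [[0.0] * size for _ in range(size)]

    # Cosine terms go on the diagonal.
    for i, f in enumerate(freqs):
        M[i][i] = math.cos(f * angle)

    # Sine terms go on the anti-diagonal, except where it meets the diagonal.
    for i, f in enumerate(freqs):
        j = size - 1 - i                       # anti-diagonal column
        if skip_overlap and i == j:            # centre entry already holds cos(0)
            continue
        M[i][j] = math.sin(f * angle)
    return M

masked = z_rot_block(1, 0.3)
unmasked = z_rot_block(1, 0.3, skip_overlap=False)
```

With the skip in place the centre entry keeps its cosine value `cos(0) = 1`; without it, the second pass overwrites that entry with `sin(0) = 0` and the block is no longer a rotation.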
# Load static data
sizes = idx_data["sizes"][:l_max+1]
offsets = idx_data["offsets"][:l_max+1]
How is the offset defined?
The following is the code for generating the 'z_rot_indices_lmax12.pt' file:
import torch

l_max_all = 12
l_list = torch.arange(0, l_max_all + 1)
sizes = 2 * l_list + 1  # block size 2l+1 for each l
# Starting index of each block: exclusive cumulative sum of the sizes
offsets = torch.cat([torch.tensor([0]), torch.cumsum(sizes, 0)[:-1]])
L = len(l_list)
Mmax = sizes.max().item()

# Per-l tables, padded to the largest block size
mask = torch.zeros(L, Mmax, dtype=torch.bool)
freq = torch.zeros(L, Mmax)
inds = torch.zeros(L, Mmax, dtype=torch.long)
reversed_inds = torch.zeros(L, Mmax, dtype=torch.long)

for i, l in enumerate(l_list):
    sz = sizes[i]
    mask[i, :sz] = True                                   # valid (unpadded) entries
    freq[i, :sz] = torch.arange(l, -l-1, -1)              # frequencies l, ..., -l
    inds[i, :sz] = torch.arange(0, sz)                    # diagonal indices
    reversed_inds[i, :sz] = torch.arange(2 * l, -1, -1)   # anti-diagonal indices

torch.save({
    "mask": mask,
    "freq": freq,
    "inds": inds,
    "reversed_inds": reversed_inds,
    "l_list": l_list,
    "sizes": sizes,
    "offsets": offsets,
    "Mmax": Mmax,
    "l_max_all": l_max_all,
}, "/root/DeePTB/dptb/nn/z_rot_indices_lmax12.pt")
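To answer the question directly: in the script above, `offsets[l]` is the exclusive cumulative sum of the block sizes, i.e. the row/column index at which the block for `l` starts inside the block-diagonal matrix. A small pure-Python sketch of the same arithmetic, using an illustrative `l_max = 3` rather than the 12 used for the saved file:

```python
from itertools import accumulate

l_max = 3                                           # illustrative value only
sizes = [2 * l + 1 for l in range(l_max + 1)]       # [1, 3, 5, 7]
offsets = [0] + list(accumulate(sizes))[:-1]        # [0, 1, 4, 9]
total = sum(sizes)                                  # 16

# The block for a given l occupies indices offsets[l] .. offsets[l] + sizes[l] - 1.
block_ranges = [(o, o + s) for o, s in zip(offsets, sizes)]
```

Since the sizes are consecutive odd numbers, `offsets[l]` equals `l**2` and the full matrix has side `(l_max + 1)**2`.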
Optimize Wigner D Computation Using Block-Diagonal Batch Operations
Summary
This PR improves the efficiency of Wigner D matrix computation by eliminating per-`l` for-loops and replacing them with batched block-diagonal matrix operations. All rotation matrices and related terms are assembled and multiplied in a single step. This approach reduces redundant computation and significantly speeds up the process, especially for large batches or high `l_max`.
Key Changes
- Rotation matrices for all `l` values are generated at once.
- `self.irreps_in` and `self.irreps_out` are grouped by quantum number `l`.
- `torch.cat` and `torch.bmm` are used to batch multiple irreps of the same `l` together for rotation.
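The grouping step can be sketched as follows. This is an illustration under stated assumptions, not the PR's code: the feature layout `(n, mul, 2l+1)`, the variable names, and the random stand-in for the Wigner D block are all hypothetical. It shows that concatenating irreps of the same `l` along the multiplicity axis and calling `torch.bmm` once gives the same result as one `torch.bmm` per irrep.

```python
import torch

torch.manual_seed(0)
n, l = 4, 1                      # batch size and angular momentum (illustrative)
dim = 2 * l + 1
muls = [2, 3]                    # multiplicities of two irreps sharing the same l

# One feature tensor per irrep, shaped (n, mul, 2l+1) (assumed layout).
feats = [torch.randn(n, mul, dim) for mul in muls]
# Random stand-in for the per-sample Wigner D block of this l, shaped (n, 2l+1, 2l+1).
D = torch.randn(n, dim, dim)

# Per-irrep loop: one bmm call per irrep.
looped = [torch.bmm(x, D.transpose(1, 2)) for x in feats]

# Batched: concatenate along the multiplicity axis, rotate once, split back.
stacked = torch.cat(feats, dim=1)                   # (n, sum(muls), 2l+1)
rotated = torch.bmm(stacked, D.transpose(1, 2))     # single bmm for this l
batched = list(torch.split(rotated, muls, dim=1))
```

The saving comes from issuing one matrix multiplication per distinct `l` instead of one per irrep, which matters most when many irreps share the same `l`.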