
Efficient Block-Diagonal Matrix Operations for Wigner D Computation #250


Open · wants to merge 5 commits into main

Conversation


@Kai-Qi commented on May 21, 2025

Optimize Wigner D Computation Using Block-Diagonal Batch Operations

Summary

This PR improves the efficiency of Wigner D matrix computation by replacing the per-l for-loops with batched block-diagonal matrix operations. All rotation matrices and related terms are assembled and multiplied in a single step, which eliminates redundant computation and significantly speeds up the calculation, especially for large batches or high l_max.

Key Changes

  • Batch generation: all small rotation matrices for the different l values are generated at once.
  • Block-diagonal assembly: the per-l matrices are combined into large block-diagonal matrices.
  • Single-step multiplication: the full Wigner D calculation is performed with a single batched matrix multiplication (see the sketch after this list).
  • Irreps grouping: self.irreps_in and self.irreps_out are grouped by quantum number l.
  • Batched rotation: torch.cat and torch.bmm batch multiple irreps of the same l together for rotation.
  • Numerical correctness: output reshaping and scattering were verified to match the original structure.
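
To make the approach concrete, here is a minimal sketch (not the PR's exact code: block_z_rot and batch_wigner_D_sketch are illustrative names, and the Xa · J · Xb · J · Xc factorization follows the e3nn wigner_D convention the original per-l routine is based on). The PR's actual implementation also eliminates the small loop over l below by using precomputed, padded index tensors:

import torch

def block_z_rot(angle, l_max):
    # angle: (batch,) tensor of rotation angles about the z axis.
    sizes = [2 * l + 1 for l in range(l_max + 1)]
    offsets = [sum(sizes[:l]) for l in range(l_max + 1)]   # block l starts at l**2
    D = sum(sizes)                                         # total dimension (l_max + 1)**2
    M = angle.new_zeros(angle.shape[0], D, D)
    for l, off in zip(range(l_max + 1), offsets):
        freq = torch.arange(l, -l - 1, -1, dtype=angle.dtype, device=angle.device)
        rows = off + torch.arange(2 * l + 1, device=angle.device)
        anti = off + torch.arange(2 * l, -1, -1, device=angle.device)
        M[:, rows, anti] = torch.sin(freq * angle[:, None])
        M[:, rows, rows] = torch.cos(freq * angle[:, None])   # overwrites the m = 0 overlap
    return M

def batch_wigner_D_sketch(l_max, alpha, beta, gamma, Jd):
    # Jd: list of per-l J matrices, as in e3nn's _Jd.
    J = torch.block_diag(*[Jd[l] for l in range(l_max + 1)])  # (D, D), shared by the batch
    Xa, Xb, Xc = (block_z_rot(a, l_max) for a in (alpha, beta, gamma))
    return Xa @ J @ Xb @ J @ Xc   # one batched chain instead of a loop over l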

@Kai-Qi Kai-Qi changed the title Optimize SO(2) convolution in SO2_Linear layer via Quantum Number Grouping for Efficient Batched Computation Efficient Block-Diagonal Matrix Operations for Wigner D Computation Jun 13, 2025
@QG-phy QG-phy requested review from floatingCatty and Copilot June 20, 2025 02:17
@Copilot Copilot AI (Contributor) left a comment


Pull Request Overview

This PR refactors Wigner D matrix computation to use batched block-diagonal operations instead of per-l loops, significantly improving performance for large l_max or batch sizes.

  • Generate all small rotation matrices in one go and assemble into block-diagonal tensors
  • Perform a single batched matrix multiplication sequence for the full Wigner D
  • Group irreps by quantum number and rotate features in bulk rather than per-l
Comments suppressed due to low confidence (3)

dptb/nn/tensor_product.py:17

  • The docstring for build_z_rot_multi mentions an l_max parameter which is not present; update the parameter list and descriptions to match the actual signature.
    l_max: int

dptb/nn/tensor_product.py:55

  • Add unit tests comparing batch_wigner_D against the original wigner_D for small l_max and random angles to ensure the batched implementation matches legacy outputs.
def batch_wigner_D(l_max, alpha, beta, gamma, _Jd):
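
    In the spirit of this suggestion, such a test could look roughly like the sketch below (the return layout of batch_wigner_D, the legacy wigner_D signature, and the import path are assumptions):

    import torch
    # Hypothetical imports; adjust to the actual module layout:
    # from dptb.nn.tensor_product import batch_wigner_D, wigner_D, _Jd

    def test_batch_wigner_D_matches_legacy(l_max=4, batch=8):
        torch.manual_seed(0)
        alpha, beta, gamma = (torch.rand(batch) * 2 * torch.pi for _ in range(3))
        # Assumed layout: one block-diagonal (batch, S, S) matrix, S = (l_max + 1)**2
        D_big = batch_wigner_D(l_max, alpha, beta, gamma, _Jd)
        off = 0
        for l in range(l_max + 1):
            sz = 2 * l + 1
            ref = wigner_D(l, alpha, beta, gamma)   # legacy per-l routine (signature assumed)
            assert torch.allclose(D_big[:, off:off + sz, off:off + sz], ref, atol=1e-6)
            off += sz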

dptb/nn/tensor_product.py:257

  • This loop is rotating data in out, which is still zero; it should be reshaping and rotating x_ (the input features) rather than out to preserve the computed values.
                x_slice = out[:, slice_in].reshape(n, mul, -1)
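
    For concreteness, the suggested fix amounts to something like the following (x_ is assumed to be the input-feature tensor in the surrounding code):

    # before: reads from the zero-initialized output buffer
    x_slice = out[:, slice_in].reshape(n, mul, -1)
    # after: reads the actual input features
    x_slice = x_[:, slice_in].reshape(n, mul, -1)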

# Assign values to the diagonal
M_total[:, global_row, global_col_diag] = cos_val[idx_l, :, idx_row].transpose(0,1)
# Assign values to non-overlapping anti-diagonals
overlap_mask = (global_row == global_col_anti)
Member:

Why does it require a mask here?

Author (@Kai-Qi):
The overlap_mask avoids assigning sine values to matrix entries that are already filled with cosine values on the diagonal. Each block has odd size 2l + 1, so its central entry (m = 0) lies on both the diagonal and the anti-diagonal; overlap_mask = (global_row == global_col_anti) flags exactly those positions. Sine terms are therefore written only to true off-diagonal (anti-diagonal) positions and never overwrite the cosine values.
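
A tiny self-contained illustration of the overlap for l = 1 (toy code, not from the PR):

import torch

l = 1
inds = torch.arange(2 * l + 1)        # row indices within the block: [0, 1, 2]
rev = torch.arange(2 * l, -1, -1)     # anti-diagonal column indices: [2, 1, 0]
overlap = inds == rev                 # [False, True, False]: the central m = 0 entry
print(inds[~overlap], rev[~overlap])  # only these positions receive sine values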


# Load static data
sizes = idx_data["sizes"][:l_max+1]
offsets = idx_data["offsets"][:l_max+1]
Member:

How is the offset defined?

Author (@Kai-Qi):
The following is the code for generating the 'z_rot_indices_lmax12.pt' file:

import torch

# Index tables for assembling batched z-rotation matrices up to l_max_all.
l_max_all = 12
l_list = torch.arange(0, l_max_all + 1)
sizes = 2 * l_list + 1                                                  # block size 2l + 1 per l
offsets = torch.cat([torch.tensor([0]), torch.cumsum(sizes, 0)[:-1]])   # start index of each block
L = len(l_list)
Mmax = sizes.max().item()                                               # widest block, used for padding

# Per-l index tensors, padded to Mmax columns; `mask` marks the valid entries.
mask = torch.zeros(L, Mmax, dtype=torch.bool)
freq = torch.zeros(L, Mmax)
inds = torch.zeros(L, Mmax, dtype=torch.long)
reversed_inds = torch.zeros(L, Mmax, dtype=torch.long)

for i, l in enumerate(l_list.tolist()):
    sz = int(sizes[i])
    mask[i, :sz] = True
    freq[i, :sz] = torch.arange(l, -l - 1, -1)           # frequencies m = l, ..., -l
    inds[i, :sz] = torch.arange(0, sz)                   # diagonal indices within the block
    reversed_inds[i, :sz] = torch.arange(2 * l, -1, -1)  # anti-diagonal indices

torch.save({
    "mask": mask,
    "freq": freq,
    "inds": inds,
    "reversed_inds": reversed_inds,
    "l_list": l_list,
    "sizes": sizes,
    "offsets": offsets,
    "Mmax": Mmax,
    "l_max_all": l_max_all,
}, "/root/DeePTB/dptb/nn/z_rot_indices_lmax12.pt")
