Matmul Optimizations Short-Term Roadmap #128

Open · 8 of 11 tasks
louisfd opened this issue Sep 18, 2024 · 0 comments

Labels: enhancement (New feature or request)

louisfd commented Sep 18, 2024

Mostly notes to myself, as this is very precise and short-term. It concerns the CMMA version of matmul.
Each step should come with a config flag so it doesn't become mandatory, which lets us merge faster.
Steps must be done in this order:

Compute loop

  • Invert the k and n for-loops: at the moment the outer loop is on n and the inner on k, which forbids the next step.
  • Keep the lhs fragment in registers instead of reloading the same data on each iteration (see the sketch after this list).
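
For illustration, a CUDA `wmma` sketch (not the actual CubeCL kernel) of what the inverted loop order buys; `lhs_smem`, `rhs_smem`, the contiguous 16x16 tile layout, and the function signature are assumptions:

```cuda
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

// Sketch only: tile counts, pointer offsets, and buffer names are illustrative.
// With k as the outer loop, the lhs fragment for a given k is loaded once and
// reused across every n iteration instead of being reloaded on each of them.
__device__ void compute_loop(const half* lhs_smem, const half* rhs_smem,
                             wmma::fragment<wmma::accumulator, 16, 16, 16, float>* acc,
                             int num_buffers, int num_accumulators) {
    for (int k = 0; k < num_buffers; ++k) {              // outer loop on k
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> lhs_frag;
        wmma::load_matrix_sync(lhs_frag, lhs_smem + k * 16 * 16, 16);   // one lhs load per k

        for (int n = 0; n < num_accumulators; ++n) {     // inner loop on n
            wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> rhs_frag;
            wmma::load_matrix_sync(rhs_frag, rhs_smem + (k * num_accumulators + n) * 16 * 16, 16);
            wmma::mma_sync(acc[n], lhs_frag, rhs_frag, acc[n]);
        }
    }
}
```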

Loading to SMEM

  • At the moment each warp fills its own tile, but there is no real reason to split responsibilities this way. We should instead change the algorithm so that a warp loads as much data as it can in one coalesced write (with vectorization taken into account), then offsets itself by this amount times the number of warps, following the respective layouts of the GMEM and SMEM.
  • Allow different layouts for the SMEM. As of now, tiles from lhs are row-major and tiles from rhs are col-major, but the contrary is probably more suitable for double buffering.
  • Instead of assuming which warp loads what based on its cube position, write the loading as a function of an id and a number of warps for better flexibility (see the sketch after this list).
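
A sketch of that loading scheme in CUDA terms; the `float4` vectorization and the identity index mapping are assumptions, the real mapping must follow the respective GMEM/SMEM layouts:

```cuda
// Each warp performs one coalesced, vectorized write per iteration, then jumps
// ahead by (32 lanes x number of warps). Which warp copies what depends only on
// warp_id and num_warps, not on the warp's position in the cube.
__device__ void load_to_smem(const float4* gmem, float4* smem,
                             int num_vec_elems,  // total float4 elements to copy
                             int warp_id, int num_warps) {
    const int lane = threadIdx.x % 32;
    for (int i = warp_id * 32 + lane; i < num_vec_elems; i += num_warps * 32) {
        smem[i] = gmem[i];  // identity mapping kept for brevity
    }
}
```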

Config

  • Relax the config constraint b_m = b_n.
  • Fix num_compute_warps = b_m / 16, num_buffers = b_k / 16, num_accumulators = b_n / 16. Have one warp per row to maximize reutilization of the lhs fragment (see the sketch after this list).
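
Under those formulas the derived quantities look like this (a sketch; the struct and its names are illustrative, not the actual CubeCL config):

```cuda
// Illustrative config with 16x16x16 CMMA tiles; field and method names are made up.
struct CmmaConfig {
    int b_m, b_k, b_n;                                  // b_m == b_n no longer required
    int num_compute_warps() const { return b_m / 16; }  // one compute warp per 16-row slice
    int num_buffers()       const { return b_k / 16; }  // SMEM buffers along k
    int num_accumulators()  const { return b_n / 16; }  // accumulators held by each compute warp
};
// Example: b_m = 64, b_k = 32, b_n = 128
// -> 4 compute warps, 2 buffers, 8 accumulators per warp.
```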

Double buffering with warp specialization

  • Specialize some warps into compute warps. Their number should equal the number of tensor cores (typically 4 or 8) and should be num_compute_warps. Allow some other warps to serve as loading warps. Adjust sync_units accordingly.
  • Define a specialization strategy to determine which warps should do what (for instance: 0..7 compute, the rest load).
  • Use double buffering by alternating between the first and second half of the SMEM.
  • It's probably better to have the compute warps load the first tiles to save a sync (see the sketch after this list).
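
A skeletal CUDA-style outline of the flow; the helper names, the SMEM split into two halves, and the use of `__syncthreads()` instead of finer-grained sync_units are all assumptions:

```cuda
#include <cuda_fp16.h>

// Placeholder stand-ins for the real routines; bodies intentionally elided.
__device__ void load_tiles(half* /*smem_half*/, int /*k_step*/) { /* fill one SMEM half */ }
__device__ void compute_tiles(const half* /*smem_half*/) { /* CMMA over one SMEM half */ }

// Skeleton of warp specialization + double buffering: compute warps work on one
// SMEM half while loading warps prefetch the next k-step into the other half.
__device__ void matmul_double_buffered(int warp_id, int num_compute_warps,
                                       int num_k_steps, half* smem_halves[2]) {
    const bool is_compute_warp = warp_id < num_compute_warps;  // e.g. 0..7 compute, rest load

    // Compute warps load the very first tiles themselves to save one sync.
    if (is_compute_warp) load_tiles(smem_halves[0], /*k_step=*/0);
    __syncthreads();

    for (int k = 0; k < num_k_steps; ++k) {
        const int cur = k % 2;
        const int nxt = (k + 1) % 2;
        if (is_compute_warp) {
            compute_tiles(smem_halves[cur]);                   // consume the current half
        } else if (k + 1 < num_k_steps) {
            load_tiles(smem_halves[nxt], /*k_step=*/k + 1);    // prefetch into the other half
        }
        __syncthreads();  // in practice, sync only the relevant sync_units
    }
}
```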