Mostly notes to myself, as this is very precise and short-term. Concerns the cmma version of matmul.
Each step should come with a config flag so it isn't mandatory; that way we can merge faster.
Steps must be done in this order:
Compute loop
Invert the k and n loops: at the moment the outer loop is over n and the inner over k, which prevents the next step:
Keep the lhs fragment in registers instead of reloading the same data on each iteration.
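The two points above can be sketched with a scalar CPU stand-in (a hypothetical `compute_loop`, with a plain `f32` standing in for a CMMA fragment): once k is the inner loop per row, the lhs value for a given (row, k) is loaded once and reused across every n iteration instead of being refetched per (n, k) pair.

```rust
// Scalar stand-in for the CMMA compute loop. With k inner and the lhs
// value hoisted, each lhs element is read once per row of output tiles
// it contributes to, instead of once per (n, k) iteration.
fn compute_loop(lhs: &[f32], rhs: &[f32], acc: &mut [f32], m: usize, k_dim: usize, n: usize) {
    for i in 0..m {
        for k in 0..k_dim {
            // The "lhs fragment" stays in a register for the whole n loop.
            let lhs_frag = lhs[i * k_dim + k];
            for j in 0..n {
                acc[i * n + j] += lhs_frag * rhs[k * n + j];
            }
        }
    }
}
```

The n-outer ordering forbids this hoisting because the lhs element needed changes on every inner iteration.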
Loading to SMEM
At the moment each warp fills its own tile, but there is no real reason to split responsibilities this way. Instead, change the algorithm so that a warp loads as much as it can in one coalesced write (taking vectorization into account), then offsets itself by that chunk size times the number of warps, of course following the respective layouts of GMEM and SMEM.
Allow different layouts for SMEM. As of now, tiles from lhs are row-major and tiles from rhs are col-major, but the opposite is probably more suitable for double buffering.
Instead of assuming which warp loads what based on its cube position, write the loading as a function of a warp id and a number of warps, for better flexibility.
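A hypothetical sketch of the loading pattern, written exactly as a function of a warp id and a warp count as suggested above (assuming a warp size of 32; `warp_load_offsets` and its parameters are illustrative names, not existing APIs): each warp performs one coalesced, vectorized write of `warp_size * vec_width` elements, then strides forward by that chunk size times the number of warps until the tile is covered.

```rust
/// Element offsets a given warp is responsible for when filling a tile
/// of `tile_len` elements. Each step is one coalesced write of
/// `warp_size * vec_width` elements; warps interleave round-robin.
fn warp_load_offsets(
    warp_id: usize,
    num_warps: usize,
    warp_size: usize,
    vec_width: usize,
    tile_len: usize,
) -> Vec<usize> {
    let chunk = warp_size * vec_width; // elements per coalesced write
    (0..)
        .map(|step: usize| (step * num_warps + warp_id) * chunk)
        .take_while(|&off| off < tile_len)
        .collect()
}
```

Because the offsets depend only on `warp_id` and `num_warps`, the same function serves any warp-to-work assignment, including the specialization scheme below.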
Config
Relax the config constraint b_m = b_n
Fix num_compute_warps = b_m / 16, num_buffers = b_k / 16, num_accumulators = b_n / 16. Have one warp per row to maximize reuse of the lhs fragment.
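The config relations above can be sketched as derived values on a hypothetical config struct (the struct and method names are illustrative; the 16 comes from the 16x16 CMMA tile size implied by the divisions above). Note that b_m and b_n are now free to differ.

```rust
/// Illustrative config sketch with the proposed derived quantities.
struct CmmaConfig {
    b_m: usize,
    b_k: usize,
    b_n: usize,
}

impl CmmaConfig {
    // One compute warp per row of 16x16 tiles in the m dimension.
    fn num_compute_warps(&self) -> usize { self.b_m / 16 }
    fn num_buffers(&self) -> usize { self.b_k / 16 }
    fn num_accumulators(&self) -> usize { self.b_n / 16 }
}
```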
Double buffering with warp specialization
Specialize some warps as compute warps. Their number should equal the number of tensor cores (typically 4 or 8) and should be num_compute_warps. Let some other warps serve as loading warps. Adjust sync_units accordingly.
Define a specialization strategy to determine which warps do what (for instance: warps 0..7 compute, the rest load).
Use double buffering by alternating between the first and second halves of SMEM.
It's probably better to have the compute warps load the first tiles to save a sync
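A minimal sketch of the two mechanisms above, under the assumptions stated in the notes (`Role`, `role`, and `buffer_offset` are hypothetical names): the specialization strategy is a pure function of the warp id, and double buffering reduces to alternating a base offset between the two halves of SMEM on each main-loop iteration.

```rust
/// Which job a warp performs under the specialization strategy.
#[derive(Debug, PartialEq)]
enum Role {
    Compute,
    Load,
}

/// Warps 0..num_compute_warps compute; the rest load.
fn role(warp_id: usize, num_compute_warps: usize) -> Role {
    if warp_id < num_compute_warps { Role::Compute } else { Role::Load }
}

/// Base offset into SMEM for this iteration: even iterations use the
/// first half, odd iterations the second, so computing on one half can
/// overlap with loading into the other.
fn buffer_offset(iteration: usize, half_smem_len: usize) -> usize {
    (iteration % 2) * half_smem_len
}
```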