This repository is an attempt to implement the ideas presented in the blog post "Accelerating 2D Dynamic Block Quantized Float8 GEMMs in Triton". It is designed specifically for NVIDIA H100 GPUs and leverages float8 computation, Triton's high-performance GPU programming capabilities, and the Tensor Memory Accelerator (TMA). Together, these features enable state-of-the-art GEMM kernels by improving memory-transfer efficiency, reducing latency, and maximizing computational throughput.
The goal of this project is to accelerate GEMMs (General Matrix Multiplications) through dynamic block quantization with float8. By reducing memory traffic while maintaining computational accuracy, this work aims to push the boundaries of GPU performance.
To install the dependencies, run:

```bash
pip install -r requirements.txt
```
The repository is organized as follows:

- `bench.py`: Benchmarking utilities that compare the Triton and PyTorch implementations.
- `gemm.py`: Triton kernels for persistent GEMMs with dynamic block quantization.
- `main.py`: Entry point for the quantization and GEMM computation pipeline.
- `quant.py`: Block-level quantization of matrices into float8 format.
- `requirements.txt`: Python dependencies.
The `quant.py` module quantizes input matrices in two passes (a minimal sketch appears after this list):
- Dividing matrices into subblocks and determining the maximum absolute value for each block during the first pass.
- Using the maximum value found in the first pass as a scaling factor to quantize the data into the float8 representation during the second pass.
- Storing these scaling factors for later dequantization to recover the original matrix.
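As a concrete illustration, here is a minimal sketch of this two-pass scheme in plain PyTorch. The function name, default block size, and float8 format (`e4m3`) are illustrative assumptions rather than the repo's actual API, and the matrix dimensions are assumed to be multiples of the block size.

```python
import torch


# A minimal sketch of two-pass block quantization to float8 (e4m3).
# Names and defaults here are illustrative, not the repo's API; M and N
# are assumed to be exact multiples of block_size.
def block_quantize_fp8(x: torch.Tensor, block_size: int = 128):
    M, N = x.shape
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448 for e4m3

    # View the matrix as a grid of (block_size x block_size) subblocks.
    blocks = x.reshape(M // block_size, block_size, N // block_size, block_size)

    # Pass 1: per-block maximum absolute value (each block's dynamic range).
    amax = blocks.abs().amax(dim=(1, 3), keepdim=True).float().clamp(min=1e-12)

    # Pass 2: scale each block so its max maps to the float8 max, then cast down.
    scales = amax / fp8_max
    q = (blocks / scales).to(torch.float8_e4m3fn).reshape(M, N)

    # The per-block scales are kept so the matrix can be dequantized later:
    #   x_approx[block] = q[block].float() * scales[block]
    return q, scales.reshape(M // block_size, N // block_size)
```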
The `gemm.py` module leverages Triton to do the following (a simplified sketch of the persistent loop appears after this list):
- Load matrix blocks efficiently using Tensor Memory Accelerator (TMA) descriptors.
- Perform tiled, persistent matrix multiplication to optimize SM utilization.
- TODO: Employ cooperative kernel design with warp specialization, enabling parallelism by dedicating warp groups to specific tasks and reducing data dependencies.
- Store results back to memory in float8 format, utilizing optimized scaling and dequantization techniques.
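For illustration, the sketch below shows a persistent tiling loop in Triton. It uses ordinary pointer loads rather than the TMA descriptors mentioned above (the TMA API is experimental and omitted here), skips the float8 scaling logic, and assumes M, N, and K are exact multiples of the block sizes, so it is a simplified stand-in for the real kernel in `gemm.py`.

```python
import triton
import triton.language as tl


@triton.jit
def persistent_gemm_sketch(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    NUM_SMS: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_tiles_m = tl.cdiv(M, BLOCK_M)
    num_tiles_n = tl.cdiv(N, BLOCK_N)
    total_tiles = num_tiles_m * num_tiles_n

    # Persistent loop: the grid has only NUM_SMS programs, and each one
    # keeps claiming output tiles until the whole tile grid is covered.
    for tile_id in range(pid, total_tiles, NUM_SMS):
        pid_m = tile_id // num_tiles_n
        pid_n = tile_id % num_tiles_n
        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

        # Accumulate along K in float32 for accuracy.
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        for k in range(0, K, BLOCK_K):
            offs_k = k + tl.arange(0, BLOCK_K)
            a = tl.load(a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
            b = tl.load(b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
            acc = tl.dot(a, b, acc)

        tl.store(c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, acc)
```

The persistent pattern launches only `NUM_SMS` programs (one per streaming multiprocessor, e.g. `torch.cuda.get_device_properties(0).multi_processor_count` on the host) and lets each program loop over output tiles, which keeps the SMs busy and amortizes launch overhead.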
Benchmarks in `bench.py` compare the Triton-based implementation with PyTorch, plotting speedups and memory efficiency across a range of matrix sizes.
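A timing comparison in this spirit can be written with `triton.testing.do_bench`; the helper below is a hypothetical sketch, and `triton_fp8_gemm` is a stand-in for the repo's actual kernel launcher.

```python
import torch
import triton


# Hypothetical benchmark helper: times torch.matmul against a Triton
# float8 GEMM launcher passed in by the caller.
def compare(M: int, N: int, K: int, triton_fp8_gemm):
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

    torch_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
    triton_ms = triton.testing.do_bench(lambda: triton_fp8_gemm(a, b))
    print(f"{M}x{N}x{K}: torch {torch_ms:.3f} ms | "
          f"triton {triton_ms:.3f} ms | speedup {torch_ms / triton_ms:.2f}x")
```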
To run the pipeline for 4096 x 4096 matrices:

```bash
python main.py
```
Preliminary results on H100 GPUs demonstrate:
- Significant speedups due to reduced memory usage with float8 precision.
- Optimized memory access patterns via Triton's TMA descriptors.
Future updates will include detailed benchmarks and accuracy analyses.
- Implement warp specialization, recently merged into Triton.
- Enhance benchmarking and accuracy testing.
This repository draws inspiration from the blog post "Accelerating 2D Dynamic Block Quantized Float8 GEMMs in Triton" and Triton's official documentation and examples.
If you use this work, please cite the original blog post and this repository.