Triton是一种用于编写高效自定义深度学习原语的语言和编译器。它的目标是提供一个开源环境,让用户能够以比使用 CUDA 更高的生产效率编写快速代码,同时还能比其他现有的领域特定语言(DSL)更具灵活性。
在 Triton 中,grid 用于定义 GPU 内核的执行网格,即决定内核在 GPU 上如何并行执行。
构建 grid 的方式:
- 固定
grid(直接赋值为元组)
grid = (16, 16) # 2D grid,16x16 blocks
kernel[grid](...)
- 动态
grid(lambda 函数)
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
kernel[grid](...)
- 更复杂的动态
grid(自定义函数)
def grid(META):
return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1)
kernel[grid](...)
git clone https://github.com/triton-lang/triton.git
cd triton
pip install ninja cmake wheel; # build-time dependencies
pip install -e .
Triton前端与接口部分使用 Python 实现,而核心部分使用 C++ 实现,这是由于其核心任务涉及矩阵运算等密集型计算,以及对底层硬件指令的精准控制。因此,安装Triton涉及对其核心部分的 C++ 代码进行编译。(note: 与后续提到的Triton kernel的编译是不同的概念。)
Triton的 C++ 核心实现目录为lib/,包含:
Triton IR的数据结构和操作- 编译 Pass管理(优化、调度、IR-Lowering)
- 将
Triton IR转换为LLVM IR的代码 - 调用
LLVM生成PTX的逻辑
最终产物为一个共享库python/triton/_C/libtriton.so
libtriton.so是Triton的 C++ 编译器核心,其通过Pybind11暴露为 Python 可调用模块triton._C.libtriton。它通常不会由用户手动调用,而是由Triton的 Python 包的内部模块自动使用。
- 在 Python 中完成,不直接调用
libtriton.so
# python/triton/compiler.py
libtriton.compile_ttir_to_llir(...)
# python/triton/compiler.py
libtriton.compile_llir_to_ptx(...)
# python/triton/compiler.py
libtriton.link_ptx(...)
# python/triton/runtime/launcher.py
libtriton.get_function(...)
libtriton.launch(...)
