From d1107e1ba10312001fd89eafaaf75d3b795026a4 Mon Sep 17 00:00:00 2001 From: Zoran Jovanovic Date: Fri, 25 Apr 2025 17:06:11 -0500 Subject: [PATCH] Added support for waves_per_eu function attribute. --- jax/_src/pallas/triton/core.py | 3 +++ jax/_src/pallas/triton/pallas_call_registration.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/jax/_src/pallas/triton/core.py b/jax/_src/pallas/triton/core.py index 097f8497e8f7..3f64e90a76ab 100644 --- a/jax/_src/pallas/triton/core.py +++ b/jax/_src/pallas/triton/core.py @@ -29,10 +29,13 @@ class TritonCompilerParams(pallas_core.CompilerParams): 32 threads. num_stages: The number of stages the compiler should use for software pipelining loops. + waves_per_eu: Manages Vector General Purpose Registers (VGPR) usage to achieve + desired occupancy levels. serialized_metadata: Additional compiler metadata. This field is unstable and may be removed in the future. """ PLATFORM: ClassVar[str] = "triton" num_warps: int | None = None num_stages: int | None = None + waves_per_eu: int | None = None serialized_metadata: bytes | None = None diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index 59b1b86f33fc..865a2c6460f1 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -67,9 +67,12 @@ def pallas_call_lowering( if lowering_platform == "rocm": num_stages = triton_params.get("num_stages", 1) num_stages = 1 if num_stages is None else num_stages + waves_per_eu = triton_params.get("waves_per_eu", 1) + waves_per_eu = 1 if waves_per_eu is None else waves_per_eu else: num_stages = triton_params.get("num_stages", 3) num_stages = 3 if num_stages is None else num_stages + waves_per_eu = 1 if debug: print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:") @@ -98,6 +101,7 @@ def pallas_call_lowering( ir=ir.StringAttr.get(buf.getvalue()), num_stages=mlir.i32_attr(num_stages), num_warps=mlir.i32_attr(num_warps), + waves_per_eu=mlir.i32_attr(waves_per_eu), grid_x=mlir.i32_attr(grid_x), grid_y=mlir.i32_attr(grid_y), grid_z=mlir.i32_attr(grid_z),