From 73373ab98afa5d26a7f0d6e258cfd8c9a4c086a3 Mon Sep 17 00:00:00 2001 From: The jax_triton Authors Date: Mon, 18 Sep 2023 06:03:45 -0700 Subject: [PATCH] Testing triton integration 2023-09-14 PiperOrigin-RevId: 566278630 --- jax_triton/triton_lib.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 220847b5..5ee1c4af 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -163,8 +163,20 @@ def compile_ttir_to_ptx_inplace( print(ttir) try: ttir = tc.optimize_ttir(ttir, compute_capability) - ttgir = tc.ttir_to_ttgir(ttir, num_warps) - ttgir = tc.optimize_ttgir(ttgir, num_stages, compute_capability) + ttgir = tc.ttir_to_ttgir( + ttir, num_warps, num_ctas=1, arch=compute_capability + ) + ttgir = tc.optimize_ttgir( + ttgir, + num_stages, + num_warps, + num_ctas=1, + arch=compute_capability, + cluster_info=_triton.ClusterInfo(), + enable_warp_specialization=False, + enable_persistent=False, + optimize_epilogue=False, + ) except RuntimeError as e: ttir.dump() raise ValueError("TTIR->TTGIR pass failed!") from e @@ -172,7 +184,9 @@ def compile_ttir_to_ptx_inplace( print(ttgir) extern_libs = {} try: - llir = tc.ttgir_to_llir(ttgir, extern_libs, compute_capability) + llir = tc.ttgir_to_llir( + ttgir, extern_libs, compute_capability, _triton.TMAInfos() + ) except RuntimeError as e: ttgir.dump() raise ValueError("TTGIR->LLIR pass failed!") from e