diff --git a/warp/builtins.py b/warp/builtins.py index 21b62429..dc028acf 100644 --- a/warp/builtins.py +++ b/warp/builtins.py @@ -5681,9 +5681,7 @@ def make_function(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout): raise RuntimeError("time_matmul(A, B, C) requires all inputs to be real or complex") element_type = a_type - lto_symbol = ( - f"dot_{M}_{N}_{K}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}" - ) + lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}" # early out if LTO for this combination already exists for this module if lto_symbol in builder.ltoirs: