Triton generates incorrect MLIR code when assigning values to elements of a tuple in the body of an if statement #5446

Open
Pluto-Zy opened this issue Dec 17, 2024 · 0 comments

Describe the bug

I noticed that Triton introduced support for tuples in a recent commit (#5220). However, the Triton frontend does not handle the case where a user modifies a tuple inside an if statement. For example, consider the following code:

import torch
import triton
import triton.language as tl

@triton.jit
def func(cond, out):
    values = (100, 101)
    if cond > 0.0:
        values[1] = values[1] + 42
    else:
        values[1] = values[1] + 43

    tl.store(out, values[1])


output = torch.zeros((1,), dtype=torch.int32, device="cuda")
func[(1,)](1.3, output)

In this code, I modify values[1] in both the true and false branches of the if statement, and finally store the result in out. When I run this program with export MLIR_ENABLE_DUMP=1, Triton generates the following MLIR (with adjusted formatting and location information removed):

// -----// IR Dump Before Inliner (inline) ('builtin.module' operation) //----- //
"builtin.module"() ({
  "tt.func"() <{arg_attrs = [{}, {tt.divisibility = 16 : i32}], function_type = (f32, !tt.ptr<i32>) -> (), sym_name = "func", sym_visibility = "public"}> ({
  ^bb0(%arg0: f32, %arg1: !tt.ptr<i32>):
    %0 = "arith.constant"() <{value = 100 : i32}> : () -> i32
    %1 = "arith.constant"() <{value = 101 : i32}> : () -> i32
    %2 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
    %3 = "arith.cmpf"(%arg0, %2) <{fastmath = #arith.fastmath<none>, predicate = 2 : i64}> : (f32, f32) -> i1
    "scf.if"(%3) ({
      %15 = "arith.constant"() <{value = 42 : i32}> : () -> i32
      %16 = "arith.constant"() <{value = 42 : i32}> : () -> i32
      %17 = "arith.extsi"(%1) : (i32) -> i64
      %18 = "arith.extsi"(%16) : (i32) -> i64
      %19 = "arith.addi"(%17, %18) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
      %20 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
      %21 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
      %22 = "arith.cmpi"(%19, %20) <{predicate = 3 : i64}> : (i64, i64) -> i1
      %23 = "arith.cmpi"(%19, %21) <{predicate = 5 : i64}> : (i64, i64) -> i1
      %24 = "arith.andi"(%22, %23) : (i1, i1) -> i1
      %25 = "arith.addi"(%1, %16) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
      "scf.yield"() : () -> ()
    }, {
      %4 = "arith.constant"() <{value = 43 : i32}> : () -> i32
      %5 = "arith.constant"() <{value = 43 : i32}> : () -> i32
      %6 = "arith.extsi"(%25) : (i32) -> i64
      %7 = "arith.extsi"(%5) : (i32) -> i64
      %8 = "arith.addi"(%6, %7) <{overflowFlags = #arith.overflow<none>}> : (i64, i64) -> i64
      %9 = "arith.constant"() <{value = 2147483647 : i64}> : () -> i64
      %10 = "arith.constant"() <{value = -2147483648 : i64}> : () -> i64
      %11 = "arith.cmpi"(%8, %9) <{predicate = 3 : i64}> : (i64, i64) -> i1
      %12 = "arith.cmpi"(%8, %10) <{predicate = 5 : i64}> : (i64, i64) -> i1
      %13 = "arith.andi"(%11, %12) : (i1, i1) -> i1
      %14 = "arith.addi"(%25, %5) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
      "scf.yield"() : () -> ()
    }) : (i1) -> ()
    "tt.store"(%arg1, %14) <{boundaryCheck = array<i32>, cache = 1 : i32, evict = 1 : i32}> : (!tt.ptr<i32>, i32) -> ()
    "tt.return"() : () -> ()
  }) {noinline = false} : () -> ()
}) : () -> ()

In addition to the overflow checks emitted for the addition, the generated scf.if has three problems:

  1. The scf.if does not return any result. The correct behavior should be to return the modified value of values[1].
  2. %6 = arith.extsi(%25) incorrectly uses %25, which is defined in another region.
  3. tt.store incorrectly uses %14, which is defined within the if block.

I suspect the error arises from the local_defs bookkeeping in code_generator.py, which records only the names that are defined and does not account for modifications to tuple elements. I believe the same error could occur when assigning through a property access.
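To illustrate the suspicion above, here is a minimal sketch using only Python's standard `ast` module. It is not Triton's actual code; `assigned_names`, `assigned_subscripts`, and the extra `total` variable are hypothetical and exist only for this illustration. It shows that a collector which looks only at stored names never sees `values[1] = ...`, because the assignment target there is an `ast.Subscript` node and the inner `values` name is merely loaded.

```python
import ast

# The if statement from the report, plus a hypothetical plain-name
# assignment (`total`) for contrast. The source is parsed, never executed.
SOURCE = """
if cond > 0.0:
    values[1] = values[1] + 42
    total = 1
else:
    values[1] = values[1] + 43
    total = 2
"""


def assigned_names(stmts):
    """Collect identifiers that are stored to as plain names."""
    names = set()
    for stmt in stmts:
        for node in ast.walk(stmt):
            if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store):
                names.add(node.id)
    return names


def assigned_subscripts(stmts):
    """Collect subscript expressions that are stored to, e.g. `values[1]`."""
    found = set()
    for stmt in stmts:
        for node in ast.walk(stmt):
            if isinstance(node, ast.Subscript) and isinstance(node.ctx, ast.Store):
                found.add(ast.unparse(node))  # requires Python 3.9+
    return found


if_node = ast.parse(SOURCE).body[0]
then_names = assigned_names(if_node.body)
then_subs = assigned_subscripts(if_node.body)

print(then_names)  # only 'total'; 'values' is absent
print(then_subs)   # 'values[1]'
```

If the definitions tracked for the if statement are derived from names alone, as assumed here, nothing tells the code generator that `values[1]` must be yielded from both branches, which would be consistent with the `scf.if` above returning no result.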

If possible, I would be happy to contribute a fix for this issue to Triton. However, I am not sure how Triton would prefer to handle such cases: whether to simply prohibit users from modifying tuple elements, or whether there is a better strategy for tracking values modified inside an if statement.
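As a rough sketch of the first option (prohibiting such assignments), the check below walks a kernel's source with Python's standard `ast` module and reports every assignment whose target is a subscript or an attribute. It is only an illustration of the idea: `find_unsupported_stores` and both sample sources are hypothetical, and nothing here reflects how Triton's frontend is actually structured.

```python
import ast

# Kernel body from the report; it is only parsed, never executed.
BAD_SOURCE = """
def func(cond, out):
    values = (100, 101)
    if cond > 0.0:
        values[1] = values[1] + 42
    else:
        values[1] = values[1] + 43
    tl.store(out, values[1])
"""

# Hypothetical rewrite that assigns only to plain names.
GOOD_SOURCE = """
def func(cond, out):
    v1 = 101
    if cond > 0.0:
        v1 = v1 + 42
    else:
        v1 = v1 + 43
    tl.store(out, v1)
"""


def find_unsupported_stores(source):
    """Return (line, text) pairs for stores to subscripts or attributes."""
    offenders = []
    for node in ast.walk(ast.parse(source)):
        is_target = isinstance(node, (ast.Subscript, ast.Attribute))
        if is_target and isinstance(node.ctx, ast.Store):
            offenders.append((node.lineno, ast.unparse(node)))  # Python 3.9+
    return offenders


for line, text in find_unsupported_stores(BAD_SOURCE):
    print(f"line {line}: assignment to {text!r} is not supported")
```

A frontend could turn each reported pair into a compile-time error instead of emitting invalid IR. The second sample keeps the modified value in a plain name, which the name-based tracking described above should already handle.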

Environment details

Triton: commit 24b8d43
System: Ubuntu 22.04

@Pluto-Zy Pluto-Zy added the bug label Dec 17, 2024