Bug in cases like eltwise_mult_CSRxDense_oCSR #71

Open
pthomadakis opened this issue Oct 30, 2024 · 1 comment

Labels
bug Something isn't working

Comments

@pthomadakis
Collaborator
Running the test case in ops/eltwise_mult_CSRxDense_oCSR.ta reveals a bug in the IR that is hard to notice because it does not affect the results.
Specifically, lowering to loops (--emit-loops) produces the following IR:


module {
  func.func @main() {
    %idx-1 = index.constant -1
    %idx1 = index.constant 1
    %idx0 = index.constant 0
    %cst = arith.constant 2.700000e+00 : f64
    %cst_0 = arith.constant 0.000000e+00 : f64
    %c10 = arith.constant 10 : index
    %c9 = arith.constant 9 : index
    %c8 = arith.constant 8 : index
    %c7 = arith.constant 7 : index
    %c6 = arith.constant 6 : index
    %c5 = arith.constant 5 : index
    %c4 = arith.constant 4 : index
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %c3 = arith.constant 3 : index
    %c2 = arith.constant 2 : index
    %c-1 = arith.constant -1 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %alloc = memref.alloc() : memref<13xindex>
    %cast = memref.cast %alloc : memref<13xindex> to memref<*xindex>
    call @read_input_sizes_2D_f64(%c0_i32, %c0, %c-1, %c1, %c-1, %cast, %c1_i32) {filename = "SPARSE_FILE_NAME0"} : (i32, index, index, index, index, memref<*xindex>, i32) -> ()
    %0 = memref.load %alloc[%c0] : memref<13xindex>
    %1 = memref.load %alloc[%c1] : memref<13xindex>
    %2 = memref.load %alloc[%c2] : memref<13xindex>
    %3 = memref.load %alloc[%c3] : memref<13xindex>
    %4 = memref.load %alloc[%c4] : memref<13xindex>
    %5 = memref.load %alloc[%c5] : memref<13xindex>
    %6 = memref.load %alloc[%c6] : memref<13xindex>
    %7 = memref.load %alloc[%c7] : memref<13xindex>
    %8 = memref.load %alloc[%c8] : memref<13xindex>
    %9 = memref.load %alloc[%c9] : memref<13xindex>
    %10 = memref.load %alloc[%c10] : memref<13xindex>
    %alloc_1 = memref.alloc(%0) : memref<?xindex>
    scf.for %arg0 = %c0 to %0 step %c1 {
      memref.store %c0, %alloc_1[%arg0] : memref<?xindex>
    }
    %cast_2 = memref.cast %alloc_1 : memref<?xindex> to memref<*xindex>
    %alloc_3 = memref.alloc(%1) : memref<?xindex>
    scf.for %arg0 = %c0 to %1 step %c1 {
      memref.store %c0, %alloc_3[%arg0] : memref<?xindex>
    }
    %cast_4 = memref.cast %alloc_3 : memref<?xindex> to memref<*xindex>
    %alloc_5 = memref.alloc(%2) : memref<?xindex>
    scf.for %arg0 = %c0 to %2 step %c1 {
      memref.store %c0, %alloc_5[%arg0] : memref<?xindex>
    }
    %cast_6 = memref.cast %alloc_5 : memref<?xindex> to memref<*xindex>
    %alloc_7 = memref.alloc(%3) : memref<?xindex>
    scf.for %arg0 = %c0 to %3 step %c1 {
      memref.store %c0, %alloc_7[%arg0] : memref<?xindex>
    }
    %cast_8 = memref.cast %alloc_7 : memref<?xindex> to memref<*xindex>
    %alloc_9 = memref.alloc(%4) : memref<?xindex>
    scf.for %arg0 = %c0 to %4 step %c1 {
      memref.store %c0, %alloc_9[%arg0] : memref<?xindex>
    }
    %cast_10 = memref.cast %alloc_9 : memref<?xindex> to memref<*xindex>
    %alloc_11 = memref.alloc(%5) : memref<?xindex>
    scf.for %arg0 = %c0 to %5 step %c1 {
      memref.store %c0, %alloc_11[%arg0] : memref<?xindex>
    }
    %cast_12 = memref.cast %alloc_11 : memref<?xindex> to memref<*xindex>
    %alloc_13 = memref.alloc(%6) : memref<?xindex>
    scf.for %arg0 = %c0 to %6 step %c1 {
      memref.store %c0, %alloc_13[%arg0] : memref<?xindex>
    }
    %cast_14 = memref.cast %alloc_13 : memref<?xindex> to memref<*xindex>
    %alloc_15 = memref.alloc(%7) : memref<?xindex>
    scf.for %arg0 = %c0 to %7 step %c1 {
      memref.store %c0, %alloc_15[%arg0] : memref<?xindex>
    }
    %cast_16 = memref.cast %alloc_15 : memref<?xindex> to memref<*xindex>
    %alloc_17 = memref.alloc(%8) : memref<?xf64>
    scf.for %arg0 = %c0 to %8 step %c1 {
      memref.store %cst_0, %alloc_17[%arg0] : memref<?xf64>
    }
    %cast_18 = memref.cast %alloc_17 : memref<?xf64> to memref<*xf64>
    call @read_input_2D_f64(%c0_i32, %c0, %c-1, %c1, %c-1, %cast_2, %cast_4, %cast_6, %cast_8, %cast_10, %cast_12, %cast_14, %cast_16, %cast_18, %c1_i32) {filename = "SPARSE_FILE_NAME0"} : (i32, index, index, index, index, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xf64>, i32) -> ()
    %11 = bufferization.to_tensor %alloc_9 restrict writable : memref<?xindex>
    %12 = bufferization.to_tensor %alloc_11 restrict writable : memref<?xindex>
    %13 = bufferization.to_tensor %alloc_17 restrict writable : memref<?xf64>
    %alloc_19 = memref.alloc(%9, %10) {alignment = 32 : i64} : memref<?x?xf64>
    linalg.fill ins(%cst : f64) outs(%alloc_19 : memref<?x?xf64>)
    %14 = bufferization.to_tensor %alloc_19 restrict writable : memref<?x?xf64>
    %alloc_20 = memref.alloc() : memref<1xindex>
    memref.store %9, %alloc_20[%idx0] : memref<1xindex>
    %15 = bufferization.to_tensor %alloc_20 restrict writable : memref<1xindex>
    %alloc_21 = memref.alloc(%5) : memref<?xf64>
    scf.for %arg0 = %idx0 to %5 step %idx1 {
      memref.store %cst_0, %alloc_21[%arg0] : memref<?xf64>
    }
    %16 = bufferization.to_tensor %alloc_21 restrict writable : memref<?xf64>
    %17:3 = scf.for %arg0 = %c0 to %9 step %c1 iter_args(%arg1 = %idx0, %arg2 = %12, %arg3 = %16) -> (index, tensor<?xindex>, tensor<?xf64>) {
      %19 = arith.addi %arg0, %c1 : index
      %extracted = tensor.extract %11[%arg0] : tensor<?xindex>
      %extracted_22 = tensor.extract %11[%19] : tensor<?xindex>
      %20:3 = scf.for %arg4 = %extracted to %extracted_22 step %c1 iter_args(%arg5 = %arg1, %arg6 = %arg2, %arg7 = %arg3) -> (index, tensor<?xindex>, tensor<?xf64>) {
        %extracted_23 = tensor.extract %12[%arg4] : tensor<?xindex>
        %extracted_24 = tensor.extract %13[%arg4] : tensor<?xf64>
        %extracted_25 = tensor.extract %14[%arg0, %extracted_23] : tensor<?x?xf64>
        %21 = arith.mulf %extracted_24, %extracted_25 : f64
        %inserted_26 = tensor.insert %extracted_23 into %arg6[%arg5] : tensor<?xindex> <-- We insert into %arg6, which is %12 = %alloc_11
        %22 = index.add %arg5, %idx1
        %inserted_27 = tensor.insert %21 into %arg7[%arg5] : tensor<?xf64>
        scf.yield %22, %inserted_26, %inserted_27 : index, tensor<?xindex>, tensor<?xf64>
      }
      scf.yield %20#0, %20#1, %20#2 : index, tensor<?xindex>, tensor<?xf64>
    }
    %18 = bufferization.alloc_tensor() : tensor<1xindex>
    %inserted = tensor.insert %idx-1 into %18[%idx0] : tensor<1xindex>
    "ta.print"(%15) : (tensor<1xindex>) -> ()
    "ta.print"(%inserted) : (tensor<1xindex>) -> ()
    "ta.print"(%11) : (tensor<?xindex>) -> () <-- issue 1: input data printed as output
    "ta.print"(%17#1) : (tensor<?xindex>) -> () <-- issue 2: produced by inserting into input data
    "ta.print"(%17#2) : (tensor<?xf64>) -> ()
    return
  }
  func.func private @read_input_2D_f64(i32, index, index, index, index, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xf64>, i32)
  func.func private @read_input_sizes_2D_f64(i32, index, index, index, index, memref<*xindex>, i32)
}

Here, we can see two issues:

  1. SSA value %11, which refers to data of one of the input sparse tensors, is printed when I try to print the result/output tensor. This means the input and output tensors share a reference to the same underlying data structure, which should not be the case.
  2. Even worse, SSA value %17#1 is a tensor produced by tensor.insert into the underlying data of that same input tensor, specifically this operation: %inserted_26 = tensor.insert %extracted_23 into %arg6[%arg5] : tensor<?xindex>.

I have marked the related pieces of IR with <-- comments to make it easier to track what I'm referring to.
The problem does not show up in the results because we never check the input tensors afterwards, and, in addition, one-shot-bufferize saves us by creating a copy of the tensor we try to insert into before the insertion happens.
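To make the data flow concrete, here is a minimal Python sketch (hypothetical, not COMET code) of the kernel this IR implements: C = A .* B with A in CSR, B dense, and C in CSR. Since A and C have identical sparsity, C shares A's pos array, and every coordinate written to the output equals the one already in A's crd array.

```python
# Hypothetical sketch (plain Python, not COMET code) of the kernel the IR
# above implements: C = A .* B, with A in CSR, B dense, C in CSR.

def eltwise_mult_csr_x_dense(pos, crd, val, B):
    """pos/crd/val: CSR arrays of A; B: dense 2-D list. Returns CSR of C."""
    out_crd = [0] * len(crd)      # corresponds to %alloc_22 in the bufferized IR
    out_val = [0.0] * len(val)    # corresponds to %alloc_21
    k = 0                         # running insert position (%arg5)
    for i in range(len(pos) - 1):             # outer scf.for over rows
        for j in range(pos[i], pos[i + 1]):   # inner scf.for over nonzeros of row i
            col = crd[j]
            out_crd[k] = col                  # the coordinate insert (%inserted_26)
            out_val[k] = val[j] * B[i][col]   # arith.mulf + value insert (%inserted_27)
            k += 1
    return pos, out_crd, out_val  # pos is shared with the input, like %11 in the IR

# 2x3 example: A has nonzeros at (0,0), (0,2), (1,1); B is dense.
A_pos, A_crd, A_val = [0, 2, 3], [0, 2, 1], [1.0, 2.0, 3.0]
B = [[2.0] * 3 for _ in range(2)]
C_pos, C_crd, C_val = eltwise_mult_csr_x_dense(A_pos, A_crd, A_val, B)
```

Note that k stays equal to j here (it starts at 0 and increments once per nonzero), so even if out_crd aliased crd, each store would write back the value already present; that is why the aliasing never corrupts the results.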

If I let bufferization happen, here is what we get (--convert-ta-to-it --convert-to-loops):


  func.func @main() {
    %idx-1 = index.constant -1
    %idx1 = index.constant 1
    %idx0 = index.constant 0
    %cst = arith.constant 2.700000e+00 : f64
    %cst_0 = arith.constant 0.000000e+00 : f64
    %c10 = arith.constant 10 : index
    %c9 = arith.constant 9 : index
    %c8 = arith.constant 8 : index
    %c7 = arith.constant 7 : index
    %c6 = arith.constant 6 : index
    %c5 = arith.constant 5 : index
    %c4 = arith.constant 4 : index
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %c3 = arith.constant 3 : index
    %c2 = arith.constant 2 : index
    %c-1 = arith.constant -1 : index
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %alloc = memref.alloc() : memref<13xindex>
    %cast = memref.cast %alloc : memref<13xindex> to memref<*xindex>
    call @read_input_sizes_2D_f64(%c0_i32, %c0, %c-1, %c1, %c-1, %cast, %c1_i32) {filename = "SPARSE_FILE_NAME0"} : (i32, index, index, index, index, memref<*xindex>, i32) -> ()
    %0 = memref.load %alloc[%c0] : memref<13xindex>
    %1 = memref.load %alloc[%c1] : memref<13xindex>
    %2 = memref.load %alloc[%c2] : memref<13xindex>
    %3 = memref.load %alloc[%c3] : memref<13xindex>
    %4 = memref.load %alloc[%c4] : memref<13xindex>
    %5 = memref.load %alloc[%c5] : memref<13xindex>
    %6 = memref.load %alloc[%c6] : memref<13xindex>
    %7 = memref.load %alloc[%c7] : memref<13xindex>
    %8 = memref.load %alloc[%c8] : memref<13xindex>
    %9 = memref.load %alloc[%c9] : memref<13xindex>
    %10 = memref.load %alloc[%c10] : memref<13xindex>
    %alloc_1 = memref.alloc(%0) : memref<?xindex>
    scf.for %arg0 = %c0 to %0 step %c1 {
      memref.store %c0, %alloc_1[%arg0] : memref<?xindex>
    }
    %cast_2 = memref.cast %alloc_1 : memref<?xindex> to memref<*xindex>
    %alloc_3 = memref.alloc(%1) : memref<?xindex>
    scf.for %arg0 = %c0 to %1 step %c1 {
      memref.store %c0, %alloc_3[%arg0] : memref<?xindex>
    }
    %cast_4 = memref.cast %alloc_3 : memref<?xindex> to memref<*xindex>
    %alloc_5 = memref.alloc(%2) : memref<?xindex>
    scf.for %arg0 = %c0 to %2 step %c1 {
      memref.store %c0, %alloc_5[%arg0] : memref<?xindex>
    }
    %cast_6 = memref.cast %alloc_5 : memref<?xindex> to memref<*xindex>
    %alloc_7 = memref.alloc(%3) : memref<?xindex>
    scf.for %arg0 = %c0 to %3 step %c1 {
      memref.store %c0, %alloc_7[%arg0] : memref<?xindex>
    }
    %cast_8 = memref.cast %alloc_7 : memref<?xindex> to memref<*xindex>
    %alloc_9 = memref.alloc(%4) : memref<?xindex>
    scf.for %arg0 = %c0 to %4 step %c1 {
      memref.store %c0, %alloc_9[%arg0] : memref<?xindex>
    }
    %cast_10 = memref.cast %alloc_9 : memref<?xindex> to memref<*xindex>
    %alloc_11 = memref.alloc(%5) : memref<?xindex>
    scf.for %arg0 = %c0 to %5 step %c1 {
      memref.store %c0, %alloc_11[%arg0] : memref<?xindex>
    }
    %cast_12 = memref.cast %alloc_11 : memref<?xindex> to memref<*xindex>
    %alloc_13 = memref.alloc(%6) : memref<?xindex>
    scf.for %arg0 = %c0 to %6 step %c1 {
      memref.store %c0, %alloc_13[%arg0] : memref<?xindex>
    }
    %cast_14 = memref.cast %alloc_13 : memref<?xindex> to memref<*xindex>
    %alloc_15 = memref.alloc(%7) : memref<?xindex>
    scf.for %arg0 = %c0 to %7 step %c1 {
      memref.store %c0, %alloc_15[%arg0] : memref<?xindex>
    }
    %cast_16 = memref.cast %alloc_15 : memref<?xindex> to memref<*xindex>
    %alloc_17 = memref.alloc(%8) : memref<?xf64>
    scf.for %arg0 = %c0 to %8 step %c1 {
      memref.store %cst_0, %alloc_17[%arg0] : memref<?xf64>
    }
    %cast_18 = memref.cast %alloc_17 : memref<?xf64> to memref<*xf64>
    call @read_input_2D_f64(%c0_i32, %c0, %c-1, %c1, %c-1, %cast_2, %cast_4, %cast_6, %cast_8, %cast_10, %cast_12, %cast_14, %cast_16, %cast_18, %c1_i32) {filename = "SPARSE_FILE_NAME0"} : (i32, index, index, index, index, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xindex>, memref<*xf64>, i32) -> ()
    %alloc_19 = memref.alloc(%9, %10) {alignment = 32 : i64} : memref<?x?xf64>
    linalg.fill ins(%cst : f64) outs(%alloc_19 : memref<?x?xf64>)
    %alloc_20 = memref.alloc() : memref<1xindex>
    memref.store %9, %alloc_20[%idx0] : memref<1xindex>
    %alloc_21 = memref.alloc(%5) : memref<?xf64>
    scf.for %arg0 = %idx0 to %5 step %idx1 {
      memref.store %cst_0, %alloc_21[%arg0] : memref<?xf64>
    }
    %alloc_22 = memref.alloc(%5) {alignment = 64 : i64} : memref<?xindex>
    memref.copy %alloc_11, %alloc_22 : memref<?xindex> to memref<?xindex> <-- Inserted by one-shot-bufferize
    %11 = scf.for %arg0 = %c0 to %9 step %c1 iter_args(%arg1 = %idx0) -> (index) {
      %12 = arith.addi %arg0, %c1 : index
      %13 = memref.load %alloc_9[%arg0] : memref<?xindex>
      %14 = memref.load %alloc_9[%12] : memref<?xindex>
      %15 = scf.for %arg2 = %13 to %14 step %c1 iter_args(%arg3 = %arg1) -> (index) {
        %16 = memref.load %alloc_11[%arg2] : memref<?xindex>
        %17 = memref.load %alloc_17[%arg2] : memref<?xf64>
        %18 = memref.load %alloc_19[%arg0, %16] : memref<?x?xf64>
        %19 = arith.mulf %17, %18 : f64
        memref.store %16, %alloc_22[%arg3] : memref<?xindex>
        %20 = index.add %arg3, %idx1
        memref.store %19, %alloc_21[%arg3] : memref<?xf64>
        scf.yield %20 : index
      }
      scf.yield %15 : index
    }
    %alloc_23 = memref.alloc() {alignment = 64 : i64} : memref<1xindex>
    memref.store %idx-1, %alloc_23[%idx0] : memref<1xindex>
    %cast_24 = memref.cast %alloc_20 : memref<1xindex> to memref<*xindex>
    call @comet_print_memref_i64(%cast_24) : (memref<*xindex>) -> ()
    %cast_25 = memref.cast %alloc_23 : memref<1xindex> to memref<*xindex>
    call @comet_print_memref_i64(%cast_25) : (memref<*xindex>) -> ()
    call @comet_print_memref_i64(%cast_10) : (memref<*xindex>) -> ()
    %cast_26 = memref.cast %alloc_22 : memref<?xindex> to memref<*xindex>
    call @comet_print_memref_i64(%cast_26) : (memref<*xindex>) -> ()
    %cast_27 = memref.cast %alloc_21 : memref<?xf64> to memref<*xf64>
    call @comet_print_memref_f64(%cast_27) : (memref<*xf64>) -> ()
    return
  }

Notice the copy operation inserted by one-shot-bufferize (marked with the <-- comment). However, we cannot rely on this.

@pthomadakis pthomadakis added the bug Something isn't working label Oct 30, 2024

@AK2000
Collaborator

AK2000 commented Nov 11, 2024

I don't think the issue described above is accurate. %11 and %12 are coordinate tensors. When inferring the domain of the output tensor, we recognize that it is the same as the input tensor's, so we reuse the same coordinate tensors. Because tensors guarantee SSA semantics, this is always safe: if at a later point we want to modify the coordinates, a copy of the tensor is guaranteed to be created (as happens to %12 here).

I think the (smaller) issue is that we are doing too much inserting. We recognize that we can reuse the same coordinate tensors, but then we insert the coordinates anyway (%inserted_26). Note that this is only a performance issue and will always produce correct results. It happens because the ta.TensorInsertOp lowering assumes we must insert a coordinate along with the float value. At some point in the lowering process we should detect which coordinate tensors actually need a value inserted. This will also be crucial for supporting other sparse output formats.
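The redundancy can be seen in isolation with a small Python sketch (hypothetical, not COMET code): when the output reuses the input's coordinate tensor and the insert position tracks the nonzero index, each coordinate insert stores back the value already at that position.

```python
# Sketch of the point above: when the output reuses the input's coordinate
# tensor, the coordinate insert lowered from ta.TensorInsertOp writes back
# the value already at that position, so it is redundant work, not a bug.
crd = [0, 2, 1]        # coordinate array shared by input and output
before = list(crd)
k = 0                  # insert position; tracks the nonzero index j exactly
for j in range(len(crd)):
    crd[k] = crd[j]    # corresponds to %inserted_26 in the IR
    k += 1
assert crd == before   # nothing changed: the inserts were pure overhead
```

Eliminating such inserts would save one store per nonzero without changing any result.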
