
[TransOp fusion]: Fuse tt.trans with tt.load to exploit 2D block read operations #4450

@etiotto

Description


Consider a loop containing a tt.dot operation that consumes the result of a transposition operation:

    %13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<256x32xbf16, #linear>>) : i32 {
      %17 = tt.advance %9, [%11, %arg5] : <tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
      %18 = tt.load %17 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
      %19 = tt.advance %arg6, [%12, %arg5] : <tensor<256x32xbf16, #linear>>
      %20 = tt.load %19 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "column_major"} : !tt.ptr<tensor<256x32xbf16, #linear>>
      %21 = tt.trans %20 {order = array<i32: 1, 0>} : tensor<256x32xbf16, #linear> -> tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %22 = tt.dot %18, %21, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
      %23 = arith.addi %arg5, %c32_i32 : i32
      scf.yield %22, %23, %19 : tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<256x32xbf16, #linear>>
    }

Here the load for operand "B" of the tt.dot operation lacks the dot layout, and therefore it is not lowered to the efficient 2D block read HW primitives. To improve performance, the Triton backend can implement a new transformation that fuses the tt.load operation with the tt.trans operation into a single load of the transposed tensor, carrying the dot layout and the block_io "row_major" attribute, as sketched below.
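For illustration only, a hand-written sketch of the IR the fusion could produce, assuming the block pointer carried in %arg6 (and its defining tt.make_tensor_ptr outside the loop, not shown above) is also rewritten to describe the transposed 32x256 view, and that the tt.advance offsets and boundary-check dimensions are swapped to match; the exact attribute choices are up to the pass:

      // The tt.trans is folded into the load: the load now yields the transposed
      // tensor directly, with the dot layout and a "row_major" block_io hint.
      %19 = tt.advance %arg6, [%arg5, %12] : <tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
      %20 = tt.load %19 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
      %21 = tt.dot %18, %20, %arg4, inputPrecision = tf32 : tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xbf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>

With the dot layout present on the load result, the load can then be lowered to the 2D block read HW primitives.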
