
Pallas/Triton segfault on H100 #17356

Open
jaro-sevcik opened this issue Sep 19, 2024 · 7 comments · Fixed by triton-lang/triton#4803

Comments

@jaro-sevcik
Contributor

After the commit cb304cf, JAX crashes in Triton on H100 with the following repro:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def mha_forward_kernel(
    q_ref,
    k_ref,
    v_ref,
    o_ref,  # Output
):
    q = pl.load(q_ref, (pl.dslice(0, 64), pl.dslice(None)))
    k = pl.load(k_ref, (pl.dslice(None), pl.dslice(0, 64)))
    qk = pl.dot(q, k).astype(jnp.bfloat16)
    v = pl.load(v_ref, (pl.dslice(0, 64), pl.dslice(0, 128)))
    o = pl.dot(qk, v).astype(jnp.bfloat16)
    pl.store(o_ref, (pl.dslice(0, 64), pl.dslice(None)), o)

q = jnp.zeros((64, 128), dtype=jnp.bfloat16)
k = jnp.zeros((128, 64), dtype=jnp.bfloat16)
v = jnp.zeros((64, 128), dtype=jnp.bfloat16)
pl.pallas_call(
        mha_forward_kernel,
        grid=(1,),
        in_specs=[
            pl.BlockSpec(lambda _: (0, 0), (64, 128)),
            pl.BlockSpec(lambda _: (0, 0), (128, 64)),
            pl.BlockSpec(lambda _: (0, 0), (64, 128)),
        ],
        out_specs=pl.BlockSpec(lambda _: (0, 0), (64, 128)),
        compiler_params=dict(triton=dict(num_warps=8)),
        out_shape=jax.ShapeDtypeStruct(shape=(64, 128), dtype=q.dtype),
        name="mha_forward",
    )(q, k, v)

The stack trace:

#0  mlir::detail::OperandStorage::OperandStorage
#1  mlir::Operation::create
#2  mlir::Operation::create
#3  mlir::Operation::create
#4  mlir::OpBuilder::create
#5  mlir::LLVM::InsertElementOp mlir::OpBuilder::create<mlir::LLVM::InsertElementOp, mlir::Type&, mlir::Value&, mlir::Value&, mlir::Value>
#6  loadReg
#7  convertDot
#8  convertWGMMA
#9  (anonymous namespace)::WarpGroupDotOpConversion::matchAndRewrite
#10 mlir::ConvertOpToLLVMPattern<mlir::triton::nvidia_gpu::WarpGroupDotOp>::matchAndRewrite
#11 mlir::ConversionPattern::matchAndRewrite
#12 void llvm::function_ref<void ()>::callback_fn<mlir::PatternApplicator::matchAndRewrite
#13 mlir::PatternApplicator::matchAndRewrite
#14 (anonymous namespace)::OperationLegalizer::legalize
#15 mlir::OperationConverter::convert
#16 mlir::OperationConverter::convertOperations
#17 mlir::applyPartialConversion
#18 (anonymous namespace)::ConvertTritonGPUToLLVM::runOnOperation
#19 mlir::detail::OpToOpPassAdaptor::run
#20 mlir::detail::OpToOpPassAdaptor::runPipeline
#21 mlir::PassManager::run
#22 xla::gpu::CompileTritonToLLVM
#23 std::_Function_handler<absl::lts_20230802::StatusOr<xla::gpu::KernelReuseCache::Entry> (), xla::gpu::IrEmitterUnnested::EmitTritonCustomCall
#24 xla::gpu::KernelReuseCache::GetWithStatus
#25 xla::gpu::IrEmitterUnnested::EmitTritonCustomCall
#26 xla::gpu::IrEmitterUnnested::EmitHloInstruction
...

Here is the JAX version:

>>> import jax; jax.print_environment_info()
jax:    0.4.34.dev20240918+988ed2bd7
jaxlib: 0.4.34.dev20240919
numpy:  1.26.4
python: 3.10.12 (main, Sep 11 2024, 15:47:36) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CudaDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='...', release='5.15.0-84-generic', version='#93-Ubuntu SMP Tue Sep 5 17:16:10 UTC 2023', machine='x86_64')


$ nvidia-smi
Thu Sep 19 08:36:26 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:45:00.0 Off |                    0 |
| N/A   25C    P0             65W /  700W |     534MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+
@jaro-sevcik
Contributor Author

For completeness, here is the crashing HLO:

HloModule jit_wrapped, entry_computation_layout={(bf16[64,128]{1,0}, bf16[128,64]{1,0}, bf16[64,128]{1,0})->bf16[64,128]{1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true}

ENTRY main.5 {
  Arg_0.1 = bf16[64,128]{1,0} parameter(0), metadata={op_name="args[0]"}
  Arg_1.2 = bf16[128,64]{1,0} parameter(1), metadata={op_name="args[1]"}
  Arg_2.3 = bf16[64,128]{1,0} parameter(2), metadata={op_name="args[2]"}
  ROOT custom-call.4 = bf16[64,128]{1,0} custom-call(Arg_0.1, Arg_1.2, Arg_2.3), custom_call_target="__gpu$xla.gpu.triton", operand_layout_constraints={bf16[64,128]{1,0}, bf16[128,64]{1,0}, bf16[64,128]{1,0}}, api_version=API_VERSION_TYPED_FFI, metadata={op_name="jit(wrapped)/jit(main)/pallas_call" source_file="//workspace/triton_crash_repro_wip.py" source_line=21}, backend_config={debug = false, grid_x = 1 : i32, grid_y = 1 : i32, grid_z = 1 : i32, ir = "ML\EFR\0DMLIR20.0.0git\00\017\07\01\05\09!\01\03\0F\03\17\13\17\1B\1F#'+/37;\05\09?CGK\03\A5g+\01e\07\0F\0F\0F\0F\0F\0F\13\13\0B\0B\0F\13\13\0F\0B\0F\0F\0B\0F\0B\0F\0B\0F\1B\0B\0F\0B\0B\0F\0F\13\0B\13\0F\0F\13\0F\13\0F\0F\0F\13\0F\13\0F\0B\0F\0F\13\05\03Y\01)\0F\1F\1F\1F\07\17\17\1F\1B\1B\07\1F\1F\1F\1F\0B\1B\1B\1F\1F\03\039\02\AA\04\1F\11\01\01\1D]_\1D\1F;\1D\1FE\1D\1FQ\11\01\05\11\01\02\04\11\01\02\02\05'\05)\1DAC#\01\01\01\03\0335\11\1F\00\05+\1D)+\1D)/\05-\13\09\01\05/\15K\17\051\15W\17\01\09\1B\1B\1B\1B\053\11\01\81\0D\1D\055\15=\17\1D\15?\17\13\17\01\057\17\13+\01\15G\17\1D\15I\17\13\19\01\1D\15M\17\13\1B\01\1D-+\15S\17\1D\15U\17\13\1D\01\1D\15Y\17\13\1F\01\1D-/\059\15a\17\1D\15c\17\13!\01#arith.overflow<none>\00\01\02\02\1B\05\02\02\02\04\01\1B\05\02\04\02\02\01\1B\05\02\02\02\04)\0B\1B\03\02\02\01\1B\03\02\04\01\1B\05\02\02\02\04\15\1B\05\02\02\05\01\1B\05\05\02\04\01\07\1B\05\02\04\02\02)\1B\05\02\02\02\02\09\1B\05\02\02\02\04\09\05\09))))\01\01\09\1B\05\02\04\05\01\1B\05\05\02\02\01\1B\05\02\04\02\02\15\1B\05\02\02\02\02\15!tt.ptr<bf16>\00\04\16\0F\05\01P\01\01\07\04\F2\0E\03\01\05\11P\01\03\07\04\C6\0E\03\06\02\FE\03\09QQQQ\00\13B\01\05\03\01\19B\01\05\03\01\19B\01\05\03\01\19B\01\07\03\01\1BF\01\09\03\01\05\0B\0F\19B\01\0B\03\01\1BF\01\09\03\01\05\0D\13\19B\01\05\03\01\19B\01\05\03\01\19B\01\0B\03\01\1BF\01\09\03\01\05\17\1B\19B\01\07\03\01\1BF\01\09\03\01\05\19\1F\19B\01\05\03\01\19B\01\05\03\01\19B\01\07\03\01\1BF\01\09\03\01\05#'\19B\01\0B\03\01\1BF\01\09\03\01\05%+\19B\01\05\03\01\19B\01\05\03\01\19B\01\07\03\01\1BF\01\09\03\01\05/3\19B\01\0B\03\01\1BF\01\09\03\01\0517\19B\07\05\03\01\03\06\07\03\03\03;\05B\07\0D\03\0B\07F\07\0F\03\11\03?\09\06\07\03\03\03A\03\06\07\03\03\03\11\1DF\07\09\03\03\05CE\19B\07\0B\03\01\03\06\07\03\03\03I\1BF\07\09\03\03\05GK\1DF\07\09\03\03\05=M\05B\07\11\03\0D\07F\07\05\03\13\03Q\09\06\07\03\03\03S\03\06\07\03\03\03\15\1DF\07\09\03\03\05UW\19B\07\0F\03\01\03\06\07\03\03\03[\1BF\07\09\03\03\05Y]\1DF\07\09\03\03\05O_\03\06\07\03\07\03\01\0B\06\07\03\07\05ca\0DF\07\13\03\0F\03e\19B\09\05\03\01\03\06\09\03\05\03i\05B\09\11\03\0D\07F\09\0F\03!\03m\09\06\09\03\05\03o\03\06\09\03\05\03\1D\1DF\09\09\03\05\05qs\19B\09\07\03\01\03\06\09\03\05\03w\1BF\09\09\03\05\05uy\1DF\09\09\03\05\05k{\05B\09\0D\03\0B\07F\09\05\03#\03\7F\09\06\09\03\05\03\81\03\06\09\03\05\03!\1DF\09\09\03\05\05\83\85\19B\09\0F\03\01\03\06\09\03\05\03\89\1BF\09\09\03\05\05\87\8B\1DF\09\09\03\05\05}\8D\03\06\09\03\17\03\03\0B\06\09\03\17\05\91\8F\0DF\09\13\03%\03\93\19B!\15\03\09\03\06!\03\19\03\97\0FF!\17\03\19\07g\95\99\1FFO\19\03'\03\9B\19B\0B\05\03\01\03\06\0B\03\03\03\9F\05B\0B\0D\03\0B\07F\0B\0F\03\11\03\A3\09\06\0B\03\03\03\A5\03\06\0B\03\03\03)\1DF\0B\09\03\03\05\A7\A9\19B\0B\0B\03\01\03\06\0B\03\03\03\AD\1BF\0B\09\03\03\05\AB\AF\1DF\0B\09\03\03\05\A1\B1\05B\0B\11\03\0D\07F\0B\05\03\13\03\B5\09\06\0B\03\03\03\B7\03\06\0B\03\03\03-\1DF\0B\09\03\03\05\B9\BB\19B\0B\0F\03\01\03\06\0B\03\03\03\BF\1BF\0B\09\03\03\05\BD\C1\1DF\0B\09\03\03\05\B3\C3\03\06\0B\03\07\03\05\0B\06\0B\03\07\05\C7\C5\0DF\0B\13\03\0F\03\C9\19B#\15\03\09\03\06#\0
3\1B\03\CD\0FF#\17\03\1B\07\9D\CB\CF\1FF[\19\03\0F\03\D1\19B\05\05\03\01\03\06\05\03\03\03\D5\05B\05\0D\03\0B\07F\05\0F\03\11\03\D9\09\06\05\03\03\03\DB\03\06\05\03\03\035\1DF\05\09\03\03\05\DD\DF\19B\05\0B\03\01\03\06\05\03\03\03\E3\1BF\05\09\03\03\05\E1\E5\1DF\05\09\03\03\05\D7\E7\05B\05\11\03\0D\07F\05\05\03\13\03\EB\09\06\05\03\03\03\ED\03\06\05\03\03\039\1DF\05\09\03\03\05\EF\F1\19B\05\0F\03\01\03\06\05\03\03\03\F5\1BF\05\09\03\03\05\F3\F7\1DF\05\09\03\03\05\E9\F9\03\06\05\03\07\03\07\0B\06\05\03\07\05\FD\FB\0DF\05\13\03\0F\03\FF\15D\05\1B\05\FF\D3\17\00\01\06\03\01\05\01\00*\05;\1B\13\0F!-\1B\19\1B'M\0F\0B\0B\13\0F\0D\1F\0B\09\0B\0F\15\19\17\0D\0F\0D\07\11builtin\00tt\00arith\00module\00splat\00make_range\00expand_dims\00broadcast\00addptr\00load\00dot\00func\00get_program_id\00store\00return\00constant\00muli\00addi\00truncf\00//workspace/triton_crash_repro_wip.py\00mha_forward_kernel\00/masked_load\00mha_forward\00/dot_general\00/convert_element_type\00tt.divisibility\00public\00<module>\00/masked_swap\00\08_\1D\05K\01\0Bc7\01%s\03\03\03\11\03\CB\03\0F\05\11\03\03\0D\05\0F\03\113\1B\1B;\01\07\01\03\03'\05\07\07\05\01\01\073\1B\1B", name = "mha_forward", num_stages = 3 : i32, num_warps = 8 : i32}
}

@cheshire
Member

Seems like a Triton crash, but could you provide the segfault trace with ASan?

@jaro-sevcik
Contributor Author

I have not run with ASan, but the debug build fails on an out-of-bounds access:

... libc ...
#4  __assert_fail () from /lib/x86_64-linux-gnu/libc.so.6
#5  llvm::SmallVectorTemplateCommon<mlir::Value, void>::operator[] (this=0x7ffffffe2230, idx=16) at external/llvm-project/llvm/include/llvm/ADT/SmallVector.h:295
#6  loadReg (rewriter=..., loc=..., elements=..., startIndex=16, numElements=8, insertBefore=0x555557ae3910) at external/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp:292
#7  convertDot (typeConverter=0x7ffffffe4a38, rewriter=..., loc=..., op=0x5555579280d0, a=..., b=..., c=..., d=..., useCOperand=..., loadedA=..., loadedB=..., loadedC=..., allowTF32=true,
    needsPartialAccumulator=false, maxNumImpreciseAcc=0, sync=true, thread=...) at external/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp:447
#8  convertWGMMA (op=..., adaptor=..., typeConverter=0x7ffffffe4a38, rewriter=..., thread=...) at external/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/WGMMA.cpp:511
#9  (anonymous namespace)::WarpGroupDotOpConversion::matchAndRewrite (this=0x5555579d6600, op=..., adaptor=..., rewriter=...)
    at external/triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp:95
#10 mlir::ConvertOpToLLVMPattern<mlir::triton::nvidia_gpu::WarpGroupDotOp>::matchAndRewrite (this=0x5555579d6600, op=0x5555579280d0, operands=..., rewriter=...)
    at external/llvm-project/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h:165
...

@vwbaker
Member

vwbaker commented Sep 24, 2024

I think I narrowed it down to this PR: triton-lang/triton#4492.

After this PR, the new TTGIR for your test looks like this:

// -----// IR Dump Before ConvertTritonGPUToLLVM (convert-triton-gpu-to-llvm) ('builtin.module' operation: @mha_forward) //----- //
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 64, 16]}>
#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 32, 16]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}>
module @mha_forward attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.shared = 34816 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @mha_forward(%arg0: !tt.ptr<bf16> {tt.divisibility = 32 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 32 : i32}, %arg2: !tt.ptr<bf16> {tt.divisibility = 32 : i32}, %arg3: !tt.ptr<bf16> {tt.divisibility = 32 : i32}) {
    %cst = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #mma>
    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma1>
    %cst_1 = arith.constant dense<64> : tensor<128x1xi32, #blocked>
    %cst_2 = arith.constant dense<128> : tensor<64x1xi32, #blocked1>
    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1>
    %2 = arith.muli %1, %cst_2 : tensor<64x1xi32, #blocked1>
    %3 = tt.broadcast %2 : tensor<64x1xi32, #blocked1> -> tensor<64x128xi32, #blocked1>
    %4 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
    %5 = tt.expand_dims %4 {axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1>
    %6 = tt.broadcast %5 : tensor<1x128xi32, #blocked1> -> tensor<64x128xi32, #blocked1>
    %7 = arith.addi %3, %6 : tensor<64x128xi32, #blocked1>
    %8 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<64x128x!tt.ptr<bf16>, #blocked1>
    %9 = tt.addptr %8, %7 : tensor<64x128x!tt.ptr<bf16>, #blocked1>, tensor<64x128xi32, #blocked1>
    %10 = tt.load %9 : tensor<64x128x!tt.ptr<bf16>, #blocked1>
    %11 = triton_gpu.local_alloc %10 {allocation.offset = 0 : i32} : (tensor<64x128xbf16, #blocked1>) -> !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory>
    %12 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
    %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked>
    %14 = arith.muli %13, %cst_1 : tensor<128x1xi32, #blocked>
    %15 = tt.broadcast %14 : tensor<128x1xi32, #blocked> -> tensor<128x64xi32, #blocked>
    %16 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>>
    %17 = tt.expand_dims %16 {axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
    %18 = tt.broadcast %17 : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked>
    %19 = arith.addi %15, %18 : tensor<128x64xi32, #blocked>
    %20 = tt.splat %arg1 : !tt.ptr<bf16> -> tensor<128x64x!tt.ptr<bf16>, #blocked>
    %21 = tt.addptr %20, %19 : tensor<128x64x!tt.ptr<bf16>, #blocked>, tensor<128x64xi32, #blocked>
    %22 = tt.load %21 : tensor<128x64x!tt.ptr<bf16>, #blocked>
    %23 = triton_gpu.local_alloc %22 {allocation.offset = 16384 : i32} : (tensor<128x64xbf16, #blocked>) -> !tt.memdesc<128x64xbf16, #shared, #triton_gpu.shared_memory>
    triton_nvidia_gpu.fence_async_shared {bCluster = false}
    %24 = triton_nvidia_gpu.warp_group_dot %11, %23, %cst_0 {inputPrecision = 0 : i32} : !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<128x64xbf16, #shared, #triton_gpu.shared_memory> -> tensor<64x64xf32, #mma1>
    %25 = arith.truncf %24 : tensor<64x64xf32, #mma1> to tensor<64x64xbf16, #mma1>
    %26 = tt.splat %arg2 : !tt.ptr<bf16> -> tensor<64x128x!tt.ptr<bf16>, #blocked1>
    %27 = tt.addptr %26, %7 : tensor<64x128x!tt.ptr<bf16>, #blocked1>, tensor<64x128xi32, #blocked1>
    %28 = tt.load %27 : tensor<64x128x!tt.ptr<bf16>, #blocked1>
    %29 = triton_gpu.local_alloc %28 {allocation.offset = 0 : i32} : (tensor<64x128xbf16, #blocked1>) -> !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory>
    %30 = triton_gpu.convert_layout %25 : tensor<64x64xbf16, #mma1> -> tensor<64x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
    triton_nvidia_gpu.fence_async_shared {bCluster = false}
    %31 = triton_nvidia_gpu.warp_group_dot %30, %29, %cst {inputPrecision = 0 : i32} : tensor<64x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> -> tensor<64x128xf32, #mma>
    %32 = triton_gpu.convert_layout %31 {allocation.offset = 0 : i32} : tensor<64x128xf32, #mma> -> tensor<64x128xf32, #blocked1>
    %33 = arith.truncf %32 : tensor<64x128xf32, #blocked1> to tensor<64x128xbf16, #blocked1>
    %34 = tt.splat %arg3 : !tt.ptr<bf16> -> tensor<64x128x!tt.ptr<bf16>, #blocked1>
    %35 = tt.addptr %34, %7 : tensor<64x128x!tt.ptr<bf16>, #blocked1>, tensor<64x128xi32, #blocked1>
    tt.store %35, %33 : tensor<64x128x!tt.ptr<bf16>, #blocked1>
    tt.return
  }
}

Running this through triton-opt -convert-triton-gpu-to-llvm repros the crash.
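Concretely, that invocation looks like this (repro.ttgir is just a placeholder name for the dump above saved to a file):

$ triton-opt repro.ttgir -convert-triton-gpu-to-llvm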

I think the interesting part is here:

    %30 = triton_gpu.convert_layout %25 : tensor<64x64xbf16, #mma1> -> tensor<64x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
    triton_nvidia_gpu.fence_async_shared {bCluster = false}
    %31 = triton_nvidia_gpu.warp_group_dot %30, %29, %cst {inputPrecision = 0 : i32} : tensor<64x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> -> tensor<64x128xf32, #mma>

The PR tries to chain dots more efficiently: instead of doing an extra local_alloc and putting the intermediate result back into shared memory, it keeps it in registers for the next MMA. However, Hopper WGMMA with the LHS in registers is not supported in Triton yet, as far as I know, so this can't work. I'm not sure exactly what the correct solution is yet; maybe it needs a check to skip this in specific situations. But I do know that @ggengnv is working on exactly this (openxla/triton#18), so perhaps they can spot the issue quicker than I can?
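For comparison, before that PR the chained operand would have gone back through shared memory, roughly like this (a hand-written sketch mimicking the dump above, not actual pass output; the exact #shared encoding and allocation offset may differ):

    %30 = triton_gpu.local_alloc %25 : (tensor<64x64xbf16, #mma1>) -> !tt.memdesc<64x64xbf16, #shared, #triton_gpu.shared_memory>
    %31 = triton_nvidia_gpu.warp_group_dot %30, %29, %cst {inputPrecision = 0 : i32} : !tt.memdesc<64x64xbf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x128xbf16, #shared, #triton_gpu.shared_memory> -> tensor<64x128xf32, #mma>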

@ggengnv

ggengnv commented Sep 24, 2024

@vwbaker I think the name of my PR is a bit misleading :) - WGMMA with the LHS in registers does already exist in Triton, specifically for chained MMAs. That means keeping MMA1's accumulator in registers, possibly doing some casting (and maybe shuffling? I haven't looked at the code closely), and then using those registers as the LHS for MMA2. My PR was for loading from shared memory into registers for MMA.

The TTGIR-level optimization for chained MMAs is done in OptimizeDotOperands (link).

The LLVM lowering for this is in TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM (link).

On the other hand, the PR you linked pertains to "MMA to MMA" layout conversion ("MMA layout" being the accumulator layout). It might have modified the LLVM lowering for the following TTGIR:

    %25 = arith.truncf %24 : tensor<64x64xf32, #mma1> to tensor<64x64xbf16, #mma1>

which was later passed to this MMA -> dot conversion

    %30 = triton_gpu.convert_layout %25 : tensor<64x64xbf16, #mma1> -> tensor<64x64xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>

and then to WGMMA, where the crash happened.

I'm not familiar with that part of the codebase, but I think it's likely the PR introduced a bug in TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp. A few fixes have landed in that file since then (history); it might be worth bumping Triton's version? (A later commit also added a knob for disabling the new MMA -> MMA logic.)

@jaro-sevcik
Contributor Author

This looks like the same root cause as triton-lang/triton#4502. The repro there is very similar and the same workaround (changing from num_warps=8 to num_warps=4) works for the Pallas repro, too.
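Concretely, only the compiler_params line of the repro above changes (a sketch; everything else stays the same):

pl.pallas_call(
        mha_forward_kernel,
        grid=(1,),
        in_specs=[
            pl.BlockSpec(lambda _: (0, 0), (64, 128)),
            pl.BlockSpec(lambda _: (0, 0), (128, 64)),
            pl.BlockSpec(lambda _: (0, 0), (64, 128)),
        ],
        out_specs=pl.BlockSpec(lambda _: (0, 0), (64, 128)),
        compiler_params=dict(triton=dict(num_warps=4)),  # was num_warps=8, which triggers the segfault
        out_shape=jax.ShapeDtypeStruct(shape=(64, 128), dtype=q.dtype),
        name="mha_forward",
    )(q, k, v)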

vwbaker added a commit to openxla/triton that referenced this issue Sep 25, 2024
triton-lang#4492 started causing an
issue where chained MMAs on hopper would segfault with 8 warps. It seems
that previously this was checked, but the check got removed in this PR
and it's still unsupported.

Adding back this check means these MMAs will have to go back to shared
memory, but it's better than segfaulting until it's actually supported.

Resolves openxla/xla#17356
Jokeren pushed a commit to triton-lang/triton that referenced this issue Sep 25, 2024
…4803)

#4492 started causing an
issue where chained MMAs on hopper would segfault with 8 warps. It seems
that previously this was checked, but the check got removed in this PR
and it's still unsupported.

Adding back this check means these MMAs will have to go back to shared
memory, but it's better than segfaulting until it's actually supported.

Resolves openxla/xla#17356

Co-authored-by: Tori <[email protected]>
@vwbaker
Member

vwbaker commented Sep 26, 2024

This should be resolved by triton-lang/triton#4803 and will be merged into openxla/xla in next week's integrate.
