
[AMD][Atomics, Buffer Ops] Add support for buffer atomic RMW #5549

Merged · 25 commits · Jan 17, 2025

Conversation

@SamGinzburg (Contributor) commented Jan 7, 2025:

Overview

This PR enables the raw.ptr.buffer.atomic.* RMW ops in the AMD backend. They feature calling conventions and semantics similar to the other buffer ops in the AMD backend.

The new ops are gated behind the AMDGCN_ENABLE_BUFFER_ATOMICS environment variable, which must be used in conjunction with AMDGCN_USE_BUFFER_OPS. They are also gated behind the GPU being CDNA3 (MI300-series GPUs) for now, as the optimizations I added make assumptions regarding GFX942.
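A minimal sketch of how the gating might be enabled (the variable names come from this PR; the exact point at which the backend reads them is an assumption):

```python
import os

# Assumed workflow: set both flags before Triton compiles the kernel.
os.environ["AMDGCN_USE_BUFFER_OPS"] = "1"         # prerequisite flag
os.environ["AMDGCN_ENABLE_BUFFER_ATOMICS"] = "1"  # new flag from this PR

import triton  # the AMD backend is assumed to read the flags at compile time
```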

I originally started exploratory work on this PR to better understand the comment in LoadStoreOpToLLVM.cpp referring to buffer atomics as "more efficient". In short, I found that on their own they aren't necessarily more efficient, but using them in conjunction with more careful control over how cache-coherence ops/memory fences are emitted can improve performance by a significant fraction.

How

I've added a new buffer atomic RMW op in the AMDGPUOps dialect which has its own lowering in the backend. There are a number of checks in place to ensure that the lowering is done correctly between the ConvertToBufferOps pass and the LoadStoreOpToLLVM lowering.

The actual lowering is where most of the performance gains come from. At a high level, when non-buffer atomic RMW ops are emitted, the memory fences lower to something along the lines of:

```
buffer_wbl2 sc1
s_waitcnt lgkmcnt(0)
atomicRMWop()
s_waitcnt vmcnt(0)
buffer_inv sc1
buffer_wbl2 sc1
s_waitcnt lgkmcnt(0)
atomicRMWop()
s_waitcnt vmcnt(0)
buffer_inv sc1
```

If my understanding of the GFX942 memory model is correct, then given several assumptions regarding CDNA3, this can actually be lowered to something that resembles:

```
buffer_wbl2 sc1
s_waitcnt lgkmcnt(0)
atomicRMWop()
s_waitcnt vmcnt(0) # AMDGCN-specific cross-CU synchronization primitive
atomicRMWop()
s_waitcnt vmcnt(0)
buffer_inv sc1
```

There are comments in the code explaining the reasoning for why (I believe) this is okay.

It appears that AMD's CK library (AMD's analogue of CUTLASS) uses similar synchronization mechanisms, although I am probably missing some of the context here (https://github.com/ROCm/composable_kernel/blob/9e95d54cd2160dffc07c1197951a9ab1ca6c35f2/include/ck_tile/core/arch/amd_buffer_addressing.hpp#L619).

Results and Testing

In addition to the added lit test, I ran the existing in-tree atomic RMW tests with buffer ops + buffer atomics enabled, and they appear to pass.

Following this, I evaluated FP16 split-K GEMM with Llama shapes in tritonbench on an MI300X. Some minor modifications to the kernel were made to emit buffer ops (e.g., tl.assume calls); a sketch is shown below. For testing purposes, I disabled the non-split-K configurations. I also checked the numerical accuracy with rtol=atol=1e-4 for all shapes.
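For reference, the kind of modification involved looks roughly like the following (a hypothetical device-side epilogue sketch, not the benchmarked tritonbench kernel; the tl.assume hints let the compiler prove offsets are non-negative so the pointer arithmetic can be converted to buffer ops):

```python
import triton
import triton.language as tl

@triton.jit
def splitk_epilogue(c_ptr, acc, M, N, stride_cm, stride_cn,
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Hypothetical helper called from the GEMM kernel; `acc` is the
    # (BLOCK_M, BLOCK_N) partial-result block of one split-K slice.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Hints that make non-negativity provable, enabling buffer-op conversion.
    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)
    tl.assume(stride_cm >= 0)
    tl.assume(stride_cn >= 0)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    # Each split-K partial is combined into C with an atomic RMW add.
    tl.atomic_add(c_ptrs, acc, mask=mask)
```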

[figure: average TFLOPS per M-dim bucket, buffer atomics vs. regular atomics]

Each bucket in the figure above corresponds to the average TFLOPS of all shapes with the same shared M-dim.

At smaller batch sizes the performance is roughly equivalent. At BS=32, buffer atomics achieve ~50% higher TFLOPS; at BS=256, ~3.75x the TFLOPS.

Note: the purpose of this test is to evaluate the performance of buffer atomics; split-K is not always the optimal strategy for these shapes/workloads.

============================================================================================

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

@antiagainst (Collaborator) left a comment:


Nice! Thanks for adding support for it! I have a couple of comments. Also, could you turn AMDGCN_USE_BUFFER_OPS on for now so we can test it out? We will turn it back off before landing.

Review comments (since resolved) were left on:
- third_party/amd/python/triton_amd.cc
- third_party/amd/backend/compiler.py
- third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
- third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h (two comments)
@giuseros (Contributor) left a comment:


This is an amazing PR! Thanks @SamGinzburg for not only extending buffer support but also coming up with a better lowering for atomic operations! I left a few comments and agree with the comments left by @antiagainst!

```
TypesMatchWith<"value and mask have the same shape", "value", "mask", "getI1SameShape($_self)",
               "($_op.getOperands().size() <= 3) || std::equal_to<>()">,
]>{
  let summary = "Load from a scalar base pointer and a tensor offset";
```
Contributor:

Is this summary correct?

@SamGinzburg (author):

Thanks for catching this; I've updated it to accurately describe atomicrmw.

```
Type bufferElementType = elementType;
if (elementType.isBF16())
  // We don't want to cast to bf16 if we are emitting buffer atomics
```
Contributor:

Why? I had a few bugs with memory operations when I was not casting bf16 to i16. Are those bugs not present for atomics?

@SamGinzburg (author) replied on Jan 10, 2025:

They are there, but in different forms. Casting to i16 causes an error in LLVM ("LLVM translation failed for operation"), and passing bf16 through causes issues with instruction selection. There is also a second issue: for loads/stores the type of the buffer is less important (you just need a correctly sized op and can bitcast later, which is what I believe the code does today). For atomic RMW, I think the type needs to be correct (e.g., fadd for fp16 vs. bf16 is different).

[screenshot: ISA documentation listing the bf16 buffer atomic instruction]

The instruction does exist (or at least according to the docs it should).

I'm going to try to reach out to the AMD/LLVM team regarding this at some point, but since buffer ops are off by default, and I had to disable Triton's bf16 atomic fadd check to trigger this, I don't think it should block the PR.
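For context, a minimal (hypothetical) kernel that would exercise this path is sketched below; as noted above, Triton's bf16 atomic fadd check had to be disabled to even reach this lowering, so this is illustrative rather than something stock Triton accepts:

```python
import triton
import triton.language as tl

@triton.jit
def bf16_atomic_fadd(out_ptr, in_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    v = tl.load(in_ptr + offs, mask=mask)        # bf16 values
    # Lowers to an atomic RMW fadd; the element type (fp16 vs. bf16)
    # must be preserved for the correct instruction to be selected.
    tl.atomic_add(out_ptr + offs, v, mask=mask)
```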

```
@@ -164,7 +200,7 @@ void BufferEmitter::fillCommonArgs(Type type, Value rsrcDesc,
   // bit 0: GLC = 0 (atomics drop value, less coherency)
   // bits 1-2: SLC, DLC = 0 (similarly)
   // bit 3: swizzled (0 for raw)
-  Value cacheModifiers = int_val(32, 0);
+  Value cacheModifiers = int_val(32, cacheModifiersFlag);
```
Contributor:

Not mandatory, but could we sync with @yiqian1 to get her PR merged first? I would like to have the cache-modifiers story sorted properly instead of an ad hoc value only for the RMW case. But this is correct anyway, so if @yiqian1's PR takes too long to be merged, I am happy for you to proceed.

@SamGinzburg (author):

Yeah, we can do it either way; I think it's up to whichever PR is ready to land first. I don't mind rebasing.

@SamGinzburg force-pushed the PR-BufferAtomicRMW branch 2 times, most recently from 8f7ce03 to cb1a267 on January 10, 2025.
@SamGinzburg (author):

> Nice! Thanks for adding support for it! I have a couple of comments. Also, could you turn AMDGCN_USE_BUFFER_OPS on for now so we can test it out? We will turn it back off before landing.

Thanks! I've set the flag to be true for now!

@SamGinzburg force-pushed the PR-BufferAtomicRMW branch 2 times, most recently from 458a1aa to 4769d58 on January 13, 2025.
```
// CHECK: %[[scalar_ptr:.*]] = tt.addptr %arg0
%5 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
%8 = tt.splat %5 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
%9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
```
Collaborator:

We also need to CHECK that amdgpu.buffer_atomic_rmw is generated?

@antiagainst (Collaborator):

The patch LGTM now; can you resolve the conflicts so we can land, @SamGinzburg?

@scxiao (Contributor) commented Jan 15, 2025:

Hi @SamGinzburg, just wondering whether this lowering optimization is applicable to non-buffer atomics (i.e., global_atomic)? Thanks!

@SamGinzburg (author) replied Jan 15, 2025:

> Hi @SamGinzburg, just wondering whether this lowering optimization is applicable to non-buffer atomics (i.e., global_atomic)? Thanks!

Yes, I think so; buffer ops just make it easier to control the lowering. I can put up a follow-up PR that does the same for those, but we would just be emitting inline assembly, if that is okay (unless LLVM can add an optimization that does this automatically).

@scxiao commented Jan 15, 2025:

> > Hi @SamGinzburg, just wondering whether this lowering optimization is applicable to non-buffer atomics (i.e., global_atomic)? Thanks!
>
> Yes, I think so; buffer ops just make it easier to control the lowering. I can put up a follow-up PR that does the same for those, but we would just be emitting inline assembly, if that is okay (unless LLVM can add an optimization that does this automatically).

Thanks for the quick reply. One other thought: I am wondering whether the tl.atomic_add() in the split-K GEMM can use the sem input "relaxed"; see https://triton-lang.org/main/python-api/generated/triton.language.atomic_add.html#triton.language.atomic_add. The default "acq_rel" is used to create a critical section for data communication, as in this example:

```
def serialized_add(data, Lock, SEM: tl.constexpr):
```

@SamGinzburg (author):

> Thanks for the quick reply. One other thought: I am wondering whether the tl.atomic_add() in the split-K GEMM can use the sem input "relaxed". The default "acq_rel" is used to create a critical section for data communication.

Yes, that is correct: with sem="relaxed", the performance is equivalent between buffer atomics and regular atomics. With sem="acq_rel" the gap is much larger; e.g., for M=128, N=13312, K=16384, the gap is 75 vs. 177 TFLOPS with acq_rel. With sem="relaxed", both get ~228 TFLOPS.
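To make the knob concrete, the difference is just the sem argument on the atomic (a sketch; c_ptrs, acc, and mask are placeholder names from the epilogue example above):

```python
# Default: acquire-release ordering, which emits the fence sequences shown earlier.
tl.atomic_add(c_ptrs, acc, mask=mask)                 # sem="acq_rel"

# Relaxed ordering: no such fences when partials are merely accumulated.
tl.atomic_add(c_ptrs, acc, mask=mask, sem="relaxed")
```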

@SamGinzburg (author):

Tests are currently failing with "urllib.error.HTTPError: HTTP Error 524". @antiagainst, CI possibly needs to be restarted.

@antiagainst merged commit 6556ec6 into triton-lang:main on Jan 17, 2025. 7 checks passed.
antiagainst pushed a commit that referenced this pull request Jan 28, 2025
This is a minor change: when implementing PR #5549 I used `rewriter.notifyMatchFailure` in place of `return failure();`, per suggestions to leverage the MLIR infra for errors.

We should probably be consistent throughout the file and use the MLIR
infra for the other buffer ops.