[AMD][Atomics, Buffer Ops] Add support for buffer atomic RMW #5549
Conversation
Nice! Thanks for adding support for it! I have a couple of comments. Also, could you turn `AMDGCN_USE_BUFFER_OPS` on for now so we can test it out? We will turn it back off before landing.
This is an amazing PR! Thanks @SamGinzburg for not only extending buffer support but also coming up with a better lowering for atomic operations! I left a few comments and agree with the comments left by @antiagainst!
TypesMatchWith<"value and mask have the same shape", "value", "mask", "getI1SameShape($_self)",
               "($_op.getOperands().size() <= 3) || std::equal_to<>()">,
]>{
  let summary = "Load from a scalar base pointer and a tensor offset";
Is this summary correct?
Thanks for catching this, I've updated it to accurately describe atomicrmw.
Type bufferElementType = elementType;
if (elementType.isBF16())
  // We don't want to cast to bf16 if we are emitting buffer atomics
Why? I had a few bugs with memory operations when I was not casting bf16 to i16. Are those bugs not there for atomics?
They are there, but present in different forms. Casting to i16 causes an error in LLVM ("LLVM translation failed for operation"), and passing bf16 through causes issues with instruction selection. There's a second issue: for loads/stores the type of the buffer is less important (we just need a correctly sized op and can bitcast later, which is what I believe the code does today). For atomic RMW I think the type needs to be correct (e.g., fadd for fp16 vs bf16 is different).
The instruction does exist (or at least according to the docs it should).
I'm going to try to reach out to the AMD/LLVM team regarding this at some point, but since buffer ops are off by default and I had to disable the Triton bf16 atomic fadd check to trigger this, I don't think it should necessarily block the PR.
@@ -164,7 +200,7 @@ void BufferEmitter::fillCommonArgs(Type type, Value rsrcDesc,
   // bit 0: GLC = 0 (atomics drop value, less coherency)
   // bits 1-2: SLC, DLC = 0 (similarly)
   // bit 3: swizzled (0 for raw)
-  Value cacheModifiers = int_val(32, 0);
+  Value cacheModifiers = int_val(32, cacheModifiersFlag);
Yeah, we can do it either way. I think it's up to whichever PR is ready to land first; I don't mind rebasing.
Thanks! I've set the flag to be true for now!
// CHECK: %[[scalar_ptr:.*]] = tt.addptr %arg0
%5 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
%8 = tt.splat %5 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
%9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
We also need to CHECK that amdgpu.buffer_atomic_rmw is generated?
The patch LGTM now; can you resolve the conflicts so we can land, @SamGinzburg?
Hi @SamGinzburg, just wondering whether this lowering optimization is applicable to non-buffer atomics (i.e., global_atomic)? Thanks
Yes, I think so; buffer ops just make it easier to control the lowering. I can put up a follow-up PR which does the same for those, but we will just be emitting inline assembly if that is okay (unless LLVM can add an optimization which does this automatically).
Thanks for the quick reply. Some other thoughts: I am wondering whether the tl.atomic_add() in the split-K GEMM can use the approach at triton/python/test/unit/language/test_core.py, line 1614 in f9d9fad.
Yes, that is correct: with sem="relaxed", the performance is equivalent between buffer atomics and regular atomics. When sem="acq_rel" the gap is much larger. For example, for M=128, N=13312, K=16384, the gap is 75 vs 177 TFLOPs with acq_rel. With sem="relaxed", both get ~228 TFLOPs.
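To make the comparison concrete, a minimal sketch (illustrative only, not the tritonbench kernel) of a split-K style accumulation that passes relaxed semantics to the atomic add; the kernel name, shapes, and launch below are hypothetical:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def splitk_accumulate(c_ptr, partial_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program adds one tile of a partial result into the output C.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    partial = tl.load(partial_ptr + offs, mask=mask)
    # sem="relaxed" drops the acquire/release ordering around the atomic;
    # omitting sem gives Triton's default "acq_rel" ordering, which is the
    # slower configuration discussed above.
    tl.atomic_add(c_ptr + offs, partial, mask=mask, sem="relaxed")


c = torch.zeros(1024, device="cuda", dtype=torch.float32)
partial = torch.randn(1024, device="cuda", dtype=torch.float32)
splitk_accumulate[(4,)](c, partial, c.numel(), BLOCK=256)
```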
Tests are currently failing with "urllib.error.HTTPError: HTTP Error 524"; @antiagainst, the CI job possibly needs to be restarted.
This is a minor change: when implementing PR #5549 I used `rewriter.notifyMatchFailure` in place of `return failure();`, as per suggestions to leverage the MLIR infra for errors. We should probably be consistent throughout the file and use the MLIR infra for the other buffer ops.
Overview
This PR enables the raw.ptr.buffer.atomic.* RMW ops in the AMD backend. They feature similar calling conventions and semantics to the other buffer ops in the AMD backend.
The new ops are gated behind the `AMDGCN_ENABLE_BUFFER_ATOMICS` environment variable, which must be used in conjunction with `AMDGCN_USE_BUFFER_OPS`. They are also gated behind the GPU being CDNA3 (MI300-series GPUs) for now, as the optimizations I added make assumptions regarding GFX942 (see the usage sketch below).
I originally started exploratory work on the PR to better understand the comment in `LoadStoreOpToLLVM.cpp` referring to buffer atomics as "more efficient". In short, I found that on their own they aren't necessarily more efficient, but using them in conjunction with more careful control over how cache coherence ops/memory fences are emitted can improve performance by a significant fraction.
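As a rough usage sketch (assuming both flags are read from the environment when the kernel is compiled and the target is CDNA3; the toy kernel and launch below are only illustrative):

```python
import os

# Both flags are required for the buffer-atomic path; it is additionally
# gated on the target being CDNA3 (gfx942 / MI300-series).
os.environ["AMDGCN_USE_BUFFER_OPS"] = "1"
os.environ["AMDGCN_ENABLE_BUFFER_ATOMICS"] = "1"

import torch
import triton
import triton.language as tl


@triton.jit
def count_hits(ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    # An atomic RMW that the ConvertToBufferOps pass may rewrite into a
    # buffer atomic when the gating conditions above are met.
    tl.atomic_add(ptr + offs, 1.0, mask=mask)


x = torch.zeros(1024, device="cuda", dtype=torch.float32)
count_hits[(4,)](x, x.numel(), BLOCK=256)
```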
How
I've added a new buffer atomic RMW op in the AMDGPUOps dialect which has its own lowering in the backend. There are a number of checks in place to ensure that the lowering is done correctly between the ConvertToBufferOps pass and the LoadStoreOpToLLVM lowering.
The actual lowering is where most of the performance gains come from. At a high level, when non-buffer atomic RMW ops are emitted, the memory fences lower to something along the lines of:
If my understanding of the GFX942 memory model is correct, then given several assumptions regarding CDNA3, this can actually be lowered to something that resembles:
There are comments in the code which explain the thought process for why (I think) this is okay.
It appears that AMD's CK library (AMD's version of CUTLASS) uses similar synchronization mechanisms, although I am probably missing some of the context here (https://github.com/ROCm/composable_kernel/blob/9e95d54cd2160dffc07c1197951a9ab1ca6c35f2/include/ck_tile/core/arch/amd_buffer_addressing.hpp#L619).
Results and Testing
In addition to the added lit test, I ran the existing atomic rmw tests in tree with buffer ops + buffer atomics enabled and they appear to pass.
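A rough sketch of how such a run can be reproduced (the pytest selection and path here are illustrative, not the exact invocation used):

```python
# Run the existing atomic RMW unit tests with the buffer paths enabled.
import os
import subprocess

env = dict(os.environ, AMDGCN_USE_BUFFER_OPS="1", AMDGCN_ENABLE_BUFFER_ATOMICS="1")
subprocess.run(
    ["pytest", "python/test/unit/language/test_core.py", "-k", "atomic"],
    env=env,
    check=True,
)
```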
Following this, I evaluated an FP16 split-K GEMM with Llama shapes in tritonbench using an MI300X. Some minor modifications to the kernel were made to emit buffer ops (e.g., tl.assume calls; see the sketch below). For testing purposes, I disabled the non-split-K configurations. I also checked the numerical accuracy with rtol=atol=1e-4 for all shapes here.
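As a hedged sketch of the kind of hint this refers to (not the actual tritonbench kernel; the function and parameter names are made up), tl.assume lets the kernel assert facts such as non-negative program IDs and positive strides, which, as I understand it, helps the compiler decide that the accesses qualify for buffer addressing:

```python
import triton
import triton.language as tl


@triton.jit
def scale_rows(x_ptr, out_ptr, n_cols, stride_row, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    # Compiler hints: row offsets are non-negative, so the AMD backend can
    # consider rewriting these accesses as buffer operations.
    tl.assume(pid >= 0)
    tl.assume(stride_row > 0)
    cols = tl.arange(0, BLOCK)
    mask = cols < n_cols
    offs = pid * stride_row + cols
    row = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, row * 2.0, mask=mask)
```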
Each bucket in the figure above corresponds to the average TFlops of all shapes with the same shared M-dim.
At smaller batch sizes the performance is roughly equivalent. At BS=32, buffer atomics have ~50% greater TFlops. At BS=256, buffer atomics have ~3.75x the TFlops.
Note: the purpose of this test is to evaluate the performance of buffer atomics; split-K is not always optimal for these shapes/workloads, etc.
============================================================================================
New contributor declaration
- I am not making a trivial change, such as fixing a typo in a comment.
- I have written a PR description following these rules.
- I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - `/test` for `lit` tests
  - `/unittest` for C++ tests
  - `/python/test` for end-to-end tests
  - `FILL THIS IN`.
- Select one of the following.
  - `lit` tests.
  - The `lit` tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)