This repository has been archived by the owner on Jun 28, 2024. It is now read-only.

fix: fix reduce1d and reduce2d by reverting wave size
evshiron committed Jun 20, 2023
1 parent 0290036 commit bbd1441
Showing 14 changed files with 34 additions and 82 deletions.
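Nearly every hunk below replaces a hard-coded wave size of 64 with 32: gfx1100 (RDNA3) runs 32-lane wavefronts, so the 64-lane value (appropriate for wave64 CDNA parts) made the lowering budget for lanes that do not exist, which is what broke the reduce1d/reduce2d paths. As a minimal sketch of the idea (not code from the diff; kWaveSize and ctaThreads are illustrative names):

```cpp
#include <cstdio>

// Illustrative only: the wave-size constant this commit reverts and the
// thread bookkeeping it feeds throughout the conversion passes.
constexpr unsigned kWaveSize = 32;  // was hard-coded to 64

// Total hardware threads available to a CTA with `numWarps` warps.
constexpr unsigned ctaThreads(unsigned numWarps) { return numWarps * kWaveSize; }

int main() {
  // With 4 warps the passes now budget for 128 threads, not 256.
  std::printf("threads per CTA: %u\n", ctaThreads(4));
}
```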
9 changes: 0 additions & 9 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -240,13 +240,8 @@ for
"ArrayRef<unsigned>":$order,
"unsigned":$numWarps), [{
int rank = sizePerThread.size();
-#ifdef USE_ROCM
-unsigned remainingLanes = 64;
-unsigned remainingThreads = numWarps*64;
-#else
unsigned remainingLanes = 32;
unsigned remainingThreads = numWarps*32;
-#endif
unsigned remainingWarps = numWarps;
unsigned prevLanes = 1;
unsigned prevWarps = 1;
@@ -264,11 +259,7 @@ for
prevWarps *= warpsPerCTA[i];
}
// Expand the last dimension to fill the remaining lanes and warps
-#ifdef USE_ROCM
-threadsPerWarp[order[rank-1]] = 64 / prevLanes;
-#else
threadsPerWarp[order[rank-1]] = 32 / prevLanes;
-#endif
warpsPerCTA[order[rank-1]] = numWarps / prevWarps;
return $_get(context, sizePerThread, threadsPerWarp, warpsPerCTA, order);

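A worked instance of the fill step above may help; prevLanes = 4 is an assumed value, not taken from the diff:

```cpp
// Sketch: earlier dimensions of the blocked layout already claimed 4 lanes,
// so the last dimension is given whatever is left of the 32-lane wave.
constexpr unsigned prevLanes = 4;                  // assumed
constexpr unsigned lastDimLanes = 32 / prevLanes;  // 8 lanes remain
static_assert(lastDimLanes == 8, "wave32 leaves 8 lanes for the last dim");
// The old value of 64 would have assigned 16 lanes, more than a wave32
// target actually has.
```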
4 changes: 0 additions & 4 deletions lib/Analysis/Utility.cpp
@@ -72,11 +72,7 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
/// shared memory block1:
auto mod = op->getParentOfType<ModuleOp>();
unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
-#ifdef USE_ROCM
-smemShapes[1].push_back(numWarps * 64);
-#else
smemShapes[1].push_back(numWarps * 32);
-#endif
return smemShapes;
}

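For a sense of scale (numbers assumed, not from the diff): with numWarps = 4 the scratch block for the fast-reduce path is now sized for 4 * 32 = 128 entries instead of 4 * 64 = 256.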
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -207,7 +207,8 @@ struct ConvertLayoutOpConversion
SmallVector<Value> mfmaColIdx(4);
SmallVector<Value> mfmaRowIdx(16);
Value threadId = getThreadId(rewriter, loc);
-Value warpSize = i32_val(64);
+// TODO: shall we change this?
+Value warpSize = i32_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
// TODO: fix the bug in MMAEncodingAttr document
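The lane/warp decomposition emitted above is ordinary integer division and remainder; a small sketch with an assumed thread id (illustrative values only):

```cpp
// Thread 70 on a wave32 target sits in warp 2, lane 6.
constexpr unsigned warpSize = 32;
constexpr unsigned threadId = 70;                 // assumed
constexpr unsigned laneId = threadId % warpSize;  // urem -> 6
constexpr unsigned warpId = threadId / warpSize;  // udiv -> 2
static_assert(laneId == 6 && warpId == 2, "70 = 2 * 32 + 6");
```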
@@ -149,13 +149,13 @@ Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
auto numRepM = numReps[0];
auto numRepK = numReps[1];

-Value waveSize = i32_val(64);
+Value waveSize = i32_val(32);
Value wave = udiv(thread, waveSize);
Value lane = urem(thread, waveSize);

Value waveM =
getWaveM(rewriter, loc, wave, warpsPerCTA, mfmaInstrM, shape[0]);
-int numOfElems = std::max<int>(mfmaInstrM * mfmaInstrK / 64 /*wave size*/, 1);
+int numOfElems = std::max<int>(mfmaInstrM * mfmaInstrK / 32 /*wave size*/, 1);
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
unsigned int maxNumWarps = shape[0] / mfmaInstrM;
int warpsPerGroupM = std::min(warpsPerCTA[0], maxNumWarps);
@@ -224,13 +224,13 @@ Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
auto numRepK = numReps[0];
auto numRepN = numReps[1];

-Value waveSize = i32_val(64);
+Value waveSize = i32_val(32);
Value wave = udiv(thread, waveSize);
Value lane = urem(thread, waveSize);

Value waveN = getWaveN(rewriter, loc, wave, warpsPerCTA,
mfmaInstrN, shape[1]);
-int numOfElems = std::max<int>(mfmaInstrK * mfmaInstrN / 64 /*wave size*/, 1);
+int numOfElems = std::max<int>(mfmaInstrK * mfmaInstrN / 32 /*wave size*/, 1);
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);

int macroTileM =
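A worked instance of the numOfElems formula above, with an assumed 32x8 MFMA tile (the tile sizes are illustrative, not from the diff):

```cpp
#include <algorithm>

// Assumed tile: mfmaInstrM = 32, mfmaInstrK = 8.
constexpr int numOfElems = std::max(32 * 8 / 32 /*wave size*/, 1);
static_assert(numOfElems == 8, "each of the 32 lanes loads 8 elements");
// With the previous divisor of 64 the same tile yielded 4 elements per lane.
```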
9 changes: 0 additions & 9 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -384,11 +384,7 @@ struct ReduceOpConversion
}

Value threadId = getThreadId(rewriter, loc);
-#ifdef USE_ROCM
-Value warpSize = i32_val(64);
-#else
Value warpSize = i32_val(32);
-#endif
Value warpId = udiv(threadId, warpSize);
Value laneId = urem(threadId, warpSize);

@@ -451,13 +447,8 @@ struct ReduceOpConversion
//
// Each thread needs to process:
// elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads
-#ifdef USE_ROCM
-unsigned numThreads =
-    product<unsigned>(triton::gpu::getWarpsPerCTA(srcLayout)) * 64;
-#else
unsigned numThreads =
    product<unsigned>(triton::gpu::getWarpsPerCTA(srcLayout)) * 32;
-#endif
unsigned elemsPerThread = std::max<unsigned>(elems / numThreads, 1);
Value readOffset = threadId;
for (unsigned round = 0; round < elemsPerThread; ++round) {
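A worked instance of the bookkeeping above, under assumed sizes (not from the diff): 512 elements reduced by a CTA whose warpsPerCTA product is 4.

```cpp
#include <algorithm>

constexpr unsigned elems = 512, warps = 4;                       // assumed
constexpr unsigned numThreads = warps * 32;                      // 128
constexpr unsigned elemsPerThread = std::max(elems / numThreads, 1u);
static_assert(elemsPerThread == 4, "each thread runs 4 rounds of the loop");
// Under the old wave size of 64 the same reduction budgeted only 2 rounds.
```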
11 changes: 2 additions & 9 deletions lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
@@ -479,11 +479,7 @@ class ConvertTritonGPUOpToLLVMPatternBase {
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
auto order = triton::gpu::getOrder(layout);
auto shapePerCTA = triton::gpu::getShapePerCTA(layout, shape);
-#ifdef USE_ROCM
-Value warpSize = i32_val(64);
-#else
Value warpSize = i32_val(32);
-#endif
Value laneId = urem(tid, warpSize);
Value warpId = udiv(tid, warpSize);
SmallVector<Value> multiDimWarpId =
@@ -719,11 +715,7 @@ class ConvertTritonGPUOpToLLVMPatternBase {
const BlockedEncodingAttr &blocked_layout, RankedTensorType type) const {
auto shape = type.getShape();
Value threadId = getThreadId(rewriter, loc);
-#ifdef USE_ROCM
-Value warpSize = i32_val(64);
-#else
Value warpSize = i32_val(32);
-#endif
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
auto sizePerThread = blocked_layout.getSizePerThread();
@@ -996,7 +988,8 @@ class ConvertTritonGPUOpToLLVMPatternBase {
i32_val(_warpsPerCTA[1])};

Value threadId = getThreadId(rewriter, loc);
-Value warpSize = i32_val(64);
+// TODO: shall we change this?
+Value warpSize = i32_val(32);
Value laneId = urem(threadId, warpSize);

Value warpId = udiv(threadId, warpSize);
4 changes: 0 additions & 4 deletions lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -256,11 +256,7 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
int numWarps = typeConverter->getNumWarps();

SmallVector<unsigned> retSizePerThread = {1, 1};
-#ifdef USE_ROCM
-int warpSize = 64;
-#else
int warpSize = 32;
-#endif

if (origShape[0] * origShape[1] / (numWarps * warpSize) >= 4)
retSizePerThread = {2, 2};
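The reverted warp size also shifts the cutover point of the heuristic above; a worked instance with an assumed 16x32 dot shape and 4 warps (shape chosen for illustration):

```cpp
constexpr int shapeM = 16, shapeN = 32, numWarps = 4, warpSize = 32;  // assumed
constexpr bool widen = shapeM * shapeN / (numWarps * warpSize) >= 4;  // 512 / 128 = 4
static_assert(widen, "wave32: retSizePerThread becomes {2, 2}");
// With warpSize = 64 the quotient is 2 and the layout stays at {1, 1}.
```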
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -675,7 +675,7 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
if (auto mfmaParent = getParent().dyn_cast<MfmaEncodingAttr>()) {
int warpsPerCTAM = mfmaParent.getWarpsPerCTA()[0];
int warpsPerCTAN = mfmaParent.getWarpsPerCTA()[1];
-constexpr int waveSize = 64;
+// constexpr int waveSize = 32;
auto tileSize = getMFMAElemsPerThread(eltTy);
auto rep = getMFMARep(shape, eltTy);
return rep[0] * rep[1];
4 changes: 0 additions & 4 deletions lib/Dialect/TritonGPU/Transforms/Coalesce.cpp
@@ -46,11 +46,7 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
}
}
int numElems = product(origType.getShape());
-#ifdef USE_ROCM
-int numThreads = numWarps * 64;
-#else
int numThreads = numWarps * 32;
-#endif
int numElemsPerThread = std::max(numElems / numThreads, 1);
// Thread tile size depends on memory alignment
SmallVector<unsigned, 4> sizePerThread(rank, 1);
41 changes: 13 additions & 28 deletions python/test/unit/language/test_core_amd.py
@@ -2143,34 +2143,19 @@ def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
def __str__(self):
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"

-if torch.version.hip is not None:
-    layouts = [
-        # MmaLayout(version=1, warps_per_cta=[1, 4]),
-        # MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
-        # MmaLayout(version=1, warps_per_cta=[4, 1]),
-        # MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
-        BlockedLayout([1, 2], [2, 32], [2, 2], [1, 0]),
-        BlockedLayout([2, 2], [4, 16], [2, 2], [1, 0]),
-        BlockedLayout([1, 1], [1, 64], [2, 2], [1, 0]),
-        BlockedLayout([4, 2], [16, 4], [1, 4], [0, 1]),
-        BlockedLayout([4, 2], [8, 8], [2, 2], [0, 1]),
-        BlockedLayout([1, 1], [32, 2], [2, 2], [0, 1]),
-        BlockedLayout([4, 2], [1, 64], [4, 1], [1, 0])
-    ]
-else:
-    layouts = [
-        # MmaLayout(version=1, warps_per_cta=[1, 4]),
-        MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
-        # MmaLayout(version=1, warps_per_cta=[4, 1]),
-        MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
-        BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
-        BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
-        BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
-        BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
-        BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
-        BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1]),
-        BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
-    ]
+layouts = [
+    # MmaLayout(version=1, warps_per_cta=[1, 4]),
+    # MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
+    # MmaLayout(version=1, warps_per_cta=[4, 1]),
+    # MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
+    BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
+    BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
+    BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
+    BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
+    BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
+    BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1]),
+    BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])
+]

@pytest.mark.parametrize("shape", [(128, 128)])
@pytest.mark.parametrize("dtype", ['float16'])
2 changes: 1 addition & 1 deletion python/triton/compiler/make_launcher.py
@@ -122,7 +122,7 @@ def format_of(ty):
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, hipStream_t stream, hipFunction_t function, {arg_decls}) {{
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
if (gridX*gridY*gridZ > 0) {{
-HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, 64*num_warps, 1, 1, shared_memory, stream, params, 0));
+HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0));
}}
}}
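In concrete terms, a kernel compiled with num_warps = 4 is now launched with a block of 4 * 32 = 128 work-items (four wave32 wavefronts) rather than 256, matching the wave size assumed by the generated code above.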
7 changes: 5 additions & 2 deletions scripts/gfx1100/test.sh
@@ -14,10 +14,12 @@ chmod -R 777 $LOG_DIR

bash scripts/amd/lit.sh 2>&1 | tee $LOG_DIR/lit.log

UNIT_TEST="python/test/unit/language/test_core_amd.py"
# UNIT_TEST="python/test/unit/language/test_core_amd.py"
# UNIT_TEST="python/test/unit/language/test_core_amd.py::test_math_op"
# UNIT_TEST="python/test/unit/language/test_core_amd.py::test_reduce1d"
# UNIT_TEST="python/test/unit/language/test_core_amd.py::test_reduce2d"
# UNIT_TEST="python/test/unit/language/test_core_amd.py::test_convert2d"
# UNIT_TEST="python/test/unit/language/test_core_amd.py::test_reduce_layouts"
# UNIT_TEST="python/test/unit/language/test_core.py::test_empty_kernel[float32]"
# UNIT_TEST="python/test/unit/runtime/test_cache.py::test_compile_in_subproc"
# UNIT_TEST="python/test/unit/language/test_core_amd.py::test_shift_op[int8-int8-<<]"
@@ -43,7 +45,8 @@ if [ "$1" == "backtrace" ]; then
2>&1 | tee $LOG_DIR/backtrace.log

else
-pytest --capture=tee-sys -rfs --verbose "$UNIT_TEST" 2>&1 | tee $LOG_DIR/unit_test.log
+# pytest --capture=tee-sys -rfs --verbose "$UNIT_TEST" 2>&1 | tee $LOG_DIR/unit_test.log
+pytest --capture=tee-sys -rfs --verbose python/test/unit/language/test_core_amd.py -k 'test_math_op or test_reduce1d or test_reduce2d or test_convert2d or test_reduce_layouts' 2>&1 | tee $LOG_DIR/unit_test.log
fi

# bash scripts/amd/cache_print.sh 2>&1 |tee $LOG_DIR/cache.log
8 changes: 4 additions & 4 deletions test/Conversion/triton_to_tritongpu.mlir
@@ -31,11 +31,11 @@ tt.func @load_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
// -----

tt.func @reduce_ops(%ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
-// Test if the total number of threadsPerWarp is 64
+// Test if the total number of threadsPerWarp is 32
// Test if the total number of warps is 2
-// CHECK: #[[blocked0:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 16], warpsPerCTA = [1, 2], order = [0, 1]}>
-// CHECK: #[[blocked1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
-// CHECK: #[[blocked2:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 2], order = [0, 1]}>
+// CHECK: #[[blocked0:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1]}>
+// CHECK: #[[blocked1:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 2], order = [0, 1]}>
+// CHECK: #[[blocked2:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 2], order = [0, 1]}>
// CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}}
%c0 = arith.constant dense<1.00e+00> : tensor<4x4xf32>
%c1 = arith.constant dense<2.00e+00> : tensor<8x2xf32>
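The updated CHECK patterns are consistent with 32-lane warps: each threadsPerWarp product is 4 * 8 = 8 * 4 = 16 * 2 = 32, and with warpsPerCTA = [1, 2] the module still accounts for 64 threads in total, whereas the old patterns ([4, 16], [8, 8], [16, 4]) assumed 64 lanes per warp.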
4 changes: 2 additions & 2 deletions test/TritonGPU/coalesce.mlir
@@ -9,8 +9,8 @@
module attributes {"triton_gpu.num-warps" = 4 : i32} {


-// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
-// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
+// CHECK: [[row_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+// CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [0, 1]}>
// CHECK: [[load_ptr:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64x!tt.ptr<f32>, [[row_layout]]>
// CHECK: [[load_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[row_layout]]>
// CHECK: [[load_other:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[row_layout]]>
