[compiler] Perform hardcoded workgroup distribution for NsNet2 (#48)
This unblocks running kernels by reducing memory usage while deferring
a proper tiling heuristic to the future. I believe the design space is
rather small (tiles need to fit into L1), but it is work nevertheless.
zero9178 authored Jun 27, 2024
1 parent e86161a commit 719d60b
Showing 8 changed files with 210 additions and 17 deletions.
1 change: 1 addition & 0 deletions codegen/compiler/src/Quidditch/Target/CMakeLists.txt
@@ -21,6 +21,7 @@ iree_cc_library(
"Passes.h"
"Passes.h.inc"
SRCS
"ConfigureForSnitch.cpp"
"DisableQuidditchVariant.cpp"
"FormMicrokernels.cpp"
"LinkExecutables.cpp"
143 changes: 143 additions & 0 deletions codegen/compiler/src/Quidditch/Target/ConfigureForSnitch.cpp
@@ -0,0 +1,143 @@
#include "Passes.h"

#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Utils/CPUUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace quidditch {
#define GEN_PASS_DEF_CONFIGUREFORSNITCHPASS
#include "Quidditch/Target/Passes.h.inc"
} // namespace quidditch

using namespace mlir;
using namespace mlir::iree_compiler;

namespace {
class ConfigureForSnitch
: public quidditch::impl::ConfigureForSnitchPassBase<ConfigureForSnitch> {
public:
using Base::Base;

protected:
void runOnOperation() override;
};
} // namespace

static LogicalResult setRootConfig(FunctionOpInterface funcOp,
Operation *rootOp) {
return TypeSwitch<Operation *, LogicalResult>(rootOp)
.Case<linalg::MatmulTransposeBOp>([&](linalg::LinalgOp op) {
if (funcOp.getName() ==
"main$async_dispatch_0_matmul_transpose_b_1x400x161_f64") {
SmallVector<int64_t> bounds(3, 0);
// Future subgroup distribution.
bounds[0] = 1;
// How many rows we are processing (0 to 400). Should fit in L1.
// Should be as high as possible for subgroup distribution.
// (Could almost be 40).
bounds[1] = 50;

// Reduction dimension (0 to 161). How many columns are we processing
// at once?
// Cannot be distributed. As wide as possible for FPU utilization of a
// single core.
bounds[2] = 0;

TileSizesListType tileSizes = {bounds};
return setOpConfigAndEntryPointFnTranslation(
funcOp, rootOp, tileSizes,
IREE::Codegen::DispatchLoweringPassPipeline::None);
}
if (funcOp.getName() ==
"main$async_dispatch_7_matmul_transpose_b_1x600x400_f64") {
SmallVector<int64_t> bounds(3, 0);
// Future subgroup distribution.
bounds[0] = 1;
// How many rows we are processing (0 to 600). Should fit in L1.
// Should be as high as possible for subgroup distribution.
// (Could almost be 40).
bounds[1] = 25;

// Reduction dimension (0 to 400). How many columns are we processing
// at once?
// Cannot be distributed. As wide as possible for FPU utilization of a
// single core.
bounds[2] = 0;

TileSizesListType tileSizes = {bounds};
return setOpConfigAndEntryPointFnTranslation(
funcOp, rootOp, tileSizes,
IREE::Codegen::DispatchLoweringPassPipeline::None);
}
if (funcOp.getName() ==
"main$async_dispatch_8_matmul_transpose_b_1x600x600_f64") {
SmallVector<int64_t> bounds(3, 0);
// Future subgroup distribution.
bounds[0] = 1;
// How many rows we are processing (0 to 600). Should fit in L1.
// Should be as high as possible for subgroup distribution.
bounds[1] = 15;

// Reduction dimension (0 to 600). How many columns are we processing
// at once?
// Cannot be distributed. As wide as possible for FPU utilization of a
// single core.
bounds[2] = 0;

TileSizesListType tileSizes = {bounds};
return setOpConfigAndEntryPointFnTranslation(
funcOp, rootOp, tileSizes,
IREE::Codegen::DispatchLoweringPassPipeline::None);
}
if (funcOp.getName() ==
"main$async_dispatch_1_matmul_transpose_b_1x1200x400_f64") {
SmallVector<int64_t> bounds(3, 0);
// Future subgroup distribution.
bounds[0] = 0;
// How many rows we are processing (0 to 1200). Should fit in L1.
// Should be as high as possible for subgroup distribution.
bounds[1] = 25;
// Reduction dimension (0 to 400). How many columns are we processing
// at once?
// Cannot be distributed. As wide as possible for FPU utilization of a
// single core.
bounds[2] = 0;

TileSizesListType tileSizes = {bounds};
return setOpConfigAndEntryPointFnTranslation(
funcOp, rootOp, tileSizes,
IREE::Codegen::DispatchLoweringPassPipeline::None);
}

return success();
})
.Default(success());
}

void ConfigureForSnitch::runOnOperation() {
FunctionOpInterface funcOp = getOperation();
if (getTranslationInfo(funcOp))
return;

SmallVector<Operation *> computeOps = getComputeOps(funcOp);
FailureOr<Operation *> rootOp = getRootOperation(computeOps);
if (failed(rootOp))
return signalPassFailure();
Operation *rootOperation = rootOp.value();
if (!rootOperation)
return;

if (failed(setRootConfig(funcOp, rootOperation)))
return signalPassFailure();

// The root configuration setting introduces `tensor.dim` operations.
// Resolve those away.
RewritePatternSet patterns(funcOp.getContext());
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
signalPassFailure();
}
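
The hardcoded tile sizes above can be sanity-checked with a rough footprint estimate. The sketch below is not part of the commit: it assumes that, for each matmul_transpose_b dispatch (C[1xN] = A[1xK] * B[NxK]^T in f64), the B row tile, the A row, and the output tile must all be resident in L1 at once, and compares that against the 100000-byte l1MemoryBytes budget set in QuidditchTarget.cpp below.

/* Illustrative L1-footprint check; shapes, tile sizes, and the budget are
 * taken from this commit, while the "everything resident at once" model is
 * an assumption. */
#include <stdio.h>

int main(void) {
  struct { const char *shape; unsigned k, tile_n; } d[] = {
      {"1x400x161", 161, 50},
      {"1x600x400", 400, 25},
      {"1x600x600", 600, 15},
      {"1x1200x400", 400, 25},
  };
  const unsigned l1_budget = 100000; /* l1MemoryBytes in QuidditchTarget.cpp */
  for (unsigned i = 0; i < sizeof(d) / sizeof(d[0]); ++i) {
    /* f64 bytes: B tile (tile_n x k) + A row (1 x k) + C tile (1 x tile_n). */
    unsigned bytes =
        (d[i].tile_n * d[i].k + d[i].k + d[i].tile_n) * sizeof(double);
    printf("%-11s tile_n=%-3u ~%6u bytes %s\n", d[i].shape, d[i].tile_n, bytes,
           bytes < l1_budget ? "(fits)" : "(exceeds budget)");
  }
  return 0;
}

Under that model all four choices land between roughly 66 kB and 83 kB, consistent with the commit message's note that tiles need to fit into L1.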
4 changes: 4 additions & 0 deletions codegen/compiler/src/Quidditch/Target/Passes.td
@@ -31,4 +31,8 @@ def DisableQuidditchVariantPass : Pass<"quidditch-disable-variant",

def ReluToMaxPass : Pass<"quidditch-relu-to-max">;

def ConfigureForSnitchPass
: InterfacePass<"quidditch-configure-for-snitch",
"mlir::FunctionOpInterface">;

#endif
29 changes: 24 additions & 5 deletions codegen/compiler/src/Quidditch/Target/QuidditchTarget.cpp
@@ -73,7 +73,8 @@ struct QuidditchTargetOptions {
std::string xDSLOptPath;
std::string toolChainRoot;
bool assertCompiled = false;
unsigned l1MemoryBytes = 112640;
// TODO: This should actually be 112640 but DMA stack overflows. Ooopsie!
unsigned l1MemoryBytes = 100000;

void bindOptions(OptionsBinder &binder) {
LLVMInitializeRISCVTarget();
@@ -147,6 +148,21 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend {
StringAttr::get(context, "static"), list.getDictionary(context)));
}

void
buildConfigurationPassPipeline(IREE::HAL::ExecutableTargetAttr targetAttr,
OpPassManager &passManager) override {
OpPassManager &modulePassManager = passManager.nest<ModuleOp>();
{
FunctionLikeNest funcPassManager(modulePassManager);
addCommonTargetExecutablePreprocessingPasses(
funcPassManager,
/*useDecomposeSoftmaxFusion=*/false);
}
modulePassManager.addPass(createMaterializeUserConfigsPass());
FunctionLikeNest funcPassManager(modulePassManager);
funcPassManager.addPass(quidditch::createConfigureForSnitchPass);
}

void buildTranslationPassPipeline(IREE::HAL::ExecutableTargetAttr targetAttr,
OpPassManager &passManager) override {
OpPassManager &modulePassManager = passManager.nest<ModuleOp>();
@@ -178,6 +194,8 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend {
.addPass(createEliminateEmptyTensorsPass)
.addPass(bufferization::createEmptyTensorToAllocTensorPass)
.addPass(quidditch::Snitch::createPromoteToL1Pass)
.addPass(createCanonicalizerPass)
.addPass(createCSEPass)
.addPass([&] {
return createIREEComprehensiveBufferizePass(allocationFn, memcpyFn);
});
@@ -289,7 +307,7 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend {
return objectFiles;
}

std::unique_ptr<llvm::Module>
static std::unique_ptr<llvm::Module>
toLLVMModule(llvm::LLVMContext &context, ModuleOp module,
const llvm::TargetMachine &machine,
IREE::HAL::ExecutableVariantOp variantOp) {
@@ -367,7 +385,8 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend {
return llvmModule;
}

void optimizeLLVMModule(llvm::Module &module, llvm::TargetMachine &machine) {
static void optimizeLLVMModule(llvm::Module &module,
llvm::TargetMachine &machine) {

llvm::LoopAnalysisManager loopAnalysisManager;
llvm::FunctionAnalysisManager functionAnalysisManager;
@@ -392,7 +411,7 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend {
modulePassManager.run(module, moduleAnalysisManager);
}

FailureOr<IREE::HAL::Artifact>
static FailureOr<IREE::HAL::Artifact>
compileLLVMModule(llvm::Module &module, llvm::TargetMachine &machine) {
auto objectFile = IREE::HAL::Artifact::createTemporary("iree-out", "o");

@@ -481,7 +500,7 @@ class QuidditchTargetBackend final : public IREE::HAL::TargetBackend {
std::vector<uint8_t> libraryNameVector(libraryName.begin(),
libraryName.end());
executableBuilder.create<IREE::HAL::ExecutableBinaryOp>(
variantOp.getLoc(), variantOp.getSymName(), "static",
variantOp.getLoc(), variantOp.getSymName(), "snitch",
libraryNameVector);

return success();
2 changes: 1 addition & 1 deletion runtime/runtime/src/Quidditch/dispatch/dispatch.c
@@ -93,7 +93,6 @@ void quidditch_dispatch_queue_workgroup(
if (nextCoreToUse != snrt_cluster_compute_core_num()) return;

quidditch_dispatch_execute_workgroups();
reset_workgroup_state();
}

void quidditch_dispatch_execute_workgroups() {
@@ -104,4 +103,5 @@ void quidditch_dispatch_execute_workgroups() {
snrt_cluster_hw_barrier();
// Then wait for workers to be done.
snrt_cluster_hw_barrier();
reset_workgroup_state();
}
42 changes: 32 additions & 10 deletions runtime/runtime/src/Quidditch/executable/executable.c
@@ -96,6 +96,8 @@ iree_status_t quidditch_executable_create(
executable->identifier = iree_make_cstring_view((*library_header)->name);
executable->dispatch_attrs = executable->library.v0->exports.attrs;
}
executable->is_llvm = !iree_string_view_equal(
executable_params->executable_format, IREE_SV("snitch"));

// Copy executable constants so we own them.
if (iree_status_is_ok(status) && executable_params->constant_count > 0) {
@@ -164,7 +166,7 @@ iree_status_t quidditch_executable_issue_dispatch_inline(
});
#endif // IREE_HAL_VERBOSE_TRACING_ENABLE

iree_hal_executable_workgroup_state_v0_t workgroup_state;
iree_hal_executable_workgroup_state_v0_t workgroup_state = {0};

workgroup_state.local_memory = local_memory.data;
workgroup_state.local_memory_size = (size_t)local_memory.data_length;
@@ -181,19 +183,39 @@
dispatch_state);

read_csr(mcycle);
for (uint32_t z = 0; z < workgroup_count_z; ++z) {
workgroup_state.workgroup_id_z = z;
for (uint32_t y = 0; y < workgroup_count_y; ++y) {
workgroup_state.workgroup_id_y = y;
for (uint32_t x = 0; x < workgroup_count_x; ++x) {
workgroup_state.workgroup_id_x = x;

quidditch_dispatch_queue_workgroup(&workgroup_state);
if (executable->is_llvm) {
// LLVM distributes workgroups to compute cores.
for (uint32_t z = 0; z < workgroup_count_z; ++z) {
workgroup_state.workgroup_id_z = z;
for (uint32_t y = 0; y < workgroup_count_y; ++y) {
workgroup_state.workgroup_id_y = y;
for (uint32_t x = 0; x < workgroup_count_x; ++x) {
workgroup_state.workgroup_id_x = x;

quidditch_dispatch_queue_workgroup(&workgroup_state);
}
}
}

quidditch_dispatch_execute_workgroups();
} else {
// Snitch distributes workgroups to clusters.
// I.e., one workgroup runs on one cluster.
// TODO: Subgroup distribution.
for (uint32_t z = 0; z < workgroup_count_z; ++z) {
workgroup_state.workgroup_id_z = z;
for (uint32_t y = 0; y < workgroup_count_y; ++y) {
workgroup_state.workgroup_id_y = y;
for (uint32_t x = 0; x < workgroup_count_x; ++x) {
workgroup_state.workgroup_id_x = x;

quidditch_dispatch_queue_workgroup(&workgroup_state);
quidditch_dispatch_execute_workgroups();
}
}
}
}

quidditch_dispatch_execute_workgroups();
read_csr(mcycle);

if (quidditch_dispatch_errors_occurred())
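
The branch added above is the core of the new workgroup distribution. A minimal standalone sketch of the two strategies (not part of the runtime; queue() and execute() are stand-ins for quidditch_dispatch_queue_workgroup and quidditch_dispatch_execute_workgroups, and the 3-D workgroup grid is flattened to one dimension):

#include <stdbool.h>
#include <stdio.h>

static void queue(int wg) { printf("  queue workgroup %d\n", wg); }
static void execute(void) { printf("  execute queued workgroups\n"); }

static void dispatch(bool is_llvm, int workgroup_count) {
  printf(is_llvm ? "llvm path:\n" : "snitch path:\n");
  for (int wg = 0; wg < workgroup_count; ++wg) {
    queue(wg);
    /* Snitch path: one workgroup occupies the whole cluster, so run it
     * before queuing the next one. */
    if (!is_llvm) execute();
  }
  /* LLVM path: workgroups were spread over the compute cores; run them
   * all in one go at the end. */
  if (is_llvm) execute();
}

int main(void) {
  dispatch(/*is_llvm=*/true, 3);  /* queue 0, 1, 2, then a single execute */
  dispatch(/*is_llvm=*/false, 3); /* queue + execute per workgroup */
  return 0;
}

The LLVM-compiled path keeps the existing behaviour of distributing workgroups to individual compute cores, while the new snitch path gives each workgroup the whole cluster, matching the hardcoded distribution chosen by ConfigureForSnitch.cpp.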
2 changes: 2 additions & 0 deletions runtime/runtime/src/Quidditch/executable/executable.h
@@ -47,6 +47,8 @@ typedef struct quidditch_executable_t {
const iree_hal_executable_library_v0_t* v0;
} library;

bool is_llvm;

iree_hal_pipeline_layout_t* layouts[];
} quidditch_executable_t;

4 changes: 3 additions & 1 deletion runtime/runtime/src/Quidditch/loader/loader.c
@@ -100,7 +100,9 @@ static bool quidditch_loader_query_support(
iree_hal_executable_caching_mode_t caching_mode,
iree_string_view_t executable_format) {
return iree_string_view_equal(executable_format,
iree_make_cstring_view("static"));
iree_make_cstring_view("static")) ||
iree_string_view_equal(executable_format,
iree_make_cstring_view("snitch"));
}

static iree_status_t quidditch_loader_try_load(
