Skip to content

Commit

Permalink
Custom kernel lowering (#191)
Browse files Browse the repository at this point in the history
* Custom kernel lowering

* fix

* actually add file

* gpu kernel generation

* almost complete

* use the same context

* fixup

* Doing things properly

* cleaning up

* cleanup

* cleanup

* fmt

* final clean

* cleanup

* fmt

* Now with dynamic shmem

* fmt

---------

Co-authored-by: Alex Zinenko <[email protected]>
  • Loading branch information
wsmoses and ftynse authored Dec 17, 2024
1 parent 51687b0 commit fb483c0
Show file tree
Hide file tree
Showing 18 changed files with 1,192 additions and 67 deletions.
41 changes: 38 additions & 3 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@ load("@rules_python//python:packaging.bzl", "py_wheel")
load(":package.bzl", "py_package")
load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION")

load(
"@xla//xla/tsl/platform:build_config_root.bzl",
"if_llvm_aarch32_available",
"if_llvm_aarch64_available",
"if_llvm_powerpc_available",
"if_llvm_system_z_available",
"if_llvm_x86_available",
)

licenses(["notice"])

package(
Expand All @@ -24,7 +33,10 @@ py_package(

cc_binary(
name = "enzymexlamlir-opt",
srcs = ["//src/enzyme_ad/jax:enzymexlamlir-opt.cpp"],
srcs = [
"//src/enzyme_ad/jax:enzymexlamlir-opt.cpp",
"//src/enzyme_ad/jax:RegistryUtils.cpp",
],
visibility = ["//visibility:public"],
deps = [
"@enzyme//:EnzymeMLIR",
Expand All @@ -44,6 +56,7 @@ cc_binary(
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:NVGPUDialect",
"@llvm-project//mlir:OpenMPDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
Expand All @@ -52,8 +65,30 @@ cc_binary(
"//src/enzyme_ad/jax:TransformOps",
"//src/enzyme_ad/jax:XLADerivatives",
"@stablehlo//:chlo_ops",
"@stablehlo//stablehlo/tests:check_ops"
],
"@stablehlo//stablehlo/tests:check_ops",
"@llvm-project//mlir:ArithToLLVM",
"@llvm-project//mlir:BuiltinToLLVMIRTranslation",
"@llvm-project//mlir:ComplexToLLVM",
"@llvm-project//mlir:ControlFlowToLLVM",
"@llvm-project//mlir:GPUToLLVMIRTranslation",
"@llvm-project//mlir:LLVMToLLVMIRTranslation",
"@llvm-project//mlir:NVVMToLLVMIRTranslation",
] + if_llvm_aarch32_available([
"@llvm-project//llvm:ARMAsmParser",
"@llvm-project//llvm:ARMCodeGen",
]) + if_llvm_aarch64_available([
"@llvm-project//llvm:AArch64AsmParser",
"@llvm-project//llvm:AArch64CodeGen",
]) + if_llvm_powerpc_available([
"@llvm-project//llvm:PowerPCAsmParser",
"@llvm-project//llvm:PowerPCCodeGen",
]) + if_llvm_system_z_available([
"@llvm-project//llvm:SystemZAsmParser",
"@llvm-project//llvm:SystemZCodeGen",
]) + if_llvm_x86_available([
"@llvm-project//llvm:X86AsmParser",
"@llvm-project//llvm:X86CodeGen",
]),
copts = [
"-Wno-unused-variable",
"-Wno-return-type",
Expand Down
66 changes: 63 additions & 3 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ load("@jax//jaxlib:symlink_files.bzl", "symlink_inputs")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library")
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")

exports_files(["enzymexlamlir-opt.cpp"])
exports_files(["enzymexlamlir-opt.cpp", "RegistryUtils.cpp"])

licenses(["notice"])

Expand Down Expand Up @@ -205,6 +205,55 @@ gentbl_cc_library(
deps = [":EnzymeXLAPassesTdFiles"],
)


td_library(
name = "EnzymeXLADialectTdFiles",
srcs = [
"Dialect/Dialect.td",
],
includes = ["."],
deps = [
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
"@llvm-project//mlir:FunctionInterfacesTdFiles",
"@llvm-project//mlir:LoopLikeInterfaceTdFiles",
"@llvm-project//mlir:MemorySlotInterfacesTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
"@llvm-project//mlir:ViewLikeInterfaceTdFiles",
],
)

gentbl_cc_library(
name = "EnzymeXLAOpsIncGen",
tbl_outs = [
(
["-gen-op-decls"],
"Dialect/EnzymeXLAOps.h.inc",
),
(
["-gen-op-defs"],
"Dialect/EnzymeXLAOps.cpp.inc",
),
(
[
"-gen-dialect-decls",
"-dialect=enzymexla",
],
"Dialect/EnzymeXLADialect.h.inc",
),
(
[
"-gen-dialect-defs",
"-dialect=enzymexla",
],
"Dialect/EnzymeXLADialect.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Dialect/EnzymeXLAOps.td",
deps = [":EnzymeXLADialectTdFiles", "@enzyme//:EnzymeDialectTdFiles", "@stablehlo//:stablehlo_ops_td_files"],
)

gentbl_cc_library(
name = "EnzyeHLOPatternsIncGen",
tbl_outs = [
Expand All @@ -228,11 +277,13 @@ cc_library(
[
"Implementations/*.cpp",
"Passes/*.cpp",
"Dialect/*.cpp",
],
),
hdrs = glob([
"Implementations/*.h",
"Passes/*.h",
"Dialect/*.h",
]),
copts = [
"-Werror=unused-variable",
Expand All @@ -241,8 +292,14 @@ cc_library(
"-Werror=unused-result",
],
deps = [
":EnzymeXLAOpsIncGen",
":EnzymeXLAPassesIncGen",
":EnzyeHLOPatternsIncGen",
"@llvm-project//mlir:GPUPipelines",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:ExecutionEngine",
"@llvm-project//llvm:OrcJIT",
"@llvm-project//llvm:OrcTargetProcess",
":mhlo-derivatives",
":stablehlo-derivatives",
":chlo-derivatives",
Expand Down Expand Up @@ -271,11 +328,12 @@ cc_library(

pybind_library(
name = "compile_with_xla",
srcs = ["compile_with_xla.cc"],
srcs = ["compile_with_xla.cc", "RegistryUtils.cpp"],
hdrs = glob([
"compile_with_xla.h",
"Implementations/*.h",
"Passes/*.h",
"RegistryUtils.h"
]),
deps = [
":XLADerivatives",
Expand Down Expand Up @@ -368,7 +426,9 @@ pybind_library(

pybind_extension(
name = "enzyme_call",
srcs = ["enzyme_call.cc"],
srcs = ["enzyme_call.cc",
"RegistryUtils.cpp"
],
visibility = ["//visibility:public"],
deps = [
":clang_compile",
Expand Down
52 changes: 52 additions & 0 deletions src/enzyme_ad/jax/Dialect/Dialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//===- EnzymeXLADialect.cpp - EnzymeXLA dialect -----------------------*- C++
//-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Dialect.h"
#include "Ops.h"
#include "mlir/IR/DialectImplementation.h"

#include "mlir/IR/Builders.h"
#include "llvm/ADT/TypeSwitch.h"

#include "mlir/IR/Dialect.h"

// #include "Dialect/EnzymeEnums.cpp.inc"
#include "src/enzyme_ad/jax/Dialect/EnzymeXLADialect.cpp.inc"

#define GET_OP_CLASSES
#include "src/enzyme_ad/jax/Dialect/EnzymeXLAOps.cpp.inc"

// #define GET_TYPEDEF_CLASSES
// #include "Dialect/EnzymeXLAOpsTypes.cpp.inc"
// #include "Dialect/EnzymeTypes.cpp.inc"

using namespace mlir;
using namespace mlir::enzymexla;

//===----------------------------------------------------------------------===//
// Enzyme dialect.
//===----------------------------------------------------------------------===//

void EnzymeXLADialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "src/enzyme_ad/jax/Dialect/EnzymeXLAOps.cpp.inc"
>();
// addAttributes<
// #define GET_ATTRDEF_LIST
// #include "src/enzyme_ad/jax/Dialect/EnzymeXLAAttributes.cpp.inc"
// >();
// addTypes<
// #define GET_TYPEDEF_LIST
// #include "src/enzyme_ad/jax/Dialect/EnzymeXLAOpsTypes.cpp.inc"
// >();
}

// #define GET_ATTRDEF_CLASSES
// #include "src/enzyme_ad/jax/Dialect/EnzymeXLAAttributes.cpp.inc"
17 changes: 17 additions & 0 deletions src/enzyme_ad/jax/Dialect/Dialect.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
//===- Dialect.h - EnzymeXLA dialect -------------------------------*- C++
//-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef ENZYMEXLA_DIALECT_H
#define ENZYMEXLA_DIALECT_H

#include "mlir/IR/Dialect.h"

#include "src/enzyme_ad/jax/Dialect/EnzymeXLADialect.h.inc"

#endif // ENZYME_DIALECT_H
36 changes: 36 additions & 0 deletions src/enzyme_ad/jax/Dialect/Dialect.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//===- EnzymeXLA.td - EnzymeXLA dialect --------------------------*- tablegen -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef ENZYMEXLA_DIALECT
#define ENZYMEXLA_DIALECT

include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"

//===----------------------------------------------------------------------===//
// Enzyme dialect definition.
//===----------------------------------------------------------------------===//

def EnzymeXLA_Dialect : Dialect {
let name = "enzymexla";
let description = [{}];
let cppNamespace = "::mlir::enzymexla";
// let useDefaultAttributePrinterParser = 1;
// let useDefaultTypePrinterParser = 1;
}

//===----------------------------------------------------------------------===//
// Base Enzyme operation definition.
//===----------------------------------------------------------------------===//

class EnzymeXLA_Op<string mnemonic, list<Trait> traits = []>
: Op<EnzymeXLA_Dialect, mnemonic, traits>;

class EnzymeXLA_Type<string name> : TypeDef<EnzymeXLA_Dialect, name>;

#endif // ENZYMEXLA_DIALECT
74 changes: 74 additions & 0 deletions src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
//===- EnzymeXLAOps.td - EnzymeXLA dialect ops ------------------*- tablegen -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef ENZYMEXLA_OPS
#define ENZYMEXLA_OPS

include "Enzyme/MLIR/Dialect/Dialect.td"
include "Dialect.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/IR/EnumAttr.td"

include "mlir/IR/OpBase.td"
include "mlir/IR/SymbolInterfaces.td"

include "mlir/IR/AttrTypeBase.td"

include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

//include "stablehlo/dialect/Base.td"
//include "stablehlo/dialect/StablehloAttrs.td"

def TensorI64 : Type<CPred<"::llvm::isa<::mlir::TensorType>($_self) && ::llvm::cast<::mlir::TensorType>($_self).getShape().size() == 0 && ::llvm::cast<::mlir::TensorType>($_self).getElementType().isSignlessInteger(64)">, "tensor<i64>",
"::mlir::TensorType">,
BuildableType<"RankedTensorType::get({}, $_builder.getIntegerType(64))">;

def KernelCallOp: EnzymeXLA_Op<"kernel_call", [DeclareOpInterfaceMethods<SymbolUserOpInterface>, Pure]> {
let summary = "Kernel Call operation";
let description = [{
}];

let arguments = (ins
FlatSymbolRefAttr:$fn,
TensorI64:$gridx,
TensorI64:$gridy,
TensorI64:$gridz,
TensorI64:$blockx,
TensorI64:$blocky,
TensorI64:$blockz,
TensorI64:$shmem,
Variadic<AnyType>:$inputs,
DefaultValuedStrAttr<StrAttr, "">:$backend_config,
OptionalAttr<AnyAttr>:$operand_layouts,
OptionalAttr<AnyAttr>:$result_layouts,
DefaultValuedOptionalAttr<
ArrayAttr, "{}">:$output_operand_aliases
//OptionalAttr<StableHLO_ArrayOfLayoutAttr>:$operand_layouts,
//OptionalAttr<StableHLO_ArrayOfLayoutAttr>:$result_layouts,
//DefaultValuedOptionalAttr<
// TypedArrayAttrBase<
// StableHLO_OutputOperandAlias,
// "Aliasing attribute for outputs and operands of CustomCall">,
// "{}">:$output_operand_aliases
);

let results = (outs Variadic<AnyType>);


let assemblyFormat = [{
$fn ` ` `blocks` `in` `(` $gridx `,` $gridy `,` $gridz `)` ` ` `threads` `in` `(` $blockx `,` $blocky `,` $blockz `)` ` ` `shmem` `=` $shmem ` ` `(` $inputs `)` attr-dict `:` functional-type($inputs, results)
}];

}

#endif // ENZYMEXLA_OPS
Loading

0 comments on commit fb483c0

Please sign in to comment.