-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Custom kernel lowering * fix * actually add file * gpu kernel generation * almost complete * use the same context * fixup * Doing things properly * cleaning up * cleanup * cleanup * fmt * final clean * cleanup * fmt * Now with dynamic shmem * fmt --------- Co-authored-by: Alex Zinenko <[email protected]>
- Loading branch information
Showing
18 changed files
with
1,192 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
//===- EnzymeXLADialect.cpp - EnzymeXLA dialect -----------------------*- C++ | ||
//-*-===// | ||
// | ||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "Dialect.h" | ||
#include "Ops.h" | ||
#include "mlir/IR/DialectImplementation.h" | ||
|
||
#include "mlir/IR/Builders.h" | ||
#include "llvm/ADT/TypeSwitch.h" | ||
|
||
#include "mlir/IR/Dialect.h" | ||
|
||
// #include "Dialect/EnzymeEnums.cpp.inc" | ||
#include "src/enzyme_ad/jax/Dialect/EnzymeXLADialect.cpp.inc" | ||
|
||
#define GET_OP_CLASSES | ||
#include "src/enzyme_ad/jax/Dialect/EnzymeXLAOps.cpp.inc" | ||
|
||
// #define GET_TYPEDEF_CLASSES | ||
// #include "Dialect/EnzymeXLAOpsTypes.cpp.inc" | ||
// #include "Dialect/EnzymeTypes.cpp.inc" | ||
|
||
using namespace mlir; | ||
using namespace mlir::enzymexla; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Enzyme dialect. | ||
//===----------------------------------------------------------------------===// | ||
|
||
void EnzymeXLADialect::initialize() { | ||
addOperations< | ||
#define GET_OP_LIST | ||
#include "src/enzyme_ad/jax/Dialect/EnzymeXLAOps.cpp.inc" | ||
>(); | ||
// addAttributes< | ||
// #define GET_ATTRDEF_LIST | ||
// #include "src/enzyme_ad/jax/Dialect/EnzymeXLAAttributes.cpp.inc" | ||
// >(); | ||
// addTypes< | ||
// #define GET_TYPEDEF_LIST | ||
// #include "src/enzyme_ad/jax/Dialect/EnzymeXLAOpsTypes.cpp.inc" | ||
// >(); | ||
} | ||
|
||
// #define GET_ATTRDEF_CLASSES | ||
// #include "src/enzyme_ad/jax/Dialect/EnzymeXLAAttributes.cpp.inc" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
//===- Dialect.h - EnzymeXLA dialect -------------------------------*- C++ | ||
//-*-===// | ||
// | ||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef ENZYMEXLA_DIALECT_H | ||
#define ENZYMEXLA_DIALECT_H | ||
|
||
#include "mlir/IR/Dialect.h" | ||
|
||
#include "src/enzyme_ad/jax/Dialect/EnzymeXLADialect.h.inc" | ||
|
||
#endif // ENZYME_DIALECT_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
//===- EnzymeXLA.td - EnzymeXLA dialect --------------------------*- tablegen -*-===// | ||
// | ||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef ENZYMEXLA_DIALECT | ||
#define ENZYMEXLA_DIALECT | ||
|
||
include "mlir/IR/OpBase.td" | ||
include "mlir/IR/AttrTypeBase.td" | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Enzyme dialect definition. | ||
//===----------------------------------------------------------------------===// | ||
|
||
def EnzymeXLA_Dialect : Dialect { | ||
let name = "enzymexla"; | ||
let description = [{}]; | ||
let cppNamespace = "::mlir::enzymexla"; | ||
// let useDefaultAttributePrinterParser = 1; | ||
// let useDefaultTypePrinterParser = 1; | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// Base Enzyme operation definition. | ||
//===----------------------------------------------------------------------===// | ||
|
||
class EnzymeXLA_Op<string mnemonic, list<Trait> traits = []> | ||
: Op<EnzymeXLA_Dialect, mnemonic, traits>; | ||
|
||
class EnzymeXLA_Type<string name> : TypeDef<EnzymeXLA_Dialect, name>; | ||
|
||
#endif // ENZYMEXLA_DIALECT |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
//===- EnzymeXLAOps.td - EnzymeXLA dialect ops ------------------*- tablegen -*-===// | ||
// | ||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef ENZYMEXLA_OPS | ||
#define ENZYMEXLA_OPS | ||
|
||
include "Enzyme/MLIR/Dialect/Dialect.td" | ||
include "Dialect.td" | ||
include "mlir/Interfaces/ViewLikeInterface.td" | ||
include "mlir/IR/SymbolInterfaces.td" | ||
include "mlir/IR/EnumAttr.td" | ||
|
||
include "mlir/IR/OpBase.td" | ||
include "mlir/IR/SymbolInterfaces.td" | ||
|
||
include "mlir/IR/AttrTypeBase.td" | ||
|
||
include "mlir/Interfaces/ControlFlowInterfaces.td" | ||
include "mlir/Interfaces/FunctionInterfaces.td" | ||
include "mlir/Interfaces/LoopLikeInterface.td" | ||
include "mlir/Interfaces/MemorySlotInterfaces.td" | ||
include "mlir/Interfaces/SideEffectInterfaces.td" | ||
|
||
//include "stablehlo/dialect/Base.td" | ||
//include "stablehlo/dialect/StablehloAttrs.td" | ||
|
||
def TensorI64 : Type<CPred<"::llvm::isa<::mlir::TensorType>($_self) && ::llvm::cast<::mlir::TensorType>($_self).getShape().size() == 0 && ::llvm::cast<::mlir::TensorType>($_self).getElementType().isSignlessInteger(64)">, "tensor<i64>", | ||
"::mlir::TensorType">, | ||
BuildableType<"RankedTensorType::get({}, $_builder.getIntegerType(64))">; | ||
|
||
def KernelCallOp: EnzymeXLA_Op<"kernel_call", [DeclareOpInterfaceMethods<SymbolUserOpInterface>, Pure]> { | ||
let summary = "Kernel Call operation"; | ||
let description = [{ | ||
}]; | ||
|
||
let arguments = (ins | ||
FlatSymbolRefAttr:$fn, | ||
TensorI64:$gridx, | ||
TensorI64:$gridy, | ||
TensorI64:$gridz, | ||
TensorI64:$blockx, | ||
TensorI64:$blocky, | ||
TensorI64:$blockz, | ||
TensorI64:$shmem, | ||
Variadic<AnyType>:$inputs, | ||
DefaultValuedStrAttr<StrAttr, "">:$backend_config, | ||
OptionalAttr<AnyAttr>:$operand_layouts, | ||
OptionalAttr<AnyAttr>:$result_layouts, | ||
DefaultValuedOptionalAttr< | ||
ArrayAttr, "{}">:$output_operand_aliases | ||
//OptionalAttr<StableHLO_ArrayOfLayoutAttr>:$operand_layouts, | ||
//OptionalAttr<StableHLO_ArrayOfLayoutAttr>:$result_layouts, | ||
//DefaultValuedOptionalAttr< | ||
// TypedArrayAttrBase< | ||
// StableHLO_OutputOperandAlias, | ||
// "Aliasing attribute for outputs and operands of CustomCall">, | ||
// "{}">:$output_operand_aliases | ||
); | ||
|
||
let results = (outs Variadic<AnyType>); | ||
|
||
|
||
let assemblyFormat = [{ | ||
$fn ` ` `blocks` `in` `(` $gridx `,` $gridy `,` $gridz `)` ` ` `threads` `in` `(` $blockx `,` $blocky `,` $blockz `)` ` ` `shmem` `=` $shmem ` ` `(` $inputs `)` attr-dict `:` functional-type($inputs, results) | ||
}]; | ||
|
||
} | ||
|
||
#endif // ENZYMEXLA_OPS |
Oops, something went wrong.