Skip to content

Commit

Permalink
Consolidate TcpOpsIncGen to include attrdef and enum decls/defs (#13)
Browse files Browse the repository at this point in the history
This is idiomatic to usage in upstream MLIR. Includes other minor
cleanup.
  • Loading branch information
sjain-stanford authored Oct 25, 2023
1 parent 4670cab commit 02e950c
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 119 deletions.
156 changes: 45 additions & 111 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ package(
)

td_library(
name = "TcpTdFiles",
name = "TcpDialectTdFiles",
srcs = [
"include/mlir-tcp/Dialect/IR/TcpBase.td",
"include/mlir-tcp/Dialect/IR/TcpEnums.td",
Expand All @@ -25,46 +25,6 @@ td_library(
],
)

gentbl_cc_library(
name = "TcpEnumsIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-enum-decls"],
"include/mlir-tcp/Dialect/IR/TcpEnums.h.inc",
),
(
["-gen-enum-defs"],
"include/mlir-tcp/Dialect/IR/TcpEnums.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/mlir-tcp/Dialect/IR/TcpOps.td",
deps = [
":TcpTdFiles",
],
)

gentbl_cc_library(
name = "TcpAttrsIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-attrdef-decls"],
"include/mlir-tcp/Dialect/IR/TcpAttrs.h.inc",
),
(
["-gen-attrdef-defs"],
"include/mlir-tcp/Dialect/IR/TcpAttrs.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/mlir-tcp/Dialect/IR/TcpOps.td",
deps = [
":TcpTdFiles",
],
)

gentbl_cc_library(
name = "TcpOpsIncGen",
strip_include_prefix = "include",
Expand All @@ -91,12 +51,26 @@ gentbl_cc_library(
],
"include/mlir-tcp/Dialect/IR/TcpDialect.cpp.inc",
),
(
["-gen-attrdef-decls"],
"include/mlir-tcp/Dialect/IR/TcpAttrs.h.inc",
),
(
["-gen-attrdef-defs"],
"include/mlir-tcp/Dialect/IR/TcpAttrs.cpp.inc",
),
(
["-gen-enum-decls"],
"include/mlir-tcp/Dialect/IR/TcpEnums.h.inc",
),
(
["-gen-enum-defs"],
"include/mlir-tcp/Dialect/IR/TcpEnums.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/mlir-tcp/Dialect/IR/TcpOps.td",
deps = [
":TcpTdFiles",
],
deps = [":TcpDialectTdFiles"],
)

cc_library(
Expand All @@ -111,8 +85,6 @@ cc_library(
],
strip_include_prefix = "include",
deps = [
":TcpAttrsIncGen",
":TcpEnumsIncGen",
":TcpOpsIncGen",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:DialectUtils",
Expand All @@ -121,19 +93,8 @@ cc_library(
],
)

td_library(
name = "TcpTransformsPassesTdFiles",
srcs = [
"include/mlir-tcp/Dialect/Transforms/Passes.td",
],
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)

gentbl_cc_library(
name = "TcpTransformsPassesIncGen",
name = "TcpDialectPassesIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
Expand All @@ -144,7 +105,8 @@ gentbl_cc_library(
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/mlir-tcp/Dialect/Transforms/Passes.td",
deps = [
":TcpTransformsPassesTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)

Expand All @@ -166,48 +128,31 @@ cc_library(
strip_include_prefix = "include",
deps = [
":TcpDialect",
":TcpTransformsPassesIncGen",
":TcpDialectPassesIncGen",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)

td_library(
name = "TcpConversionPassesTdFiles",
srcs = [
"include/mlir-tcp/Conversion/Passes.td",
],
includes = ["include"],
)

gentbl_cc_library(
name = "TcpConversionPassesIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
[
"-gen-pass-decls",
],
["-gen-pass-decls"],
"include/mlir-tcp/Conversion/Passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/mlir-tcp/Conversion/Passes.td",
deps = [
":TcpConversionPassesTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
deps = ["@llvm-project//mlir:PassBaseTdFiles"],
)

cc_library(
name = "TcpConversionPasses",
srcs = [
"lib/Conversion/Passes.cpp",
],
hdrs = [
"include/mlir-tcp/Conversion/Passes.h",
],
srcs = ["lib/Conversion/Passes.cpp"],
hdrs = ["include/mlir-tcp/Conversion/Passes.h"],
strip_include_prefix = "include",
deps = [
":StablehloToTcp",
Expand All @@ -219,12 +164,17 @@ cc_library(

cc_library(
name = "TorchToTcp",
srcs = glob([
"lib/Conversion/*.h",
"lib/Conversion/TorchToTcp/*.h",
"lib/Conversion/TorchToTcp/*.cpp",
]),
hdrs = glob(["include/mlir-tcp/Conversion/TorchToTcp/*.h"]),
srcs = [
"lib/Conversion/PassDetail.h",
"lib/Conversion/TorchToTcp/DataMovement.cpp",
"lib/Conversion/TorchToTcp/Elementwise.cpp",
"lib/Conversion/TorchToTcp/Misc.cpp",
"lib/Conversion/TorchToTcp/PopulatePatterns.h",
"lib/Conversion/TorchToTcp/TorchToTcp.cpp",
"lib/Conversion/TorchToTcp/Utils.cpp",
"lib/Conversion/TorchToTcp/Utils.h",
],
hdrs = ["include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h"],
strip_include_prefix = "include",
deps = [
":TcpConversionPassesIncGen",
Expand All @@ -242,9 +192,7 @@ cc_library(
"lib/Conversion/PassDetail.h",
"lib/Conversion/StablehloToTcp/StablehloToTcp.cpp",
],
hdrs = [
"include/mlir-tcp/Conversion/StablehloToTcp/StablehloToTcp.h",
],
hdrs = ["include/mlir-tcp/Conversion/StablehloToTcp/StablehloToTcp.h"],
strip_include_prefix = "include",
deps = [
":TcpConversionPassesIncGen",
Expand All @@ -266,9 +214,7 @@ cc_library(
"lib/Conversion/TcpToLinalg/PopulatePatterns.h",
"lib/Conversion/TcpToLinalg/TcpToLinalg.cpp",
],
hdrs = [
"include/mlir-tcp/Conversion/TcpToLinalg/TcpToLinalg.h",
],
hdrs = ["include/mlir-tcp/Conversion/TcpToLinalg/TcpToLinalg.h"],
strip_include_prefix = "include",
deps = [
":TcpConversionPassesIncGen",
Expand All @@ -289,9 +235,7 @@ cc_library(
"lib/Conversion/PassDetail.h",
"lib/Conversion/TcpToArith/TcpToArith.cpp",
],
hdrs = [
"include/mlir-tcp/Conversion/TcpToArith/TcpToArith.h",
],
hdrs = ["include/mlir-tcp/Conversion/TcpToArith/TcpToArith.h"],
strip_include_prefix = "include",
deps = [
":TcpConversionPassesIncGen",
Expand All @@ -307,12 +251,8 @@ cc_library(

cc_library(
name = "TcpInitAll",
srcs = [
"lib/InitAll.cpp",
],
hdrs = [
"include/mlir-tcp/InitAll.h",
],
srcs = ["lib/InitAll.cpp"],
hdrs = ["include/mlir-tcp/InitAll.h"],
strip_include_prefix = "include",
deps = [
":TcpConversionPasses",
Expand All @@ -327,12 +267,8 @@ cc_library(

cc_library(
name = "Pipeline",
srcs = [
"lib/Pipeline/Pipeline.cpp",
],
hdrs = [
"include/mlir-tcp/Pipeline/Pipeline.h",
],
srcs = ["lib/Pipeline/Pipeline.cpp"],
hdrs = ["include/mlir-tcp/Pipeline/Pipeline.h"],
strip_include_prefix = "include",
deps = [
":TcpConversionPasses",
Expand All @@ -345,9 +281,7 @@ cc_library(

cc_binary(
name = "tcp-opt",
srcs = [
"tools/tcp-opt/tcp-opt.cpp",
],
srcs = ["tools/tcp-opt/tcp-opt.cpp"],
deps = [
":Pipeline",
":TcpDialect",
Expand Down
12 changes: 6 additions & 6 deletions include/mlir-tcp/Dialect/IR/TcpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def Tcp_Dialect : Dialect {
// Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end].
// Where low and high ends are 0,255 when unsigned, -128,127 when signed, for
// the 8-bit case.
class TCP_QuantizedType<string n, list<int> params, bit signed>
class Tcp_QuantizedType<string n, list<int> params, bit signed>
: Type<And<[CPred<"$_self.isa<mlir::quant::QuantizedType>()">,
CPred<"$_self.cast<mlir::quant::QuantizedType>()" #
".getStorageTypeIntegralWidth() == " # !head(params)>]>,
Expand All @@ -63,12 +63,12 @@ class TCP_QuantizedType<string n, list<int> params, bit signed>
// q8ss : symmetric signed
// q16ss : symmetric signed
//===----------------------------------------------------------------------===//
def TCP_QuantizedInt : AnyTypeOf<[ TCP_QuantizedType<"q8ua", [8], 0>,
TCP_QuantizedType<"q8sa", [8], 1>,
TCP_QuantizedType<"q8ss", [8, 0], 1>,
TCP_QuantizedType<"q16ss", [16, 0], 1>]>;
def Tcp_QuantizedInt : AnyTypeOf<[Tcp_QuantizedType<"q8ua", [8], 0>,
Tcp_QuantizedType<"q8sa", [8], 1>,
Tcp_QuantizedType<"q8ss", [8, 0], 1>,
Tcp_QuantizedType<"q16ss", [16, 0], 1>]>;

def Tcp_Scalar : AnyTypeOf<[AnyFloat, AnySignlessInteger, TCP_QuantizedInt]>;
def Tcp_Scalar : AnyTypeOf<[AnyFloat, AnySignlessInteger, Tcp_QuantizedInt]>;
def Tcp_Tensor : RankedTensorOf<[Tcp_Scalar]>;
def Tcp_TensorOrScalar : AnyTypeOf<[Tcp_Tensor, Tcp_Scalar]>;

Expand Down
3 changes: 1 addition & 2 deletions lib/Conversion/TorchToTcp/TorchToTcp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ class ConvertTorchToTcp : public ConvertTorchToTcpBase<ConvertTorchToTcp> {

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
mlir::tcp::createConvertTorchToTcpPass() {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToTcpPass() {
return std::make_unique<ConvertTorchToTcp>();
}

Expand Down

0 comments on commit 02e950c

Please sign in to comment.