From 02e950c7b4d6c897c140276fd635931af6ff3bbb Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Wed, 25 Oct 2023 06:46:38 -0700 Subject: [PATCH] Consolidate `TcpOpsIncGen` to include attrdef and enum decls/defs (#13) This is idiomatic to usage in upstream MLIR. Includes other minor cleanup. --- BUILD | 156 +++++++---------------- include/mlir-tcp/Dialect/IR/TcpBase.td | 12 +- lib/Conversion/TorchToTcp/TorchToTcp.cpp | 3 +- 3 files changed, 52 insertions(+), 119 deletions(-) diff --git a/BUILD b/BUILD index beb114ba..407f8a7f 100644 --- a/BUILD +++ b/BUILD @@ -12,7 +12,7 @@ package( ) td_library( - name = "TcpTdFiles", + name = "TcpDialectTdFiles", srcs = [ "include/mlir-tcp/Dialect/IR/TcpBase.td", "include/mlir-tcp/Dialect/IR/TcpEnums.td", @@ -25,46 +25,6 @@ td_library( ], ) -gentbl_cc_library( - name = "TcpEnumsIncGen", - strip_include_prefix = "include", - tbl_outs = [ - ( - ["-gen-enum-decls"], - "include/mlir-tcp/Dialect/IR/TcpEnums.h.inc", - ), - ( - ["-gen-enum-defs"], - "include/mlir-tcp/Dialect/IR/TcpEnums.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-tcp/Dialect/IR/TcpOps.td", - deps = [ - ":TcpTdFiles", - ], -) - -gentbl_cc_library( - name = "TcpAttrsIncGen", - strip_include_prefix = "include", - tbl_outs = [ - ( - ["-gen-attrdef-decls"], - "include/mlir-tcp/Dialect/IR/TcpAttrs.h.inc", - ), - ( - ["-gen-attrdef-defs"], - "include/mlir-tcp/Dialect/IR/TcpAttrs.cpp.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "include/mlir-tcp/Dialect/IR/TcpOps.td", - deps = [ - ":TcpTdFiles", - ], -) - gentbl_cc_library( name = "TcpOpsIncGen", strip_include_prefix = "include", @@ -91,12 +51,26 @@ gentbl_cc_library( ], "include/mlir-tcp/Dialect/IR/TcpDialect.cpp.inc", ), + ( + ["-gen-attrdef-decls"], + "include/mlir-tcp/Dialect/IR/TcpAttrs.h.inc", + ), + ( + ["-gen-attrdef-defs"], + "include/mlir-tcp/Dialect/IR/TcpAttrs.cpp.inc", + ), + ( + ["-gen-enum-decls"], + "include/mlir-tcp/Dialect/IR/TcpEnums.h.inc", + ), + ( + ["-gen-enum-defs"], + "include/mlir-tcp/Dialect/IR/TcpEnums.cpp.inc", + ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-tcp/Dialect/IR/TcpOps.td", - deps = [ - ":TcpTdFiles", - ], + deps = [":TcpDialectTdFiles"], ) cc_library( @@ -111,8 +85,6 @@ cc_library( ], strip_include_prefix = "include", deps = [ - ":TcpAttrsIncGen", - ":TcpEnumsIncGen", ":TcpOpsIncGen", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:DialectUtils", @@ -121,19 +93,8 @@ cc_library( ], ) -td_library( - name = "TcpTransformsPassesTdFiles", - srcs = [ - "include/mlir-tcp/Dialect/Transforms/Passes.td", - ], - deps = [ - "@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:PassBaseTdFiles", - ], -) - gentbl_cc_library( - name = "TcpTransformsPassesIncGen", + name = "TcpDialectPassesIncGen", strip_include_prefix = "include", tbl_outs = [ ( @@ -144,7 +105,8 @@ gentbl_cc_library( tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-tcp/Dialect/Transforms/Passes.td", deps = [ - ":TcpTransformsPassesTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:PassBaseTdFiles", ], ) @@ -166,48 +128,31 @@ cc_library( strip_include_prefix = "include", deps = [ ":TcpDialect", - ":TcpTransformsPassesIncGen", + ":TcpDialectPassesIncGen", "@llvm-project//mlir:Pass", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], ) -td_library( - name = "TcpConversionPassesTdFiles", - srcs = [ - "include/mlir-tcp/Conversion/Passes.td", - ], - includes = ["include"], -) - gentbl_cc_library( name = "TcpConversionPassesIncGen", strip_include_prefix = "include", tbl_outs = [ ( - [ - "-gen-pass-decls", - ], + ["-gen-pass-decls"], "include/mlir-tcp/Conversion/Passes.h.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "include/mlir-tcp/Conversion/Passes.td", - deps = [ - ":TcpConversionPassesTdFiles", - "@llvm-project//mlir:PassBaseTdFiles", - ], + deps = ["@llvm-project//mlir:PassBaseTdFiles"], ) cc_library( name = "TcpConversionPasses", - srcs = [ - "lib/Conversion/Passes.cpp", - ], - hdrs = [ - "include/mlir-tcp/Conversion/Passes.h", - ], + srcs = ["lib/Conversion/Passes.cpp"], + hdrs = ["include/mlir-tcp/Conversion/Passes.h"], strip_include_prefix = "include", deps = [ ":StablehloToTcp", @@ -219,12 +164,17 @@ cc_library( cc_library( name = "TorchToTcp", - srcs = glob([ - "lib/Conversion/*.h", - "lib/Conversion/TorchToTcp/*.h", - "lib/Conversion/TorchToTcp/*.cpp", - ]), - hdrs = glob(["include/mlir-tcp/Conversion/TorchToTcp/*.h"]), + srcs = [ + "lib/Conversion/PassDetail.h", + "lib/Conversion/TorchToTcp/DataMovement.cpp", + "lib/Conversion/TorchToTcp/Elementwise.cpp", + "lib/Conversion/TorchToTcp/Misc.cpp", + "lib/Conversion/TorchToTcp/PopulatePatterns.h", + "lib/Conversion/TorchToTcp/TorchToTcp.cpp", + "lib/Conversion/TorchToTcp/Utils.cpp", + "lib/Conversion/TorchToTcp/Utils.h", + ], + hdrs = ["include/mlir-tcp/Conversion/TorchToTcp/TorchToTcp.h"], strip_include_prefix = "include", deps = [ ":TcpConversionPassesIncGen", @@ -242,9 +192,7 @@ cc_library( "lib/Conversion/PassDetail.h", "lib/Conversion/StablehloToTcp/StablehloToTcp.cpp", ], - hdrs = [ - "include/mlir-tcp/Conversion/StablehloToTcp/StablehloToTcp.h", - ], + hdrs = ["include/mlir-tcp/Conversion/StablehloToTcp/StablehloToTcp.h"], strip_include_prefix = "include", deps = [ ":TcpConversionPassesIncGen", @@ -266,9 +214,7 @@ cc_library( "lib/Conversion/TcpToLinalg/PopulatePatterns.h", "lib/Conversion/TcpToLinalg/TcpToLinalg.cpp", ], - hdrs = [ - "include/mlir-tcp/Conversion/TcpToLinalg/TcpToLinalg.h", - ], + hdrs = ["include/mlir-tcp/Conversion/TcpToLinalg/TcpToLinalg.h"], strip_include_prefix = "include", deps = [ ":TcpConversionPassesIncGen", @@ -289,9 +235,7 @@ cc_library( "lib/Conversion/PassDetail.h", "lib/Conversion/TcpToArith/TcpToArith.cpp", ], - hdrs = [ - "include/mlir-tcp/Conversion/TcpToArith/TcpToArith.h", - ], + hdrs = ["include/mlir-tcp/Conversion/TcpToArith/TcpToArith.h"], strip_include_prefix = "include", deps = [ ":TcpConversionPassesIncGen", @@ -307,12 +251,8 @@ cc_library( cc_library( name = "TcpInitAll", - srcs = [ - "lib/InitAll.cpp", - ], - hdrs = [ - "include/mlir-tcp/InitAll.h", - ], + srcs = ["lib/InitAll.cpp"], + hdrs = ["include/mlir-tcp/InitAll.h"], strip_include_prefix = "include", deps = [ ":TcpConversionPasses", @@ -327,12 +267,8 @@ cc_library( cc_library( name = "Pipeline", - srcs = [ - "lib/Pipeline/Pipeline.cpp", - ], - hdrs = [ - "include/mlir-tcp/Pipeline/Pipeline.h", - ], + srcs = ["lib/Pipeline/Pipeline.cpp"], + hdrs = ["include/mlir-tcp/Pipeline/Pipeline.h"], strip_include_prefix = "include", deps = [ ":TcpConversionPasses", @@ -345,9 +281,7 @@ cc_library( cc_binary( name = "tcp-opt", - srcs = [ - "tools/tcp-opt/tcp-opt.cpp", - ], + srcs = ["tools/tcp-opt/tcp-opt.cpp"], deps = [ ":Pipeline", ":TcpDialect", diff --git a/include/mlir-tcp/Dialect/IR/TcpBase.td b/include/mlir-tcp/Dialect/IR/TcpBase.td index af110d0d..31e4b98c 100644 --- a/include/mlir-tcp/Dialect/IR/TcpBase.td +++ b/include/mlir-tcp/Dialect/IR/TcpBase.td @@ -42,7 +42,7 @@ def Tcp_Dialect : Dialect { // Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end]. // Where low and high ends are 0,255 when unsigned, -128,127 when signed, for // the 8-bit case. -class TCP_QuantizedType params, bit signed> +class Tcp_QuantizedType params, bit signed> : Type()">, CPred<"$_self.cast()" # ".getStorageTypeIntegralWidth() == " # !head(params)>]>, @@ -63,12 +63,12 @@ class TCP_QuantizedType params, bit signed> // q8ss : symmetric signed // q16ss : symmetric signed //===----------------------------------------------------------------------===// -def TCP_QuantizedInt : AnyTypeOf<[ TCP_QuantizedType<"q8ua", [8], 0>, - TCP_QuantizedType<"q8sa", [8], 1>, - TCP_QuantizedType<"q8ss", [8, 0], 1>, - TCP_QuantizedType<"q16ss", [16, 0], 1>]>; +def Tcp_QuantizedInt : AnyTypeOf<[Tcp_QuantizedType<"q8ua", [8], 0>, + Tcp_QuantizedType<"q8sa", [8], 1>, + Tcp_QuantizedType<"q8ss", [8, 0], 1>, + Tcp_QuantizedType<"q16ss", [16, 0], 1>]>; -def Tcp_Scalar : AnyTypeOf<[AnyFloat, AnySignlessInteger, TCP_QuantizedInt]>; +def Tcp_Scalar : AnyTypeOf<[AnyFloat, AnySignlessInteger, Tcp_QuantizedInt]>; def Tcp_Tensor : RankedTensorOf<[Tcp_Scalar]>; def Tcp_TensorOrScalar : AnyTypeOf<[Tcp_Tensor, Tcp_Scalar]>; diff --git a/lib/Conversion/TorchToTcp/TorchToTcp.cpp b/lib/Conversion/TorchToTcp/TorchToTcp.cpp index 29eb4aff..122ee98d 100644 --- a/lib/Conversion/TorchToTcp/TorchToTcp.cpp +++ b/lib/Conversion/TorchToTcp/TorchToTcp.cpp @@ -75,8 +75,7 @@ class ConvertTorchToTcp : public ConvertTorchToTcpBase { } // namespace -std::unique_ptr> -mlir::tcp::createConvertTorchToTcpPass() { +std::unique_ptr> createConvertTorchToTcpPass() { return std::make_unique(); }