diff --git a/BUILD b/BUILD index 407f8a7f..c5593dce 100644 --- a/BUILD +++ b/BUILD @@ -17,6 +17,7 @@ td_library( "include/mlir-tcp/Dialect/IR/TcpBase.td", "include/mlir-tcp/Dialect/IR/TcpEnums.td", "include/mlir-tcp/Dialect/IR/TcpOps.td", + "include/mlir-tcp/Dialect/IR/TcpTypes.td", ], includes = ["include"], deps = [ @@ -73,6 +74,24 @@ gentbl_cc_library( deps = [":TcpDialectTdFiles"], ) +gentbl_cc_library( + name = "TcpTypesIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-typedef-decls"], + "include/mlir-tcp/Dialect/IR/TcpTypes.h.inc", + ), + ( + ["-gen-typedef-defs"], + "include/mlir-tcp/Dialect/IR/TcpTypes.cpp.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "include/mlir-tcp/Dialect/IR/TcpTypes.td", + deps = [":TcpDialectTdFiles"], +) + cc_library( name = "TcpDialect", srcs = [ @@ -86,6 +105,7 @@ cc_library( strip_include_prefix = "include", deps = [ ":TcpOpsIncGen", + ":TcpTypesIncGen", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:FuncDialect", diff --git a/include/mlir-tcp/Dialect/IR/TcpBase.td b/include/mlir-tcp/Dialect/IR/TcpBase.td index 31e4b98c..552c159e 100644 --- a/include/mlir-tcp/Dialect/IR/TcpBase.td +++ b/include/mlir-tcp/Dialect/IR/TcpBase.td @@ -11,9 +11,12 @@ #define TCP_BASE include "mlir/IR/OpBase.td" - include "mlir/Interfaces/SideEffectInterfaces.td" +//===----------------------------------------------------------------------===// +// Tcp Dialect Base +//===----------------------------------------------------------------------===// + def Tcp_Dialect : Dialect { let name = "tcp"; let cppNamespace = "::mlir::tcp"; @@ -35,48 +38,7 @@ def Tcp_Dialect : Dialect { } //===----------------------------------------------------------------------===// -// Tcp Type Definitions. -//===----------------------------------------------------------------------===// - -// The base class of a quantized type. -// Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end]. -// Where low and high ends are 0,255 when unsigned, -128,127 when signed, for -// the 8-bit case. -class Tcp_QuantizedType params, bit signed> - : Type()">, - CPred<"$_self.cast()" # - ".getStorageTypeIntegralWidth() == " # !head(params)>]>, - "Q" # !if (signed, "int", "uint") # !head(params) # " type"> { - string name = n; - string asTraitArgsStr = !interleave(params, ", ") # - !if(signed, ", true", ", false"); -} - -//===----------------------------------------------------------------------===// -// Quantized Integer Types. -//===----------------------------------------------------------------------===// -//===----------------------------------------------------------------------===// -// Name Symmetry Sign -//===----------------------------------------------------------------------===// -// q8ua : asymmetric unsigned -// q8sa : asymmetric signed -// q8ss : symmetric signed -// q16ss : symmetric signed -//===----------------------------------------------------------------------===// -def Tcp_QuantizedInt : AnyTypeOf<[Tcp_QuantizedType<"q8ua", [8], 0>, - Tcp_QuantizedType<"q8sa", [8], 1>, - Tcp_QuantizedType<"q8ss", [8, 0], 1>, - Tcp_QuantizedType<"q16ss", [16, 0], 1>]>; - -def Tcp_Scalar : AnyTypeOf<[AnyFloat, AnySignlessInteger, Tcp_QuantizedInt]>; -def Tcp_Tensor : RankedTensorOf<[Tcp_Scalar]>; -def Tcp_TensorOrScalar : AnyTypeOf<[Tcp_Tensor, Tcp_Scalar]>; - -def Tcp_FloatTensor : RankedTensorOf<[AnyFloat]>; -def Tcp_FloatOrIntTensor : RankedTensorOf<[AnyFloat, AnySignlessInteger]>; - -//===----------------------------------------------------------------------===// -// Tcp Ops Base. +// Tcp Ops Base //===----------------------------------------------------------------------===// class Tcp_Op traits = []> : @@ -97,4 +59,12 @@ class Tcp_BinaryElementwiseOp traits = []> : SameOperandsAndResultShape])> { } +//===----------------------------------------------------------------------===// +// Tcp Types Base +//===----------------------------------------------------------------------===// + +class Tcp_Type : TypeDef { + let mnemonic = typeMnemonic; +} + #endif // TCP_BASE diff --git a/include/mlir-tcp/Dialect/IR/TcpEnums.td b/include/mlir-tcp/Dialect/IR/TcpEnums.td index 535f3ad3..6d78072a 100644 --- a/include/mlir-tcp/Dialect/IR/TcpEnums.td +++ b/include/mlir-tcp/Dialect/IR/TcpEnums.td @@ -12,6 +12,8 @@ include "mlir/IR/EnumAttr.td" +include "mlir-tcp/Dialect/IR/TcpBase.td" + // TCP Signedness enum (mirror the mlir::IntegerType::SignednessSemantics enum) def Tcp_Signedness_Signless : I32EnumAttrCase<"Signless", 0>; def Tcp_Signedness_Signed : I32EnumAttrCase<"Signed", 1>; diff --git a/include/mlir-tcp/Dialect/IR/TcpOps.td b/include/mlir-tcp/Dialect/IR/TcpOps.td index b9346b9e..4e999e16 100644 --- a/include/mlir-tcp/Dialect/IR/TcpOps.td +++ b/include/mlir-tcp/Dialect/IR/TcpOps.td @@ -10,10 +10,11 @@ #ifndef TCP_OPS #define TCP_OPS +include "mlir/IR/OpBase.td" + include "mlir-tcp/Dialect/IR/TcpBase.td" include "mlir-tcp/Dialect/IR/TcpEnums.td" - -include "mlir/IR/OpBase.td" +include "mlir-tcp/Dialect/IR/TcpTypes.td" def Tcp_TanhOp : Tcp_UnaryElementwiseOp<"tanh", [SameOperandsAndResultElementType]> { let summary = "Computes tanh of input, elementwise"; diff --git a/include/mlir-tcp/Dialect/IR/TcpTypes.td b/include/mlir-tcp/Dialect/IR/TcpTypes.td new file mode 100644 index 00000000..0a4c2a0e --- /dev/null +++ b/include/mlir-tcp/Dialect/IR/TcpTypes.td @@ -0,0 +1,59 @@ +//===-------------------------------------------------------*- tablegen -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TCP_TYPES +#define TCP_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/DialectBase.td" + +include "mlir-tcp/Dialect/IR/TcpBase.td" + +//===----------------------------------------------------------------------===// +// Tcp Type Definitions. +//===----------------------------------------------------------------------===// + +// The base class of a quantized type. +// Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end]. +// Where low and high ends are 0,255 when unsigned, -128,127 when signed, for +// the 8-bit case. +class Tcp_QuantizedType params, bit signed> + : Type()">, + CPred<"$_self.cast()" # + ".getStorageTypeIntegralWidth() == " # !head(params)>]>, + "Q" # !if (signed, "int", "uint") # !head(params) # " type"> { + string name = n; + string asTraitArgsStr = !interleave(params, ", ") # + !if(signed, ", true", ", false"); +} + +//===----------------------------------------------------------------------===// +// Quantized Integer Types. +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// Name Symmetry Sign +//===----------------------------------------------------------------------===// +// q8ua : asymmetric unsigned +// q8sa : asymmetric signed +// q8ss : symmetric signed +// q16ss : symmetric signed +//===----------------------------------------------------------------------===// +def Tcp_QuantizedInt : AnyTypeOf<[Tcp_QuantizedType<"q8ua", [8], 0>, + Tcp_QuantizedType<"q8sa", [8], 1>, + Tcp_QuantizedType<"q8ss", [8, 0], 1>, + Tcp_QuantizedType<"q16ss", [16, 0], 1>]>; + +def Tcp_Scalar : AnyTypeOf<[AnyFloat, AnySignlessInteger, Tcp_QuantizedInt]>; +def Tcp_Tensor : RankedTensorOf<[Tcp_Scalar]>; +def Tcp_TensorOrScalar : AnyTypeOf<[Tcp_Tensor, Tcp_Scalar]>; + +def Tcp_FloatTensor : RankedTensorOf<[AnyFloat]>; +def Tcp_FloatOrIntTensor : RankedTensorOf<[AnyFloat, AnySignlessInteger]>; + +#endif // TCP_TYPES