Skip to content

Commit

Permalink
Create TcpTypes.td and move type definitions there
Browse files Browse the repository at this point in the history
  • Loading branch information
sjain-stanford committed Oct 26, 2023
1 parent 02e950c commit 6dbabbd
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 45 deletions.
20 changes: 20 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ td_library(
"include/mlir-tcp/Dialect/IR/TcpBase.td",
"include/mlir-tcp/Dialect/IR/TcpEnums.td",
"include/mlir-tcp/Dialect/IR/TcpOps.td",
"include/mlir-tcp/Dialect/IR/TcpTypes.td",
],
includes = ["include"],
deps = [
Expand Down Expand Up @@ -73,6 +74,24 @@ gentbl_cc_library(
deps = [":TcpDialectTdFiles"],
)

gentbl_cc_library(
name = "TcpTypesIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-typedef-decls"],
"include/mlir-tcp/Dialect/IR/TcpTypes.h.inc",
),
(
["-gen-typedef-defs"],
"include/mlir-tcp/Dialect/IR/TcpTypes.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "include/mlir-tcp/Dialect/IR/TcpTypes.td",
deps = [":TcpDialectTdFiles"],
)

cc_library(
name = "TcpDialect",
srcs = [
Expand All @@ -86,6 +105,7 @@ cc_library(
strip_include_prefix = "include",
deps = [
":TcpOpsIncGen",
":TcpTypesIncGen",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
Expand Down
56 changes: 13 additions & 43 deletions include/mlir-tcp/Dialect/IR/TcpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
#define TCP_BASE

include "mlir/IR/OpBase.td"

include "mlir/Interfaces/SideEffectInterfaces.td"

//===----------------------------------------------------------------------===//
// Tcp Dialect Base
//===----------------------------------------------------------------------===//

def Tcp_Dialect : Dialect {
let name = "tcp";
let cppNamespace = "::mlir::tcp";
Expand All @@ -35,48 +38,7 @@ def Tcp_Dialect : Dialect {
}

//===----------------------------------------------------------------------===//
// Tcp Type Definitions.
//===----------------------------------------------------------------------===//

// The base class of a quantized type.
// Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end].
// Where low and high ends are 0,255 when unsigned, -128,127 when signed, for
// the 8-bit case.
class Tcp_QuantizedType<string n, list<int> params, bit signed>
: Type<And<[CPred<"$_self.isa<mlir::quant::QuantizedType>()">,
CPred<"$_self.cast<mlir::quant::QuantizedType>()" #
".getStorageTypeIntegralWidth() == " # !head(params)>]>,
"Q" # !if (signed, "int", "uint") # !head(params) # " type"> {
string name = n;
string asTraitArgsStr = !interleave(params, ", ") #
!if(signed, ", true", ", false");
}

//===----------------------------------------------------------------------===//
// Quantized Integer Types.
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Name Symmetry Sign
//===----------------------------------------------------------------------===//
// q8ua : asymmetric unsigned
// q8sa : asymmetric signed
// q8ss : symmetric signed
// q16ss : symmetric signed
//===----------------------------------------------------------------------===//
def Tcp_QuantizedInt : AnyTypeOf<[Tcp_QuantizedType<"q8ua", [8], 0>,
Tcp_QuantizedType<"q8sa", [8], 1>,
Tcp_QuantizedType<"q8ss", [8, 0], 1>,
Tcp_QuantizedType<"q16ss", [16, 0], 1>]>;

def Tcp_Scalar : AnyTypeOf<[AnyFloat, AnySignlessInteger, Tcp_QuantizedInt]>;
def Tcp_Tensor : RankedTensorOf<[Tcp_Scalar]>;
def Tcp_TensorOrScalar : AnyTypeOf<[Tcp_Tensor, Tcp_Scalar]>;

def Tcp_FloatTensor : RankedTensorOf<[AnyFloat]>;
def Tcp_FloatOrIntTensor : RankedTensorOf<[AnyFloat, AnySignlessInteger]>;

//===----------------------------------------------------------------------===//
// Tcp Ops Base.
// Tcp Ops Base
//===----------------------------------------------------------------------===//

class Tcp_Op<string mnemonic, list<Trait> traits = []> :
Expand All @@ -97,4 +59,12 @@ class Tcp_BinaryElementwiseOp<string mnemonic, list<Trait> traits = []> :
SameOperandsAndResultShape])> {
}

//===----------------------------------------------------------------------===//
// Tcp Types Base
//===----------------------------------------------------------------------===//

class Tcp_Type<string name, string typeMnemonic> : TypeDef<Tcp_Dialect, name> {
let mnemonic = typeMnemonic;
}

#endif // TCP_BASE
2 changes: 2 additions & 0 deletions include/mlir-tcp/Dialect/IR/TcpEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

include "mlir/IR/EnumAttr.td"

include "mlir-tcp/Dialect/IR/TcpBase.td"

// TCP Signedness enum (mirror the mlir::IntegerType::SignednessSemantics enum)
def Tcp_Signedness_Signless : I32EnumAttrCase<"Signless", 0>;
def Tcp_Signedness_Signed : I32EnumAttrCase<"Signed", 1>;
Expand Down
5 changes: 3 additions & 2 deletions include/mlir-tcp/Dialect/IR/TcpOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
#ifndef TCP_OPS
#define TCP_OPS

include "mlir/IR/OpBase.td"

include "mlir-tcp/Dialect/IR/TcpBase.td"
include "mlir-tcp/Dialect/IR/TcpEnums.td"

include "mlir/IR/OpBase.td"
include "mlir-tcp/Dialect/IR/TcpTypes.td"

def Tcp_TanhOp : Tcp_UnaryElementwiseOp<"tanh", [SameOperandsAndResultElementType]> {
let summary = "Computes tanh of input, elementwise";
Expand Down
59 changes: 59 additions & 0 deletions include/mlir-tcp/Dialect/IR/TcpTypes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Also available under a BSD-style license. See LICENSE.
//
//===----------------------------------------------------------------------===//

#ifndef TCP_TYPES
#define TCP_TYPES

include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/DialectBase.td"

include "mlir-tcp/Dialect/IR/TcpBase.td"

//===----------------------------------------------------------------------===//
// Tcp Type Definitions.
//===----------------------------------------------------------------------===//

// The base class of a quantized type.
// Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end].
// Where low and high ends are 0,255 when unsigned, -128,127 when signed, for
// the 8-bit case.
class Tcp_QuantizedType<string n, list<int> params, bit signed>
: Type<And<[CPred<"$_self.isa<mlir::quant::QuantizedType>()">,
CPred<"$_self.cast<mlir::quant::QuantizedType>()" #
".getStorageTypeIntegralWidth() == " # !head(params)>]>,
"Q" # !if (signed, "int", "uint") # !head(params) # " type"> {
string name = n;
string asTraitArgsStr = !interleave(params, ", ") #
!if(signed, ", true", ", false");
}

//===----------------------------------------------------------------------===//
// Quantized Integer Types.
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Name Symmetry Sign
//===----------------------------------------------------------------------===//
// q8ua : asymmetric unsigned
// q8sa : asymmetric signed
// q8ss : symmetric signed
// q16ss : symmetric signed
//===----------------------------------------------------------------------===//
def Tcp_QuantizedInt : AnyTypeOf<[Tcp_QuantizedType<"q8ua", [8], 0>,
Tcp_QuantizedType<"q8sa", [8], 1>,
Tcp_QuantizedType<"q8ss", [8, 0], 1>,
Tcp_QuantizedType<"q16ss", [16, 0], 1>]>;

def Tcp_Scalar : AnyTypeOf<[AnyFloat, AnySignlessInteger, Tcp_QuantizedInt]>;
def Tcp_Tensor : RankedTensorOf<[Tcp_Scalar]>;
def Tcp_TensorOrScalar : AnyTypeOf<[Tcp_Tensor, Tcp_Scalar]>;

def Tcp_FloatTensor : RankedTensorOf<[AnyFloat]>;
def Tcp_FloatOrIntTensor : RankedTensorOf<[AnyFloat, AnySignlessInteger]>;

#endif // TCP_TYPES

0 comments on commit 6dbabbd

Please sign in to comment.