Skip to content

Commit

Permalink
refactoring shape inference work in progres
Browse files Browse the repository at this point in the history
  • Loading branch information
Tobias Gysi committed May 19, 2020
1 parent 0c69b13 commit cf9c0f5
Show file tree
Hide file tree
Showing 14 changed files with 309 additions and 328 deletions.
11 changes: 0 additions & 11 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,6 @@
cmake_minimum_required(VERSION 3.10)

if(POLICY CMP0068)
cmake_policy(SET CMP0068 NEW)
set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON)
endif()

if(POLICY CMP0075)
cmake_policy(SET CMP0075 NEW)
endif()

if(POLICY CMP0077)
cmake_policy(SET CMP0077 NEW)
endif()

project(oec-opt LANGUAGES CXX C)
include(CheckLanguage)
Expand Down
3 changes: 3 additions & 0 deletions include/Dialect/Stencil/StencilDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ constexpr static int kKDimension = 2;
// Stencil dimensionality
constexpr static int64_t kNumOfDimensions = 3; // TODO accessible inside tablegen?

// Index type used to store offsets and bounds
typedef SmallVector<int64_t, 3> Index;

class StencilDialect : public Dialect {
public:
explicit StencilDialect(MLIRContext *context);
Expand Down
71 changes: 59 additions & 12 deletions include/Dialect/Stencil/StencilInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,65 @@ include "mlir/IR/OpBase.td"
// Stencil Interfaces
//===----------------------------------------------------------------------===//

// TODO
// def ShapeInterface : OpInterface<"ShapeInterface"> {
// let description = [{
// Interface to access the lower and upper bounds of an operation.
// }];
def ShapeAccess : OpInterface<"ShapeAccess"> {
let description = [{
Interface to get the operation bounds.
}];

// let methods = [
// InterfaceMethod<"/*Method returns the FieldType of the field operand or result*/",
// "FieldType", "getFieldType", (ins), [{
// return op.field().getType().template cast<FieldType>();
// }]>
// ];
// }
let methods = [
InterfaceMethod<"/*Get the lower bound of the operation*/",
"Index", "getLB", (ins), [{
Index result;
Optional<ArrayAttr> lb = op.lb();
for (auto &elem : lb.getValue())
result.push_back(elem.cast<IntegerAttr>().getValue().getSExtValue());
return result;
}]>,
InterfaceMethod<"/*Get the upper bound of the operation*/",
"Index", "getUB", (ins), [{
Index result;
Optional<ArrayAttr> ub = op.ub();
for (auto &elem : ub.getValue())
result.push_back(elem.cast<IntegerAttr>().getValue().getSExtValue());
return result;
}]>,
InterfaceMethod<"/*Verify if the operation has valid bounds*/",
"bool", "hasShape", (ins), [{
Optional<ArrayAttr> lb = op.lb();
Optional<ArrayAttr> ub = op.ub();
return lb.hasValue() && ub.hasValue();
}]>,
InterfaceMethod<"/*Get the rank of the operation*/",
"int64_t", "getRank", (ins), [{
Optional<ArrayAttr> lb = op.lb();
Optional<ArrayAttr> ub = op.ub();
assert(lb.getValue().size() == ub.getValue().size() &&
"expected lower and upper bound to have the same rank");
return lb.getValue().size();
}]>,
];
}

def ShapeInference : OpInterface<"ShapeInference"> {
let description = [{
Interface to set the operation bounds and result types.
}];

let methods = [
InterfaceMethod<"/*Method to get the bounds of the operation.*/",
"void", "setOpShape", (ins "ArrayRef<int64_t>":$lb, "ArrayRef<int64_t>":$ub), [{
assert(lb.size() == ub.size() &&
"expected bounds to have the same size");
SmallVector<Attribute, 3> lbAttr;
SmallVector<Attribute, 3> ubAttr;
for (size_t i = 0, e = lb.size(); i != e; ++i) {
lbAttr.push_back(IntegerAttr::get(IntegerType::get(64, op.getContext()), lb[i]));
ubAttr.push_back(IntegerAttr::get(IntegerType::get(64, op.getContext()), ub[i]));
}
op.lbAttr(ArrayAttr::get(lbAttr, op.getContext()));
op.ubAttr(ArrayAttr::get(ubAttr, op.getContext()));
}]>
];
}

#endif // Stencil_INTERFACES
125 changes: 75 additions & 50 deletions include/Dialect/Stencil/StencilOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
#define STENCIL_OPS

include "Dialect/Stencil/StencilBase.td"
include "Dialect/Stencil/StencilInterfaces.td"
include "mlir/Interfaces/SideEffects.td"

//===----------------------------------------------------------------------===//
// Concrete Operations
//===----------------------------------------------------------------------===//

def Stencil_AssertOp : Stencil_Op<"assert", []> {
def Stencil_AssertOp : Stencil_Op<"assert", [DeclareOpInterfaceMethods<ShapeAccess>]> {
let summary = "assert the input field size";
let description = [{
This operation asserts the size of input fields.
Expand All @@ -34,10 +35,23 @@ def Stencil_AssertOp : Stencil_Op<"assert", []> {
}];

let verifier = [{
// TODO check the dimensions
// TODO check the input shape is dynamic?
// (only check the sizes during the lowering)
auto fieldType = field().getType().cast<stencil::GridType>();
auto shapeOp = cast<ShapeAccess>(this->getOperation());
if(fieldType.hasStaticShape())
return emitOpError("expected fields to have dynamic shape");
if(shapeOp.getRank() != fieldType.getRank())
return emitOpError("expected op and field to have the same rank");

// Verify all users fit the shape
for(auto user : field().getUsers()) {
if(auto userOp = dyn_cast<ShapeAccess>(user)) {
if(userOp.hasShape() &&
(shapeOp.getLB() != mapFunctionToIdxPair(shapeOp.getLB(), userOp.getLB(), minimum) ||
shapeOp.getUB() != mapFunctionToIdxPair(shapeOp.getUB(), userOp.getUB(), maximum)))
return emitOpError("asserted shape not large enough to fit all accesses");
}
}

if (llvm::count_if(field().getUsers(), hasOpType<stencil::AssertOp>) != 1)
return emitOpError("expected exactly one assert operation");
return success();
Expand All @@ -46,17 +60,17 @@ def Stencil_AssertOp : Stencil_Op<"assert", []> {
let extraClassDeclaration = [{
static StringRef getLBAttrName() { return "lb"; }
static StringRef getUBAttrName() { return "ub"; }
SmallVector<int64_t, 3> getLB() {
return convertAttrToVec(lb());
Index getLB() {
return convertAttrToIndex(lb());
}
SmallVector<int64_t, 3> getUB() {
return convertAttrToVec(ub());
Index getUB() {
return convertAttrToIndex(ub());
}
void setLB(ArrayRef<int64_t> bound) {
lbAttr(convertVecToAttr(bound, getContext()));
lbAttr(convertIndexToAttr(bound, getContext()));
}
void setUB(ArrayRef<int64_t> bound) {
ubAttr(convertVecToAttr(bound, getContext()));
ubAttr(convertIndexToAttr(bound, getContext()));
}
stencil::FieldType getFieldType() {
return field().getType().cast<stencil::FieldType>();
Expand Down Expand Up @@ -100,16 +114,18 @@ def Stencil_AccessOp : Stencil_Op<"access", [NoSideEffect]> {

let extraClassDeclaration = [{
static StringRef getOffsetAttrName() { return "offset"; }
SmallVector<int64_t, 3> getOffset() {
return convertAttrToVec(offset());
Index getOffset() {
return convertAttrToIndex(offset());
}
void setOffset(ArrayRef<int64_t> bound) {
setAttr(stencil::AccessOp::getOffsetAttrName(), convertVecToAttr(bound, getContext()));
setAttr(stencil::AccessOp::getOffsetAttrName(), convertIndexToAttr(bound, getContext()));
}
}];
}

def Stencil_LoadOp : Stencil_Op<"load", [NoSideEffect]> {
def Stencil_LoadOp : Stencil_Op<"load", [DeclareOpInterfaceMethods<ShapeAccess>,
DeclareOpInterfaceMethods<ShapeInference>,
NoSideEffect]> {
let summary = "load operation";
let description = [{
This operation takes a field and returns a temporary values.
Expand Down Expand Up @@ -137,12 +153,16 @@ def Stencil_LoadOp : Stencil_Op<"load", [NoSideEffect]> {
}];

let verifier = [{
// TODO check the optional bounds dimensionality matches

if (getShape(field()).size() != getShape(res()).size())
return emitOpError("the field and temp dimensions do not match");
if (getElementType(field()) != getElementType(res()))
return emitOpError("the field and temp element types do not match");
// Check the field and result types
auto fieldType = field().getType().cast<stencil::GridType>();
auto resType = res().getType().cast<stencil::GridType>();
if (fieldType.getRank() != resType.getRank())
return emitOpError("the field and temp types have different rank");
if (fieldType.getAllocation() != resType.getAllocation())
return emitOpError("the field and temp types have different allocation");
if (fieldType.getElementType() != resType.getElementType())
return emitOpError("the field and temp types have different element type");

if (llvm::count_if(field().getUsers(), hasOpType<stencil::AssertOp>) != 1)
return emitOpError("expected exactly one assert operation");
return success();
Expand All @@ -151,22 +171,22 @@ def Stencil_LoadOp : Stencil_Op<"load", [NoSideEffect]> {
let extraClassDeclaration = [{
static StringRef getLBAttrName() { return "lb"; }
static StringRef getUBAttrName() { return "ub"; }
SmallVector<int64_t, 3> getLB() {
return convertAttrToVec(lb());
Index getLB() {
return convertAttrToIndex(lb());
}
SmallVector<int64_t, 3> getUB() {
return convertAttrToVec(ub());
Index getUB() {
return convertAttrToIndex(ub());
}
void setLB(ArrayRef<int64_t> bound) {
lbAttr(convertVecToAttr(bound, getContext()));
lbAttr(convertIndexToAttr(bound, getContext()));
}
void setUB(ArrayRef<int64_t> bound) {
ubAttr(convertVecToAttr(bound, getContext()));
ubAttr(convertIndexToAttr(bound, getContext()));
}
}];
}

def Stencil_StoreOp : Stencil_Op<"store"> {
def Stencil_StoreOp : Stencil_Op<"store", [DeclareOpInterfaceMethods<ShapeAccess>]> {
let summary = "store operation";
let description = [{
This operation takes a temp and writes a field on a user defined range.
Expand Down Expand Up @@ -196,13 +216,16 @@ def Stencil_StoreOp : Stencil_Op<"store"> {
}];

let verifier = [{
// TODO check the optional bounds dimensionality matches

if (getShape(temp()).size() != getShape(field()).size())
return emitOpError("the field and temp dimensions do not match");
if (getElementType(field()) != getElementType(temp()))
return emitOpError("the field and temp element types do not match");

// Check the field and result types
auto fieldType = field().getType().cast<stencil::GridType>();
auto tempType = temp().getType().cast<stencil::GridType>();
if (fieldType.getRank() != tempType.getRank())
return emitOpError("the field and temp types have different rank");
if (fieldType.getAllocation() != tempType.getAllocation())
return emitOpError("the field and temp types have different allocation");
if (fieldType.getElementType() != tempType.getElementType())
return emitOpError("the field and temp types have different element type");

if (!dyn_cast<stencil::ApplyOp>(temp().getDefiningOp()))
return emitOpError("output temp not result of an apply");
if (llvm::count_if(field().getUsers(), hasOpType<stencil::AssertOp>) != 1)
Expand All @@ -215,22 +238,24 @@ def Stencil_StoreOp : Stencil_Op<"store"> {
let extraClassDeclaration = [{
static StringRef getLBAttrName() { return "lb"; }
static StringRef getUBAttrName() { return "ub"; }
SmallVector<int64_t, 3> getLB() {
return convertAttrToVec(lb());
Index getLB() {
return convertAttrToIndex(lb());
}
SmallVector<int64_t, 3> getUB() {
return convertAttrToVec(ub());
Index getUB() {
return convertAttrToIndex(ub());
}
void setLB(ArrayRef<int64_t> bound) {
lbAttr(convertVecToAttr(bound, getContext()));
lbAttr(convertIndexToAttr(bound, getContext()));
}
void setUB(ArrayRef<int64_t> bound) {
ubAttr(convertVecToAttr(bound, getContext()));
ubAttr(convertIndexToAttr(bound, getContext()));
}
}];
}

def Stencil_ApplyOp : Stencil_Op<"apply", [IsolatedFromAbove,
def Stencil_ApplyOp : Stencil_Op<"apply", [DeclareOpInterfaceMethods<ShapeAccess>,
DeclareOpInterfaceMethods<ShapeInference>,
IsolatedFromAbove,
SingleBlockImplicitTerminator<"ReturnOp">,
NoSideEffect]> {
let summary = "apply operation";
Expand Down Expand Up @@ -288,17 +313,17 @@ def Stencil_ApplyOp : Stencil_Op<"apply", [IsolatedFromAbove,
let extraClassDeclaration = [{
static StringRef getLBAttrName() { return "lb"; }
static StringRef getUBAttrName() { return "ub"; }
SmallVector<int64_t, 3> getLB() {
return convertAttrToVec(lb());
Index getLB() {
return convertAttrToIndex(lb());
}
SmallVector<int64_t, 3> getUB() {
return convertAttrToVec(ub());
Index getUB() {
return convertAttrToIndex(ub());
}
void setLB(ArrayRef<int64_t> bound) {
lbAttr(convertVecToAttr(bound, getContext()));
lbAttr(convertIndexToAttr(bound, getContext()));
}
void setUB(ArrayRef<int64_t> bound) {
ubAttr(convertVecToAttr(bound, getContext()));
ubAttr(convertIndexToAttr(bound, getContext()));
}
Block *getBody() { return &region().front(); }
}];
Expand Down Expand Up @@ -356,13 +381,13 @@ def Stencil_ReturnOp : Stencil_Op<"return", [Terminator,

let extraClassDeclaration = [{
static StringRef getUnrollAttrName() { return "unroll"; }
SmallVector<int64_t, 3> getUnroll() {
return convertAttrToVec(unroll());
Index getUnroll() {
return convertAttrToIndex(unroll());
}
unsigned getUnrollFactor() {
unsigned factor = 1;
if (unroll().hasValue()) {
SmallVector<int64_t, 3> unroll = getUnroll();
Index unroll = getUnroll();
factor = std::accumulate(unroll.begin(), unroll.end(), 1,
std::multiplies<int64_t>());
}
Expand Down
23 changes: 22 additions & 1 deletion include/Dialect/Stencil/StencilTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include <bits/stdint-intn.h>

namespace mlir {
Expand Down Expand Up @@ -42,6 +44,25 @@ class GridType : public Type {
/// Return the shape of the type
ArrayRef<int64_t> getShape() const;

/// Return the rank of the type
int64_t getRank() const { return getShape().size(); }

/// Return true if all dimensions are dynamic
int64_t hasStaticShape() const {
return llvm::none_of(getShape(), [](int64_t size) {
return size == kDynamicDimension;
});
}

/// Return the allocated / non-scalar dimensions
SmallVector<bool, 3> getAllocation() const {
SmallVector<bool, 3> result;
result.resize(getRank());
llvm::transform(getShape(), result.begin(),
[](int64_t x) { return x != kScalarDimension; });
return result;
}

/// Support isa, cast, and dyn_cast
static bool classof(Type type) {
return type.getKind() == StencilTypes::Field ||
Expand All @@ -56,7 +77,7 @@ class GridType : public Type {
static constexpr bool isScalar(int64_t dimSize) {
return dimSize == kScalarDimension;
}
};
}; // namespace stencil

//===----------------------------------------------------------------------===//
// FieldType
Expand Down
Loading

0 comments on commit cf9c0f5

Please sign in to comment.