Skip to content

Commit

Permalink
Remove uses of ArgResultAliasAttr as it is not used, specced, or supp…
Browse files Browse the repository at this point in the history
…orted in VHLO (#1421)

Was looking at VHLO code coverage for serialization and realized the
only unused attribute is ArgResultAliasAttr, because it was unsupported
in VHLO legalizations. Speaking with @burmako it sounds like this
attribute is unused and should not be in StableHLO/VHLO currently
without an RFC first.

This is not an incompatibility since serialization would fail for
programs that use this attribute currently.

Adding @sdasgup3 as reviewer as well for spec implications (should be
none?)
  • Loading branch information
GleasonK authored Jun 22, 2023
1 parent 5ae526d commit 72a0117
Show file tree
Hide file tree
Showing 9 changed files with 12 additions and 457 deletions.
33 changes: 0 additions & 33 deletions stablehlo/dialect/StablehloAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -108,39 +108,6 @@ def StableHLO_OutputOperandAlias : AttrDef<StableHLO_Dialect, "OutputOperandAlia
}];
}

def StableHLO_ArgResultAlias : AttrDef<StableHLO_Dialect, "ArgResultAlias"> {
let cppNamespace = "::mlir::stablehlo";
let mnemonic = "result_alias";
let summary =
"Attribute that models the alias relationship of entry function argument";
let description = [{
This attribute captures the alias relationship of a main function
argument to one of the results, denoted by `resultIndex`. The
`argTupleIndices` and `resultTupleIndices` are used to index into nested
tuples in operand and result respectively. If `isMustAlias` is true then the
operand-result pair must alias.

This is meant to be used as an attribute on a function argument.
For example, in the following code it expresses that `%arg1` may alias 0-th
result.

```mlir
func @main(%arg0: tensor<2xf32>, %arg1: tensor<3xf32>
{stablehlo.result_alias = stablehlo.result_alias<result_index = [2], ...>}
) -> tensor<2xf32>, tensor<3xf32> {
// function body ...
}
```
}];
let parameters = (ins
StableHLO_Dims:$argTupleIndices,
"int64_t":$resultIndex,
StableHLO_Dims:$resultTupleIndices,
"bool":$isMustAlias
);
let hasCustomAssemblyFormat = 1;
}

// Represents a unique identifier for each Send/Recv instruction pair or
// optionally for collective instructions (AllReduce, CollectivePermute,
// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
Expand Down
57 changes: 7 additions & 50 deletions stablehlo/dialect/StablehloBytecode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,8 @@ namespace stablehlo_encoding {
enum AttributeCode {
// TO ADD ATTRIBUTE: Add an enum value with doc string for new attr.

/// ArgResultAliasAttr {
/// argTupleIndices: svarint[]
/// resultIndex: svarint
/// resultIndex: svarint[]
/// isMustAlias: varint
/// }
kArgResultAliasAttr = 0,
/// ArgResultAliasAttr (obsolete)
// kArgResultAliasAttr = 0,

/// ChannelHandleAttr {
/// handle: svarint
Expand Down Expand Up @@ -207,8 +202,6 @@ class StablehloBytecodeInterface : public BytecodeDialectInterface {

// TO ADD ATTRIBUTE: Include a read method for each attribute in StableHLO
// Ex: SomeAttr readSomeAttr(DialectBytecodeReader &reader) const;
ArgResultAliasAttr readArgResultAliasAttr(
DialectBytecodeReader &reader) const;
ChannelHandleAttr readChannelHandleAttr(DialectBytecodeReader &reader) const;
ComparisonDirectionAttr readComparisonDirectionAttr(
DialectBytecodeReader &reader) const;
Expand All @@ -235,7 +228,6 @@ class StablehloBytecodeInterface : public BytecodeDialectInterface {

// TO ADD ATTRIBUTE: Include a write method for each attribute in StableHLO
// Ex: void write(SomeAttr attr, DialectBytecodeWriter &writer) const;
void write(ArgResultAliasAttr attr, DialectBytecodeWriter &writer) const;
void write(ChannelHandleAttr attr, DialectBytecodeWriter &writer) const;
void write(ComparisonDirectionAttr attr, DialectBytecodeWriter &writer) const;
void write(ComparisonTypeAttr attr, DialectBytecodeWriter &writer) const;
Expand Down Expand Up @@ -282,8 +274,6 @@ Attribute StablehloBytecodeInterface::readAttribute(
uint64_t code;
if (failed(reader.readVarInt(code))) return Attribute();
switch (code) {
case stablehlo_encoding::kArgResultAliasAttr:
return readArgResultAliasAttr(reader);
case stablehlo_encoding::kChannelHandleAttr:
return readChannelHandleAttr(reader);
case stablehlo_encoding::kComparisonDirectionAttr:
Expand Down Expand Up @@ -324,12 +314,11 @@ Attribute StablehloBytecodeInterface::readAttribute(
LogicalResult StablehloBytecodeInterface::writeAttribute(
Attribute attr, DialectBytecodeWriter &writer) const {
return TypeSwitch<Attribute, LogicalResult>(attr)
.Case<ArgResultAliasAttr, ChannelHandleAttr, ComparisonDirectionAttr,
ComparisonTypeAttr, ConvDimensionNumbersAttr,
DotDimensionNumbersAttr, FftTypeAttr, GatherDimensionNumbersAttr,
OutputOperandAliasAttr, PrecisionAttr, RngAlgorithmAttr,
RngDistributionAttr, ScatterDimensionNumbersAttr, TransposeAttr,
TypeExtensionsAttr>([&](auto attr) {
.Case<ChannelHandleAttr, ComparisonDirectionAttr, ComparisonTypeAttr,
ConvDimensionNumbersAttr, DotDimensionNumbersAttr, FftTypeAttr,
GatherDimensionNumbersAttr, OutputOperandAliasAttr, PrecisionAttr,
RngAlgorithmAttr, RngDistributionAttr, ScatterDimensionNumbersAttr,
TransposeAttr, TypeExtensionsAttr>([&](auto attr) {
LOG_WRITE_CALL;
write(attr, writer);
return success();
Expand All @@ -340,38 +329,6 @@ LogicalResult StablehloBytecodeInterface::writeAttribute(
});
}

//===----------------------------------------------------------------------===//
// ArgResultAliasAttr

ArgResultAliasAttr StablehloBytecodeInterface::readArgResultAliasAttr(
DialectBytecodeReader &reader) const {
LOG_READ_CALL;

llvm::SmallVector<int64_t> argTupleIndices;
int64_t resultIndex;
llvm::SmallVector<int64_t> resultTupleIndices;
uint64_t isMustAliasUint;

if (failed(reader.readSignedVarInts(argTupleIndices)) ||
failed(reader.readSignedVarInt(resultIndex)) ||
failed(reader.readSignedVarInts(resultTupleIndices)) ||
failed(reader.readVarInt(isMustAliasUint)))
return ArgResultAliasAttr();

return ArgResultAliasAttr::get(getContext(), argTupleIndices, resultIndex,
resultTupleIndices,
static_cast<bool>(isMustAliasUint));
}

void StablehloBytecodeInterface::write(ArgResultAliasAttr attr,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(stablehlo_encoding::kArgResultAliasAttr);
writer.writeSignedVarInts(attr.getArgTupleIndices());
writer.writeSignedVarInt(attr.getResultIndex());
writer.writeSignedVarInts(attr.getResultTupleIndices());
writer.writeVarInt(attr.getIsMustAlias());
}

//===----------------------------------------------------------------------===//
// ChannelHandleAttr

Expand Down
161 changes: 0 additions & 161 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2555,17 +2555,6 @@ static ParseResult parseDims(AsmParser& parser,
return success();
}

static ParseResult parseDimsWithMinimumElements(AsmParser& parser,
SmallVector<int64_t>& dimSizes,
int minElements) {
if (failed(parseDims(parser, dimSizes))) return failure();
if (static_cast<int64_t>(dimSizes.size()) < minElements)
return parser.emitError(parser.getCurrentLocation())
<< "expected at least " << minElements << " element(s), found "
<< dimSizes.size();
return success();
}

/// Parse a custom attribute that resembles a struct of the form
/// <
/// foo = something_parsed_by_custom_parser,
Expand Down Expand Up @@ -3072,136 +3061,6 @@ Attribute ConvDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
return dnums;
}

// Custom printer and parser for ArgResultAliasAttr.
constexpr char kMustAlias[] = "must_alias";
constexpr char kResult[] = "result_index";
constexpr char kArgTupleIndices[] = "tuple_indices";

void ArgResultAliasAttr::print(AsmPrinter& printer) const {
printer << "<";

// The attribute can have empty tuple indices. Only print argument tuple
// indices if they are non-empty.
if (!getArgTupleIndices().empty())
printer << kArgTupleIndices << " = [" << getArgTupleIndices() << "], ";

// Print the result index followed by any result tuple indices if present.
printer << kResult << " = [";
printer << getResultIndex();
if (!getResultTupleIndices().empty())
printer << ", " << getResultTupleIndices();
printer << "]";

// Print the "must_alias" keyword if this is a must alias, otherwise skip.
if (getIsMustAlias()) printer << ", " << kMustAlias;

printer << ">";
}

Attribute ArgResultAliasAttr::parse(AsmParser& parser, Type type) {
if (failed(parser.parseLess())) return {};
llvm::SmallVector<int64_t> argTupleIndices;
// The first element of result indices holds the aliased result index and the
// remaining elements are the result tuple indices.
llvm::SmallVector<int64_t> resultIndices;
bool isMustAlias = false;

// This conveys to parseStruct that keyword "must_alias" (3rd field) is not
// followed by a "=", but other fields are.
llvm::SmallVector<bool, 3> parseEqual = {true, true, false};

if (failed(parseStruct(parser, {kArgTupleIndices, kResult, kMustAlias},
{[&]() { return parseDims(parser, argTupleIndices); },
[&]() {
// Since the first element is the index of result,
// at least one element is expected.
return parseDimsWithMinimumElements(
parser, resultIndices, /*minElements=*/1);
},
[&]() {
// always succeeds if the keyword "must_alias" was
// parsed
isMustAlias = true;
return success();
}},
parseEqual))) {
parser.emitError(parser.getCurrentLocation())
<< "failed parsing argument-result alias attribute";
return {};
}

int64_t resultIndex = resultIndices[0];
auto resultTupleIndices =
ArrayRef<int64_t>{resultIndices.begin() + 1, resultIndices.end()};

return ArgResultAliasAttr::get(parser.getContext(), argTupleIndices,
resultIndex, resultTupleIndices, isMustAlias);
}

// Returns the element type pointed to by `indices` in type `t`. If the indices
// are invalid, returns nullptr.
static Type getTypeFromTupleIndices(Type type, ArrayRef<int64_t> indices) {
Type current = type;
for (auto index : indices) {
TupleType tupleType = current.dyn_cast<TupleType>();
if (!tupleType || index >= static_cast<int64_t>(tupleType.size()))
return {};
current = tupleType.getType(index);
}
return current;
}

static LogicalResult verifyArgResultAliasAttr(StringAttr attrName,
ArgResultAliasAttr aliasAttr,
unsigned argIndex,
Operation* op) {
// The attribute can only be applied to function-like operations.
if (!isa<mlir::FunctionOpInterface>(op))
return op->emitOpError() << "attribute " << attrName
<< " can only be used on function-like operations";

// Verify there are no negative indices.
auto tupleIndices = llvm::concat<const int64_t>(
aliasAttr.getArgTupleIndices(), aliasAttr.getResultTupleIndices());
if (llvm::any_of(tupleIndices, [](const int64_t val) { return val < 0; }) ||
aliasAttr.getResultIndex() < 0)
return op->emitOpError()
<< "attribute " << attrName
<< " expects all argument and result indices to be >= 0";

// Verify that the result index is not out of range. Since the attribute is a
// function argument attribute, the argument index is always correct when this
// verifier is called.
FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
ArrayRef<Type> argTypes = funcOp.getArgumentTypes();
ArrayRef<Type> resultTypes = funcOp.getResultTypes();
if (aliasAttr.getResultIndex() >= static_cast<int64_t>(resultTypes.size()))
return op->emitOpError()
<< "attribute " << attrName
<< " result index is out of range, must be <" << resultTypes.size();

// Verify that argument and result types pointed to by the indices are valid
// and compatible.
Type argType = getTypeFromTupleIndices(argTypes[argIndex],
aliasAttr.getArgTupleIndices());
if (!argType)
return op->emitOpError()
<< "attribute " << attrName << " argument tuple indices are invalid";
Type resultType =
getTypeFromTupleIndices(resultTypes[aliasAttr.getResultIndex()],
aliasAttr.getResultTupleIndices());
if (!resultType)
return op->emitOpError()
<< "attribute " << attrName << " result tuple indices are invalid";

if (failed(mlir::verifyCompatibleShape(argType, resultType)) ||
getElementTypeOrSelf(argType) != getElementTypeOrSelf(resultType))
return op->emitOpError() << "attribute " << attrName
<< " aliases do not have compatible types, "
<< argType << " vs. " << resultType;
return success();
}

namespace {
// Custom formatting for convolution window attributes.
void printWindowAttribute(OpAsmPrinter& p, DenseElementsAttr attribute) {
Expand Down Expand Up @@ -3421,25 +3280,5 @@ Operation* StablehloDialect::materializeConstant(OpBuilder& builder,
return builder.create<ConstantOp>(loc, type, elementsAttr);
}

LogicalResult StablehloDialect::verifyRegionArgAttribute(
Operation* op, unsigned /*regionIndex*/, unsigned argIndex,
NamedAttribute attr) {
if (auto aliasAttr = attr.getValue().dyn_cast<ArgResultAliasAttr>())
if (failed(
verifyArgResultAliasAttr(attr.getName(), aliasAttr, argIndex, op)))
return failure();
return success();
}

LogicalResult StablehloDialect::verifyOperationAttribute(Operation* op,
NamedAttribute attr) {
if (auto aliasAttr = attr.getValue().dyn_cast<ArgResultAliasAttr>())
if (!isa<mlir::FunctionOpInterface>(op))
return op->emitOpError()
<< "attribute " << attr.getName()
<< " can only be used on function-like operations";
return success();
}

} // namespace stablehlo
} // namespace mlir
10 changes: 0 additions & 10 deletions stablehlo/dialect/StablehloOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,6 @@ class StablehloDialect : public Dialect {
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
Location loc) override;

// Registered hook to verify region arg attributes on operations.
LogicalResult verifyRegionArgAttribute(mlir::Operation *op,
unsigned regionIndex,
unsigned argIndex,
mlir::NamedAttribute attr) override;

// Registered hook to verify an attribute from this dialect on operations.
LogicalResult verifyOperationAttribute(mlir::Operation *op,
mlir::NamedAttribute attr) override;

// Parses a type registered to this dialect.
Type parseType(DialectAsmParser &parser) const override;

Expand Down
14 changes: 0 additions & 14 deletions stablehlo/dialect/VhloAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,6 @@ class VHLO_AttrDef<string name, string minVersion, string maxVersion>
}];
}

// TODO(#740): ArgResultAlias is not yet part of the StableHLO spec.
// At the moment, it is used to represent buffer donation, and we're planning
// to look into it as part of the work on speccing buffer donation in StableHLO.
def VHLO_ArgResultAliasAttrV1 : VHLO_AttrDef<"ArgResultAliasV1", "0.9.0", "current"> {
let mnemonic = "result_alias_v1";
let parameters = (ins
VHLO_Dims:$argTupleIndices,
"int64_t":$resultIndex,
VHLO_Dims:$resultTupleIndices,
"bool":$isMustAlias
);
let assemblyFormat = "`<` struct(params) `>`";
}

// Represents attributes from the StableHLO spec which say "variadic number of",
// although not called out explicitly in the "Constants" section.
def VHLO_ArrayAttrV1 : VHLO_AttrDef<"ArrayV1", "0.9.0", "current"> {
Expand Down
Loading

0 comments on commit 72a0117

Please sign in to comment.