[torch] Add torch.aten.where.* folders (llvm#2886)
The where operation can be statically computed when it involves splats of known value. Added handling for these cases, with multiple tests.
rsuderman authored Feb 8, 2024
1 parent 23647ab commit a8aad2a
Showing 3 changed files with 128 additions and 4 deletions.
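
In effect, when the condition tensor is a known splat, torch.aten.where.* can be folded to the selected operand or to a broadcasted constant. A minimal before/after sketch of the intended fold, with illustrative operand names (not taken from the commit's tests):

// Condition is a known all-true splat, so every element selects %self.
%cond = torch.vtensor.literal(dense<true> : tensor<1x2xi1>) : !torch.vtensor<[1,2],i1>
%0 = torch.aten.where.self %cond, %self, %other
       : !torch.vtensor<[1,2],i1>, !torch.vtensor<[1,2],f32>, !torch.vtensor<[1,2],f32>
       -> !torch.vtensor<[1,2],f32>
// After folding, %0 is simply replaced by %self.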
4 changes: 4 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -10635,6 +10635,7 @@ def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
  let hasFolder = 1;
}

def Torch_AtenWhereScalarOp : Torch_Op<"aten.where.Scalar", [
@@ -10660,6 +10661,7 @@ def Torch_AtenWhereScalarOp : Torch_Op<"aten.where.Scalar", [
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
  let hasFolder = 1;
}

def Torch_AtenWhereScalarOtherOp : Torch_Op<"aten.where.ScalarOther", [
@@ -10685,6 +10687,7 @@ def Torch_AtenWhereScalarOtherOp : Torch_Op<"aten.where.ScalarOther", [
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
  let hasFolder = 1;
}

def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [
@@ -10710,6 +10713,7 @@ def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
  let hasFolder = 1;
}

def Torch_AtenNanToNumOp : Torch_Op<"aten.nan_to_num", [
120 changes: 120 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -3152,6 +3152,126 @@ OpFoldResult AtenCeilFloatOp::fold(FoldAdaptor adaptor) {
  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenWhereSelfOp
//===----------------------------------------------------------------------===//

// Broadcasts a splat DenseElementsAttr, or a scalar Integer/Float attribute,
// to a splat tensor attribute of the given result type; returns nullptr if
// the value cannot be materialized as a constant of that type.
static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) {
  if (!attr || !ty.hasDtype() || !ty.hasSizes())
    return nullptr;

  auto dty = ty.getDtype();

  if (auto valueDense = dyn_cast<DenseElementsAttr>(attr)) {
    if (!valueDense.isSplat())
      return nullptr;
    auto splattr = valueDense.getSplatValue<Attribute>();
    auto attrty = ty.toBuiltinTensor().clone(dty);
    return DenseElementsAttr::get(attrty, splattr);
  }

  if (auto intAttr = dyn_cast_or_null<IntegerAttr>(attr)) {
    if (!isa<mlir::IntegerType>(dty))
      return nullptr;
    int64_t intval = intAttr.getInt();
    auto attrty = ty.toBuiltinTensor().clone(dty);
    return DenseElementsAttr::get(attrty, IntegerAttr::get(dty, intval));
  }

  if (auto fpAttr = dyn_cast_or_null<FloatAttr>(attr)) {
    if (!isa<mlir::FloatType>(dty))
      return nullptr;
    double dblval = fpAttr.getValueAsDouble();
    auto attrty = ty.toBuiltinTensor().clone(dty);
    return DenseElementsAttr::get(attrty, FloatAttr::get(dty, dblval));
  }

  return nullptr;
}

OpFoldResult AtenWhereSelfOp::fold(FoldAdaptor adaptor) {
  auto dense = dyn_cast_or_null<DenseElementsAttr>(adaptor.getCondition());
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense ||
      !dense.isSplat())
    return nullptr;

  auto condattr = dense.getSplatValue<APInt>();
  auto value = getSelf();
  auto valueAttr = adaptor.getSelf();
  if (condattr.isZero()) {
    value = getOther();
    valueAttr = adaptor.getOther();
  }

  auto valueTy = dyn_cast<ValueTensorType>(value.getType());
  if (valueTy && valueTy.hasSizes() && valueTy.hasDtype() &&
      valueTy == resultTy)
    return value;

  return getBroadcastedAttr(valueAttr, resultTy);
}

//===----------------------------------------------------------------------===//
// AtenWhereScalarOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenWhereScalarOp::fold(FoldAdaptor adaptor) {
  auto dense = dyn_cast_or_null<DenseElementsAttr>(adaptor.getCondition());
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense ||
      !dense.isSplat())
    return nullptr;

  auto condattr = dense.getSplatValue<APInt>();
  auto valueAttr = adaptor.getSelf();
  if (condattr.isZero()) {
    valueAttr = adaptor.getOther();
  }

  return getBroadcastedAttr(valueAttr, resultTy);
}
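
For the scalar variants, both branches are constant Scalars, so a splat condition always folds to a broadcasted splat of the chosen scalar. An illustrative sketch, assuming an si64 result dtype:

%int5 = torch.constant.int 5
%int7 = torch.constant.int 7
%cond = torch.vtensor.literal(dense<true> : tensor<3xi1>) : !torch.vtensor<[3],i1>
%0 = torch.aten.where.Scalar %cond, %int5, %int7
       : !torch.vtensor<[3],i1>, !torch.int, !torch.int -> !torch.vtensor<[3],si64>
// Folds to: torch.vtensor.literal(dense<5> : tensor<3xsi64>)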

//===----------------------------------------------------------------------===//
// AtenWhereScalarOtherOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenWhereScalarOtherOp::fold(FoldAdaptor adaptor) {
  auto dense = dyn_cast_or_null<DenseElementsAttr>(adaptor.getCondition());
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense ||
      !dense.isSplat())
    return nullptr;

  auto condattr = dense.getSplatValue<APInt>();
  auto valueAttr = adaptor.getSelf();
  if (condattr.isZero()) {
    valueAttr = adaptor.getOther();
  }

  return getBroadcastedAttr(valueAttr, resultTy);
}

//===----------------------------------------------------------------------===//
// AtenWhereScalarSelfOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenWhereScalarSelfOp::fold(FoldAdaptor adaptor) {
  auto dense = dyn_cast_or_null<DenseElementsAttr>(adaptor.getCondition());
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense ||
      !dense.isSplat())
    return nullptr;

  auto condattr = dense.getSplatValue<APInt>();
  auto valueAttr = adaptor.getSelf();
  if (condattr.isZero()) {
    valueAttr = adaptor.getOther();
  }

  return getBroadcastedAttr(valueAttr, resultTy);
}

//===----------------------------------------------------------------------===//
// PrimMaxIntOp
//===----------------------------------------------------------------------===//
8 changes: 4 additions & 4 deletions projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -649,10 +649,10 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
emit("aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)")
emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)")
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)", has_folder=True)
emit("aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_folder=True)
emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)", has_folder=True)
emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)", has_folder=True)
emit("aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)")
emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)", has_folder=True)
emit("aten::len.Tensor : (Tensor) -> (int)")
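
The commit message mentions multiple tests, which are not shown in this excerpt. A minimal sketch of the kind of FileCheck canonicalization test these folders enable (function name and exact CHECK lines are illustrative):

// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s

// CHECK-LABEL: func.func @where_self_splat_true
// CHECK: return %arg0
func.func @where_self_splat_true(%arg0: !torch.vtensor<[4],f32>,
                                 %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32> {
  %cond = torch.vtensor.literal(dense<true> : tensor<4xi1>) : !torch.vtensor<[4],i1>
  %0 = torch.aten.where.self %cond, %arg0, %arg1
         : !torch.vtensor<[4],i1>, !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>
         -> !torch.vtensor<[4],f32>
  return %0 : !torch.vtensor<[4],f32>
}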
