Skip to content

Commit

Permalink
Rob's atenTensor folder
Browse files Browse the repository at this point in the history
  • Loading branch information
Dave Liddell committed Feb 5, 2024
1 parent b3a56c0 commit 9483afa
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 1 deletion.
1 change: 1 addition & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -8582,6 +8582,7 @@ def Torch_AtenTensorOp : Torch_Op<"aten.tensor", [
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenTensorBoolOp : Torch_Op<"aten.tensor.bool", [
Expand Down
18 changes: 18 additions & 0 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2758,6 +2758,24 @@ void AtenDeviceWithIndexOp::getCanonicalizationPatterns(
});
}

//===----------------------------------------------------------------------===//
// AtenTensorOp
//===----------------------------------------------------------------------===//

// Fold `aten.tensor` of a single constant int into a dense literal.
//
// Only handles the narrow case of a one-element constant integer list with a
// fully-specified value-tensor result type; everything else is left to
// runtime materialization (returns nullptr = "not folded").
OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
  auto resultTy = dyn_cast<ValueTensorType>(getType());
  // Guard: the result may be a non-value tensor type (or lack dtype/sizes),
  // in which case `getDtype()`/`toBuiltinTensor()` below would be invalid.
  if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes())
    return nullptr;

  Type eTy = resultTy.getDtype();
  ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy);

  SmallVector<int64_t> data;
  // Fold only a statically-known single-element int list, e.g.
  // `torch.prim.ListConstruct %int42` -> dense<42>.
  if (matchPattern(getData(), m_TorchListOfConstantInts(data)) &&
      data.size() == 1) {
    Attribute attribute = IntegerAttr::get(eTy, data[0]);
    return DenseElementsAttr::get(shapedTy, attribute);
  }

  return nullptr;
}

//===----------------------------------------------------------------------===//
// AtenIntTensorOp
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)")
emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)", has_folder=True)
emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)")
emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)")
emit("aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
Expand Down
11 changes: 11 additions & 0 deletions test/Dialect/Torch/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1461,6 +1461,17 @@ func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !to
return %0 : !torch.tensor<[],f32>
}

// Folding test: aten.tensor of a single constant int must canonicalize to a
// vtensor literal with the matching dtype and shape.
// CHECK-LABEL: func.func @torch.aten.tensor$one_elem(
// CHECK-NEXT: torch.vtensor.literal(dense<42> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
func.func @torch.aten.tensor$one_elem() -> (!torch.vtensor<[1],si64>) {
  %none = torch.constant.none
  %false = torch.constant.bool false
  %int42 = torch.constant.int 42
  %list = torch.prim.ListConstruct %int42 : (!torch.int) -> !torch.list<int>
  %result = torch.aten.tensor %list, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>
  return %result : !torch.vtensor<[1],si64>
}

// CHECK-LABEL: func.func @torch.aten.to.dtype$same_dtype(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<*,f32>
Expand Down

0 comments on commit 9483afa

Please sign in to comment.