Skip to content

Commit dbdd9e7

Browse files
WillFroomGoogle-ML-Automation
authored andcommitted
[XTile] Introduce to/from tensor ops.
PiperOrigin-RevId: 827514977
1 parent adc0442 commit dbdd9e7

File tree

3 files changed

+93
-0
lines changed

3 files changed

+93
-0
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: emitters_opt %s -canonicalize | FileCheck %s
2+
3+
// CHECK-LABEL: @to_scalar_roundtrip
4+
func.func @to_scalar_roundtrip(%arg0: tensor<i32>) -> tensor<i32> {
5+
%0 = xtile.to_scalar %arg0 : tensor<i32>
6+
%1 = xtile.to_tensor %0 : i32
7+
// CHECK: return %arg0 : tensor<i32>
8+
return %1 : tensor<i32>
9+
}
10+
11+
// CHECK-LABEL: @to_tensor_roundtrip
12+
func.func @to_tensor_roundtrip(%arg0: i32) -> i32 {
13+
%0 = xtile.to_tensor %arg0 : i32
14+
%1 = xtile.to_scalar %0 : tensor<i32>
15+
// CHECK: return %arg0 : i32
16+
return %1 : i32
17+
}

xla/codegen/xtile/ir/xtile_ops.cc

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,4 +196,58 @@ mlir::TypedValue<mlir::RankedTensorType> InsertTileOp::getTile() {
196196

197197
mlir::LogicalResult InsertTileOp::verify() { return VerifyBufferOp(*this); }
198198

199+
mlir::LogicalResult ToScalarOp::inferReturnTypes(
200+
mlir::MLIRContext* context, ::std::optional<mlir::Location> location,
201+
mlir::ValueRange operands, mlir::DictionaryAttr attributes,
202+
mlir::OpaqueProperties properties, mlir::RegionRange regions,
203+
::llvm::SmallVectorImpl<mlir::Type>& inferredReturnTypes) {
204+
if (operands.size() != 1) {
205+
return mlir::failure();
206+
}
207+
208+
auto tensor_type =
209+
mlir::dyn_cast<mlir::RankedTensorType>(operands[0].getType());
210+
if (!tensor_type) {
211+
return mlir::failure();
212+
}
213+
214+
if (tensor_type.getRank() != 0) {
215+
return mlir::failure();
216+
}
217+
218+
inferredReturnTypes.push_back(tensor_type.getElementType());
219+
return mlir::success();
220+
}
221+
222+
mlir::OpFoldResult ToScalarOp::fold(FoldAdaptor adaptor) {
223+
if (auto to_tensor = getOperand().getDefiningOp<ToTensorOp>()) {
224+
// to_scalar(to_tensor(x)) -> x
225+
return to_tensor.getOperand();
226+
}
227+
228+
return {};
229+
}
230+
231+
mlir::LogicalResult ToTensorOp::inferReturnTypes(
232+
mlir::MLIRContext* context, ::std::optional<mlir::Location> location,
233+
mlir::ValueRange operands, mlir::DictionaryAttr attributes,
234+
mlir::OpaqueProperties properties, mlir::RegionRange regions,
235+
::llvm::SmallVectorImpl<mlir::Type>& inferredReturnTypes) {
236+
if (operands.size() != 1) {
237+
return mlir::failure();
238+
}
239+
inferredReturnTypes.push_back(
240+
mlir::RankedTensorType::get({}, operands[0].getType()));
241+
return mlir::success();
242+
}
243+
244+
mlir::OpFoldResult ToTensorOp::fold(FoldAdaptor adaptor) {
245+
if (auto to_scalar = getOperand().getDefiningOp<ToScalarOp>()) {
246+
// to_tensor(to_scalar(x)) -> x
247+
return to_scalar.getOperand();
248+
}
249+
250+
return {};
251+
}
252+
199253
} // namespace xla::xtile

xla/codegen/xtile/ir/xtile_ops.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,5 +202,27 @@ def InsertTileOp : XTile_Op<"insert", [TiledBufferInterface]> {
202202
let hasVerifier = 1;
203203
}
204204

205+
// TODO(willfroom): Revisit if/where these are needed after the migration of the
206+
// triton emitter to support 0D tensors is complete.
207+
def ToScalarOp : XTile_Op<"to_scalar", [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
208+
let summary = "Converts a 0D tensor to a scalar";
209+
let arguments = (ins 0DTensorOf<[AnyType]>:$input);
210+
let results = (outs AnyType:$output);
211+
212+
let assemblyFormat = "$input `:` type($input) attr-dict";
213+
214+
let hasFolder = 1;
215+
}
216+
217+
def ToTensorOp : XTile_Op<"to_tensor", [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
218+
let summary = "Converts a scalar to a 0D tensor";
219+
let arguments = (ins AnyType:$input);
220+
let results = (outs 0DTensorOf<[AnyType]>:$output);
221+
222+
let assemblyFormat = "$input `:` type($input) attr-dict";
223+
224+
let hasFolder = 1;
225+
}
226+
205227
#endif // XLA_CODEGEN_XTILE_IR_XTILE_OPS
206228

0 commit comments

Comments
 (0)