@@ -196,4 +196,58 @@ mlir::TypedValue<mlir::RankedTensorType> InsertTileOp::getTile() {
196196
197197mlir::LogicalResult InsertTileOp::verify () { return VerifyBufferOp (*this ); }
198198
199+ mlir::LogicalResult ToScalarOp::inferReturnTypes (
200+ mlir::MLIRContext* context, ::std::optional<mlir::Location> location,
201+ mlir::ValueRange operands, mlir::DictionaryAttr attributes,
202+ mlir::OpaqueProperties properties, mlir::RegionRange regions,
203+ ::llvm::SmallVectorImpl<mlir::Type>& inferredReturnTypes) {
204+ if (operands.size () != 1 ) {
205+ return mlir::failure ();
206+ }
207+
208+ auto tensor_type =
209+ mlir::dyn_cast<mlir::RankedTensorType>(operands[0 ].getType ());
210+ if (!tensor_type) {
211+ return mlir::failure ();
212+ }
213+
214+ if (tensor_type.getRank () != 0 ) {
215+ return mlir::failure ();
216+ }
217+
218+ inferredReturnTypes.push_back (tensor_type.getElementType ());
219+ return mlir::success ();
220+ }
221+
222+ mlir::OpFoldResult ToScalarOp::fold (FoldAdaptor adaptor) {
223+ if (auto to_tensor = getOperand ().getDefiningOp <ToTensorOp>()) {
224+ // to_scalar(to_tensor(x)) -> x
225+ return to_tensor.getOperand ();
226+ }
227+
228+ return {};
229+ }
230+
231+ mlir::LogicalResult ToTensorOp::inferReturnTypes (
232+ mlir::MLIRContext* context, ::std::optional<mlir::Location> location,
233+ mlir::ValueRange operands, mlir::DictionaryAttr attributes,
234+ mlir::OpaqueProperties properties, mlir::RegionRange regions,
235+ ::llvm::SmallVectorImpl<mlir::Type>& inferredReturnTypes) {
236+ if (operands.size () != 1 ) {
237+ return mlir::failure ();
238+ }
239+ inferredReturnTypes.push_back (
240+ mlir::RankedTensorType::get ({}, operands[0 ].getType ()));
241+ return mlir::success ();
242+ }
243+
244+ mlir::OpFoldResult ToTensorOp::fold (FoldAdaptor adaptor) {
245+ if (auto to_scalar = getOperand ().getDefiningOp <ToScalarOp>()) {
246+ // to_tensor(to_scalar(x)) -> x
247+ return to_scalar.getOperand ();
248+ }
249+
250+ return {};
251+ }
252+
199253} // namespace xla::xtile
0 commit comments