diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index c7d0710791193..7cbc4295d5aaa 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -628,42 +628,39 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, operand); return success(); }); - patterns.onOp("Not", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) { - return failure(); - } + patterns.onOp( + "Not", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) { + return failure(); + } - auto loc = binder.getLoc(); - auto operandTy = - cast(operand.getType()); - auto eTy = operandTy.getDtype(); - - if (!eTy.isInteger(1)) { - auto i1ty = rewriter.getI1Type(); - auto ty = rewriter.getType( - operandTy.getSizes(), i1ty); - auto torchqTy = Torch::getScalarTypeForType(i1ty); - Value tyConst = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr( - rewriter.getIntegerType(64), - static_cast(torchqTy))); - Value none = rewriter.create(loc); - Value cstFalse = - rewriter.create(loc, false); - operand = rewriter.create( - loc, ty, operand, tyConst, - /*non_blocking=*/cstFalse, /*copy=*/cstFalse, - /*memory_format=*/none); - } - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand); - return success(); - }); + auto loc = binder.getLoc(); + auto operandTy = cast(operand.getType()); + auto eTy = operandTy.getDtype(); + + if (!eTy.isInteger(1)) { + auto i1ty = rewriter.getI1Type(); + auto ty = rewriter.getType( + operandTy.getSizes(), i1ty); + auto torchqTy = Torch::getScalarTypeForType(i1ty); + Value tyConst = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + static_cast(torchqTy))); + Value none = rewriter.create(loc); + Value cstFalse = rewriter.create(loc, false); + operand = rewriter.create( + loc, ty, operand, tyConst, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); patterns.onOp("Or", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp index 551f79c472885..d20e626656d24 100644 --- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp @@ -189,9 +189,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, do_bcast = true; } else { op->emitError("The size of tensor a (") - << inDim << ")" - << "must match the size of tensor b (" << outDim << ")" - << "at non-singleton dimension " << inPos; + << inDim << ")" << "must match the size of tensor b (" << outDim + << ")" << "at non-singleton dimension " << inPos; } } std::reverse(bcastDims.begin(), bcastDims.end()); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 701300fefe436..f02a7fe4a823c 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -305,8 +305,7 @@ class LowerToBackendContractPass return signalPassFailure(); } while (!satisfiesBackendContract(module, target)); LLVM_DEBUG({ - llvm::dbgs() << "LowerToBackendContractPass: " - << "succeeded after " << i + llvm::dbgs() << "LowerToBackendContractPass: " << "succeeded after " << i << " iterations of the simplification pipeline\n"; }); } diff --git a/projects/pt1/examples/torchscript_resnet_inference.ipynb b/projects/pt1/examples/torchscript_resnet_inference.ipynb index 9970f90b8bb2e..e045f6a4c27da 100644 --- a/projects/pt1/examples/torchscript_resnet_inference.ipynb +++ b/projects/pt1/examples/torchscript_resnet_inference.ipynb @@ -1,5 +1,5 @@ { - "cells": [ + "cells": [ { "cell_type": "code", "execution_count": 14, diff --git a/test/CAPI/torch.c b/test/CAPI/torch.c index e9c5d23e24385..d42cf96d554cd 100644 --- a/test/CAPI/torch.c +++ b/test/CAPI/torch.c @@ -36,7 +36,7 @@ static void testTensor(MlirContext ctx, intptr_t numSizes, int64_t *sizes, fprintf(stderr, #TTT "Type %s rank: %zu\n", testName, \ torchMlirTorch##TTT##TypeGetRank(TTT##Type)); \ int64_t *TTT##Sizes = malloc(sizeof(int64_t) * numSizes); \ - torchMlirTorch##TTT##TypeGetSizes(TTT##Type, TTT##Sizes); \ + torchMlirTorch##TTT##TypeGetSizes(TTT##Type, TTT##Sizes); \ for (int i = 0; i < numSizes; ++i) { \ fprintf(stderr, #TTT "Type %s pos %d size: %ld\n", testName, i, \ TTT##Sizes[i]); \ @@ -157,22 +157,26 @@ static void testTypeMetaDataAccessors(MlirContext ctx) { MlirType dictType1 = torchMlirTorchDictTypeGet(strType, floatType); fprintf(stderr, "dict keyType: "); - mlirTypePrint(torchMlirTorchDictTypeGetKeyType(dictType1), printToStderr, NULL); + mlirTypePrint(torchMlirTorchDictTypeGetKeyType(dictType1), printToStderr, + NULL); fprintf(stderr, "\n"); // CHECK: dict keyType: !torch.str fprintf(stderr, "dict valueType: "); - mlirTypePrint(torchMlirTorchDictTypeGetValueType(dictType1), printToStderr, NULL); + mlirTypePrint(torchMlirTorchDictTypeGetValueType(dictType1), printToStderr, + NULL); fprintf(stderr, "\n"); // CHECK: dict valueType: !torch.float MlirType dictType2 = torchMlirTorchDictTypeGet(floatType, strType); fprintf(stderr, "dict keyType: "); - mlirTypePrint(torchMlirTorchDictTypeGetKeyType(dictType2), printToStderr, NULL); + mlirTypePrint(torchMlirTorchDictTypeGetKeyType(dictType2), printToStderr, + NULL); fprintf(stderr, "\n"); // CHECK: dict keyType: !torch.float fprintf(stderr, "dict valueType: "); - mlirTypePrint(torchMlirTorchDictTypeGetValueType(dictType2), printToStderr, NULL); + mlirTypePrint(torchMlirTorchDictTypeGetValueType(dictType2), printToStderr, + NULL); fprintf(stderr, "\n"); // CHECK: dict valueType: !torch.str }