Skip to content

Commit

Permalink
Auto fix: clang-format
Browse files Browse the repository at this point in the history
  • Loading branch information
stellaraccident committed Apr 22, 2024
1 parent 74b81be commit 61960d4
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 46 deletions.
67 changes: 32 additions & 35 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -628,42 +628,39 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, operand);
return success();
});
patterns.onOp("Not", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType)) {
return failure();
}
patterns.onOp(
"Not", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value operand;
if (binder.tensorOperand(operand) ||
binder.tensorResultType(resultType)) {
return failure();
}

auto loc = binder.getLoc();
auto operandTy =
cast<Torch::ValueTensorType>(operand.getType());
auto eTy = operandTy.getDtype();

if (!eTy.isInteger(1)) {
auto i1ty = rewriter.getI1Type();
auto ty = rewriter.getType<Torch::ValueTensorType>(
operandTy.getSizes(), i1ty);
auto torchqTy = Torch::getScalarTypeForType(i1ty);
Value tyConst = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(
rewriter.getIntegerType(64),
static_cast<int64_t>(torchqTy)));
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(loc, false);
operand = rewriter.create<Torch::AtenToDtypeOp>(
loc, ty, operand, tyConst,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
}
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseNotOp>(
binder.op, resultType, operand);
return success();
});
auto loc = binder.getLoc();
auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
auto eTy = operandTy.getDtype();

if (!eTy.isInteger(1)) {
auto i1ty = rewriter.getI1Type();
auto ty = rewriter.getType<Torch::ValueTensorType>(
operandTy.getSizes(), i1ty);
auto torchqTy = Torch::getScalarTypeForType(i1ty);
Value tyConst = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
static_cast<int64_t>(torchqTy)));
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
operand = rewriter.create<Torch::AtenToDtypeOp>(
loc, ty, operand, tyConst,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
}
rewriter.replaceOpWithNewOp<Torch::AtenBitwiseNotOp>(
binder.op, resultType, operand);
return success();
});
patterns.onOp("Or", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Expand Down
5 changes: 2 additions & 3 deletions lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
do_bcast = true;
} else {
op->emitError("The size of tensor a (")
<< inDim << ")"
<< "must match the size of tensor b (" << outDim << ")"
<< "at non-singleton dimension " << inPos;
<< inDim << ")" << "must match the size of tensor b (" << outDim
<< ")" << "at non-singleton dimension " << inPos;
}
}
std::reverse(bcastDims.begin(), bcastDims.end());
Expand Down
3 changes: 1 addition & 2 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,7 @@ class LowerToBackendContractPass
return signalPassFailure();
} while (!satisfiesBackendContract(module, target));
LLVM_DEBUG({
llvm::dbgs() << "LowerToBackendContractPass: "
<< "succeeded after " << i
llvm::dbgs() << "LowerToBackendContractPass: " << "succeeded after " << i
<< " iterations of the simplification pipeline\n";
});
}
Expand Down
2 changes: 1 addition & 1 deletion projects/pt1/examples/torchscript_resnet_inference.ipynb
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"cells": [
"cells": [
{
"cell_type": "code",
"execution_count": 14,
Expand Down
14 changes: 9 additions & 5 deletions test/CAPI/torch.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static void testTensor(MlirContext ctx, intptr_t numSizes, int64_t *sizes,
fprintf(stderr, #TTT "Type %s rank: %zu\n", testName, \
torchMlirTorch##TTT##TypeGetRank(TTT##Type)); \
int64_t *TTT##Sizes = malloc(sizeof(int64_t) * numSizes); \
torchMlirTorch##TTT##TypeGetSizes(TTT##Type, TTT##Sizes); \
torchMlirTorch##TTT##TypeGetSizes(TTT##Type, TTT##Sizes); \
for (int i = 0; i < numSizes; ++i) { \
fprintf(stderr, #TTT "Type %s pos %d size: %ld\n", testName, i, \
TTT##Sizes[i]); \
Expand Down Expand Up @@ -157,22 +157,26 @@ static void testTypeMetaDataAccessors(MlirContext ctx) {
MlirType dictType1 = torchMlirTorchDictTypeGet(strType, floatType);

fprintf(stderr, "dict keyType: ");
mlirTypePrint(torchMlirTorchDictTypeGetKeyType(dictType1), printToStderr, NULL);
mlirTypePrint(torchMlirTorchDictTypeGetKeyType(dictType1), printToStderr,
NULL);
fprintf(stderr, "\n");
// CHECK: dict keyType: !torch.str
fprintf(stderr, "dict valueType: ");
mlirTypePrint(torchMlirTorchDictTypeGetValueType(dictType1), printToStderr, NULL);
mlirTypePrint(torchMlirTorchDictTypeGetValueType(dictType1), printToStderr,
NULL);
fprintf(stderr, "\n");
// CHECK: dict valueType: !torch.float

MlirType dictType2 = torchMlirTorchDictTypeGet(floatType, strType);

fprintf(stderr, "dict keyType: ");
mlirTypePrint(torchMlirTorchDictTypeGetKeyType(dictType2), printToStderr, NULL);
mlirTypePrint(torchMlirTorchDictTypeGetKeyType(dictType2), printToStderr,
NULL);
fprintf(stderr, "\n");
// CHECK: dict keyType: !torch.float
fprintf(stderr, "dict valueType: ");
mlirTypePrint(torchMlirTorchDictTypeGetValueType(dictType2), printToStderr, NULL);
mlirTypePrint(torchMlirTorchDictTypeGetValueType(dictType2), printToStderr,
NULL);
fprintf(stderr, "\n");
// CHECK: dict valueType: !torch.str
}
Expand Down

0 comments on commit 61960d4

Please sign in to comment.