diff --git a/lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp b/lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp index 3155f7980..cb20bcaf3 100644 --- a/lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp +++ b/lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp @@ -415,33 +415,33 @@ struct ConvertLWEReinterpretUnderlyingType } // namespace // BGV -using ConvertBGVAddOp = - ConvertRlweBinOp; -using ConvertBGVSubOp = - ConvertRlweBinOp; -using ConvertBGVMulOp = - ConvertRlweBinOp; +using ConvertBGVAddOp = ConvertRlweBinOp; +using ConvertBGVSubOp = ConvertRlweBinOp; +using ConvertBGVMulOp = ConvertRlweBinOp; using ConvertBGVAddPlainOp = ConvertRlwePlainOp; + lattigo::BGVAddNewOp>; using ConvertBGVSubPlainOp = ConvertRlwePlainOp; + lattigo::BGVSubNewOp>; using ConvertBGVMulPlainOp = ConvertRlwePlainOp; + lattigo::BGVMulNewOp>; using ConvertBGVRelinOp = ConvertRlweUnaryOp; + lattigo::BGVRelinearizeNewOp>; using ConvertBGVModulusSwitchOp = ConvertRlweUnaryOp; + lattigo::BGVRescaleNewOp>; // TODO(#1186): figure out generic rotating using BGVRotateColumns/RowsOp using ConvertBGVRotateOp = ConvertRlweRotateOp; + lattigo::BGVRotateColumnsNewOp>; using ConvertBGVEncryptOp = ConvertRlweUnaryOp { let results = (outs Lattigo_BGVEncoder:$encoder); } -def Lattigo_BGVEncodeOp : Lattigo_BGVOp<"encode"> { +def Lattigo_BGVEncodeOp : Lattigo_BGVOp<"encode", [InplaceOpInterface]> { let summary = "Encode a plaintext value in the Lattigo BGV dialect"; let description = [{ This operation encodes a plaintext value using the specified encoder in the Lattigo BGV dialect. @@ -55,6 +55,8 @@ def Lattigo_BGVEncodeOp : Lattigo_BGVOp<"encode"> { Lattigo_RLWEPlaintext:$plaintext ); let results = (outs Lattigo_RLWEPlaintext:$encoded); + + let extraClassDeclaration = "int getInplaceOperandIndex() { return 2; }"; } def Lattigo_BGVDecodeOp : Lattigo_BGVOp<"decode", [AllTypesMatch<["value", "decoded"]>]> { @@ -69,6 +71,8 @@ def Lattigo_BGVDecodeOp : Lattigo_BGVOp<"decode", [AllTypesMatch<["value", "deco Lattigo_RLWEPlaintext:$plaintext, RankedTensorOf<[AnyInteger]>:$value ); + // although bgv.Decode is also an inplace operation as bgv.Encode, as there are post-processing + // steps in emitter, we mark it as a normal operation. let results = (outs RankedTensorOf<[AnyInteger]>:$decoded); } @@ -102,27 +106,72 @@ class Lattigo_BGVBinaryOp : let results = (outs Lattigo_RLWECiphertext:$output); } -def Lattigo_BGVAddOp : Lattigo_BGVBinaryOp<"add"> { +def Lattigo_BGVAddNewOp : Lattigo_BGVBinaryOp<"add_new"> { let summary = "Add two ciphertexts in the Lattigo BGV dialect"; let description = [{ This operation adds two ciphertext values in the Lattigo BGV dialect. }]; } -def Lattigo_BGVSubOp : Lattigo_BGVBinaryOp<"sub"> { +def Lattigo_BGVSubNewOp : Lattigo_BGVBinaryOp<"sub_new"> { let summary = "Subtract two ciphertexts in the Lattigo BGV dialect"; let description = [{ This operation subtracts one ciphertext value from another in the Lattigo BGV dialect. }]; } -def Lattigo_BGVMulOp : Lattigo_BGVBinaryOp<"mul"> { +def Lattigo_BGVMulNewOp : Lattigo_BGVBinaryOp<"mul_new"> { let summary = "Multiply two ciphertexts in the Lattigo BGV dialect"; let description = [{ This operation multiplies two ciphertext values in the Lattigo BGV dialect. }]; } +class Lattigo_BGVBinaryInplaceOp : + Lattigo_BGVOp { + let arguments = (ins + Lattigo_BGVEvaluator:$evaluator, + Lattigo_RLWECiphertext:$lhs, + Lattigo_RLWECiphertextOrPlaintext:$rhs, + // Lattigo API is like bgv.Add(lhs, rhs, out) but for MLIR we need to + // satisfy the SSA form, so we still have a separate output. + Lattigo_RLWECiphertext:$inplace + ); + let results = (outs Lattigo_RLWECiphertext:$output); + + let extraClassDeclaration = "int getInplaceOperandIndex() { return 3; }"; +} + +def Lattigo_BGVAddOp : Lattigo_BGVBinaryInplaceOp<"add"> { + let summary = "Add two ciphertexts in the Lattigo BGV dialect"; + let description = [{ + This operation adds two ciphertext values in the Lattigo BGV dialect. + + The result will be written to the `inplace` operand. The `output`result is + a transitive reference to the `inplace` operand for sake of the MLIR SSA form. + }]; +} + +def Lattigo_BGVSubOp : Lattigo_BGVBinaryInplaceOp<"sub"> { + let summary = "Subtract two ciphertexts in the Lattigo BGV dialect"; + let description = [{ + This operation subtracts one ciphertext value from another in the Lattigo BGV dialect. + + The result will be written to the `inplace` operand. The `output`result is + a transitive reference to the `inplace` operand for sake of the MLIR SSA form. + }]; +} + +def Lattigo_BGVMulOp : Lattigo_BGVBinaryInplaceOp<"mul"> { + let summary = "Multiply two ciphertexts in the Lattigo BGV dialect"; + let description = [{ + This operation multiplies two ciphertext values in the Lattigo BGV dialect. + + The result will be written to the `inplace` operand. The `output`result is + a transitive reference to the `inplace` operand for sake of the MLIR SSA form. + }]; +} + class Lattigo_BGVUnaryOp : Lattigo_BGVOp { let arguments = (ins @@ -132,21 +181,80 @@ class Lattigo_BGVUnaryOp : let results = (outs Lattigo_RLWECiphertext:$output); } -def Lattigo_BGVRelinearizeOp : Lattigo_BGVUnaryOp<"relinearize"> { +def Lattigo_BGVRelinearizeNewOp : Lattigo_BGVUnaryOp<"relinearize_new"> { + let summary = "Relinearize a ciphertext in the Lattigo BGV dialect"; + let description = [{ + This operation relinearizes a ciphertext value in the Lattigo BGV dialect. + }]; +} + +def Lattigo_BGVRescaleNewOp : Lattigo_BGVUnaryOp<"rescale_new"> { + let summary = "Rescale a ciphertext in the Lattigo BGV dialect"; + let description = [{ + This operation rescales a ciphertext value in the Lattigo BGV dialect. + }]; +} + +def Lattigo_BGVRotateColumnsNewOp : Lattigo_BGVOp<"rotate_columns_new"> { + let summary = "Rotate columns of a ciphertext in the Lattigo BGV dialect"; + let description = [{ + This operation rotates the columns of a ciphertext value in the Lattigo BGV dialect. + + Lattigo exposes the SIMD slot of BGV as a N/2 x 2 matrix, where N/2 is the column. + + Offset is valid for both positive and negative number. + }]; + let arguments = (ins + Lattigo_BGVEvaluator:$evaluator, + Lattigo_RLWECiphertext:$input, + Builtin_IntegerAttr:$offset + ); + let results = (outs Lattigo_RLWECiphertext:$output); +} + +def Lattigo_BGVRotateRowsNewOp : Lattigo_BGVUnaryOp<"rotate_rows_new"> { + let summary = "Rotate rows of a ciphertext in the Lattigo BGV dialect"; + let description = [{ + This operation swap the rows of a ciphertext value in the Lattigo BGV dialect. + + Lattigo exposes the SIMD slot of BGV as a N/2 x 2 matrix, where 2 is the row. + }]; +} + +class Lattigo_BGVUnaryInplaceOp : + Lattigo_BGVOp { + let arguments = (ins + Lattigo_BGVEvaluator:$evaluator, + Lattigo_RLWECiphertext:$input, + // see BinaryInplaceOp above + Lattigo_RLWECiphertext:$inplace + ); + let results = (outs Lattigo_RLWECiphertext:$output); + + let extraClassDeclaration = "int getInplaceOperandIndex() { return 2; }"; +} + +def Lattigo_BGVRelinearizeOp : Lattigo_BGVUnaryInplaceOp<"relinearize"> { let summary = "Relinearize a ciphertext in the Lattigo BGV dialect"; let description = [{ This operation relinearizes a ciphertext value in the Lattigo BGV dialect. + + The result will be written to the `inplace` operand. The `output`result is + a transitive reference to the `inplace` operand for sake of the MLIR SSA form. }]; } -def Lattigo_BGVRescaleOp : Lattigo_BGVUnaryOp<"rescale"> { +def Lattigo_BGVRescaleOp : Lattigo_BGVUnaryInplaceOp<"rescale"> { let summary = "Rescale a ciphertext in the Lattigo BGV dialect"; let description = [{ This operation rescales a ciphertext value in the Lattigo BGV dialect. + + The result will be written to the `inplace` operand. The `output`result is + a transitive reference to the `inplace` operand for sake of the MLIR SSA form. }]; } -def Lattigo_BGVRotateColumnsOp : Lattigo_BGVOp<"rotate_columns"> { +def Lattigo_BGVRotateColumnsOp : Lattigo_BGVUnaryInplaceOp<"rotate_columns"> { let summary = "Rotate columns of a ciphertext in the Lattigo BGV dialect"; let description = [{ This operation rotates the columns of a ciphertext value in the Lattigo BGV dialect. @@ -154,21 +262,28 @@ def Lattigo_BGVRotateColumnsOp : Lattigo_BGVOp<"rotate_columns"> { Lattigo exposes the SIMD slot of BGV as a N/2 x 2 matrix, where N/2 is the column. Offset is valid for both positive and negative number. + + The result will be written to the `inplace` operand. The `output`result is + a transitive reference to the `inplace` operand for sake of the MLIR SSA form. }]; let arguments = (ins Lattigo_BGVEvaluator:$evaluator, Lattigo_RLWECiphertext:$input, + Lattigo_RLWECiphertext:$inplace, Builtin_IntegerAttr:$offset ); let results = (outs Lattigo_RLWECiphertext:$output); } -def Lattigo_BGVRotateRowsOp : Lattigo_BGVUnaryOp<"rotate_rows"> { +def Lattigo_BGVRotateRowsOp : Lattigo_BGVUnaryInplaceOp<"rotate_rows"> { let summary = "Rotate rows of a ciphertext in the Lattigo BGV dialect"; let description = [{ This operation swap the rows of a ciphertext value in the Lattigo BGV dialect. Lattigo exposes the SIMD slot of BGV as a N/2 x 2 matrix, where 2 is the row. + + The result will be written to the `inplace` operand. The `output`result is + a transitive reference to the `inplace` operand for sake of the MLIR SSA form. }]; } diff --git a/lib/Dialect/Lattigo/IR/LattigoOps.h b/lib/Dialect/Lattigo/IR/LattigoOps.h index 016ee14f8..b49c5be44 100644 --- a/lib/Dialect/Lattigo/IR/LattigoOps.h +++ b/lib/Dialect/Lattigo/IR/LattigoOps.h @@ -3,6 +3,7 @@ #include "lib/Dialect/Lattigo/IR/LattigoDialect.h" #include "lib/Dialect/Lattigo/IR/LattigoTypes.h" +#include "lib/Utils/Tablegen/InplaceOpInterface.h" #include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project #define GET_OP_CLASSES diff --git a/lib/Dialect/Lattigo/IR/LattigoOps.td b/lib/Dialect/Lattigo/IR/LattigoOps.td index adf869d0f..da515950b 100644 --- a/lib/Dialect/Lattigo/IR/LattigoOps.td +++ b/lib/Dialect/Lattigo/IR/LattigoOps.td @@ -4,6 +4,7 @@ include "LattigoDialect.td" include "LattigoTypes.td" include "mlir/IR/OpBase.td" +include "lib/Utils/Tablegen/InplaceOpInterface.td" class Lattigo_Op traits = []> : Op { diff --git a/lib/Dialect/Lattigo/Transforms/AllocToInplace.cpp b/lib/Dialect/Lattigo/Transforms/AllocToInplace.cpp new file mode 100644 index 000000000..a498f47dd --- /dev/null +++ b/lib/Dialect/Lattigo/Transforms/AllocToInplace.cpp @@ -0,0 +1,123 @@ +#include "lib/Dialect/Lattigo/Transforms/AllocToInplace.h" + +#include "lib/Dialect/Lattigo/IR/LattigoOps.h" +#include "mlir/include/mlir/Analysis/Liveness.h" // from @llvm-project +#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace lattigo { + +template +struct ConvertBinOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + ConvertBinOp(mlir::MLIRContext *context, Liveness *liveness) + : OpRewritePattern(context), liveness(liveness) {} + + LogicalResult matchAndRewrite(BinOp op, + PatternRewriter &rewriter) const override { + // operand 0 is evaluator + auto lhs = op.getOperand(1); + if (!liveness->isDeadAfter(lhs, op)) { + return failure(); + } + + // InplaceOp has the form: output = InplaceOp(evaluator, lhs, rhs, inplace) + // where inplace is the actual output but for SSA form we need to return a + // new value + rewriter.replaceOpWithNewOp(op, op.getOperand(1).getType(), + op.getOperand(0), op.getOperand(1), + op.getOperand(2), op.getOperand(1)); + return success(); + } + + private: + Liveness *liveness; +}; + +template +struct ConvertUnaryOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + ConvertUnaryOp(mlir::MLIRContext *context, Liveness *liveness) + : OpRewritePattern(context), liveness(liveness) {} + + LogicalResult matchAndRewrite(UnaryOp op, + PatternRewriter &rewriter) const override { + // operand 0 is evaluator + auto lhs = op.getOperand(1); + if (!liveness->isDeadAfter(lhs, op)) { + return failure(); + } + + // InplaceOp has the form: output = InplaceOp(evaluator, lhs, inplace) + // where inplace is the actual output but for SSA form we need to return a + // new value + rewriter.replaceOpWithNewOp(op, op.getOperand(1).getType(), + op.getOperand(0), op.getOperand(1), + op.getOperand(1)); + return success(); + } + + private: + Liveness *liveness; +}; + +template +struct ConvertRotateOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + ConvertRotateOp(mlir::MLIRContext *context, Liveness *liveness) + : OpRewritePattern(context), liveness(liveness) {} + + LogicalResult matchAndRewrite(RotateOp op, + PatternRewriter &rewriter) const override { + // operand 0 is evaluator + auto lhs = op.getOperand(1); + if (!liveness->isDeadAfter(lhs, op)) { + return failure(); + } + + // InplaceOp has the form: output = InplaceOp(evaluator, lhs, inplace) + // {offset} where inplace is the actual output but for SSA form we need to + // return a new value + rewriter.replaceOpWithNewOp(op, op.getOperand(1).getType(), + op.getOperand(0), op.getOperand(1), + op.getOperand(1), op.getOffset()); + return success(); + } + + private: + Liveness *liveness; +}; + +#define GEN_PASS_DEF_ALLOCTOINPLACE +#include "lib/Dialect/Lattigo/Transforms/Passes.h.inc" + +struct AllocToInplace : impl::AllocToInplaceBase { + using AllocToInplaceBase::AllocToInplaceBase; + + void runOnOperation() override { + Liveness liveness(getOperation()); + + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + + patterns.add< + ConvertBinOp, + ConvertBinOp, + ConvertBinOp, + ConvertUnaryOp, + ConvertUnaryOp, + ConvertRotateOp >(context, &liveness); + + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + } +}; + +} // namespace lattigo +} // namespace heir +} // namespace mlir diff --git a/lib/Dialect/Lattigo/Transforms/AllocToInplace.h b/lib/Dialect/Lattigo/Transforms/AllocToInplace.h new file mode 100644 index 000000000..652a96a74 --- /dev/null +++ b/lib/Dialect/Lattigo/Transforms/AllocToInplace.h @@ -0,0 +1,17 @@ +#ifndef LIB_DIALECT_LATTIGO_TRANSFORMS_ALLOCTOINPLACE_H_ +#define LIB_DIALECT_LATTIGO_TRANSFORMS_ALLOCTOINPLACE_H_ + +#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace heir { +namespace lattigo { + +#define GEN_PASS_DECL_ALLOCTOINPLACE +#include "lib/Dialect/Lattigo/Transforms/Passes.h.inc" + +} // namespace lattigo +} // namespace heir +} // namespace mlir + +#endif // LIB_DIALECT_LATTIGO_TRANSFORMS_ALLOCTOINPLACE_H_ diff --git a/lib/Dialect/Lattigo/Transforms/BUILD b/lib/Dialect/Lattigo/Transforms/BUILD index a982de40f..49e41cf44 100644 --- a/lib/Dialect/Lattigo/Transforms/BUILD +++ b/lib/Dialect/Lattigo/Transforms/BUILD @@ -9,12 +9,29 @@ cc_library( name = "Transforms", hdrs = ["Passes.h"], deps = [ + ":AllocToInplace", ":ConfigureCryptoContext", ":pass_inc_gen", "@heir//lib/Dialect/Lattigo/IR:Dialect", ], ) +cc_library( + name = "AllocToInplace", + srcs = ["AllocToInplace.cpp"], + hdrs = ["AllocToInplace.h"], + deps = [ + ":pass_inc_gen", + "@heir//lib/Dialect/Lattigo/IR:Dialect", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + cc_library( name = "ConfigureCryptoContext", srcs = ["ConfigureCryptoContext.cpp"], diff --git a/lib/Dialect/Lattigo/Transforms/ConfigureCryptoContext.cpp b/lib/Dialect/Lattigo/Transforms/ConfigureCryptoContext.cpp index 1aad6a4d4..9aa874d6b 100644 --- a/lib/Dialect/Lattigo/Transforms/ConfigureCryptoContext.cpp +++ b/lib/Dialect/Lattigo/Transforms/ConfigureCryptoContext.cpp @@ -37,7 +37,7 @@ namespace lattigo { bool hasRelinOp(func::FuncOp op) { bool result = false; op.walk([&](Operation *op) { - if (isa(op)) { + if (isa(op)) { result = true; return WalkResult::interrupt(); } @@ -50,6 +50,10 @@ bool hasRelinOp(func::FuncOp op) { // TODO(#1186): handle rotate rows SmallVector findAllRotIndices(func::FuncOp op) { std::set distinctRotIndices; + op.walk([&](BGVRotateColumnsNewOp rotOp) { + distinctRotIndices.insert(rotOp.getOffset().getInt()); + return WalkResult::advance(); + }); op.walk([&](BGVRotateColumnsOp rotOp) { distinctRotIndices.insert(rotOp.getOffset().getInt()); return WalkResult::advance(); diff --git a/lib/Dialect/Lattigo/Transforms/Passes.h b/lib/Dialect/Lattigo/Transforms/Passes.h index d25457207..bfbac6fb5 100644 --- a/lib/Dialect/Lattigo/Transforms/Passes.h +++ b/lib/Dialect/Lattigo/Transforms/Passes.h @@ -2,6 +2,7 @@ #define LIB_DIALECT_LATTIGO_TRANSFORMS_PASSES_H_ #include "lib/Dialect/Lattigo/IR/LattigoDialect.h" +#include "lib/Dialect/Lattigo/Transforms/AllocToInplace.h" #include "lib/Dialect/Lattigo/Transforms/ConfigureCryptoContext.h" namespace mlir { diff --git a/lib/Dialect/Lattigo/Transforms/Passes.td b/lib/Dialect/Lattigo/Transforms/Passes.td index 6e465b3a9..f00f3c509 100644 --- a/lib/Dialect/Lattigo/Transforms/Passes.td +++ b/lib/Dialect/Lattigo/Transforms/Passes.td @@ -3,6 +3,15 @@ include "mlir/Pass/PassBase.td" +def AllocToInplace : Pass<"lattigo-alloc-to-inplace"> { + let summary = "Convert AllocOps to InplaceOps in Lattigo"; + let description = [{ + This pass converts AllocOps to InplaceOps in Lattigo. + + }]; + let dependentDialects = ["mlir::heir::lattigo::LattigoDialect"]; +} + def ConfigureCryptoContext : Pass<"lattigo-configure-crypto-context"> { let summary = "Configure the crypto context in Lattigo"; let description = [{ diff --git a/lib/Pipelines/ArithmeticPipelineRegistration.cpp b/lib/Pipelines/ArithmeticPipelineRegistration.cpp index 6c40745cb..6d3d6e727 100644 --- a/lib/Pipelines/ArithmeticPipelineRegistration.cpp +++ b/lib/Pipelines/ArithmeticPipelineRegistration.cpp @@ -10,6 +10,7 @@ #include "lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h" #include "lib/Dialect/LWE/Transforms/AddClientInterface.h" #include "lib/Dialect/LWE/Transforms/AddDebugPort.h" +#include "lib/Dialect/Lattigo/Transforms/AllocToInplace.h" #include "lib/Dialect/Lattigo/Transforms/ConfigureCryptoContext.h" #include "lib/Dialect/LinAlg/Conversions/LinalgToTensorExt/LinalgToTensorExt.h" #include "lib/Dialect/Openfhe/Transforms/ConfigureCryptoContext.h" @@ -257,6 +258,9 @@ BackendPipelineBuilder toLattigoPipelineBuilder() { // Convert LWE (and scheme-specific BGV ops) to Lattigo pm.addPass(lwe::createLWEToLattigo()); + // Convert Alloc Ops to Inplace Ops + pm.addPass(lattigo::createAllocToInplace()); + // Simplify, in case the lowering revealed redundancy pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); diff --git a/lib/Pipelines/BUILD b/lib/Pipelines/BUILD index 47b5146eb..3596d6c32 100644 --- a/lib/Pipelines/BUILD +++ b/lib/Pipelines/BUILD @@ -92,6 +92,7 @@ cc_library( "@heir//lib/Dialect/LWE/Conversions/LWEToPolynomial", "@heir//lib/Dialect/LWE/Transforms:AddClientInterface", "@heir//lib/Dialect/LWE/Transforms:AddDebugPort", + "@heir//lib/Dialect/Lattigo/Transforms:AllocToInplace", "@heir//lib/Dialect/Lattigo/Transforms:ConfigureCryptoContext", "@heir//lib/Dialect/LinAlg/Conversions/LinalgToTensorExt", "@heir//lib/Dialect/Openfhe/Transforms:ConfigureCryptoContext", diff --git a/lib/Target/Lattigo/BUILD b/lib/Target/Lattigo/BUILD index c54b46cfc..f4c8f55f1 100644 --- a/lib/Target/Lattigo/BUILD +++ b/lib/Target/Lattigo/BUILD @@ -19,6 +19,7 @@ cc_library( "@heir//lib/Dialect/Mgmt/IR:Dialect", "@heir//lib/Dialect/RNS/IR:Dialect", "@heir//lib/Utils:TargetUtils", + "@heir//lib/Utils/Tablegen:InplaceOpInterface", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", diff --git a/lib/Target/Lattigo/LattigoEmitter.cpp b/lib/Target/Lattigo/LattigoEmitter.cpp index 441416e96..ddea1f272 100644 --- a/lib/Target/Lattigo/LattigoEmitter.cpp +++ b/lib/Target/Lattigo/LattigoEmitter.cpp @@ -66,9 +66,11 @@ LogicalResult LattigoEmitter::translate(Operation &op) { RLWENewEvaluationKeySetOp, RLWEEncryptOp, RLWEDecryptOp, // BGV BGVNewParametersFromLiteralOp, BGVNewEncoderOp, BGVNewEvaluatorOp, - BGVNewPlaintextOp, BGVEncodeOp, BGVDecodeOp, BGVAddOp, BGVSubOp, - BGVMulOp, BGVRelinearizeOp, BGVRescaleOp, BGVRotateColumnsOp, - BGVRotateRowsOp, + BGVNewPlaintextOp, BGVEncodeOp, BGVDecodeOp, BGVAddNewOp, + BGVSubNewOp, BGVMulNewOp, BGVAddOp, BGVSubOp, BGVMulOp, + BGVRelinearizeOp, BGVRescaleOp, BGVRotateColumnsOp, + BGVRotateRowsOp, BGVRelinearizeNewOp, BGVRescaleNewOp, + BGVRotateColumnsNewOp, BGVRotateRowsNewOp, // CKKS CKKSNewParametersFromLiteralOp, CKKSNewEncoderOp, CKKSNewEvaluatorOp, CKKSNewPlaintextOp, CKKSEncodeOp, @@ -350,8 +352,6 @@ LogicalResult LattigoEmitter::printOperation(BGVEncodeOp op) { os << getName(op.getEncoder()) << ".Encode("; os << packedName << ", "; os << getName(op.getPlaintext()) << ")\n"; - os << getName(op.getEncoded()) << " := " << getName(op.getPlaintext()) - << "\n"; return success(); } @@ -375,32 +375,53 @@ LogicalResult LattigoEmitter::printOperation(BGVDecodeOp op) { return success(); } -LogicalResult LattigoEmitter::printOperation(BGVAddOp op) { +LogicalResult LattigoEmitter::printOperation(BGVAddNewOp op) { return printEvalNewMethod(op.getResult(), op.getEvaluator(), {op.getLhs(), op.getRhs()}, "AddNew", true); } -LogicalResult LattigoEmitter::printOperation(BGVSubOp op) { +LogicalResult LattigoEmitter::printOperation(BGVSubNewOp op) { return printEvalNewMethod(op.getResult(), op.getEvaluator(), {op.getLhs(), op.getRhs()}, "SubNew", true); } -LogicalResult LattigoEmitter::printOperation(BGVMulOp op) { +LogicalResult LattigoEmitter::printOperation(BGVMulNewOp op) { return printEvalNewMethod(op.getResult(), op.getEvaluator(), {op.getLhs(), op.getRhs()}, "MulNew", true); } -LogicalResult LattigoEmitter::printOperation(BGVRelinearizeOp op) { +LogicalResult LattigoEmitter::printOperation(BGVAddOp op) { + return printEvalInplaceMethod(op.getEvaluator(), + {op.getLhs(), op.getRhs(), op.getInplace()}, + "Add", true); +} + +LogicalResult LattigoEmitter::printOperation(BGVSubOp op) { + return printEvalInplaceMethod(op.getEvaluator(), + {op.getLhs(), op.getRhs(), op.getInplace()}, + "Sub", true); +} + +LogicalResult LattigoEmitter::printOperation(BGVMulOp op) { + return printEvalInplaceMethod(op.getEvaluator(), + {op.getLhs(), op.getRhs(), op.getInplace()}, + "Mul", true); +} + +LogicalResult LattigoEmitter::printOperation(BGVRelinearizeNewOp op) { return printEvalNewMethod(op.getOutput(), op.getEvaluator(), op.getInput(), "RelinearizeNew", true); } -LogicalResult LattigoEmitter::printOperation(BGVRescaleOp op) { - return printEvalInplaceMethod(op.getOutput(), op.getEvaluator(), - op.getInput(), op.getInput(), "Rescale", true); +LogicalResult LattigoEmitter::printOperation(BGVRescaleNewOp op) { + // there is no RescaleNew method in Lattigo, manually create new ciphertext + os << getName(op.getOutput()) << " := " << getName(op.getInput()) + << ".CopyNew()\n"; + return printEvalInplaceMethod( + op.getEvaluator(), {op.getInput(), op.getOutput()}, "Rescale", true); } -LogicalResult LattigoEmitter::printOperation(BGVRotateColumnsOp op) { +LogicalResult LattigoEmitter::printOperation(BGVRotateColumnsNewOp op) { auto errName = getErrName(); os << getName(op.getOutput()) << ", " << errName << " := " << getName(op.getEvaluator()) << ".RotateColumnsNew("; @@ -410,11 +431,36 @@ LogicalResult LattigoEmitter::printOperation(BGVRotateColumnsOp op) { return success(); } -LogicalResult LattigoEmitter::printOperation(BGVRotateRowsOp op) { +LogicalResult LattigoEmitter::printOperation(BGVRotateRowsNewOp op) { return printEvalNewMethod(op.getOutput(), op.getEvaluator(), {op.getInput()}, "RotateRowsNew", true); } +LogicalResult LattigoEmitter::printOperation(BGVRelinearizeOp op) { + return printEvalInplaceMethod( + op.getEvaluator(), {op.getInput(), op.getInplace()}, "Relinearize", true); +} + +LogicalResult LattigoEmitter::printOperation(BGVRescaleOp op) { + return printEvalInplaceMethod( + op.getEvaluator(), {op.getInput(), op.getInplace()}, "Rescale", true); +} + +LogicalResult LattigoEmitter::printOperation(BGVRotateColumnsOp op) { + auto errName = getErrName(); + os << errName << " := " << getName(op.getEvaluator()) << ".RotateColumns("; + os << getName(op.getInput()) << ", "; + os << op.getOffset().getInt() << ", "; + os << getName(op.getInplace()) << ")\n"; + printErrPanic(errName); + return success(); +} + +LogicalResult LattigoEmitter::printOperation(BGVRotateRowsOp op) { + return printEvalInplaceMethod( + op.getEvaluator(), {op.getInput(), op.getInplace()}, "RotateRows", true); +} + std::string printDenseI32ArrayAttr(DenseI32ArrayAttr attr) { std::string res = "[]int{"; res += commaSeparated(attr.asArrayRef()); @@ -641,6 +687,21 @@ LogicalResult LattigoEmitter::printEvalInplaceMethod( return success(); } +LogicalResult LattigoEmitter::printEvalInplaceMethod( + ::mlir::Value evaluator, ::mlir::ValueRange operands, std::string_view op, + bool err) { + std::string errName = getErrName(); + if (err) { + os << errName << " := "; + } + os << getName(evaluator) << "." << op << "(" + << getCommaSeparatedNames(operands) << ");\n"; + if (err) { + printErrPanic(errName); + } + return success(); +} + LogicalResult LattigoEmitter::printEvalNewMethod(::mlir::ValueRange results, ::mlir::Value evaluator, ::mlir::ValueRange operands, diff --git a/lib/Target/Lattigo/LattigoEmitter.h b/lib/Target/Lattigo/LattigoEmitter.h index 7fadf8577..c7b4074da 100644 --- a/lib/Target/Lattigo/LattigoEmitter.h +++ b/lib/Target/Lattigo/LattigoEmitter.h @@ -6,6 +6,7 @@ #include "lib/Analysis/SelectVariableNames/SelectVariableNames.h" #include "lib/Dialect/Lattigo/IR/LattigoOps.h" +#include "lib/Utils/Tablegen/InplaceOpInterface.h" #include "lib/Utils/TargetUtils.h" #include "llvm/include/llvm/Support/ManagedStatic.h" // from @llvm-project #include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project @@ -77,6 +78,13 @@ class LattigoEmitter { LogicalResult printOperation(BGVNewPlaintextOp op); LogicalResult printOperation(BGVEncodeOp op); LogicalResult printOperation(BGVDecodeOp op); + LogicalResult printOperation(BGVAddNewOp op); + LogicalResult printOperation(BGVSubNewOp op); + LogicalResult printOperation(BGVMulNewOp op); + LogicalResult printOperation(BGVRelinearizeNewOp op); + LogicalResult printOperation(BGVRescaleNewOp op); + LogicalResult printOperation(BGVRotateColumnsNewOp op); + LogicalResult printOperation(BGVRotateRowsNewOp op); LogicalResult printOperation(BGVAddOp op); LogicalResult printOperation(BGVSubOp op); LogicalResult printOperation(BGVMulOp op); @@ -111,6 +119,10 @@ class LattigoEmitter { ::mlir::Value operandInplace, std::string_view op, bool err); + LogicalResult printEvalInplaceMethod(::mlir::Value evaluator, + ::mlir::ValueRange operands, + std::string_view op, bool err); + LogicalResult printEvalNewMethod(::mlir::ValueRange results, ::mlir::Value evaluator, ::mlir::ValueRange operands, @@ -128,6 +140,18 @@ class LattigoEmitter { bool isDebugPort(::llvm::StringRef debugPortName); ::llvm::StringRef canonicalizeDebugPort(::llvm::StringRef debugPortName); + // find the actual value used for inplace op + ::mlir::Value getStorageValue(::mlir::Value value) { + if (auto *op = value.getDefiningOp()) { + if (auto inplaceOpInterface = mlir::dyn_cast(op)) { + auto inplace = + op->getOperand(inplaceOpInterface.getInplaceOperandIndex()); + return getStorageValue(inplace); + } + } + return value; + } + // helper on name and type std::string getName(::mlir::Value value) { // special case for 'nil' emission @@ -139,7 +163,7 @@ class LattigoEmitter { if (value.use_empty()) { return "_"; } - return variableNames->getNameForValue(value); + return variableNames->getNameForValue(getStorageValue(value)); } std::string getErrName() { diff --git a/lib/Utils/Tablegen/BUILD b/lib/Utils/Tablegen/BUILD new file mode 100644 index 000000000..83a803299 --- /dev/null +++ b/lib/Utils/Tablegen/BUILD @@ -0,0 +1,52 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + default_applicable_licenses = ["@heir//:license"], + default_visibility = ["//visibility:public"], +) + +td_library( + name = "td_files", + srcs = [ + "InplaceOpInterface.td", + ], + # include from the heir-root to enable fully-qualified include-paths + includes = ["../../../.."], +) + +cc_library( + name = "InplaceOpInterface", + srcs = [ + "InplaceOpInterface.cpp", + ], + hdrs = [ + "InplaceOpInterface.h", + ], + deps = [ + ":inplace_op_interface_inc_gen", + "@llvm-project//mlir:IR", + ], +) + +gentbl_cc_library( + name = "inplace_op_interface_inc_gen", + tbl_outs = [ + ( + ["--gen-op-interface-decls"], + "InplaceOpInterface.h.inc", + ), + ( + ["--gen-op-interface-defs"], + "InplaceOpInterface.cpp.inc", + ), + ( + ["-gen-op-interface-docs"], + "InplaceOpInterface.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "InplaceOpInterface.td", + deps = [ + "@llvm-project//mlir:BuiltinDialectTdFiles", + ], +) diff --git a/lib/Utils/Tablegen/InplaceOpInterface.cpp b/lib/Utils/Tablegen/InplaceOpInterface.cpp new file mode 100644 index 000000000..b6141a437 --- /dev/null +++ b/lib/Utils/Tablegen/InplaceOpInterface.cpp @@ -0,0 +1,3 @@ +#include "lib/Utils/Tablegen/InplaceOpInterface.h" + +#include "lib/Utils/Tablegen/InplaceOpInterface.cpp.inc" diff --git a/lib/Utils/Tablegen/InplaceOpInterface.h b/lib/Utils/Tablegen/InplaceOpInterface.h new file mode 100644 index 000000000..effa7f2f5 --- /dev/null +++ b/lib/Utils/Tablegen/InplaceOpInterface.h @@ -0,0 +1,11 @@ +#ifndef LIB_UTILS_TABLEGEN_INPLACEOPINTERFACE_H_ +#define LIB_UTILS_TABLEGEN_INPLACEOPINTERFACE_H_ + +#include + +#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project + +// Block clang-format from reordering +#include "lib/Utils/Tablegen/InplaceOpInterface.h.inc" + +#endif // LIB_UTILS_TABLEGEN_INPLACEOPINTERFACE_H_ diff --git a/lib/Utils/Tablegen/InplaceOpInterface.td b/lib/Utils/Tablegen/InplaceOpInterface.td new file mode 100644 index 000000000..7977f4f4b --- /dev/null +++ b/lib/Utils/Tablegen/InplaceOpInterface.td @@ -0,0 +1,22 @@ +#ifndef LIB_UTILS_TABLEGEN_INPLACEOPINTERFACE_TD_ +#define LIB_UTILS_TABLEGEN_INPLACEOPINTERFACE_TD_ + +include "mlir/IR/Interfaces.td" + +def InplaceOpInterface: OpInterface<"InplaceOpInterface"> { + let cppNamespace = "::mlir::heir"; + + let description = [{ + Interface for ops to to tell which operand is the same as the result. + }]; + + let methods = [ + InterfaceMethod< + /*description=*/"Return the inplace operand for this op.", + /*retTy=*/"int", + /*methodName=*/"getInplaceOperandIndex" + >, + ]; +} + +#endif // LIB_UTILS_TABLEGEN_INPLACEOPINTERFACE_TD_ diff --git a/tests/Dialect/BGV/Conversions/bgv_to_lattigo/bgv_to_lattigo.mlir b/tests/Dialect/BGV/Conversions/bgv_to_lattigo/bgv_to_lattigo.mlir index d3fadc1b7..540e9d78b 100644 --- a/tests/Dialect/BGV/Conversions/bgv_to_lattigo/bgv_to_lattigo.mlir +++ b/tests/Dialect/BGV/Conversions/bgv_to_lattigo/bgv_to_lattigo.mlir @@ -32,17 +32,17 @@ module attributes {scheme.bgv} { // CHECK-LABEL: @test_ops // CHECK-SAME: ([[C:%.+]]: [[S:.*evaluator]], [[X:%.+]]: [[T:!lattigo.rlwe.ciphertext]], [[Y:%.+]]: [[T]]) func.func @test_ops(%x : !ct, %y : !ct) { - // CHECK: %[[v1:.*]] = lattigo.bgv.add [[C]], %[[x:.*]], %[[y:.*]]: ([[S]], [[T]], [[T]]) -> [[T]] + // CHECK: %[[v1:.*]] = lattigo.bgv.add_new [[C]], %[[x:.*]], %[[y:.*]]: ([[S]], [[T]], [[T]]) -> [[T]] %add = bgv.add %x, %y : (!ct, !ct) -> !ct - // CHECK: %[[mul:.*]] = lattigo.bgv.mul [[C]], %[[x]], %[[y]]: ([[S]], [[T]], [[T]]) -> [[T]] + // CHECK: %[[mul:.*]] = lattigo.bgv.mul_new [[C]], %[[x]], %[[y]]: ([[S]], [[T]], [[T]]) -> [[T]] %mul = bgv.mul %x, %y : (!ct, !ct) -> !ct1 - // CHECK: %[[relin:.*]] = lattigo.bgv.relinearize [[C]], %[[mul]] : ([[S]], [[T]]) -> [[T]] + // CHECK: %[[relin:.*]] = lattigo.bgv.relinearize_new [[C]], %[[mul]] : ([[S]], [[T]]) -> [[T]] %relin = bgv.relinearize %mul { from_basis = array, to_basis = array }: !ct1 -> !ct - // CHECK: %[[rescale:.*]] = lattigo.bgv.rescale [[C]], %[[relin]] : ([[S]], [[T]]) -> [[T]] + // CHECK: %[[rescale:.*]] = lattigo.bgv.rescale_new [[C]], %[[relin]] : ([[S]], [[T]]) -> [[T]] %rescale = bgv.modulus_switch %relin {to_ring = #ring_rns_L0_1_x1024_} : !ct -> !ct2 - // CHECK: %[[rot:.*]] = lattigo.bgv.rotate_columns [[C]], %[[rescale]] {offset = 1 : i64} : ([[S]], [[T]]) -> [[T]] + // CHECK: %[[rot:.*]] = lattigo.bgv.rotate_columns_new [[C]], %[[rescale]] {offset = 1 : i64} : ([[S]], [[T]]) -> [[T]] %rot = bgv.rotate %rescale { offset = 1 } : !ct2 return } diff --git a/tests/Dialect/Lattigo/Emitters/emit_lattigo.mlir b/tests/Dialect/Lattigo/Emitters/emit_lattigo.mlir index d481523f2..7d9b07794 100644 --- a/tests/Dialect/Lattigo/Emitters/emit_lattigo.mlir +++ b/tests/Dialect/Lattigo/Emitters/emit_lattigo.mlir @@ -39,16 +39,16 @@ module attributes {scheme.bgv} { // CHECK: [[ct2:[^, ].*]], [[err:.*]] := [[evaluator]].AddNew([[ct]], [[ct1]]) // CHECK: [[ct3:[^, ].*]], [[err:.*]] := [[evaluator]].MulNew([[ct2]], [[ct1]]) // CHECK: [[ct4:[^, ].*]], [[err:.*]] := [[evaluator]].RelinearizeNew([[ct3]]) - // CHECK: [[err:.*]] := [[evaluator]].Rescale([[ct4]], [[ct4]]) - // CHECK: [[ct5:[^, ].*]] := [[ct4]] + // CHECK: [[ct5:[^, ].*]] := [[ct4]].CopyNew() + // CHECK: [[err:.*]] := [[evaluator]].Rescale([[ct4]], [[ct5]]) // CHECK: [[ct6:[^, ].*]], [[err:.*]] := [[evaluator]].RotateColumnsNew([[ct5]], 1) // CHECK: return [[ct6]] func.func @compute(%evaluator : !evaluator, %ct1 : !ct, %ct2 : !ct) -> (!ct) { - %added = lattigo.bgv.add %evaluator, %ct1, %ct2 : (!evaluator, !ct, !ct) -> !ct - %mul = lattigo.bgv.mul %evaluator, %added, %ct2 : (!evaluator, !ct, !ct) -> !ct - %relin = lattigo.bgv.relinearize %evaluator, %mul : (!evaluator, !ct) -> !ct - %rescale = lattigo.bgv.rescale %evaluator, %relin : (!evaluator, !ct) -> !ct - %rotate = lattigo.bgv.rotate_columns %evaluator, %rescale {offset = 1} : (!evaluator, !ct) -> !ct + %added = lattigo.bgv.add_new %evaluator, %ct1, %ct2 : (!evaluator, !ct, !ct) -> !ct + %mul = lattigo.bgv.mul_new %evaluator, %added, %ct2 : (!evaluator, !ct, !ct) -> !ct + %relin = lattigo.bgv.relinearize_new %evaluator, %mul : (!evaluator, !ct) -> !ct + %rescale = lattigo.bgv.rescale_new %evaluator, %relin : (!evaluator, !ct) -> !ct + %rotate = lattigo.bgv.rotate_columns_new %evaluator, %rescale {offset = 1} : (!evaluator, !ct) -> !ct return %rotate : !ct } @@ -75,12 +75,10 @@ module attributes {scheme.bgv} { // CHECK: [[pt2:[^, ].*]] := bgv.NewPlaintext([[param]], [[param]].MaxLevel()) // CHECK: [[value1Packed:[^, ].*]][i] = int64([[value1]][i % len([[value1]])]) // CHECK: [[encoder]].Encode([[value1Packed]], [[pt1]]) - // CHECK: [[pt3:[^, ].*]] := [[pt1]] // CHECK: [[value2Packed:[^, ].*]][i] = int64([[value2]][i % len([[value2]])]) // CHECK: [[encoder]].Encode([[value2Packed]], [[pt2]]) - // CHECK: [[pt4:[^, ].*]] := [[pt2]] - // CHECK: [[ct1:[^, ].*]], [[err:.*]] := [[enc]].EncryptNew([[pt3]]) - // CHECK: [[ct2:[^, ].*]], [[err:.*]] := [[enc]].EncryptNew([[pt4]]) + // CHECK: [[ct1:[^, ].*]], [[err:.*]] := [[enc]].EncryptNew([[pt1]]) + // CHECK: [[ct2:[^, ].*]], [[err:.*]] := [[enc]].EncryptNew([[pt2]]) // CHECK: [[res:[^, ].*]] := compute([[eval]], [[ct1]], [[ct2]]) // CHECK: [[pt5:[^, ].*]] := [[dec]].DecryptNew([[res]]) // CHECK: [[value3:[^, ].*]] := []int64 diff --git a/tests/Dialect/Lattigo/IR/bgv_ops.mlir b/tests/Dialect/Lattigo/IR/bgv_ops.mlir index 14e7e7ae1..a5d3e8c40 100644 --- a/tests/Dialect/Lattigo/IR/bgv_ops.mlir +++ b/tests/Dialect/Lattigo/IR/bgv_ops.mlir @@ -80,24 +80,45 @@ module { return } + // CHECK-LABEL: func @test_bgv_add_new + func.func @test_bgv_add_new(%evaluator: !evaluator, %lhs: !ct, %rhs: !ct) { + // CHECK: %[[v1:.*]] = lattigo.bgv.add_new + %output = lattigo.bgv.add_new %evaluator, %lhs, %rhs : (!evaluator, !ct, !ct) -> !ct + return + } + + // CHECK-LABEL: func @test_bgv_sub_new + func.func @test_bgv_sub_new(%evaluator: !evaluator, %lhs: !ct, %rhs: !ct) { + // CHECK: %[[v1:.*]] = lattigo.bgv.sub_new + %output = lattigo.bgv.sub_new %evaluator, %lhs, %rhs : (!evaluator, !ct, !ct) -> !ct + return + } + + // CHECK-LABEL: func @test_bgv_mul_new + func.func @test_bgv_mul_new(%evaluator: !evaluator, %lhs: !ct, %rhs: !ct) { + // CHECK: %[[v1:.*]] = lattigo.bgv.mul_new + %output = lattigo.bgv.mul_new %evaluator, %lhs, %rhs : (!evaluator, !ct, !ct) -> !ct + return + } + // CHECK-LABEL: func @test_bgv_add func.func @test_bgv_add(%evaluator: !evaluator, %lhs: !ct, %rhs: !ct) { // CHECK: %[[v1:.*]] = lattigo.bgv.add - %output = lattigo.bgv.add %evaluator, %lhs, %rhs : (!evaluator, !ct, !ct) -> !ct + %output = lattigo.bgv.add %evaluator, %lhs, %rhs, %lhs : (!evaluator, !ct, !ct, !ct) -> !ct return } // CHECK-LABEL: func @test_bgv_sub func.func @test_bgv_sub(%evaluator: !evaluator, %lhs: !ct, %rhs: !ct) { // CHECK: %[[v1:.*]] = lattigo.bgv.sub - %output = lattigo.bgv.sub %evaluator, %lhs, %rhs : (!evaluator, !ct, !ct) -> !ct + %output = lattigo.bgv.sub %evaluator, %lhs, %rhs, %lhs : (!evaluator, !ct, !ct, !ct) -> !ct return } // CHECK-LABEL: func @test_bgv_mul func.func @test_bgv_mul(%evaluator: !evaluator, %lhs: !ct, %rhs: !ct) { // CHECK: %[[v1:.*]] = lattigo.bgv.mul - %output = lattigo.bgv.mul %evaluator, %lhs, %rhs : (!evaluator, !ct, !ct) -> !ct + %output = lattigo.bgv.mul %evaluator, %lhs, %rhs, %lhs : (!evaluator, !ct, !ct, !ct) -> !ct return } @@ -108,31 +129,59 @@ module { return } + // CHECK-LABEL: func @test_bgv_relinearize_new + func.func @test_bgv_relinearize_new(%evaluator: !evaluator, %ct: !ct) { + // CHECK: %[[v1:.*]] = lattigo.bgv.relinearize_new + %output = lattigo.bgv.relinearize_new %evaluator, %ct : (!evaluator, !ct) -> !ct + return + } + + // CHECK-LABEL: func @test_bgv_rescale_new + func.func @test_bgv_rescale_new(%evaluator: !evaluator, %ct: !ct) { + // CHECK: %[[v1:.*]] = lattigo.bgv.rescale_new + %output = lattigo.bgv.rescale_new %evaluator, %ct : (!evaluator, !ct) -> !ct + return + } + + // CHECK-LABEL: func @test_bgv_rotate_columns_new + func.func @test_bgv_rotate_columns_new(%evaluator: !evaluator, %ct: !ct) { + // CHECK: %[[v1:.*]] = lattigo.bgv.rotate_columns_new + %output = lattigo.bgv.rotate_columns_new %evaluator, %ct {offset = 1} : (!evaluator, !ct) -> !ct + return + } + + // CHECK-LABEL: func @test_bgv_rotate_rows_new + func.func @test_bgv_rotate_rows_new(%evaluator: !evaluator, %ct: !ct) { + // CHECK: %[[v1:.*]] = lattigo.bgv.rotate_rows_new + %output = lattigo.bgv.rotate_rows_new %evaluator, %ct : (!evaluator, !ct) -> !ct + return + } + // CHECK-LABEL: func @test_bgv_relinearize func.func @test_bgv_relinearize(%evaluator: !evaluator, %ct: !ct) { // CHECK: %[[v1:.*]] = lattigo.bgv.relinearize - %output = lattigo.bgv.relinearize %evaluator, %ct : (!evaluator, !ct) -> !ct + %output = lattigo.bgv.relinearize %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct return } // CHECK-LABEL: func @test_bgv_rescale func.func @test_bgv_rescale(%evaluator: !evaluator, %ct: !ct) { // CHECK: %[[v1:.*]] = lattigo.bgv.rescale - %output = lattigo.bgv.rescale %evaluator, %ct : (!evaluator, !ct) -> !ct + %output = lattigo.bgv.rescale %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct return } // CHECK-LABEL: func @test_bgv_rotate_columns func.func @test_bgv_rotate_columns(%evaluator: !evaluator, %ct: !ct) { // CHECK: %[[v1:.*]] = lattigo.bgv.rotate_columns - %output = lattigo.bgv.rotate_columns %evaluator, %ct {offset = 1} : (!evaluator, !ct) -> !ct + %output = lattigo.bgv.rotate_columns %evaluator, %ct, %ct {offset = 1} : (!evaluator, !ct, !ct) -> !ct return } // CHECK-LABEL: func @test_bgv_rotate_rows func.func @test_bgv_rotate_rows(%evaluator: !evaluator, %ct: !ct) { // CHECK: %[[v1:.*]] = lattigo.bgv.rotate_rows - %output = lattigo.bgv.rotate_rows %evaluator, %ct : (!evaluator, !ct) -> !ct + %output = lattigo.bgv.rotate_rows %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct return } } diff --git a/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_add.mlir b/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_add.mlir new file mode 100644 index 000000000..b61be54d8 --- /dev/null +++ b/tests/Dialect/Lattigo/Transforms/alloc_to_inplace_add.mlir @@ -0,0 +1,11 @@ +// RUN: heir-opt --mlir-to-bgv --bgv-to-lwe --lwe-to-lattigo --lattigo-alloc-to-inplace %s | FileCheck %s + +// CHECK-LABEL: func.func @add +func.func @add(%arg0 : i16 {secret.secret}) -> i16 { + // CHECK-COUNT-3: lattigo.bgv.add + // CHECK-NOT: lattigo.bgv.add_new + %0 = arith.addi %arg0, %arg0 : i16 + %1 = arith.addi %0, %0 : i16 + %2 = arith.addi %1, %1 : i16 + return %2 : i16 +} diff --git a/tests/Dialect/Lattigo/Transforms/configure_crypto_context_add.mlir b/tests/Dialect/Lattigo/Transforms/configure_crypto_context_add.mlir index 612ac1cc3..2349a8297 100644 --- a/tests/Dialect/Lattigo/Transforms/configure_crypto_context_add.mlir +++ b/tests/Dialect/Lattigo/Transforms/configure_crypto_context_add.mlir @@ -6,7 +6,7 @@ module attributes {scheme.bgv} { func.func @add(%evaluator : !evaluator, %ct : !ct) -> !ct { - %res = lattigo.bgv.add %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct + %res = lattigo.bgv.add_new %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct return %res : !ct } } diff --git a/tests/Dialect/Lattigo/Transforms/configure_crypto_context_detect.mlir b/tests/Dialect/Lattigo/Transforms/configure_crypto_context_detect.mlir index 8f9c55f7c..ae988465e 100644 --- a/tests/Dialect/Lattigo/Transforms/configure_crypto_context_detect.mlir +++ b/tests/Dialect/Lattigo/Transforms/configure_crypto_context_detect.mlir @@ -8,7 +8,7 @@ module attributes {scheme.bgv} { func.func @add(%evaluator : !evaluator, %ct : !ct) -> !ct { - %res = lattigo.bgv.add %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct + %res = lattigo.bgv.add_new %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct return %res : !ct } } @@ -26,11 +26,11 @@ module attributes {scheme.bgv} { module attributes {scheme.bgv} { func.func @sub(%evaluator : !evaluator, %ct : !ct) -> !ct { - %res = lattigo.bgv.add %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct + %res = lattigo.bgv.add_new %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct return %res : !ct } func.func @add(%evaluator : !evaluator, %ct : !ct) -> !ct { - %res = lattigo.bgv.add %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct + %res = lattigo.bgv.add_new %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct %sub = call @sub(%evaluator, %res) : (!evaluator, !ct) -> !ct return %sub : !ct } @@ -51,7 +51,7 @@ module attributes {scheme.bgv} { module attributes {scheme.bgv} { func.func private @sub(%evaluator : !evaluator, %ct : !ct) -> !ct func.func @add(%evaluator : !evaluator, %ct : !ct) -> !ct { - %res = lattigo.bgv.add %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct + %res = lattigo.bgv.add_new %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct return %res : !ct } } diff --git a/tests/Dialect/Lattigo/Transforms/configure_crypto_context_detect_diagnostic.mlir b/tests/Dialect/Lattigo/Transforms/configure_crypto_context_detect_diagnostic.mlir index dd947d642..7eba1468b 100644 --- a/tests/Dialect/Lattigo/Transforms/configure_crypto_context_detect_diagnostic.mlir +++ b/tests/Dialect/Lattigo/Transforms/configure_crypto_context_detect_diagnostic.mlir @@ -7,7 +7,7 @@ // expected-warning@+1 {{Entry function not found, please provide entry-function in the pass options}} module attributes {scheme.bgv} { func.func @__add(%evaluator : !evaluator, %ct : !ct) -> !ct { - %res = lattigo.bgv.add %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct + %res = lattigo.bgv.add_new %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct return %res : !ct } } diff --git a/tests/Dialect/Lattigo/Transforms/configure_crypto_context_relin.mlir b/tests/Dialect/Lattigo/Transforms/configure_crypto_context_relin.mlir index 390a95bb8..16a86d2f2 100644 --- a/tests/Dialect/Lattigo/Transforms/configure_crypto_context_relin.mlir +++ b/tests/Dialect/Lattigo/Transforms/configure_crypto_context_relin.mlir @@ -6,8 +6,8 @@ module attributes {scheme.bgv} { func.func @relin(%evaluator : !evaluator, %ct : !ct) -> !ct { - %ct1 = lattigo.bgv.mul %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct - %res = lattigo.bgv.relinearize %evaluator, %ct1 : (!evaluator, !ct) -> !ct + %ct1 = lattigo.bgv.mul_new %evaluator, %ct, %ct : (!evaluator, !ct, !ct) -> !ct + %res = lattigo.bgv.relinearize_new %evaluator, %ct1 : (!evaluator, !ct) -> !ct return %res : !ct } } diff --git a/tests/Dialect/Lattigo/Transforms/configure_crypto_context_rotate.mlir b/tests/Dialect/Lattigo/Transforms/configure_crypto_context_rotate.mlir index 4c887bea6..41e33459f 100644 --- a/tests/Dialect/Lattigo/Transforms/configure_crypto_context_rotate.mlir +++ b/tests/Dialect/Lattigo/Transforms/configure_crypto_context_rotate.mlir @@ -6,7 +6,7 @@ module attributes {scheme.bgv} { func.func @rotate(%evaluator : !evaluator, %ct : !ct) -> !ct { - %res = lattigo.bgv.rotate_columns %evaluator, %ct {offset = 1} : (!evaluator, !ct) -> !ct + %res = lattigo.bgv.rotate_columns_new %evaluator, %ct {offset = 1} : (!evaluator, !ct) -> !ct return %res : !ct } } diff --git a/tools/BUILD b/tools/BUILD index bf5c5edee..af6062eb5 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -59,6 +59,7 @@ cc_binary( "@heir//lib/Dialect/LWE/Transforms:AddClientInterface", "@heir//lib/Dialect/Lattigo/IR:Dialect", "@heir//lib/Dialect/Lattigo/Transforms", + "@heir//lib/Dialect/Lattigo/Transforms:AllocToInplace", "@heir//lib/Dialect/Lattigo/Transforms:ConfigureCryptoContext", "@heir//lib/Dialect/LinAlg/Conversions/LinalgToTensorExt", "@heir//lib/Dialect/Mgmt/IR:Dialect",