Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement alloc-to-inplace pass to support inplace ops #1407

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions lib/Dialect/LWE/Conversions/LWEToLattigo/LWEToLattigo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,33 +430,33 @@ struct ConvertLWEReinterpretUnderlyingType
} // namespace

// BGV
using ConvertBGVAddOp =
ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RAddOp, lattigo::BGVAddOp>;
using ConvertBGVSubOp =
ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RSubOp, lattigo::BGVSubOp>;
using ConvertBGVMulOp =
ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RMulOp, lattigo::BGVMulOp>;
using ConvertBGVAddOp = ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RAddOp,
lattigo::BGVAddNewOp>;
using ConvertBGVSubOp = ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RSubOp,
lattigo::BGVSubNewOp>;
using ConvertBGVMulOp = ConvertRlweBinOp<lattigo::BGVEvaluatorType, lwe::RMulOp,
lattigo::BGVMulNewOp>;
using ConvertBGVAddPlainOp =
ConvertRlwePlainOp<lattigo::BGVEvaluatorType, bgv::AddPlainOp,
lattigo::BGVAddOp>;
lattigo::BGVAddNewOp>;
using ConvertBGVSubPlainOp =
ConvertRlwePlainOp<lattigo::BGVEvaluatorType, bgv::SubPlainOp,
lattigo::BGVSubOp>;
lattigo::BGVSubNewOp>;
using ConvertBGVMulPlainOp =
ConvertRlwePlainOp<lattigo::BGVEvaluatorType, bgv::MulPlainOp,
lattigo::BGVMulOp>;
lattigo::BGVMulNewOp>;

using ConvertBGVRelinOp =
ConvertRlweUnaryOp<lattigo::BGVEvaluatorType, bgv::RelinearizeOp,
lattigo::BGVRelinearizeOp>;
lattigo::BGVRelinearizeNewOp>;
using ConvertBGVModulusSwitchOp =
ConvertRlweUnaryOp<lattigo::BGVEvaluatorType, bgv::ModulusSwitchOp,
lattigo::BGVRescaleOp>;
lattigo::BGVRescaleNewOp>;

// TODO(#1186): figure out generic rotating using BGVRotateColumns/RowsOp
using ConvertBGVRotateOp =
ConvertRlweRotateOp<lattigo::BGVEvaluatorType, bgv::RotateOp,
lattigo::BGVRotateColumnsOp>;
lattigo::BGVRotateColumnsNewOp>;

using ConvertBGVEncryptOp =
ConvertRlweUnaryOp<lattigo::RLWEEncryptorType, lwe::RLWEEncryptOp,
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Lattigo/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ cc_library(
":ops_inc_gen",
":types_inc_gen",
"@heir//lib/Utils/Tablegen:AsmInterfaces",
"@heir//lib/Utils/Tablegen:InplaceOpInterface",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
Expand Down
131 changes: 123 additions & 8 deletions lib/Dialect/Lattigo/IR/LattigoBGVOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def Lattigo_BGVNewEncoderOp : Lattigo_BGVOp<"new_encoder"> {
let results = (outs Lattigo_BGVEncoder:$encoder);
}

def Lattigo_BGVEncodeOp : Lattigo_BGVOp<"encode"> {
def Lattigo_BGVEncodeOp : Lattigo_BGVOp<"encode", [InplaceOpInterface]> {
let summary = "Encode a plaintext value in the Lattigo BGV dialect";
let description = [{
This operation encodes a plaintext value using the specified encoder in the Lattigo BGV dialect.
Expand All @@ -55,6 +55,8 @@ def Lattigo_BGVEncodeOp : Lattigo_BGVOp<"encode"> {
Lattigo_RLWEPlaintext:$plaintext
);
let results = (outs Lattigo_RLWEPlaintext:$encoded);

let extraClassDeclaration = "int getInplaceOperandIndex() { return 2; }";
}

def Lattigo_BGVDecodeOp : Lattigo_BGVOp<"decode", [AllTypesMatch<["value", "decoded"]>]> {
Expand All @@ -69,6 +71,8 @@ def Lattigo_BGVDecodeOp : Lattigo_BGVOp<"decode", [AllTypesMatch<["value", "deco
Lattigo_RLWEPlaintext:$plaintext,
RankedTensorOf<[AnyInteger]>:$value
);
// although bgv.Decode is also an inplace operation as bgv.Encode, as there are post-processing
// steps in emitter, we mark it as a normal operation.
let results = (outs RankedTensorOf<[AnyInteger]>:$decoded);
}

Expand Down Expand Up @@ -102,27 +106,72 @@ class Lattigo_BGVBinaryOp<string mnemonic> :
let results = (outs Lattigo_RLWECiphertext:$output);
}

def Lattigo_BGVAddOp : Lattigo_BGVBinaryOp<"add"> {
def Lattigo_BGVAddNewOp : Lattigo_BGVBinaryOp<"add_new"> {
let summary = "Add two ciphertexts in the Lattigo BGV dialect";
let description = [{
This operation adds two ciphertext values in the Lattigo BGV dialect.
}];
}

def Lattigo_BGVSubOp : Lattigo_BGVBinaryOp<"sub"> {
def Lattigo_BGVSubNewOp : Lattigo_BGVBinaryOp<"sub_new"> {
let summary = "Subtract two ciphertexts in the Lattigo BGV dialect";
let description = [{
This operation subtracts one ciphertext value from another in the Lattigo BGV dialect.
}];
}

def Lattigo_BGVMulOp : Lattigo_BGVBinaryOp<"mul"> {
def Lattigo_BGVMulNewOp : Lattigo_BGVBinaryOp<"mul_new"> {
let summary = "Multiply two ciphertexts in the Lattigo BGV dialect";
let description = [{
This operation multiplies two ciphertext values in the Lattigo BGV dialect.
}];
}

class Lattigo_BGVBinaryInplaceOp<string mnemonic> :
Lattigo_BGVOp<mnemonic, [InplaceOpInterface]> {
let arguments = (ins
Lattigo_BGVEvaluator:$evaluator,
Lattigo_RLWECiphertext:$lhs,
Lattigo_RLWECiphertextOrPlaintext:$rhs,
// Lattigo API is like bgv.Add(lhs, rhs, out) but for MLIR we need to
// satisfy the SSA form, so we still have a separate output.
Lattigo_RLWECiphertext:$inplace
);
let results = (outs Lattigo_RLWECiphertext:$output);

let extraClassDeclaration = "int getInplaceOperandIndex() { return 3; }";
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to return 1 here, and remove the inplace operand? I don't think you need the extra operand if you have this op interface to identify it.

}

def Lattigo_BGVAddOp : Lattigo_BGVBinaryInplaceOp<"add"> {
let summary = "Add two ciphertexts in the Lattigo BGV dialect";
let description = [{
This operation adds two ciphertext values in the Lattigo BGV dialect.

The result will be written to the `inplace` operand. The `output`result is
a transitive reference to the `inplace` operand for sake of the MLIR SSA form.
}];
}

def Lattigo_BGVSubOp : Lattigo_BGVBinaryInplaceOp<"sub"> {
let summary = "Subtract two ciphertexts in the Lattigo BGV dialect";
let description = [{
This operation subtracts one ciphertext value from another in the Lattigo BGV dialect.

The result will be written to the `inplace` operand. The `output`result is
a transitive reference to the `inplace` operand for sake of the MLIR SSA form.
}];
}

def Lattigo_BGVMulOp : Lattigo_BGVBinaryInplaceOp<"mul"> {
let summary = "Multiply two ciphertexts in the Lattigo BGV dialect";
let description = [{
This operation multiplies two ciphertext values in the Lattigo BGV dialect.

The result will be written to the `inplace` operand. The `output`result is
a transitive reference to the `inplace` operand for sake of the MLIR SSA form.
}];
}

class Lattigo_BGVUnaryOp<string mnemonic> :
Lattigo_BGVOp<mnemonic> {
let arguments = (ins
Expand All @@ -132,43 +181,109 @@ class Lattigo_BGVUnaryOp<string mnemonic> :
let results = (outs Lattigo_RLWECiphertext:$output);
}

def Lattigo_BGVRelinearizeOp : Lattigo_BGVUnaryOp<"relinearize"> {
def Lattigo_BGVRelinearizeNewOp : Lattigo_BGVUnaryOp<"relinearize_new"> {
let summary = "Relinearize a ciphertext in the Lattigo BGV dialect";
let description = [{
This operation relinearizes a ciphertext value in the Lattigo BGV dialect.
}];
}

def Lattigo_BGVRescaleNewOp : Lattigo_BGVUnaryOp<"rescale_new"> {
let summary = "Rescale a ciphertext in the Lattigo BGV dialect";
let description = [{
This operation rescales a ciphertext value in the Lattigo BGV dialect.
}];
}

def Lattigo_BGVRotateColumnsNewOp : Lattigo_BGVOp<"rotate_columns_new"> {
let summary = "Rotate columns of a ciphertext in the Lattigo BGV dialect";
let description = [{
This operation rotates the columns of a ciphertext value in the Lattigo BGV dialect.

Lattigo exposes the SIMD slot of BGV as a N/2 x 2 matrix, where N/2 is the column.

Offset is valid for both positive and negative number.
}];
let arguments = (ins
Lattigo_BGVEvaluator:$evaluator,
Lattigo_RLWECiphertext:$input,
Builtin_IntegerAttr:$offset
);
let results = (outs Lattigo_RLWECiphertext:$output);
}

def Lattigo_BGVRotateRowsNewOp : Lattigo_BGVUnaryOp<"rotate_rows_new"> {
let summary = "Rotate rows of a ciphertext in the Lattigo BGV dialect";
let description = [{
This operation swap the rows of a ciphertext value in the Lattigo BGV dialect.

Lattigo exposes the SIMD slot of BGV as a N/2 x 2 matrix, where 2 is the row.
}];
}

class Lattigo_BGVUnaryInplaceOp<string mnemonic> :
Lattigo_BGVOp<mnemonic, [InplaceOpInterface]> {
let arguments = (ins
Lattigo_BGVEvaluator:$evaluator,
Lattigo_RLWECiphertext:$input,
// see BinaryInplaceOp above
Lattigo_RLWECiphertext:$inplace
);
let results = (outs Lattigo_RLWECiphertext:$output);

let extraClassDeclaration = "int getInplaceOperandIndex() { return 2; }";
}

def Lattigo_BGVRelinearizeOp : Lattigo_BGVUnaryInplaceOp<"relinearize"> {
let summary = "Relinearize a ciphertext in the Lattigo BGV dialect";
let description = [{
This operation relinearizes a ciphertext value in the Lattigo BGV dialect.

The result will be written to the `inplace` operand. The `output`result is
a transitive reference to the `inplace` operand for sake of the MLIR SSA form.
}];
}

def Lattigo_BGVRescaleOp : Lattigo_BGVUnaryOp<"rescale"> {
def Lattigo_BGVRescaleOp : Lattigo_BGVUnaryInplaceOp<"rescale"> {
let summary = "Rescale a ciphertext in the Lattigo BGV dialect";
let description = [{
This operation rescales a ciphertext value in the Lattigo BGV dialect.

The result will be written to the `inplace` operand. The `output`result is
a transitive reference to the `inplace` operand for sake of the MLIR SSA form.
}];
}

def Lattigo_BGVRotateColumnsOp : Lattigo_BGVOp<"rotate_columns"> {
def Lattigo_BGVRotateColumnsOp : Lattigo_BGVUnaryInplaceOp<"rotate_columns"> {
let summary = "Rotate columns of a ciphertext in the Lattigo BGV dialect";
let description = [{
This operation rotates the columns of a ciphertext value in the Lattigo BGV dialect.

Lattigo exposes the SIMD slot of BGV as a N/2 x 2 matrix, where N/2 is the column.

Offset is valid for both positive and negative number.

The result will be written to the `inplace` operand. The `output`result is
a transitive reference to the `inplace` operand for sake of the MLIR SSA form.
}];
let arguments = (ins
Lattigo_BGVEvaluator:$evaluator,
Lattigo_RLWECiphertext:$input,
Lattigo_RLWECiphertext:$inplace,
Builtin_IntegerAttr:$offset
);
let results = (outs Lattigo_RLWECiphertext:$output);
}

def Lattigo_BGVRotateRowsOp : Lattigo_BGVUnaryOp<"rotate_rows"> {
def Lattigo_BGVRotateRowsOp : Lattigo_BGVUnaryInplaceOp<"rotate_rows"> {
let summary = "Rotate rows of a ciphertext in the Lattigo BGV dialect";
let description = [{
This operation swap the rows of a ciphertext value in the Lattigo BGV dialect.

Lattigo exposes the SIMD slot of BGV as a N/2 x 2 matrix, where 2 is the row.

The result will be written to the `inplace` operand. The `output`result is
a transitive reference to the `inplace` operand for sake of the MLIR SSA form.
}];
}

Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Lattigo/IR/LattigoOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "lib/Dialect/Lattigo/IR/LattigoDialect.h"
#include "lib/Dialect/Lattigo/IR/LattigoTypes.h"
#include "lib/Utils/Tablegen/InplaceOpInterface.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project

#define GET_OP_CLASSES
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Lattigo/IR/LattigoOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ include "LattigoDialect.td"
include "LattigoTypes.td"
include "mlir/IR/OpBase.td"
include "mlir/IR/OpAsmInterface.td"
include "lib/Utils/Tablegen/InplaceOpInterface.td"

class Lattigo_Op<string mnemonic, list<Trait> traits = []> :
Op<Lattigo_Dialect, mnemonic, traits # [OpAsmOpInterface]> {
Expand Down
Loading
Loading