
Implement alloc-to-inplace pass to support inplace ops #1407

Open · wants to merge 1 commit into main
Conversation

ZenithalHourlyRate
Collaborator

See #1327.

Lattigo's API prefers the inplace version. It is actually more flexible in that it has the form

Add(op0, op1, opOut)

so a more dedicated analysis could use a no-longer-used buffer as the result buffer, instead of the inplace buffer.

Currently, I emit lattigo.bgv.add op0, op1 in the IR and use op0 as opOut in the emitter.

Discussion

  • This pass could obviously also be applied to the OpenFHE backend. I currently put it under the lattigo folder, but we can decide whether it should live in Transforms.
  • I made another attempt at this using the canonicalizer to rewrite alloc to inplace, but found getUses() hard to deal with (it might be unordered).

TODO

  • For some APIs Lattigo offers both the alloc version and the inplace version, but for some it offers only the inplace version, like bgv.Rescale. I previously pretended they were all alloc versions in the IR and handled the difference in the emitter. I should resolve the difference later so that we won't hit errors caused by it (a fake alloc op backed by an inplace op whose input has more than one use).
  • Add tests.

Example

Pure addition

The input

func.func @add(%arg0 : i16 {secret.secret}) -> i16 {
    %0 = arith.addi %arg0, %arg0 : i16
    %1 = arith.addi %0, %0 : i16
    %2 = arith.addi %1, %1 : i16
    %3 = arith.addi %2, %2 : i16
    %4 = arith.addi %3, %3 : i16
    %5 = arith.addi %4, %4 : i16
    return %5 : i16
} 

becomes

    lattigo.bgv.add %evaluator, %ct, %ct : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext) -> ()
    lattigo.bgv.add %evaluator, %ct, %ct : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext) -> ()
    lattigo.bgv.add %evaluator, %ct, %ct : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext) -> ()
    lattigo.bgv.add %evaluator, %ct, %ct : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext) -> ()
    lattigo.bgv.add %evaluator, %ct, %ct : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext) -> ()
    lattigo.bgv.add %evaluator, %ct, %ct : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext) -> ()
    return %ct : !lattigo.rlwe.ciphertext

Dot product

    %cst = arith.constant dense<[0, 0, 0, 0, 0, 0, 0, 1]> : tensor<8xi16>
    lattigo.bgv.mul %evaluator, %ct, %ct_0 : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext) -> ()
    %ct_1 = lattigo.bgv.relinearize %evaluator, %ct : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext
    %ct_2 = lattigo.bgv.rotate_columns %evaluator, %ct_1 {offset = 4 : index} : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext
    lattigo.bgv.add %evaluator, %ct_1, %ct_2 : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext) -> ()
    %ct_3 = lattigo.bgv.rotate_columns %evaluator, %ct_1 {offset = 2 : index} : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext
    lattigo.bgv.add %evaluator, %ct_1, %ct_3 : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext) -> ()
    %ct_4 = lattigo.bgv.rotate_columns %evaluator, %ct_1 {offset = 1 : index} : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext
    lattigo.bgv.add %evaluator, %ct_1, %ct_4 : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext) -> ()
    %ct_5 = lattigo.bgv.rescale %evaluator, %ct_1 : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext
    %pt = lattigo.bgv.new_plaintext %param : (!lattigo.bgv.parameter) -> !lattigo.rlwe.plaintext
    %pt_6 = lattigo.bgv.encode %encoder, %cst, %pt : (!lattigo.bgv.encoder, tensor<8xi16>, !lattigo.rlwe.plaintext) -> !lattigo.rlwe.plaintext
    lattigo.bgv.mul %evaluator, %ct_5, %pt_6 : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.plaintext) -> ()
    %ct_7 = lattigo.bgv.rotate_columns %evaluator, %ct_5 {offset = 7 : index} : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext
    %ct_8 = lattigo.bgv.rescale %evaluator, %ct_7 : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext
    return %ct_8 : !lattigo.rlwe.ciphertext

That is, there are only 10 ciphertexts in total; in the alloc version, there are 15.

@j2kun (Collaborator) left a comment

While I support the idea behind this PR, the decision to model the ops as mutating the operand violates the SSA requirement of MLIR, and that will cause problems down the line as other transformations that rely on SSA properties are applied. For example, nothing would prevent a canonicalization pass from reordering in-place instructions (though nothing would cause it to right now).

The memref pass gets around this by having the SSA value be an (unchanging) pointer to memory, while the memory inside it may change without violating SSA. The tensor dialect gets around this by having you specify the output tensor as an argument, and having the result SSA value semantically represent that output tensor. I think the tensor strategy would work well for this case, and then the emitter would have to record the transitive references and drop the explicit results to get back to the expected in-place style.

On the other hand, since lattigo/openfhe are exit dialects, we could ignore all this and blatantly violate SSA if we agree this pass is the last thing we would do before codegen. However, we have talked before about potentially supporting openfhe as an entry dialect in the future (parsing C++ programs written against OpenFHE's API and lowering to polynomial, then to an accelerator). While that is highly speculative, respecting SSA now would be more future proof for possibilities like that in the future.

If we do "burn the SSA bridge", it would be worthwhile to document this clearly on the pass's tablegen.
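For concreteness, the destination-passing ("tensor strategy") modeling suggested above could look roughly like this (a sketch only; the op name and operand order here are illustrative, not a settled design):

    // %dst is an explicit output operand; the result %sum semantically
    // represents that output buffer, so SSA is preserved while the emitter
    // can still lower this to an in-place evaluator.Add call.
    %sum = lattigo.bgv.add %evaluator, %op0, %op1, %dst : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext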

@ZenithalHourlyRate force-pushed the lattigo-bgv-inplace branch 2 times, most recently from 7b46411 to 9302c15 on February 19, 2025 15:38
ZenithalHourlyRate (Collaborator, Author) commented Feb 19, 2025

> I think the tensor strategy would work well for this case, and then the emitter would have to record the transitive references and drop the explicit results to get back to the expected in-place style.

Implemented the tensor strategy: evaluator.Add(op0, op1, opOut) is modelled as output = bgv.add evaluator, op0, op1, opOut, and an InplaceOpInterface is implemented to indicate which operand is the inplace storage for output; the emitter uses this interface to trace back to the storage value.

Note that the OpAsmOpInterface impl is overridden, so the SSA names are not elegant. This will be fixed by OpAsmTypeInterface, recently upstreamed.

Example

    %0 = lattigo.bgv.add %evaluator, %ct, %ct, %ct : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext
    %1 = lattigo.bgv.add %evaluator, %0, %0, %0 : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext
    %2 = lattigo.bgv.add %evaluator, %1, %1, %1 : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext

Emitted to

  err0 := evaluator.Add(ct, ct, ct);
  err1 := evaluator.Add(ct, ct, ct);
  err2 := evaluator.Add(ct, ct, ct);

For the dot product we have the following, where only 3 new allocations are made (actually ct6/ct8 could reuse ct4, but that is a further optimization):

  v0 := []int64{0, 0, 0, 0, 0, 0, 0, 1}
  err0 := evaluator.Mul(ct, ct1, ct);
  err1 := evaluator.Relinearize(ct, ct);
  ct4, err2 := evaluator.RotateColumnsNew(ct, 4)
  err3 := evaluator.Add(ct, ct4, ct);
  ct6, err4 := evaluator.RotateColumnsNew(ct, 2)
  err5 := evaluator.Add(ct, ct6, ct);
  ct8, err6 := evaluator.RotateColumnsNew(ct, 1)
  err7 := evaluator.Add(ct, ct8, ct);
  err8 := evaluator.Rescale(ct, ct);
  err9 := evaluator.Mul(ct, pt, ct);
  err10 := evaluator.RotateColumns(ct, 7, ct)
  err11 := evaluator.Rescale(ct, ct);
  return ct

Collaborator
As a future idea, I can imagine a generalized version of this pass that could be applied to other backends like openfhe, in which the pass operates on a general Operation *, dyn_casts the op to a ConvertibleToInplaceOpInterface (this would be an interface attached to add_new rather than to the inplace add op). Then this interface has a method that handles the rewriting from non-inplace to inplace internally.

That said, if it's only Lattigo and OpenFHE that have this, it may not be worth the abstraction just to support two things. If we add more library API backends (say, if jaxite or tfhe-rust want this), it may be worth it.

    LogicalResult matchAndRewrite(UnaryOp op,
                                  PatternRewriter &rewriter) const override {
      // operand 0 is evaluator
      auto lhs = op.getOperand(1);
Collaborator
I may be misunderstanding, but why not use the op interface you created to fetch the operand that is mutated?

ZenithalHourlyRate (Collaborator, Author) replied Feb 20, 2025

Well, there is a further optimization that finds a previously dead value to use as the storage opOut, instead of mutating the input operand inplace.

ZenithalHourlyRate@56d6228

In Lattigo, for Add(op0, op1, opOut), opOut can be op0, op1, or another previously allocated ciphertext. For simplicity I wrote the pass with op0 always used as opOut; as you can see above, the detection algorithm has some complexity and is worth a separate PR.

For that optimization the result would be the following, where no new allocations are made:

  err0 := evaluator.Mul(ct, ct1, ct);
  err1 := evaluator.Relinearize(ct, ct);
  err2 := evaluator.RotateColumns(ct, 4, ct1)
  err3 := evaluator.Add(ct, ct1, ct);
  err4 := evaluator.RotateColumns(ct, 2, ct1)
  err5 := evaluator.Add(ct, ct1, ct);
  err6 := evaluator.RotateColumns(ct, 1, ct1)
  err7 := evaluator.Add(ct, ct1, ct);
  err8 := evaluator.Rescale(ct, ct);
  err9 := evaluator.Mul(ct, pt, ct);
  err10 := evaluator.RotateColumns(ct, 7, ct)
  err11 := evaluator.Rescale(ct, ct);
  return ct

> Does it make sense to return 1 here, and remove the inplace operand? I don't think you need the extra operand if you have this op interface to identify it.

So the InplaceOpInterface will only track opOut instead of op0, which semantically means output = opOut but not op0 = opOut = output. For the OpenFHE API we always have op0 = opOut, but that is not the case for Lattigo.

    );
    let results = (outs Lattigo_RLWECiphertext:$output);

    let extraClassDeclaration = "int getInplaceOperandIndex() { return 3; }";
Collaborator
Does it make sense to return 1 here, and remove the inplace operand? I don't think you need the extra operand if you have this op interface to identify it.
