-
Notifications
You must be signed in to change notification settings - Fork 66
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement alloc-to-inplace pass to support inplace ops #1407
base: main
Are you sure you want to change the base?
Implement alloc-to-inplace pass to support inplace ops #1407
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While I support the idea behind this PR, the decision to model the ops as mutating the operand violates the SSA requirement of MLIR, and that will cause problems down the line as other transformations that rely on SSA properties are applied. For example, nothing would prevent a canonicalization pass from reordering in-place instructions (though nothing would cause it to right now).
The memref
pass gets around this by having the SSA value be an (unchanging) pointer to memory, while the memory inside it may change without violating SSA. The tensor
dialect gets around this by having you specify the output tensor as an argument, and having the result SSA value semantically represent that output tensor. I think the tensor strategy would work well for this case, and then the emitter would have to record the transitive references and drop the explicit results to get back to the expected in-place style.
On the other hand, since lattigo/openfhe are exit dialects, we could ignore all this and blatantly violate SSA if we agree this pass is the last thing we would do before codegen. However, we have talked before about potentially supporting openfhe as an entry dialect in the future (parsing C++ programs written against OpenFHE's API and lowering to polynomial, then to an accelerator). While that is highly speculative, respecting SSA now would be more future proof for possibilities like that in the future.
If we do "burn the SSA bridge", it would be worthwhile to document this clearly on the pass's tablegen.
7b46411
to
9302c15
Compare
Implemented in the tensor strategy where Note that OpAsmOpInterface impl is overriden so the SSA name is not elegant. Will be fixed by OpAsmTypeInterface recently upstreamed. Example %0 = lattigo.bgv.add %evaluator, %ct, %ct, %ct : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext
%1 = lattigo.bgv.add %evaluator, %0, %0, %0 : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext
%2 = lattigo.bgv.add %evaluator, %1, %1, %1 : (!lattigo.bgv.evaluator, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext, !lattigo.rlwe.ciphertext) -> !lattigo.rlwe.ciphertext Emitted to err0 := evaluator.Add(ct, ct, ct);
err1 := evaluator.Add(ct, ct, ct);
err2 := evaluator.Add(ct, ct, ct); For dot product we have the following, where only 3 new allocation is made (actually ct6/ct8 can reuse ct4 but that is further optimization) v0 := []int64{0, 0, 0, 0, 0, 0, 0, 1}
err0 := evaluator.Mul(ct, ct1, ct);
err1 := evaluator.Relinearize(ct, ct);
ct4, err2 := evaluator.RotateColumnsNew(ct, 4)
err3 := evaluator.Add(ct, ct4, ct);
ct6, err4 := evaluator.RotateColumnsNew(ct, 2)
err5 := evaluator.Add(ct, ct6, ct);
ct8, err6 := evaluator.RotateColumnsNew(ct, 1)
err7 := evaluator.Add(ct, ct8, ct);
err8 := evaluator.Rescale(ct, ct);
err9 := evaluator.Mul(ct, pt, ct);
err10 := evaluator.RotateColumns(ct, 7, ct)
err11 := evaluator.Rescale(ct, ct);
return ct |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As a future idea, I can imagine a generalized version of this pass that could be applied to other backends like openfhe, in which the pass operates on a general Operation *
, dyn_casts
the op to an ConvertibleToInplaceOpInterface
(this would be an interface attached to add_new
rather than attached to the inplace add
op). Then this interface has a method that handles the rewriting from non-inplace to inplace internally.
That said, if it's only Lattigo and OpenFHE that have this, it may not be worth the abstraction just to support two things. If we add more library API backends (say, if jaxite or tfhe-rust want this), it may be worth it.
LogicalResult matchAndRewrite(UnaryOp op, | ||
PatternRewriter &rewriter) const override { | ||
// operand 0 is evaluator | ||
auto lhs = op.getOperand(1); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I may be misunderstanding, but why not use the op interface you created to fetch the operand that is mutated?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well there is a further optimization that find previous dead value as the storage opOut
, instead of mutating input operand inplace.
In Lattigo, for Add(op0, op1, opOut)
, opOut
can be either op0
, op1
or other previously allocated ciphertext. I write the pass now with op0
always as opOut
for simplicity as you can see the detection algorithm above has some complicacy and worth a separate PR.
For that optimization, the result would be the following, where no new allocation is made
err0 := evaluator.Mul(ct, ct1, ct);
err1 := evaluator.Relinearize(ct, ct);
err2 := evaluator.RotateColumns(ct, 4, ct1)
err3 := evaluator.Add(ct, ct1, ct);
err4 := evaluator.RotateColumns(ct, 2, ct1)
err5 := evaluator.Add(ct, ct1, ct);
err6 := evaluator.RotateColumns(ct, 1, ct1)
err7 := evaluator.Add(ct, ct1, ct);
err8 := evaluator.Rescale(ct, ct);
err9 := evaluator.Mul(ct, pt, ct);
err10 := evaluator.RotateColumns(ct, 7, ct)
err11 := evaluator.Rescale(ct, ct);
return ct
Does it make sense to return 1 here, and remove the inplace operand? I don't think you need the extra operand if you have this op interface to identify it.
So the InplaceOpInterface
will only track opOut
instead of op0
, which semantically suggests output = opOut
but not op0 = opOut = output
. For Openfhe API we always have op0 = opOut
but it is not the case for Lattigo.
); | ||
let results = (outs Lattigo_RLWECiphertext:$output); | ||
|
||
let extraClassDeclaration = "int getInplaceOperandIndex() { return 3; }"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it make sense to return 1
here, and remove the inplace
operand? I don't think you need the extra operand if you have this op interface to identify it.
9302c15
to
587464b
Compare
587464b
to
9c9a8be
Compare
See #1327.
Lattigo's API prefers the inplace version. Actually it is more flexible in that it has the form
so a more dedicated analysis could use some no longer used buffer as the result buffer, instead of the inplace buffer.
Currently, I implement
lattigo.bgv.add op0, op1
in the IR and useop0
asopOut
in the emitter.Discussion
Transforms
getUses()
is hard to deal with (it might be unordered)TODO
bgv.Rescale
. I previously pretend them all to be alloc version in the IR and handle these difference in emitter. I should resolve the difference later so that we wont have some error caused by such difference (fake alloc op backed by inplace op and the input has more than one uses)Example
Pure addition
The input
becomes
dot product
namely there are only 10 ct. In the alloc version, there are 15 ct in total.