Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not require explicit passing of context #7

Merged
merged 1 commit into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions examples/simple_ckks_bootstrapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,17 @@ function simple_ckks_bootstrapping(context)
init_bootstrapping!(context, private_key)

x = [0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0, 5.0]
encoded_length = length(x)

pv = PlainVector(context, x)
pv = PlainVector(x, context)
println("Input: ", pv)

sv = encrypt(context, public_key, pv)
sv = encrypt(pv, public_key)

# Perform the bootstrapping operation. The goal is to increase the number of levels
# remaining for HE computation.
sv_after = bootstrap!(context, sv)
sv_after = bootstrap!(sv)

result = decrypt(context, private_key, sv)
result = decrypt(sv, private_key)
println("Output after bootstrapping \n\t", result)
end

Expand Down
22 changes: 11 additions & 11 deletions examples/simple_real_numbers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ function simple_real_numbers(context)
x1 = [0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0, 5.0]
x2 = [5.0, 4.0, 3.0, 2.0, 1.0, 0.75, 0.5, 0.25]

pv1 = PlainVector(context, x1)
pv2 = PlainVector(context, x2)
pv1 = PlainVector(x1, context)
pv2 = PlainVector(x2, context)

println("Input x1: ", pv1)
println("Input x2: ", pv2)

sv1 = encrypt(context, public_key, pv1)
sv2 = encrypt(context, public_key, pv2)
sv1 = encrypt(pv1, public_key)
sv2 = encrypt(pv2, public_key)

sv_add = sv1 + sv2

Expand All @@ -35,25 +35,25 @@ function simple_real_numbers(context)
println()
println("Results of homomorphic computations: ")

result_sv1 = decrypt(context, private_key, sv1)
result_sv1 = decrypt(sv1, private_key)
println("x1 = ", result_sv1)

result_sv_add = decrypt(context, private_key, sv_add)
result_sv_add = decrypt(sv_add, private_key)
println("x1 + x2 = ", result_sv_add)

result_sv_sub = decrypt(context, private_key, sv_sub)
result_sv_sub = decrypt(sv_sub, private_key)
println("x1 - x2 = ", result_sv_sub)

result_sv_scalar = decrypt(context, private_key, sv_scalar)
result_sv_scalar = decrypt(sv_scalar, private_key)
println("4 * x1 = ", result_sv_scalar)

result_sv_mult = decrypt(context, private_key, sv_mult)
result_sv_mult = decrypt(sv_mult, private_key)
println("x1 * x2 = ", result_sv_mult)

result_sv_shift1 = decrypt(context, private_key, sv_shift1)
result_sv_shift1 = decrypt(sv_shift1, private_key)
println("x1 shifted circularly by -1 = ", result_sv_shift1)

result_sv_shift2 = decrypt(context, private_key, sv_shift2)
result_sv_shift2 = decrypt(sv_shift2, private_key)
println("x1 shifted circularly by 2 = ", result_sv_shift2)
end

Expand Down
28 changes: 12 additions & 16 deletions src/openfhe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,53 +44,49 @@
nothing
end

function PlainVector(context::SecureContext{<:OpenFHEBackend}, data::Vector{<:Real})
function PlainVector(data::Vector{<:Real}, context::SecureContext{<:OpenFHEBackend})
cc = get_crypto_context(context)
plaintext = OpenFHE.MakeCKKSPackedPlaintext(cc, data)
plain_vector = PlainVector(plaintext, context)

plain_vector
end

function encrypt(context::SecureContext{<:OpenFHEBackend}, public_key, data::Vector{<:Real})
function encrypt(data::Vector{<:Real}, public_key, context::SecureContext{<:OpenFHEBackend})

Check warning on line 55 in src/openfhe.jl

View check run for this annotation

Codecov / codecov/patch

src/openfhe.jl#L55

Added line #L55 was not covered by tests
plain_vector = PlainVector(context, data)
secure_vector = encrypt(context, public_key, plain_vector)

secure_vector
end

function encrypt(context::SecureContext{<:OpenFHEBackend}, public_key,
plain_vector::PlainVector)
function encrypt(plain_vector::PlainVector{<:OpenFHEBackend}, public_key)
context = plain_vector.context
cc = get_crypto_context(context)
ciphertext = OpenFHE.Encrypt(cc, public_key.public_key, plain_vector.plaintext)
secure_vector = SecureVector(ciphertext, context)

secure_vector
end

function decrypt!(plain_vector, context::SecureContext{<:OpenFHEBackend}, private_key,
secure_vector)
cc = get_crypto_context(context)
function decrypt!(plain_vector::PlainVector{<:OpenFHEBackend},
secure_vector::SecureVector{<:OpenFHEBackend}, private_key)
cc = get_crypto_context(secure_vector)
OpenFHE.Decrypt(cc, private_key.private_key, secure_vector.ciphertext,
plain_vector.plaintext)

plain_vector
end

function decrypt(context::SecureContext{<:OpenFHEBackend}, private_key, secure_vector)
function decrypt(secure_vector::SecureVector{<:OpenFHEBackend}, private_key)
context = secure_vector.context
plain_vector = PlainVector(OpenFHE.Plaintext(), context)

decrypt!(plain_vector, context, private_key, secure_vector)
decrypt!(plain_vector, secure_vector, private_key)
end


function bootstrap!(context::SecureContext{<:OpenFHEBackend}, secure_vector)
cc = get_crypto_context(context)
OpenFHE.EvalBootstrap(cc, secure_vector.ciphertext)

secure_vector
end
function bootstrap!(context::SecureContext{<:OpenFHEBackend}, secure_vector)
function bootstrap!(secure_vector::SecureVector{<:OpenFHEBackend})
context = secure_vector.context
cc = get_crypto_context(context)
OpenFHE.EvalBootstrap(cc, secure_vector.ciphertext)

Expand Down
23 changes: 10 additions & 13 deletions src/unencrypted.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,30 @@
init_rotation!(context::SecureContext{<:Unencrypted}, private_key, shifts) = nothing
init_bootstrapping!(context::SecureContext{<:Unencrypted}, private_key) = nothing

function PlainVector(context::SecureContext{<:Unencrypted}, data::Vector{<:Real})
plain_vector = PlainVector(data, context)
end
# No constructor for `PlainVector` necessary since we can directly use the inner constructor

function encrypt(context::SecureContext{<:Unencrypted}, public_key, data::Vector{<:Real})
function encrypt(data::Vector{<:Real}, public_key, context::SecureContext{<:Unencrypted})

Check warning on line 15 in src/unencrypted.jl

View check run for this annotation

Codecov / codecov/patch

src/unencrypted.jl#L15

Added line #L15 was not covered by tests
SecureVector(data, context)
end

function encrypt(context::SecureContext{<:Unencrypted}, public_key,
plain_vector::PlainVector)
SecureVector(plain_vector.plaintext, context)
function encrypt(plain_vector::PlainVector{<:Unencrypted}, public_key)
SecureVector(plain_vector.plaintext, plain_vector.context)
end

function decrypt!(plain_vector, context::SecureContext{<:Unencrypted}, private_key,
secure_vector)
function decrypt!(plain_vector::PlainVector{<:Unencrypted},
secure_vector::SecureVector{<:Unencrypted}, private_key)
plain_vector.plaintext .= secure_vector.ciphertext

plain_vector
end

function decrypt(context::SecureContext{<:Unencrypted}, private_key, secure_vector)
plain_vector = PlainVector(similar(secure_vector.ciphertext), context)
function decrypt(secure_vector::SecureVector{<:Unencrypted}, private_key)
plain_vector = PlainVector(similar(secure_vector.ciphertext), secure_vector.context)

decrypt!(plain_vector, context, private_key, secure_vector)
decrypt!(plain_vector, secure_vector, private_key)
end

bootstrap!(context::SecureContext{<:Unencrypted}, secure_vector) = secure_vector
bootstrap!(secure_vector::SecureVector{<:Unencrypted}) = secure_vector

function add(sv1::SecureVector{<:Unencrypted}, sv2::SecureVector{<:Unencrypted})
SecureVector(sv1.ciphertext .+ sv2.ciphertext, sv1.context)
Expand Down
Loading