Skip to content

Commit

Permalink
Improve usability by storing & using length of PlainVector/`Secur…
Browse files Browse the repository at this point in the history
…eVector` (#12)

* Change data field name to data for plain and secure vector

* Run unit tests first

* Introduce `length` for Plain/SecureVector

* Enable pretty-printing of types

* Improve code coverage
  • Loading branch information
sloede authored Jan 21, 2024
1 parent 5109b09 commit eb47482
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 81 deletions.
3 changes: 3 additions & 0 deletions src/SecureArithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ using OpenFHE: OpenFHE
# Basic types
export SecureContext, SecureVector, PlainVector

# Keys
export PrivateKey, PublicKey

# Backends
export Unencrypted, OpenFHEBackend

Expand Down
93 changes: 47 additions & 46 deletions src/openfhe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@ end
function get_crypto_context(context::SecureContext{<:OpenFHEBackend})
context.backend.crypto_context
end
function get_crypto_context(secure_vector::SecureVector{<:OpenFHEBackend})
get_crypto_context(secure_vector.context)
end
function get_crypto_context(plain_vector::PlainVector{<:OpenFHEBackend})
get_crypto_context(plain_vector.context)
function get_crypto_context(v::Union{SecureVector{<:OpenFHEBackend},
PlainVector{<:OpenFHEBackend}})
get_crypto_context(v.context)
end

function generate_keys(context::SecureContext{<:OpenFHEBackend})
Expand All @@ -21,21 +19,24 @@ function generate_keys(context::SecureContext{<:OpenFHEBackend})
public_key, private_key
end

function init_multiplication!(context::SecureContext{<:OpenFHEBackend}, private_key)
function init_multiplication!(context::SecureContext{<:OpenFHEBackend},
private_key::PrivateKey)
cc = get_crypto_context(context)
OpenFHE.EvalMultKeyGen(cc, private_key.private_key)

nothing
end

function init_rotation!(context::SecureContext{<:OpenFHEBackend}, private_key, shifts)
function init_rotation!(context::SecureContext{<:OpenFHEBackend}, private_key::PrivateKey,
shifts)
cc = get_crypto_context(context)
OpenFHE.EvalRotateKeyGen(cc, private_key.private_key, shifts)

nothing
end

function init_bootstrapping!(context::SecureContext{<:OpenFHEBackend}, private_key)
function init_bootstrapping!(context::SecureContext{<:OpenFHEBackend},
private_key::PrivateKey)
cc = get_crypto_context(context)
ring_dimension = OpenFHE.GetRingDimension(cc)
num_slots = div(ring_dimension, 2)
Expand All @@ -47,39 +48,39 @@ end
function PlainVector(data::Vector{<:Real}, context::SecureContext{<:OpenFHEBackend})
cc = get_crypto_context(context)
plaintext = OpenFHE.MakeCKKSPackedPlaintext(cc, data)
plain_vector = PlainVector(plaintext, context)
plain_vector = PlainVector(plaintext, length(data), context)

plain_vector
end

function encrypt(data::Vector{<:Real}, public_key, context::SecureContext{<:OpenFHEBackend})
plain_vector = PlainVector(context, data)
secure_vector = encrypt(context, public_key, plain_vector)
plain_vector = PlainVector(data, context)
secure_vector = encrypt(plain_vector, public_key)

secure_vector
end

function encrypt(plain_vector::PlainVector{<:OpenFHEBackend}, public_key)
function encrypt(plain_vector::PlainVector{<:OpenFHEBackend}, public_key::PublicKey)
context = plain_vector.context
cc = get_crypto_context(context)
ciphertext = OpenFHE.Encrypt(cc, public_key.public_key, plain_vector.plaintext)
secure_vector = SecureVector(ciphertext, context)
ciphertext = OpenFHE.Encrypt(cc, public_key.public_key, plain_vector.data)
secure_vector = SecureVector(ciphertext, length(plain_vector), context)

secure_vector
end

function decrypt!(plain_vector::PlainVector{<:OpenFHEBackend},
secure_vector::SecureVector{<:OpenFHEBackend}, private_key)
secure_vector::SecureVector{<:OpenFHEBackend}, private_key::PrivateKey)
cc = get_crypto_context(secure_vector)
OpenFHE.Decrypt(cc, private_key.private_key, secure_vector.ciphertext,
plain_vector.plaintext)
OpenFHE.Decrypt(cc, private_key.private_key, secure_vector.data,
plain_vector.data)

plain_vector
end

function decrypt(secure_vector::SecureVector{<:OpenFHEBackend}, private_key)
function decrypt(secure_vector::SecureVector{<:OpenFHEBackend}, private_key::PrivateKey)
context = secure_vector.context
plain_vector = PlainVector(OpenFHE.Plaintext(), context)
plain_vector = PlainVector(OpenFHE.Plaintext(), length(secure_vector), context)

decrypt!(plain_vector, secure_vector, private_key)
end
Expand All @@ -88,7 +89,7 @@ end
function bootstrap!(secure_vector::SecureVector{<:OpenFHEBackend})
context = secure_vector.context
cc = get_crypto_context(context)
OpenFHE.EvalBootstrap(cc, secure_vector.ciphertext)
OpenFHE.EvalBootstrap(cc, secure_vector.data)

secure_vector
end
Expand All @@ -100,105 +101,105 @@ end

function add(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend})
cc = get_crypto_context(sv1)
ciphertext = OpenFHE.EvalAdd(cc, sv1.ciphertext, sv2.ciphertext)
secure_vector = SecureVector(ciphertext, sv1.context)
ciphertext = OpenFHE.EvalAdd(cc, sv1.data, sv2.data)
secure_vector = SecureVector(ciphertext, length(sv1), sv1.context)

secure_vector
end

function add(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend})
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalAdd(cc, sv.ciphertext, pv.plaintext)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalAdd(cc, sv.data, pv.data)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end

function add(sv::SecureVector{<:OpenFHEBackend}, scalar::Real)
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalAdd(cc, sv.ciphertext, scalar)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalAdd(cc, sv.data, scalar)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end

function subtract(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend})
cc = get_crypto_context(sv1)
ciphertext = OpenFHE.EvalSub(cc, sv1.ciphertext, sv2.ciphertext)
secure_vector = SecureVector(ciphertext, sv1.context)
ciphertext = OpenFHE.EvalSub(cc, sv1.data, sv2.data)
secure_vector = SecureVector(ciphertext, length(sv1), sv1.context)

secure_vector
end

function subtract(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend})
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalSub(cc, sv.ciphertext, pv.plaintext)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalSub(cc, sv.data, pv.data)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end

function subtract(pv::PlainVector{<:OpenFHEBackend}, sv::SecureVector{<:OpenFHEBackend})
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalSub(cc, pv.plaintext, sv.ciphertext)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalSub(cc, pv.data, sv.data)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end

function subtract(sv::SecureVector{<:OpenFHEBackend}, scalar::Real)
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalSub(cc, sv.ciphertext, scalar)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalSub(cc, sv.data, scalar)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end

function subtract(scalar::Real, sv::SecureVector{<:OpenFHEBackend})
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalSub(cc, scalar, sv.ciphertext)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalSub(cc, scalar, sv.data)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end

function negate(sv::SecureVector{<:OpenFHEBackend})
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalNegate(cc, sv.ciphertext)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalNegate(cc, sv.data)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end

function multiply(sv1::SecureVector{<:OpenFHEBackend}, sv2::SecureVector{<:OpenFHEBackend})
cc = get_crypto_context(sv1)
ciphertext = OpenFHE.EvalMult(cc, sv1.ciphertext, sv2.ciphertext)
secure_vector = SecureVector(ciphertext, sv1.context)
ciphertext = OpenFHE.EvalMult(cc, sv1.data, sv2.data)
secure_vector = SecureVector(ciphertext, length(sv1), sv1.context)

secure_vector
end

function multiply(sv::SecureVector{<:OpenFHEBackend}, pv::PlainVector{<:OpenFHEBackend})
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalMult(cc, sv.ciphertext, pv.plaintext)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalMult(cc, sv.data, pv.data)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end

function multiply(sv::SecureVector{<:OpenFHEBackend}, scalar::Real)
cc = get_crypto_context(sv)
ciphertext = OpenFHE.EvalMult(cc, sv.ciphertext, scalar)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalMult(cc, sv.data, scalar)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end

function rotate(sv::SecureVector{<:OpenFHEBackend}, shift)
cc = get_crypto_context(sv)
# We use `-shift` to match Julia's usual `circshift` direction
ciphertext = OpenFHE.EvalRotate(cc, sv.ciphertext, -shift)
secure_vector = SecureVector(ciphertext, sv.context)
ciphertext = OpenFHE.EvalRotate(cc, sv.data, -shift)
secure_vector = SecureVector(ciphertext, length(sv), sv.context)

secure_vector
end
52 changes: 43 additions & 9 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,41 @@ struct SecureContext{CryptoBackendT <: AbstractCryptoBackend}
backend::CryptoBackendT
end

struct SecureVector{CryptoBackendT <: AbstractCryptoBackend, CiphertextT}
ciphertext::CiphertextT
function Base.show(io::IO, v::SecureContext)
print("SecureContext{", backend_name(v), "}()")
end

struct SecureVector{CryptoBackendT <: AbstractCryptoBackend, DataT}
data::DataT
length::Int
context::SecureContext{CryptoBackendT}

function SecureVector(ciphertext, context::SecureContext{CryptoBackendT}) where CryptoBackendT
new{CryptoBackendT, typeof(ciphertext)}(ciphertext, context)
function SecureVector(data, length, context::SecureContext{CryptoBackendT}) where CryptoBackendT
new{CryptoBackendT, typeof(data)}(data, length, context)
end
end

struct PlainVector{CryptoBackendT <: AbstractCryptoBackend, PlaintextT}
plaintext::PlaintextT
Base.length(v::SecureVector) = v.length
function Base.show(io::IO, v::SecureVector)
print("SecureVector{", backend_name(v), "}(data=<encrypted>, length=$(v.length))")
end

struct PlainVector{CryptoBackendT <: AbstractCryptoBackend, DataT}
data::DataT
length::Int
context::SecureContext{CryptoBackendT}

function PlainVector(plaintext, context::SecureContext{CryptoBackendT}) where CryptoBackendT
new{CryptoBackendT, typeof(plaintext)}(plaintext, context)
function PlainVector(data, length, context::SecureContext{CryptoBackendT}) where CryptoBackendT
new{CryptoBackendT, typeof(data)}(data, length, context)
end
end

Base.print(io::IO, plain_vector::PlainVector) = print(io, plain_vector.plaintext)
Base.length(v::PlainVector) = v.length
function Base.show(io::IO, v::PlainVector{CryptoBackendT}) where CryptoBackendT
print("PlainVector{", backend_name(v), "}(data=<plain>, length=$(v.length))")
end

Base.print(io::IO, plain_vector::PlainVector) = print(io, plain_vector.data)

struct PrivateKey{CryptoBackendT <: AbstractCryptoBackend, KeyT}
private_key::KeyT
Expand All @@ -33,6 +49,10 @@ struct PrivateKey{CryptoBackendT <: AbstractCryptoBackend, KeyT}
end
end

function Base.show(io::IO, key::PrivateKey{CryptoBackendT}) where CryptoBackendT
print("PrivateKey{", backend_name(key), "}()")
end

struct PublicKey{CryptoBackendT <: AbstractCryptoBackend, KeyT}
public_key::KeyT
context::SecureContext{CryptoBackendT}
Expand All @@ -41,3 +61,17 @@ struct PublicKey{CryptoBackendT <: AbstractCryptoBackend, KeyT}
new{CryptoBackendT, typeof(key)}(key, context)
end
end

function Base.show(io::IO, key::PublicKey{CryptoBackendT}) where CryptoBackendT
print("PublicKey{", backend_name(key), "}()")
end

# Get wrapper name of a potentially parametric type
# Copied from: https://github.com/ClapeyronThermo/Clapeyron.jl/blob/f40c282e2236ff68d91f37c39b5c1e4230ae9ef0/src/utils/core_utils.jl#L17
# Original source: https://github.com/JuliaArrays/ArrayInterface.jl/blob/40d9a87be07ba323cca00f9e59e5285c13f7ee72/src/ArrayInterface.jl#L20
# Note: prefixed by `__` since it is really, really dirty black magic internals we use here!
__parameterless_type(T) = Base.typename(T).wrapper

# Convenience method for getting the human-readable backend name
backend_name(x::Union{SecureContext{T}, SecureVector{T}, PlainVector{T}, PrivateKey{T},
PublicKey{T}}) where T = string(__parameterless_type(T))
Loading

0 comments on commit eb47482

Please sign in to comment.