Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ciphertext packing in bootstrapping #506

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 133 additions & 43 deletions circuits/ckks/bootstrapping/bootstrapping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ func TestBootstrapping(t *testing.T) {
})
})

t.Run("BootstrappingPackedWithRingDegreeSwitch", func(t *testing.T) {
t.Run("BootstrappingPackedWithoutRingDegreeSwitch", func(t *testing.T) {

schemeParamsLit := testPrec45
btpParamsLit := ParametersLiteral{}
Expand All @@ -196,75 +196,165 @@ func TestBootstrapping(t *testing.T) {
}

btpParamsLit.LogN = utils.Pointy(schemeParamsLit.LogN)
schemeParamsLit.LogNthRoot = schemeParamsLit.LogN + 1
schemeParamsLit.LogN -= 3

params, err := ckks.NewParametersFromLiteral(schemeParamsLit)
require.Nil(t, err)

btpParams, err := NewParametersFromLiteral(params, btpParamsLit)
require.Nil(t, err)

// Insecure params for fast testing only
if !*flagLongTest {
btpParams.SlotsToCoeffsParameters.LogSlots = btpParams.BootstrappingParameters.LogN() - 1
btpParams.CoeffsToSlotsParameters.LogSlots = btpParams.BootstrappingParameters.LogN() - 1

// Corrects the message ratio to take into account the smaller number of slots and keep the same precision
btpParams.Mod1ParametersLiteral.LogMessageRatio += 16 - params.LogN()
btpParamsLit.LogMessageRatio = utils.Pointy(DefaultLogMessageRatio + (16 - params.LogN()))
}

t.Logf("Params: LogN=%d/LogSlots=%d/LogQP=%f", btpParams.ResidualParameters.LogN(), btpParams.ResidualParameters.LogMaxSlots(), btpParams.ResidualParameters.LogQP())
t.Logf("BTPParams: LogN=%d/LogSlots=%d/LogQP=%f", btpParams.BootstrappingParameters.LogN(), btpParams.BootstrappingParameters.LogMaxSlots(), btpParams.BootstrappingParameters.LogQP())

sk := rlwe.NewKeyGenerator(params).GenSecretKeyNew()

t.Log("Generating Bootstrapping Keys")
btpKeys, _, err := btpParams.GenEvaluationKeys(sk)
require.Nil(t, err)
ecd := ckks.NewEncoder(params)
enc := rlwe.NewEncryptor(params, sk)
dec := rlwe.NewDecryptor(params, sk)

evaluator, err := NewEvaluator(btpParams, btpKeys)
for _, sparsity := range []int{0, 1, 2} {
btpParamsLit.LogSlots = utils.Pointy(*btpParamsLit.LogN - 1 - sparsity)
btpParams, err := NewParametersFromLiteral(params, btpParamsLit)
require.Nil(t, err)

t.Logf("Params: LogN=%d/LogSlots=%d/LogQP=%f", btpParams.ResidualParameters.LogN(), btpParams.ResidualParameters.LogMaxSlots(), btpParams.ResidualParameters.LogQP())
t.Logf("BTPParams: LogN=%d/LogSlots=%d/LogQP=%f", btpParams.BootstrappingParameters.LogN(), btpParams.BootstrappingParameters.LogMaxSlots(), btpParams.BootstrappingParameters.LogQP())

t.Log("Generating Bootstrapping Keys for LogSlots")
btpKeys, _, err := btpParams.GenEvaluationKeys(sk)
require.Nil(t, err)

evaluator, err := NewEvaluator(btpParams, btpKeys)
require.Nil(t, err)

for _, slotOffset := range []int{0, 1, 2, 3} {
logMaxSlots := params.LogMaxSlots() - slotOffset
logMaxSlots = utils.Min(logMaxSlots, btpParams.LogMaxSlots())
values := make([]complex128, 1<<logMaxSlots)
for i := range values {
values[i] = sampling.RandComplex128(-1, 1)
}

values[0] = complex(0.9238795325112867, 0.3826834323650898)
values[1] = complex(0.9238795325112867, 0.3826834323650898)
if len(values) > 2 {
values[2] = complex(0.9238795325112867, 0.3826834323650898)
values[3] = complex(0.9238795325112867, 0.3826834323650898)
}

pt := ckks.NewPlaintext(params, 0)
pt.LogDimensions = ring.Dimensions{Rows: 0, Cols: logMaxSlots}

cts := make([]rlwe.Ciphertext, 11)
for i := range cts {

require.NoError(t, ecd.Encode(utils.RotateSlice(values, i), pt))

ct, err := enc.EncryptNew(pt)
require.NoError(t, err)

cts[i] = *ct
}

if cts, err = evaluator.BootstrapMany(cts); err != nil {
t.Fatal(err)
}

for i := range cts {
// Checks that the output ciphertext is at the max level of paramsN1
require.True(t, cts[i].Level() == params.MaxLevel())
require.True(t, cts[i].Scale.Equal(params.DefaultScale()))

verifyTestVectorsBootstrapping(params, ecd, dec, utils.RotateSlice(values, i), &cts[i], t)
}
}
}
})

t.Run("BootstrappingPackedWithRingDegreeSwitch", func(t *testing.T) {

schemeParamsLit := testPrec45
btpParamsLit := ParametersLiteral{}

if *flagLongTest {
schemeParamsLit.LogN = 16
}

btpParamsLit.LogN = utils.Pointy(schemeParamsLit.LogN)
schemeParamsLit.LogNthRoot = schemeParamsLit.LogN + 1
schemeParamsLit.LogN -= 3

params, err := ckks.NewParametersFromLiteral(schemeParamsLit)
require.Nil(t, err)

// Insecure params for fast testing only
if !*flagLongTest {
// Corrects the message ratio to take into account the smaller number of slots and keep the same precision
btpParamsLit.LogMessageRatio = utils.Pointy(DefaultLogMessageRatio + (16 - params.LogN()))
}

sk := rlwe.NewKeyGenerator(params).GenSecretKeyNew()

ecd := ckks.NewEncoder(params)
enc := rlwe.NewEncryptor(params, sk)
dec := rlwe.NewDecryptor(params, sk)

values := make([]complex128, params.MaxSlots())
for i := range values {
values[i] = sampling.RandComplex128(-1, 1)
}
for _, sparsity := range []int{0, 1, 2} {
btpParamsLit.LogSlots = utils.Pointy(*btpParamsLit.LogN - 1 - sparsity)
btpParams, err := NewParametersFromLiteral(params, btpParamsLit)
require.Nil(t, err)

values[0] = complex(0.9238795325112867, 0.3826834323650898)
values[1] = complex(0.9238795325112867, 0.3826834323650898)
if len(values) > 2 {
values[2] = complex(0.9238795325112867, 0.3826834323650898)
values[3] = complex(0.9238795325112867, 0.3826834323650898)
}
t.Logf("Params: LogN=%d/LogSlots=%d/LogQP=%f", btpParams.ResidualParameters.LogN(), btpParams.ResidualParameters.LogMaxSlots(), btpParams.ResidualParameters.LogQP())
t.Logf("BTPParams: LogN=%d/LogSlots=%d/LogQP=%f", btpParams.BootstrappingParameters.LogN(), btpParams.BootstrappingParameters.LogMaxSlots(), btpParams.BootstrappingParameters.LogQP())

pt := ckks.NewPlaintext(params, 0)
t.Log("Generating Bootstrapping Keys for LogSlots")
btpKeys, _, err := btpParams.GenEvaluationKeys(sk)
require.Nil(t, err)

cts := make([]rlwe.Ciphertext, 7)
for i := range cts {
evaluator, err := NewEvaluator(btpParams, btpKeys)
require.Nil(t, err)

require.NoError(t, ecd.Encode(utils.RotateSlice(values, i), pt))
for _, slotOffset := range []int{0, 1, 2, 3} {
logMaxSlots := params.LogMaxSlots() - slotOffset
logMaxSlots = utils.Min(logMaxSlots, btpParams.LogMaxSlots())
values := make([]complex128, 1<<logMaxSlots)
for i := range values {
values[i] = sampling.RandComplex128(-1, 1)
}

ct, err := enc.EncryptNew(pt)
require.NoError(t, err)
values[0] = complex(0.9238795325112867, 0.3826834323650898)
values[1] = complex(0.9238795325112867, 0.3826834323650898)
if len(values) > 2 {
values[2] = complex(0.9238795325112867, 0.3826834323650898)
values[3] = complex(0.9238795325112867, 0.3826834323650898)
}

cts[i] = *ct
}
pt := ckks.NewPlaintext(params, 0)
pt.LogDimensions = ring.Dimensions{Rows: 0, Cols: logMaxSlots}

if cts, err = evaluator.BootstrapMany(cts); err != nil {
t.Fatal(err)
}
cts := make([]rlwe.Ciphertext, 11)
for i := range cts {

for i := range cts {
// Checks that the output ciphertext is at the max level of paramsN1
require.True(t, cts[i].Level() == params.MaxLevel())
require.True(t, cts[i].Scale.Equal(params.DefaultScale()))
require.NoError(t, ecd.Encode(utils.RotateSlice(values, i), pt))

ct, err := enc.EncryptNew(pt)
require.NoError(t, err)

verifyTestVectorsBootstrapping(params, ecd, dec, utils.RotateSlice(values, i), &cts[i], t)
cts[i] = *ct
}

if cts, err = evaluator.BootstrapMany(cts); err != nil {
t.Fatal(err)
}

for i := range cts {
// Checks that the output ciphertext is at the max level of paramsN1
require.True(t, cts[i].Level() == params.MaxLevel())
require.True(t, cts[i].Scale.Equal(params.DefaultScale()))

verifyTestVectorsBootstrapping(params, ecd, dec, utils.RotateSlice(values, i), &cts[i], t)
}
}
}
})

Expand Down
Loading
Loading