-
Notifications
You must be signed in to change notification settings - Fork 6
/
merkle_test.go
83 lines (69 loc) · 2.33 KB
/
merkle_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
package main
// Test suite for Merkle tree-based VPIR schemes. Only multi-bit schemes are
// implemented using this approach.
import (
"encoding/binary"
"fmt"
"io"
"math"
"testing"
"github.com/si-co/vpir-code/lib/client"
"github.com/si-co/vpir-code/lib/database"
"github.com/si-co/vpir-code/lib/field"
"github.com/si-co/vpir-code/lib/monitor"
"github.com/si-co/vpir-code/lib/server"
"github.com/si-co/vpir-code/lib/utils"
"github.com/stretchr/testify/require"
)
// TestMerkle builds a random Merkle-tree database of length oneMB and
// retrieves every block from it through two servers, checking each result
// via retrieveBlocksMerkle.
func TestMerkle(t *testing.T) {
	const (
		numServers = 2
		// This scheme works on bytes, so one element is 8 bits wide.
		elemBitSize = 8
	)

	dbLen := oneMB
	blockLen := field.Bytes * testBlockLength
	numBlocks := dbLen / (blockLen * elemBitSize)

	// Row count is floor(sqrt(numBlocks)); the layout is meant to be roughly
	// square (rows == columns).
	nRows := int(math.Sqrt(float64(numBlocks)))

	db := database.CreateRandomMerkle(utils.RandomPRG(), dbLen, nRows, blockLen)
	retrieveBlocksMerkle(t, utils.RandomPRG(), db, numServers, numBlocks, "Merkle")
}
// TestMerkleFourServers is the four-server variant of TestMerkle: it builds a
// random Merkle-tree database of length oneMB and retrieves every block
// through four servers via retrieveBlocksMerkle.
func TestMerkleFourServers(t *testing.T) {
	const (
		numServers = 4
		// This scheme works on bytes, so one element is 8 bits wide.
		elemBitSize = 8
	)

	dbLen := oneMB
	blockLen := field.Bytes * testBlockLength
	numBlocks := dbLen / (blockLen * elemBitSize)

	// Row count is floor(sqrt(numBlocks)); the layout is meant to be roughly
	// square (rows == columns).
	nRows := int(math.Sqrt(float64(numBlocks)))

	db := database.CreateRandomMerkle(utils.RandomPRG(), dbLen, nRows, blockLen)
	retrieveBlocksMerkle(t, utils.RandomPRG(), db, numServers, numBlocks, "MerkleFourServers")
}
// retrieveBlocksMerkle retrieves every block index in [0, numBlocks) from db
// using one PIR client (seeded from rnd) and numServers servers. It requires
// that each reconstructed result equals the corresponding slice of
// db.Entries, then prints the time recorded by a monitor started just before
// the retrieval loop, labelled with testName.
//
// When numServers is exactly 2 the servers are built with server.NewPIRTwo;
// any other count uses server.NewPIR.
func retrieveBlocksMerkle(t *testing.T, rnd io.Reader, db *database.Bytes, numServers, numBlocks int, testName string) {
	c := client.NewPIR(rnd, &db.Info)
	servers := make([]*server.PIR, numServers)
	for i := range servers {
		if numServers == 2 {
			servers[i] = server.NewPIRTwo(db)
		} else {
			servers[i] = server.NewPIR(db)
		}
	}

	totalTimer := monitor.NewMonitor()
	for i := 0; i < numBlocks; i++ {
		// The queried index is encoded as a 4-byte big-endian integer.
		in := make([]byte, 4)
		binary.BigEndian.PutUint32(in, uint32(i))
		queries, err := c.QueryBytes(in, numServers)
		require.NoError(t, err, "query block %d", i)

		answers := make([][]byte, numServers)
		// Use j (not i) so the block index is not shadowed inside this loop.
		for j, s := range servers {
			a, err := s.AnswerBytes(queries[j])
			require.NoError(t, err, "server %d answering block %d", j, i)
			answers[j] = a
		}

		res, err := c.ReconstructBytes(answers)
		require.NoError(t, err, "reconstruct block %d", i)

		// Expected data is the block's entry with the trailing ProofLen+1
		// bytes removed. NOTE(review): presumably the encoded Merkle proof
		// plus one padding/signal byte — confirm against database.CreateMerkle.
		start := i * db.BlockSize
		end := start + db.BlockSize - db.ProofLen - 1
		require.Equal(t, db.Entries[start:end], res, "block %d", i)
	}
	fmt.Printf("Total CPU time %s: %.1fms\n", testName, totalTimer.Record())
}