// partialMatch.h -- forked from homenc/HElib.
/* Copyright (C) 2020 IBM Corp.
* This program is Licensed under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. See accompanying LICENSE file.
*/
#ifndef HELIB_PARTIALMATCH_H
#define HELIB_PARTIALMATCH_H
#include <sstream>
#include <helib/Matrix.h>
#include <helib/PolyMod.h>
#include <helib/query.h>
// This code is in flux and should be considered very alpha.
// Not recommended for public use.
namespace helib {
/**
 * @brief Compares a query row against every row of a plaintext database and
 * produces a mask of {0,1}, where 1 signifies a matching element and 0 a
 * non-matching one.
 * @tparam TXT Type of the query entries. Must be a `Ptxt` or `Ctxt`.
 * @param ea The encrypted array object holding information about the scheme.
 * @param query The query set to mask against the database. Must be a row
 * vector with as many columns as the database matrix.
 * @param database The matrix holding the plaintext database.
 * @return The calculated mask. Is the same size as the database.
 * @throws InvalidArgument if `query` does not have exactly one row, or if
 * its column count differs from that of `database`.
 * @note This is the overload used when the database is not encrypted.
 **/
template <typename TXT>
inline Matrix<TXT> calculateMasks(const EncryptedArray& ea,
                                  Matrix<TXT> query,
                                  const Matrix<Ptxt<BGV>>& database)
{
  if (query.dims(0) != 1)
    throw InvalidArgument("Query must be a row vector");
  if (query.dims(1) != database.dims(1))
    throw InvalidArgument(
        "Database and query must have same number of columns");

  // TODO: case where query.dims(0) != database.dims(0)
  // TODO: Some such replication will be needed once blocks/bands exist

  // Replicate the query once per database row: transpose the row, select
  // column 0 once for every row of the database, then transpose back.
  const std::vector<long> firstColumnRepeated(database.dims(0), 0L);
  Matrix<TXT>& replicated = query; // `query` is a by-value parameter
  replicated.inPlaceTranspose();
  replicated = replicated.columns(firstColumnRepeated);
  replicated.inPlaceTranspose();

  // Per-entry steps applied, in this order, to (query - database).
  const auto toIndicator = [&ea](auto& entry) { mapTo01(ea, entry); };
  const auto flipSign = [](auto& entry) { entry.negate(); };
  const auto plusOne = [](auto& entry) { entry.addConstant(NTL::ZZX(1L)); };

  (replicated -= database).apply(toIndicator).apply(flipSign).apply(plusOne);
  return replicated;
}
/**
 * @brief Compares a query row against every row of an encrypted database and
 * produces a mask of {0,1}, where 1 signifies a matching element and 0 a
 * non-matching one.
 * @tparam TXT Type of the query entries. Must be a `Ptxt` or `Ctxt`.
 * @param ea The encrypted array object holding information about the scheme.
 * @param query The query set to mask against the database. Must be a row
 * vector with as many columns as the database matrix.
 * @param database The matrix holding the encrypted database.
 * @return The calculated mask. Is the same size as the database.
 * @throws InvalidArgument if `query` does not have exactly one row, or if
 * its column count differs from that of `database`.
 * @note This is the overload used when the database is encrypted.
 **/
template <typename TXT>
Matrix<Ctxt> calculateMasks(const EncryptedArray& ea,
                            Matrix<TXT> query,
                            const Matrix<Ctxt>& database)
{
  if (query.dims(0) != 1)
    throw InvalidArgument("Query must be a row vector");
  if (query.dims(1) != database.dims(1))
    throw InvalidArgument(
        "Database and query must have same number of columns");

  // TODO: case where query.dims(0) != database.dims(0)
  // TODO: Some such replication will be needed once blocks/bands exist

  // Replicate the query once per database row: transpose the row, select
  // column 0 once for every row of the database, then transpose back.
  const std::vector<long> firstColumnRepeated(database.dims(0), 0L);
  Matrix<TXT>& replicated = query; // `query` is a by-value parameter
  replicated.inPlaceTranspose();
  replicated = replicated.columns(firstColumnRepeated);
  replicated.inPlaceTranspose();

  // Per-entry steps applied, in this order, to the difference matrix.
  const auto toIndicator = [&ea](auto& entry) { mapTo01(ea, entry); };
  const auto flipSign = [](auto& entry) { entry.negate(); };
  const auto plusOne = [](auto& entry) { entry.addConstant(NTL::ZZX(1L)); };

  if constexpr (!std::is_same_v<TXT, Ptxt<BGV>>) {
    // Ctxt query: subtract the database from the replicated query in place.
    (replicated -= database)
        .apply(toIndicator)
        .apply(flipSign)
        .apply(plusOne);
    return replicated;
  } else {
    // Ptxt query: the result must be a Matrix<Ctxt>, so the subtraction is
    // carried out on a copy of the database instead of on the query.
    // FIXME: Avoid deep copy
    auto encryptedDiff = database.deepCopy();
    (encryptedDiff -= replicated)
        .apply(toIndicator)
        .apply(flipSign)
        .apply(plusOne);
    return encryptedDiff;
  }
}
/**
 * @brief Given a mask and information about the query to be performed,
 * calculates a score for each matching element signified by the mask.
 *
 * For every index set `i` the selected mask columns are multiplied by the
 * weight column `weights[i]`, `offsets[i]` is added to each resulting entry,
 * and the outcome is multiplied entrywise into the running result, which
 * starts out as a column of ones.
 *
 * @tparam TXT type of the mask matrix. Must be a `Ptxt` or `Ctxt`.
 * @param index_sets The set of indices signifying which columns of the mask
 * to query.
 * @param offsets The constant term to be added to the final score of each
 * queried column.
 * @param weights The weighted importance assigned to each queried column.
 * Each element must be a column vector with one row per index in the
 * corresponding index set.
 * @param mask The mask with which to calculate the score from.
 * @return A `Matrix<TXT>` constructed with `mask.dims(0)` rows and a single
 * column, holding the product of the per-index-set scores.
 * @throws InvalidArgument if `index_sets`, `offsets` and `weights` do not all
 * have the same size, if a weight set's row count differs from the size of
 * its index set, or if a weight set is not a column vector.
 * @note NOTE(review): `mask(0, 0)` is read unconditionally, so the mask is
 * presumably expected to be non-empty -- confirm against callers.
 **/
template <typename TXT>
inline Matrix<TXT> calculateScores(
    const std::vector<std::vector<long>>& index_sets,
    const std::vector<long>& offsets,
    const std::vector<Matrix<long>>& weights,
    const Matrix<TXT>& mask)
{
  assertEq<InvalidArgument>(index_sets.size(),
                            offsets.size(),
                            "index_sets and offsets must have matching size");
  assertEq<InvalidArgument>(index_sets.size(),
                            weights.size(),
                            "index_sets and weights must have matching size");

  // Build a constant-one entry of the right type by copying an existing
  // entry, clearing it and adding 1; it seeds the multiplicative result.
  auto ones(mask(0, 0));
  ones.clear();
  ones.addConstant(NTL::ZZX(1L));
  Matrix<TXT> result(ones, mask.dims(0), 1l);

  for (std::size_t i = 0; i < index_sets.size(); ++i) {
    const auto& index_set = index_sets.at(i);
    const auto& weight_set = weights.at(i);
    long offset = offsets.at(i);

    assertEq<InvalidArgument>(
        weight_set.dims(0),
        index_set.size(),
        "found mismatch between index set size and weight set size");
    assertEq<InvalidArgument>(weight_set.dims(1),
                              1lu,
                              "all weight sets must be column vectors");

    Matrix<TXT> submatrix = mask.columns(index_set);
    // Weighted sum of the selected columns. weight_set has one column, so
    // factor presumably has one too, matching the shape of result.
    Matrix<TXT> factor(submatrix * weight_set);
    factor.apply([&](auto& entry) { entry.addConstant(NTL::ZZX(offset)); });

    // result <- result (entrywise *) factor
    result.template entrywiseOperation<TXT>(
        factor,
        [](auto& lhs, const auto& rhs) -> decltype(auto) {
          lhs.multiplyBy(rhs);
          return lhs;
        });
  }
  return result;
}
/**
 * @brief Given a value, encode the value across the coefficients of a
 * polynomial.
 *
 * The value is written in base `context.getP()`, least significant digit
 * first, one digit per coefficient. There are `context.getOrdP()`
 * coefficients; unused ones stay zero and any digits that do not fit are
 * dropped.
 *
 * @param input The value of which to encode.
 * @param context The context object holding information on how to encode the
 * value.
 * @return A polynomial representing the encoded value.
 **/
inline PolyMod partialMatchEncode(uint32_t input, const Context& context)
{
  const long base = context.getP();
  std::vector<long> digits(context.getOrdP());
  // TODO - shouldn't keep checking input.
  for (long& digit : digits) {
    if (input == 0)
      break; // remaining coefficients are already zero
    digit = input % base;
    input /= base;
  }
  return PolyMod(digits, context.getSlotRing());
}
/**
 * @class Database
 * @tparam TXT The database is templated on `TXT` which can either be a `Ctxt`
 * or a `Ptxt<BGV>`
 * @brief An object representing a database which is a `HElib::Matrix<TXT>`,
 * stored together with the context used to create its data.
 **/
template <typename TXT>
class Database
{
public:
  // FIXME: Generally, should Database own the Matrix uniquely?
  // Should we force good practice and ask that Context always be shared_ptr?
  // FIXME: Should probably move Matrix or make it unique_ptr or both?
  /**
   * @brief Constructor.
   * @param M The `Matrix<TXT>` containing the data of the database.
   * @param c A shared pointer to the context used to create the data.
   **/
  Database(const Matrix<TXT>& M, std::shared_ptr<const Context> c) :
      data(M), context(c)
  {}

  // FIXME: Should this option really exist?
  /**
   * @brief Constructor.
   * @param M The `Matrix<TXT>` containing the data of the database.
   * @param c The context object used to create the data.
   * @note This version accepts a `Context` that this object is not responsible
   * for i.e. if it is on the stack. The programmer is responsible in this case
   * for scope: `c` must outlive this `Database`.
   **/
  Database(const Matrix<TXT>& M, const Context& c) :
      data(M),
      // The deleter is a no-op, so the shared_ptr never frees the
      // caller-owned context.
      context(std::shared_ptr<const helib::Context>(&c, [](auto UNUSED p) {}))
  {}

  /**
   * @brief Function for performing a database lookup given a query expression
   * and query data.
   * @tparam TXT2 The type of the query data, can be either a `Ctxt` or
   * `Ptxt<BGV>`.
   * @param lookup_query The lookup query expression to perform.
   * @param query_data The lookup query data to compare with the database.
   * @return A matrix containing 1s and 0s in slots where there was a
   * match or no match respectively.
   **/
  template <typename TXT2>
  auto contains(const QueryType& lookup_query,
                const Matrix<TXT2>& query_data) const;

  /**
   * @brief Function for performing a weighted partial match given a query
   * expression and query data.
   * @tparam TXT2 The type of the query data, can be either a `Ctxt` or
   * `Ptxt<BGV>`.
   * @param weighted_query The weighted lookup query expression to perform.
   * @param query_data The query data to compare with the database.
   * @return A matrix containing a score on weighted matches.
   **/
  template <typename TXT2>
  auto getScore(const QueryType& weighted_query,
                const Matrix<TXT2>& query_data) const;

  // TODO - correct name?
  /**
   * @brief Returns number of columns in the database.
   * @return The number of columns in the database.
   * @note Declared `const` so that it can be called on a `const Database`,
   * like `contains` and `getScore`.
   **/
  long columns() const { return data.dims(1); }

  /**
   * @brief Gives access to the matrix holding the database contents.
   * @return A mutable reference to the stored `Matrix<TXT>`.
   **/
  Matrix<TXT>& getData();

private:
  Matrix<TXT> data;                       // the database contents
  std::shared_ptr<const Context> context; // context the data was created with
};
/**
 * @brief Performs a database lookup for the given query expression and data.
 *
 * Computes the scores via `getScore` and, when the query expression has its
 * `containsOR` flag set, raises every score to the power
 * `context->getAlMod().getPPowR() - 1` (the "FLT" step).
 *
 * @tparam TXT2 The type of the query data entries.
 * @param lookup_query The lookup query expression to perform.
 * @param query_data The lookup query data to compare with the database.
 * @return The matrix returned by `getScore`, after the optional power step.
 **/
template <typename TXT>
template <typename TXT2>
inline auto Database<TXT>::contains(const QueryType& lookup_query,
                                    const Matrix<TXT2>& query_data) const
{
  auto result = getScore<TXT2>(lookup_query, query_data);
  if (lookup_query.containsOR) {
    // FLT on the scores
    // The exponent is the same for every entry, so compute it once.
    const auto exponent = context->getAlMod().getPPowR() - 1;
    // The callback modifies each entry in place and returns nothing;
    // returning `txt` by value would only copy the entry to discard it.
    result.apply([exponent](auto& txt) { txt.power(exponent); });
  }
  return result;
}
/**
 * @brief Performs a weighted partial match of the query data against the
 * database.
 *
 * First builds the match mask of `query_data` against the stored data with
 * `calculateMasks`, then combines it into scores with `calculateScores`
 * using the `Fs`, `mus` and `taus` members of the query expression.
 *
 * @tparam TXT2 The type of the query data entries.
 * @param weighted_query The weighted lookup query expression to perform.
 * @param query_data The query data to compare with the database.
 * @return The matrix of scores produced by `calculateScores`.
 **/
template <typename TXT>
template <typename TXT2>
inline auto Database<TXT>::getScore(const QueryType& weighted_query,
                                    const Matrix<TXT2>& query_data) const
{
  const auto& ea = context->getEA();
  auto match_mask = calculateMasks(ea, query_data, this->data);
  return calculateScores(weighted_query.Fs,
                         weighted_query.mus,
                         weighted_query.taus,
                         match_mask);
}
/**
 * @brief Gives access to the matrix holding the database contents.
 * @return A mutable reference to the stored `Matrix<TXT>`.
 **/
template <typename TXT>
inline Matrix<TXT>& Database<TXT>::getData()
{
  return this->data;
}
} // namespace helib
#endif