forked from facebookresearch/faiss
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathIndexFlat_T.h
87 lines (71 loc) · 2.35 KB
/
IndexFlat_T.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#pragma once
#ifdef OPT_FLAT_DTYPE
#include <cmath>
#include <vector>

#include <faiss/Index.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/dtype.h>
namespace faiss {

/** Flat (brute-force) index that stores its database vectors as type T.
 *
 * Input vectors arrive as float and are converted element-wise with
 * static_cast<T> in add(). Queries are converted through Converter_T<T>
 * before the per-metric search kernel runs.
 *
 * Metrics handled by search():
 *  - METRIC_INNER_PRODUCT / METRIC_PROJECTION -> knn_inner_product_T
 *  - METRIC_L2                                -> knn_L2Sqr_T
 *  - METRIC_L2_EXPAND                         -> knn_L2Sqr_expand_T
 * Any other metric makes search() throw.
 *
 * @tparam T storage element type of the database vectors.
 */
template <typename T>
struct IndexFlat_T : Index {
    /// Stored vectors, row-major: ntotal rows of d elements each.
    std::vector<T> base;

    /// One value vec_IP_T(y, y, d) per stored vector, computed on the float
    /// input (presumably its squared L2 norm). Filled only for
    /// METRIC_L2_EXPAND and passed to knn_L2Sqr_expand_T.
    std::vector<float> norms;

    IndexFlat_T() {}

    IndexFlat_T(idx_t d, MetricType metric = METRIC_L2) : Index(d, metric) {}

    /** Append n vectors (n rows of d floats starting at y).
     *
     * METRIC_PROJECTION: each vector is scaled to unit L2 norm before being
     * stored. A vector whose squared norm is not positive is stored unscaled,
     * because scaling by 1/sqrt(0) would turn an all-zero row into NaNs.
     * METRIC_L2_EXPAND: the vector is stored as-is and vec_IP_T(y, y, d) is
     * appended to `norms`.
     * Other metrics: the vector is stored as-is.
     */
    void add(idx_t n, const float* y) override {
        const size_t nv = static_cast<size_t>(n);
        const size_t dim = static_cast<size_t>(d);

        // Allocate once up front instead of growing inside push_back.
        base.reserve(base.size() + nv * dim);

        const float* yi = y;
        if (metric_type == METRIC_PROJECTION) {
            for (size_t i = 0; i < nv; i++, yi += dim) {
                const auto sqnorm = vec_IP_T(yi, yi, d);
                // Same convention as faiss::fvec_renorm_L2: zero vectors
                // are left untouched rather than divided by zero.
                float rnorm = 1.0f;
                if (sqnorm > 0) {
                    rnorm = 1.0f / std::sqrt(sqnorm);
                }
                for (size_t j = 0; j < dim; j++) {
                    base.push_back(static_cast<T>(yi[j] * rnorm));
                }
            }
        } else {
            const bool keep_norms = (metric_type == METRIC_L2_EXPAND);
            if (keep_norms) {
                norms.reserve(norms.size() + nv);
            }
            for (size_t i = 0; i < nv; i++, yi += dim) {
                for (size_t j = 0; j < dim; j++) {
                    base.push_back(static_cast<T>(yi[j]));
                }
                if (keep_norms) {
                    norms.push_back(vec_IP_T(yi, yi, d));
                }
            }
        }
        ntotal += n;
    }

    /// Remove all stored vectors (and their norms) and set ntotal to 0.
    void reset() override {
        base.clear();
        norms.clear();
        ntotal = 0;
    }

    /** k-nearest-neighbor search for n queries (n rows of d floats at x).
     *
     * Results go to `distances` and `labels`, which are wrapped as n heaps of
     * size k (presumably n * k entries each — confirm with the heap type).
     * Inner-product metrics use a float_minheap_array_t result; the L2
     * metrics use a float_maxheap_array_t result.
     *
     * Throws via FAISS_THROW_FMT("unsupported metric type: %d", ...) for any
     * metric other than INNER_PRODUCT, PROJECTION, L2 and L2_EXPAND.
     */
    void search(
            idx_t n,
            const float* x,
            idx_t k,
            float* distances,
            idx_t* labels) const override {
        // Convert the float queries to the storage type T.
        Converter_T<T> converter(n * d, x);
        if (metric_type == METRIC_INNER_PRODUCT ||
            metric_type == METRIC_PROJECTION) {
            float_minheap_array_t res = {
                    size_t(n), size_t(k), labels, distances};
            knn_inner_product_T(
                    converter.x, base.data(), d, n, ntotal, &res);
        } else if (metric_type == METRIC_L2) {
            float_maxheap_array_t res = {
                    size_t(n), size_t(k), labels, distances};
            knn_L2Sqr_T(converter.x, base.data(), d, n, ntotal, &res);
        } else if (metric_type == METRIC_L2_EXPAND) {
            float_maxheap_array_t res = {
                    size_t(n), size_t(k), labels, distances};
            knn_L2Sqr_expand_T(
                    converter.x,
                    base.data(),
                    d,
                    n,
                    ntotal,
                    &res,
                    norms.data());
        } else {
            FAISS_THROW_FMT("unsupported metric type: %d", (int)metric_type);
        }
    }
};

} // namespace faiss
#endif