Skip to content

Commit

Permalink
additional reimplementation of a couple of data types for QBG
Browse files Browse the repository at this point in the history
  • Loading branch information
masajiro committed Jul 11, 2023
1 parent 9b59603 commit 3f9d6f7
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 17 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.0.15
2.0.16
22 changes: 16 additions & 6 deletions lib/NGT/NGTQ/ObjectFile.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,20 @@ class ObjectFile : public ArrayFile<NGT::Object> {
objectFiles.clear();
}

bool get(const size_t streamID, size_t id, std::vector<float> &data, NGT::ObjectSpace *objectSpace) {
template<typename T>
bool get(const size_t streamID, size_t id, std::vector<T> &data, NGT::ObjectSpace *objectSpace = 0) {
if (streamID >= objectFiles.size()) {
std::cerr << "ObjectFile::streamID is invalid. " << streamID << ":" << objectFiles.size() << std::endl;
return false;
}
if (!objectFiles[streamID]->get(id, data, objectSpace)) {
if (!objectFiles[streamID]->get(id, data)) {
return false;
}
return true;
}

bool get(const size_t id, std::vector<float> &data, NGT::ObjectSpace *os = 0) {
template<typename T>
bool get(const size_t id, std::vector<T> &data, NGT::ObjectSpace *os = 0) {
if (objectSpace == 0) {
stringstream msg;
msg << "ObjectFile::Fatal Error. objectSpace is not set." << std::endl;
Expand All @@ -138,8 +140,11 @@ class ObjectFile : public ArrayFile<NGT::Object> {
}
const std::type_info &otype = objectSpace->getObjectType();
size_t dim = objectSpace->getDimension();
data.resize(pseudoDimension, 0);
if (otype == typeid(uint8_t)) {
data.resize(pseudoDimension);
if (typeid(T) == otype) {
auto *v = object->getPointer();
memcpy(data.data(), v, sizeof(T) * dim);
} else if (otype == typeid(uint8_t)) {
auto *v = static_cast<uint8_t*>(object->getPointer());
for (size_t i = 0; i < dim; i++) {
data[i] = v[i];
Expand All @@ -151,7 +156,12 @@ class ObjectFile : public ArrayFile<NGT::Object> {
}
} else if (otype == typeid(float)) {
auto *v = static_cast<float*>(object->getPointer());
memcpy(data.data(), v, sizeof(float) * dim);
for (size_t i = 0; i < dim; i++) {
data[i] = v[i];
}
}
for (size_t i = dim; i < pseudoDimension; i++) {
data[i] = 0;
}
objectSpace->deleteObject(object);
return true;
Expand Down
50 changes: 40 additions & 10 deletions lib/NGT/NGTQ/QuantizedBlobGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -928,24 +928,39 @@ namespace QBG {
auto threadid = omp_get_thread_num();
auto paddedDimension = getQuantizer().globalCodebookIndex.getObjectSpace().getPaddedDimension();
NGT::ResultPriorityQueue rs;
std::vector<float> object;
qresults.resize(results.size());
size_t idx = results.size();
while (!results.empty()) {
auto r = results.top();
results.pop();
if (objectSpace.getObjectType() == typeid(float)) {
std::vector<float> object;
#ifdef MULTIPLE_OBJECT_LISTS
quantizer.objectList.get(threadid, r.id, object, &quantizer.globalCodebookIndex.getObjectSpace());
quantizer.objectList.get(threadid, r.id, object);
#else
quantizer.objectList.get(r.id, object, &quantizer.globalCodebookIndex.getObjectSpace());
quantizer.objectList.get(r.id, object);
#endif
if (objectSpace.getObjectType() == typeid(float)) {
r.distance = NGT::PrimitiveComparator::compareL2(static_cast<float*>(searchContainer.object.getPointer()),
static_cast<float*>(object.data()), paddedDimension);
} else if (objectSpace.getObjectType() == typeid(uint8_t)) {
std::vector<uint8_t> object;
#ifdef MULTIPLE_OBJECT_LISTS
quantizer.objectList.get(threadid, r.id, object);
#else
quantizer.objectList.get(r.id, object);
#endif
r.distance = NGT::PrimitiveComparator::compareL2(static_cast<uint8_t*>(searchContainer.object.getPointer()),
static_cast<uint8_t*>(object.data()), paddedDimension);
#ifdef NGT_HALF_FLOAT
} else if (objectSpace.getObjectType() == typeid(NGT::float16)) {
r.distance = NGT::PrimitiveComparator::compareL2(reinterpret_cast<NGT::float16*>(searchContainer.object.getPointer()),
reinterpret_cast<NGT::float16*>(object.data()), paddedDimension);
std::vector<NGT::float16> object;
#ifdef MULTIPLE_OBJECT_LISTS
quantizer.objectList.get(threadid, r.id, object);
#else
quantizer.objectList.get(r.id, object);
#endif
r.distance = NGT::PrimitiveComparator::compareL2(static_cast<NGT::float16*>(searchContainer.object.getPointer()),
static_cast<NGT::float16*>(object.data()), paddedDimension);
#endif
}
qresults[--idx] = r;
Expand All @@ -961,20 +976,35 @@ namespace QBG {
auto threadid = omp_get_thread_num();
auto paddedDimension = getQuantizer().globalCodebookIndex.getObjectSpace().getPaddedDimension();
NGT::ResultPriorityQueue rs;
std::vector<float> object;
while (!results.empty()) {
auto r = results.top();
results.pop();
if (objectSpace.getObjectType() == typeid(float)) {
std::vector<float> object;
#ifdef MULTIPLE_OBJECT_LISTS
quantizer.objectList.get(threadid, r.id, object, &quantizer.globalCodebookIndex.getObjectSpace());
quantizer.objectList.get(threadid, r.id, object);
#else
quantizer.objectList.get(r.id, object, &quantizer.globalCodebookIndex.getObjectSpace());
quantizer.objectList.get(r.id, object);
#endif
if (objectSpace.getObjectType() == typeid(float)) {
r.distance = NGT::PrimitiveComparator::compareL2(static_cast<float*>(searchContainer.object.getPointer()),
static_cast<float*>(object.data()), paddedDimension);
} else if (objectSpace.getObjectType() == typeid(uint8_t)) {
std::vector<uint8_t> object;
#ifdef MULTIPLE_OBJECT_LISTS
quantizer.objectList.get(threadid, r.id, object);
#else
quantizer.objectList.get(r.id, object);
#endif
r.distance = NGT::PrimitiveComparator::compareL2(reinterpret_cast<uint8_t*>(searchContainer.object.getPointer()),
reinterpret_cast<uint8_t*>(object.data()), paddedDimension);
#ifdef NGT_HALF_FLOAT
} else if (objectSpace.getObjectType() == typeid(NGT::float16)) {
std::vector<NGT::float16> object;
#ifdef MULTIPLE_OBJECT_LISTS
quantizer.objectList.get(threadid, r.id, object);
#else
quantizer.objectList.get(r.id, object);
#endif
r.distance = NGT::PrimitiveComparator::compareL2(reinterpret_cast<NGT::float16*>(searchContainer.object.getPointer()),
reinterpret_cast<NGT::float16*>(object.data()), paddedDimension);
#endif
Expand Down

0 comments on commit 3f9d6f7

Please sign in to comment.