From 3f9d6f71bab358a6674d6cbe5595c7fcab43b3e0 Mon Sep 17 00:00:00 2001 From: Masajiro Iwasaki Date: Tue, 11 Jul 2023 09:04:50 +0900 Subject: [PATCH] additional reimplementation of a couple of data types for QBG --- VERSION | 2 +- lib/NGT/NGTQ/ObjectFile.h | 22 ++++++++++---- lib/NGT/NGTQ/QuantizedBlobGraph.h | 50 ++++++++++++++++++++++++------- 3 files changed, 57 insertions(+), 17 deletions(-) diff --git a/VERSION b/VERSION index b8061b5..a14da29 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.0.15 +2.0.16 diff --git a/lib/NGT/NGTQ/ObjectFile.h b/lib/NGT/NGTQ/ObjectFile.h index 3beeb19..da37d1b 100644 --- a/lib/NGT/NGTQ/ObjectFile.h +++ b/lib/NGT/NGTQ/ObjectFile.h @@ -114,18 +114,20 @@ class ObjectFile : public ArrayFile { objectFiles.clear(); } - bool get(const size_t streamID, size_t id, std::vector &data, NGT::ObjectSpace *objectSpace) { + template + bool get(const size_t streamID, size_t id, std::vector &data, NGT::ObjectSpace *objectSpace = 0) { if (streamID >= objectFiles.size()) { std::cerr << "ObjectFile::streamID is invalid. " << streamID << ":" << objectFiles.size() << std::endl; return false; } - if (!objectFiles[streamID]->get(id, data, objectSpace)) { + if (!objectFiles[streamID]->get(id, data)) { return false; } return true; } - bool get(const size_t id, std::vector &data, NGT::ObjectSpace *os = 0) { + template + bool get(const size_t id, std::vector &data, NGT::ObjectSpace *os = 0) { if (objectSpace == 0) { stringstream msg; msg << "ObjectFile::Fatal Error. objectSpace is not set." << std::endl; @@ -138,8 +140,11 @@ class ObjectFile : public ArrayFile { } const std::type_info &otype = objectSpace->getObjectType(); size_t dim = objectSpace->getDimension(); - data.resize(pseudoDimension, 0); - if (otype == typeid(uint8_t)) { + data.resize(pseudoDimension); + if (typeid(T) == otype) { + auto *v = object->getPointer(); + memcpy(data.data(), v, sizeof(T) * dim); + } else if (otype == typeid(uint8_t)) { auto *v = static_cast(object->getPointer()); for (size_t i = 0; i < dim; i++) { data[i] = v[i]; @@ -151,7 +156,12 @@ class ObjectFile : public ArrayFile { } } else if (otype == typeid(float)) { auto *v = static_cast(object->getPointer()); - memcpy(data.data(), v, sizeof(float) * dim); + for (size_t i = 0; i < dim; i++) { + data[i] = v[i]; + } + } + for (size_t i = dim; i < pseudoDimension; i++) { + data[i] = 0; } objectSpace->deleteObject(object); return true; diff --git a/lib/NGT/NGTQ/QuantizedBlobGraph.h b/lib/NGT/NGTQ/QuantizedBlobGraph.h index 0d7e16a..07c7005 100644 --- a/lib/NGT/NGTQ/QuantizedBlobGraph.h +++ b/lib/NGT/NGTQ/QuantizedBlobGraph.h @@ -928,24 +928,39 @@ namespace QBG { auto threadid = omp_get_thread_num(); auto paddedDimension = getQuantizer().globalCodebookIndex.getObjectSpace().getPaddedDimension(); NGT::ResultPriorityQueue rs; - std::vector object; qresults.resize(results.size()); size_t idx = results.size(); while (!results.empty()) { auto r = results.top(); results.pop(); + if (objectSpace.getObjectType() == typeid(float)) { + std::vector object; #ifdef MULTIPLE_OBJECT_LISTS - quantizer.objectList.get(threadid, r.id, object, &quantizer.globalCodebookIndex.getObjectSpace()); + quantizer.objectList.get(threadid, r.id, object); #else - quantizer.objectList.get(r.id, object, &quantizer.globalCodebookIndex.getObjectSpace()); + quantizer.objectList.get(r.id, object); #endif - if (objectSpace.getObjectType() == typeid(float)) { r.distance = NGT::PrimitiveComparator::compareL2(static_cast(searchContainer.object.getPointer()), static_cast(object.data()), paddedDimension); + } else if (objectSpace.getObjectType() == typeid(uint8_t)) { + std::vector object; +#ifdef MULTIPLE_OBJECT_LISTS + quantizer.objectList.get(threadid, r.id, object); +#else + quantizer.objectList.get(r.id, object); +#endif + r.distance = NGT::PrimitiveComparator::compareL2(static_cast(searchContainer.object.getPointer()), + static_cast(object.data()), paddedDimension); #ifdef NGT_HALF_FLOAT } else if (objectSpace.getObjectType() == typeid(NGT::float16)) { - r.distance = NGT::PrimitiveComparator::compareL2(reinterpret_cast(searchContainer.object.getPointer()), - reinterpret_cast(object.data()), paddedDimension); + std::vector object; +#ifdef MULTIPLE_OBJECT_LISTS + quantizer.objectList.get(threadid, r.id, object); +#else + quantizer.objectList.get(r.id, object); +#endif + r.distance = NGT::PrimitiveComparator::compareL2(static_cast(searchContainer.object.getPointer()), + static_cast(object.data()), paddedDimension); #endif } qresults[--idx] = r; @@ -961,20 +976,35 @@ namespace QBG { auto threadid = omp_get_thread_num(); auto paddedDimension = getQuantizer().globalCodebookIndex.getObjectSpace().getPaddedDimension(); NGT::ResultPriorityQueue rs; - std::vector object; while (!results.empty()) { auto r = results.top(); results.pop(); + if (objectSpace.getObjectType() == typeid(float)) { + std::vector object; #ifdef MULTIPLE_OBJECT_LISTS - quantizer.objectList.get(threadid, r.id, object, &quantizer.globalCodebookIndex.getObjectSpace()); + quantizer.objectList.get(threadid, r.id, object); #else - quantizer.objectList.get(r.id, object, &quantizer.globalCodebookIndex.getObjectSpace()); + quantizer.objectList.get(r.id, object); #endif - if (objectSpace.getObjectType() == typeid(float)) { r.distance = NGT::PrimitiveComparator::compareL2(static_cast(searchContainer.object.getPointer()), static_cast(object.data()), paddedDimension); + } else if (objectSpace.getObjectType() == typeid(uint8_t)) { + std::vector object; +#ifdef MULTIPLE_OBJECT_LISTS + quantizer.objectList.get(threadid, r.id, object); +#else + quantizer.objectList.get(r.id, object); +#endif + r.distance = NGT::PrimitiveComparator::compareL2(reinterpret_cast(searchContainer.object.getPointer()), + reinterpret_cast(object.data()), paddedDimension); #ifdef NGT_HALF_FLOAT } else if (objectSpace.getObjectType() == typeid(NGT::float16)) { + std::vector object; +#ifdef MULTIPLE_OBJECT_LISTS + quantizer.objectList.get(threadid, r.id, object); +#else + quantizer.objectList.get(r.id, object); +#endif r.distance = NGT::PrimitiveComparator::compareL2(reinterpret_cast(searchContainer.object.getPointer()), reinterpret_cast(object.data()), paddedDimension); #endif