From 9b59603ebc64a3006c81abfb3c80096984bb34ca Mon Sep 17 00:00:00 2001 From: Masajiro Iwasaki Date: Mon, 10 Jul 2023 14:56:32 +0900 Subject: [PATCH] fix the issue #138 --- VERSION | 2 +- lib/NGT/NGTQ/QuantizedBlobGraph.h | 11 +++++++++-- lib/NGT/NGTQ/Quantizer.h | 16 ++++++++++++++-- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/VERSION b/VERSION index 3d45b5c..b8061b5 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.0.14 +2.0.15 diff --git a/lib/NGT/NGTQ/QuantizedBlobGraph.h b/lib/NGT/NGTQ/QuantizedBlobGraph.h index 06a9962..0d7e16a 100644 --- a/lib/NGT/NGTQ/QuantizedBlobGraph.h +++ b/lib/NGT/NGTQ/QuantizedBlobGraph.h @@ -785,6 +785,11 @@ namespace QBG { std::vector &rotatedQuery = searchContainer.objectVector; if (objectSpace.getObjectType() == typeid(float)) { memcpy(rotatedQuery.data(), searchContainer.object.getPointer(), rotatedQuery.size() * sizeof(float)); + } else if (objectSpace.getObjectType() == typeid(uint8_t)) { + auto *ptr = static_cast(searchContainer.object.getPointer()); + for (size_t i = 0; i < rotatedQuery.size(); i++) { + rotatedQuery[i] = ptr[i]; + } #ifdef NGT_HALF_FLOAT } else if (objectSpace.getObjectType() == typeid(NGT::float16)) { auto *ptr = static_cast(searchContainer.object.getPointer()); @@ -893,10 +898,13 @@ namespace QBG { if (objectSpace.getObjectType() == typeid(float)) { distance = NGT::PrimitiveComparator::L2Float::compare(searchContainer.object.getPointer(), neighborptr->second->getPointer(), dimension); + } else if (objectSpace.getObjectType() == typeid(uint8_t)) { + distance = NGT::PrimitiveComparator::L2Uint8::compare(searchContainer.object.getPointer(), + neighborptr->second->getPointer(), dimension); #ifdef NGT_HALF_FLOAT } else if (objectSpace.getObjectType() == typeid(NGT::float16)) { distance = NGT::PrimitiveComparator::L2Float16::compare(searchContainer.object.getPointer(), - neighborptr->second->getPointer(), dimension); + neighborptr->second->getPointer(), dimension); #endif } else { assert(false); @@ -1158,7 +1166,6 @@ namespace QBG { } const string com = "rm -rf " + indexPath + "/" + getWorkspaceName(); - std::cerr << "pass com=" << com << std::endl; if (system(com.c_str()) == -1) { std::cerr << "Warning. cannot remove the workspace directory. " << std::endl; } diff --git a/lib/NGT/NGTQ/Quantizer.h b/lib/NGT/NGTQ/Quantizer.h index 7c9c8fa..ad09a4e 100644 --- a/lib/NGT/NGTQ/Quantizer.h +++ b/lib/NGT/NGTQ/Quantizer.h @@ -658,13 +658,15 @@ class SerializableObject : public NGT::Object { dataSize = sizeof(float) * dimension; #endif break; +#ifdef NGT_HALF_FLOAT case DataTypeFloat16: #ifdef NGTQ_QBG - dataSize = sizeof(float) * genuineDimension; + dataSize = sizeof(NGT::float16) * genuineDimension; #else - dataSize = sizeof(float) * dimension; + dataSize = sizeof(NGT::float16) * dimension; #endif break; +#endif default: NGTThrowException("Quantizer constructor: Inner error. Invalid data type."); break; @@ -2590,11 +2592,21 @@ class QuantizerInstance : public Quantizer { quantizedObjectDistance = new QuantizedObjectDistanceUint8; } else if (property.localIDByteSize == 2) { quantizedObjectDistance = new QuantizedObjectDistanceUint8; +#ifdef NGTQ_QBG + } else if (property.localIDByteSize == 1) { + quantizedObjectDistance = new QuantizedObjectDistanceFloat; +#endif } else { + std::cerr << "Inconsistent localIDByteSize and ObjectType. " << property.localIDByteSize << ":" << globalProperty.objectType << std::endl; abort(); } +#ifdef NGTQ_VECTOR_OBJECT + generateResidualObject = new GenerateResidualObjectFloat; + sizeoftype = sizeof(float); +#else generateResidualObject = new GenerateResidualObjectUint8; sizeoftype = sizeof(uint8_t); +#endif } else { cerr << "NGTQ::open: Fatal Inner Error: invalid object type. " << globalProperty.objectType << endl; cerr << " check NGT version consistency between the caller and the library." << endl;