diff --git a/CMakeLists.txt b/CMakeLists.txt index c0f165c..d9c2cde 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,31 +1,28 @@ cmake_minimum_required(VERSION 2.8) -project( ngt ) +project(ngt) -set(ngt_VERSION_MAJOR 1 ) -set(ngt_VERSION_MINOR 3 ) -set(ngt_VERSION_PATCH 2 ) +set(ngt_VERSION_MAJOR 1) +set(ngt_VERSION_MINOR 3) +set(ngt_VERSION_PATCH 3) -set( ngt_VERSION ${ngt_VERSION_MAJOR}.${ngt_VERSION_MINOR}.${ngt_VERSION_PATCH} ) -set( ngt_SOVERSION ${ngt_VERSION_MAJOR} ) +set(ngt_VERSION ${ngt_VERSION_MAJOR}.${ngt_VERSION_MINOR}.${ngt_VERSION_PATCH}) +set(ngt_SOVERSION ${ngt_VERSION_MAJOR}) if (NOT CMAKE_BUILD_TYPE) set (CMAKE_BUILD_TYPE "Release") endif (NOT CMAKE_BUILD_TYPE) +string(TOLOWER ${CMAKE_BUILD_TYPE} CMAKE_BUILD_TYPE_LOWER) message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") - -if( ${UNIX} ) - if( CMAKE_VERSION VERSION_LESS 3.1 ) - link_directories("/usr/lib64") - - set(CMAKE_SKIP_BUILD_RPATH TRUE) - set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) - - set(BUILD_DATE_OPTION "-DBUILD_DATE=\"\\\"`date +'%Y/%m/%d %H:%M:%S'`\\\"\"") - set(GIT_HASH_OPTION "-DGIT_HASH=\"\\\"`git log -1 --format='%H'`\\\"\"") - set(GIT_DATE_OPTION "-DGIT_DATE=\"\\\"`git log -1 --format='%cd'`\\\"\"") - set(GIT_TAG_OPTION "-DGIT_TAG=\"\\\"`git describe --abbrev=0`\\\"\"") - +message(STATUS "CMAKE_BUILD_TYPE_LOWER: ${CMAKE_BUILD_TYPE_LOWER}") + +if(${UNIX}) + set(BUILD_DATE_OPTION "-DBUILD_DATE=\"\\\"`date +'%Y/%m/%d %H:%M:%S'`\\\"\"") + set(GIT_HASH_OPTION "-DGIT_HASH=\"\\\"`git log -1 --format='%H'`\\\"\"") + set(GIT_DATE_OPTION "-DGIT_DATE=\"\\\"`git log -1 --format='%cd'`\\\"\"") + set(GIT_TAG_OPTION "-DGIT_TAG=\"\\\"`git describe --abbrev=0`\\\"\"") + set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) + if(CMAKE_VERSION VERSION_LESS 3.1) set(BASE_OPTIONS "-Wall -std=gnu++0x -lrt ${BUILD_DATE_OPTION} ${GIT_HASH_OPTION} ${GIT_DATE_OPTION} ${GIT_TAG_OPTION}") if( ${NGT_AVX_DISABLED} ) message(STATUS "AVX will not be used to compute distances.") @@ -33,14 +30,19 @@ if( ${UNIX} ) 
set(BASE_OPTIONS "${BASE_OPTIONS} -mavx") endif() set(CMAKE_CXX_FLAGS_DEBUG "-g ${BASE_OPTIONS}") - set(CMAKE_CXX_FLAGS_RELEASE "-O3 ${BASE_OPTIONS}") + set(CMAKE_CXX_FLAGS_RELEASE "-O3 -march=native ${BASE_OPTIONS}") else() - option(WALL "enable all warnings" ON) - if( ${WALL} ) - add_compile_options(-Wall) - endif() - - # CMAKE_CXX_STANDARD is supported from CMake 3.1 + add_definitions(${BUILD_DATE_OPTION} ${GIT_HASH_OPTION} ${GIT_DATE_OPTION} ${GIT_TAG_OPTION}) + if (CMAKE_BUILD_TYPE_LOWER STREQUAL "release") + set(CMAKE_CXX_FLAGS_RELEASE "") + add_compile_options(-Ofast -march=native -DNDEBUG) + endif() + add_compile_options(-Wall -lrt) + if(${NGT_AVX_DISABLED}) + message(STATUS "AVX will not be used to compute distances.") + else() + add_compile_options(-mavx) + endif() set(CMAKE_CXX_STANDARD 11) # for std::unordered_set, std::unique_ptr set(CMAKE_CXX_STANDARD_REQUIRED ON) endif() diff --git a/bin/ngt/Command.h b/bin/ngt/Command.h index 092ce53..df241d0 100644 --- a/bin/ngt/Command.h +++ b/bin/ngt/Command.h @@ -658,7 +658,7 @@ class Command { void prune(Args &args) { - const string usage = "Usage: ngt prune -e #-of-forcedly-pruned-edges -s #-of-selecively-pruned-edge"; + const string usage = "Usage: ngt prune -e #-of-forcedly-pruned-edges -s #-of-selecively-pruned-edge index(in/out)"; string indexName; try { indexName = args.get("#1"); diff --git a/lib/NGT/Graph.cpp b/lib/NGT/Graph.cpp index e6e20f5..378416e 100644 --- a/lib/NGT/Graph.cpp +++ b/lib/NGT/Graph.cpp @@ -73,9 +73,11 @@ NeighborhoodGraph::setupSeeds(NGT::SearchContainer &sc, ObjectDistances &seeds, std::sort(tmp.begin(), tmp.end()); +#if 0 if (tmp.size() > (size_t)property.seedSize) { tmp.resize(property.seedSize); } +#endif #ifdef NGT_GRAPH_UNCHECK_STACK for (ObjectDistances::reverse_iterator ri = tmp.rbegin(); ri != tmp.rend(); ri++) { diff --git a/lib/NGT/Graph.h b/lib/NGT/Graph.h index 82574d8..ed9a00b 100644 --- a/lib/NGT/Graph.h +++ b/lib/NGT/Graph.h @@ -45,7 +45,8 @@ #endif #ifndef 
NGT_SEED_SIZE -#define NGT_SEED_SIZE 10 +//#define NGT_SEED_SIZE 10 +#define NGT_SEED_SIZE 50 #endif #ifndef NGT_CREATION_EDGE_SIZE diff --git a/lib/NGT/Index.h b/lib/NGT/Index.h index 85532db..a8000d5 100644 --- a/lib/NGT/Index.h +++ b/lib/NGT/Index.h @@ -245,12 +245,14 @@ namespace NGT { virtual void load(const string &ifile, size_t dataSize) { getIndex().load(ifile, dataSize); } virtual void append(const string &ifile, size_t dataSize) { getIndex().append(ifile, dataSize); } virtual void append(const float *data, size_t dataSize) { getIndex().append(data, dataSize); } + virtual void append(const double *data, size_t dataSize) { getIndex().append(data, dataSize); } virtual size_t getObjectRepositorySize() { return getIndex().getObjectRepositorySize(); } virtual void createIndex(size_t threadNumber) { getIndex().createIndex(threadNumber); } virtual void saveIndex(const string &ofile) { getIndex().saveIndex(ofile); } virtual void loadIndex(const string &ofile) { getIndex().loadIndex(ofile); } virtual Object *allocateObject(const string &textLine, const string &sep) { return getIndex().allocateObject(textLine, sep); } virtual Object *allocateObject(vector &obj) { return getIndex().allocateObject(obj); } + virtual Object *allocateObject(vector &obj) { return getIndex().allocateObject(obj); } virtual size_t getSizeOfElement() { return getIndex().getSizeOfElement(); } virtual void setProperty(NGT::Property &prop) { getIndex().setProperty(prop); } virtual void getProperty(NGT::Property &prop) { getIndex().getProperty(prop); } @@ -384,9 +386,8 @@ namespace NGT { objectSpace->appendText(is, dataSize); } - virtual void append(const float *data, size_t dataSize) { - objectSpace->append(data, dataSize); - } + virtual void append(const float *data, size_t dataSize) { objectSpace->append(data, dataSize); } + virtual void append(const double *data, size_t dataSize) { objectSpace->append(data, dataSize); } virtual void saveIndex(const string &ofile) { #ifndef 
NGT_SHARED_MEMORY_ALLOCATOR @@ -982,9 +983,8 @@ namespace NGT { return objectSpace->allocateObject(textLine, sep); } - Object *allocateObject(vector &obj) { - return objectSpace->allocateObject(obj); - } + Object *allocateObject(vector &obj) { return objectSpace->allocateObject(obj); } + Object *allocateObject(vector &obj) { return objectSpace->allocateObject(obj); } void deleteObject(Object *po) { return objectSpace->deleteObject(po); @@ -1216,6 +1216,7 @@ namespace NGT { } // if seedSize is zero, the result size of the query is used as seedSize. size_t seedSize = NeighborhoodGraph::property.seedSize == 0 ? sc.size : NeighborhoodGraph::property.seedSize; + seedSize = seedSize > sc.size ? sc.size : seedSize; if (seeds.size() > seedSize) { srand(tso.nodeID.getID()); // to accelerate thinning data. @@ -1402,7 +1403,7 @@ NGT::Index::append(const string &database, const string &dataFile, size_t thread if (dataFile.size() != 0) { index.append(dataFile, dataSize); } else { - NGTThrowException("Index::create: No data file."); + NGTThrowException("Index::append: No data file."); } timer.stop(); cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl; @@ -1424,7 +1425,7 @@ NGT::Index::append(const string &database, const float *data, size_t dataSize, s if (data != 0 && dataSize != 0) { index.append(data, dataSize); } else { - NGTThrowException("Index::create: No data."); + NGTThrowException("Index::append: No data."); } timer.stop(); cerr << "Data loading time=" << timer.time << " (sec) " << timer.time * 1000.0 << " (msec)" << endl; diff --git a/lib/NGT/MmapManager.cpp b/lib/NGT/MmapManager.cpp index 7a3b72c..6be2e63 100644 --- a/lib/NGT/MmapManager.cpp +++ b/lib/NGT/MmapManager.cpp @@ -249,13 +249,13 @@ namespace MemoryManager{ return -1; } - if(size > _impl->mmapCntlHead->base_size + sizeof(chunk_head_st)){ + size_t alloc_size = getAlignSize(size); + + if( (alloc_size + sizeof(chunk_head_st)) >= _impl->mmapCntlHead->base_size ){ 
std::cerr << "alloc size over. size=" << size << "." << std::endl; return -1; } - size_t alloc_size = getAlignSize(size); - if(!not_reuse_flag){ if( _impl->mmapCntlHead->reuse_type == REUSE_DATA_CLASSIFY || _impl->mmapCntlHead->reuse_type == REUSE_DATA_QUEUE diff --git a/lib/NGT/ObjectSpace.h b/lib/NGT/ObjectSpace.h index e091db3..d14a022 100644 --- a/lib/NGT/ObjectSpace.h +++ b/lib/NGT/ObjectSpace.h @@ -18,7 +18,10 @@ #if !defined(NGT_AVX_DISABLED) && defined(__AVX__) +#warning "***** AVX is available! ************************************************" #include +#else +#warning "***** AVX is *NOT* available! ************************************************" #endif #include "Common.h" @@ -243,6 +246,7 @@ namespace NGT { virtual void readText(istream &is, size_t dataSize) = 0; virtual void appendText(ifstream &is, size_t dataSize) = 0; virtual void append(const float *data, size_t dataSize) = 0; + virtual void append(const double *data, size_t dataSize) = 0; virtual void copy(Object &objecta, Object &objectb) = 0; virtual void linearSearch(Object &query, double radius, size_t size, @@ -255,6 +259,7 @@ namespace NGT { virtual size_t getByteSizeOfObject() = 0; virtual Object *allocateObject(const string &textLine, const string &sep) = 0; virtual Object *allocateObject(vector &obj) = 0; + virtual Object *allocateObject(vector &obj) = 0; virtual void deleteObject(Object *po) = 0; virtual Object *allocateObject() = 0; virtual void remove(size_t id) = 0; @@ -851,7 +856,7 @@ namespace NGT { #else ComparatorCosineSimilarity(size_t d) : Comparator(d) {} double operator()(Object &objecta, Object &objectb) { - return ObjectSpaceT::compareAngleDistance((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); + return ObjectSpaceT::compareCosineSimilarity((OBJECT_TYPE*)&objecta[0], (OBJECT_TYPE*)&objectb[0], dimension); } #endif }; @@ -1127,7 +1132,8 @@ namespace NGT { return (double)count; } - inline static double compareAngleDistance(OBJECT_TYPE *a, OBJECT_TYPE *b, 
size_t size) { +#if defined(NGT_AVX_DISABLED) || !defined(__AVX__) + inline static double compareCosine(OBJECT_TYPE *a, OBJECT_TYPE *b, size_t size) { // Calculate the norm of A and B (the supplied vector). double normA = 0.0F; double normB = 0.0F; @@ -1137,26 +1143,62 @@ namespace NGT { normB += (double)b[loc] * (double)b[loc]; sum += (double)a[loc] * (double)b[loc]; } - assert(normA > 0.0F); assert(normB > 0.0F); // Compute the dot product of the two vectors. - double cosine = sum / (sqrt(normA) * sqrt(normB)); - // Compute the vector angle from the cosine value, and return. - // Roundoff error could have put the cosine value out of range. - // Handle these cases explicitly. - if (cosine >= 1.0F) { - return 0.0F; - } else if (cosine <= -1.0F) { - return acos (-1.0F); - } else { - return acos (cosine); + double cosine = sum / sqrt(normA * normB); + + return cosine; + } +#else + inline static double compareCosine(float *a, float *b, size_t size) { + // Calculate the norm of A and B (the supplied vector). 
+ + __m256 normA = _mm256_setzero_ps(); + __m256 normB = _mm256_setzero_ps(); + __m256 sum = _mm256_setzero_ps(); + float *last = a + size; + float *lastgroup = last - 7; + while (a < lastgroup) { + __m256 am = _mm256_loadu_ps(a); + __m256 bm = _mm256_loadu_ps(b); + normA = _mm256_add_ps(normA, _mm256_mul_ps(am, am)); + normB = _mm256_add_ps(normB, _mm256_mul_ps(bm, bm)); + sum = _mm256_add_ps(sum, _mm256_mul_ps(am, bm)); + a += 8; + b += 8; } + __attribute__((aligned(32))) float f[8]; + + _mm256_store_ps(f, normA); + double na = f[0] + f[1] + f[2] + f[3] + f[4] + f[5] + f[6] + f[7]; + _mm256_store_ps(f, normB); + double nb = f[0] + f[1] + f[2] + f[3] + f[4] + f[5] + f[6] + f[7]; + _mm256_store_ps(f, sum); + double s = f[0] + f[1] + f[2] + f[3] + f[4] + f[5] + f[6] + f[7]; + + while (a < last) { + double av = *a; + double bv = *b; + na += av * av; + nb += bv * bv; + s += av * bv; + a++; + b++; + } + + + assert(na > 0.0F); + assert(nb > 0.0F); + + double cosine = s / sqrt(na * nb); + + return cosine; } - inline static double compareCosineSimilarity(OBJECT_TYPE *a, OBJECT_TYPE *b, size_t size) { + inline static double compareCosine(unsigned char *a, unsigned char *b, size_t size) { // Calculate the norm of A and B (the supplied vector). double normA = 0.0F; double normB = 0.0F; @@ -1166,14 +1208,32 @@ namespace NGT { normB += (double)b[loc] * (double)b[loc]; sum += (double)a[loc] * (double)b[loc]; } - assert(normA > 0.0F); assert(normB > 0.0F); // Compute the dot product of the two vectors. - double cosine = sum / (sqrt(normA) * sqrt(normB)); + double cosine = sum / sqrt(normA * normB); + + return cosine; + } +#endif // #if defined(NGT_AVX_DISABLED) || !defined(__AVX__) + + inline static double compareAngleDistance(OBJECT_TYPE *a, OBJECT_TYPE *b, size_t size) { + double cosine = compareCosine(a, b, size); // use the compareCosine helper; calling compareAngleDistance here recursed with no base case + // Compute the vector angle from the cosine value, and return. + // Roundoff error could have put the cosine value out of range. 
+ // Handle these cases explicitly. + if (cosine >= 1.0F) { + return 0.0F; + } else if (cosine <= -1.0F) { + return acos(-1.0F); + } else { + return acos(cosine); + } + } - return 1.0 - cosine; + inline static double compareCosineSimilarity(OBJECT_TYPE *a, OBJECT_TYPE *b, size_t size) { + return 1.0 - compareCosine(a, b, size); } void serialize(const string &ofile) { ObjectRepository::serialize(ofile, this); } @@ -1183,6 +1243,7 @@ namespace NGT { void readText(istream &is, size_t dataSize) { ObjectRepository::readText(is, dataSize); } void appendText(ifstream &is, size_t dataSize) { ObjectRepository::appendText(is, dataSize); } void append(const float *data, size_t dataSize) { ObjectRepository::append(data, dataSize); } + void append(const double *data, size_t dataSize) { ObjectRepository::append(data, dataSize); } @@ -1245,6 +1306,9 @@ namespace NGT { Object *allocateObject(vector &obj) { return ObjectRepository::allocateObject(obj); } + Object *allocateObject(vector &obj) { + return ObjectRepository::allocateObject(obj); + } size_t getSize() { return ObjectRepository::size(); } size_t getSizeOfElement() { return sizeof(OBJECT_TYPE); } diff --git a/lib/NGT/defines.h.in b/lib/NGT/defines.h.in index 53cfd9b..e388374 100644 --- a/lib/NGT/defines.h.in +++ b/lib/NGT/defines.h.in @@ -33,7 +33,11 @@ #define NGT_COMPACT_VECTOR + +#ifndef NGT_GRAPH_CHECK_VECTOR #define NGT_GRAPH_CHECK_BOOLEANSET // use original booleanset to check whether nodes were accessed. +#endif + #if defined(NGT_GRAPH_CHECK_BOOLEANSET) || defined(NGT_GRAPH_CHECK_BITSET) #define NGT_GRAPH_CHECK_VECTOR // use vector to check whether nodes were accessed. 
#endif diff --git a/python/README.md b/python/README.md index ba19fc2..44a7f19 100644 --- a/python/README.md +++ b/python/README.md @@ -8,7 +8,7 @@ You **MUST** install the NGT library according to the [README](../README.md#buil ``` cd NGT_ROOT/python python setup.py sdist -pip install dist/ngt-1.0.0.tar.gz +pip install dist/ngt-1.1.0.tar.gz ``` ## Usage diff --git a/python/ngt/base.py b/python/ngt/base.py index 03f3e43..f2111f7 100644 --- a/python/ngt/base.py +++ b/python/ngt/base.py @@ -56,7 +56,7 @@ class Index(object): objects.append(vector) query = objects[0] - index = ngt.Index.create("tmp", dim) + index = ngt.Index.create(b"tmp", dim) index.insert(objects) # You can also insert objects from a file like this. # index.insert_from_tsv('list.dat') diff --git a/python/setup.py b/python/setup.py index 6e5f825..8d2d881 100644 --- a/python/setup.py +++ b/python/setup.py @@ -1,15 +1,18 @@ #!/usr/bin/env python import os +import sys import json import glob from setuptools import setup +if sys.version_info.major >= 3: + from setuptools import Extension + import pip + +version = '1.1.0' basedir = os.path.abspath(os.path.dirname(__file__)) -version = '1.0.0' -# Create a dictionary of our arguments, this way this script can be imported -# without running setup() to allow external scripts to see the setup settings. 
args = { 'name': 'ngt', 'version': version, @@ -17,18 +20,26 @@ 'author_email': 'https://www.yahoo-help.jp/', 'url': 'https://github.com/yahoojapan/NGT', 'license': 'Apache License Version 2.0', - 'packages': ['ngt'], - #'namespace_packages': ['ngt'], + 'packages': ['ngt'] } + +if sys.version_info.major >= 3: + module1 = Extension('ngtpy', + include_dirs=['/usr/local/include', + os.path.dirname(pip.locations.distutils_scheme('pybind11')['headers']), + os.path.dirname(pip.locations.distutils_scheme('pybind11', True)['headers'])], + library_dirs=['/usr/local/lib', '/usr/local/lib64'], + libraries=['ngt'], + extra_compile_args=['-std=c++11', '-mavx', '-Ofast', '-march=native', '-lrt', '-DNDEBUG'], + sources=['src/ngtpy.cpp']) + args['ext_modules'] = [module1] + setup_arguments = args -# Add any scripts we want to package if os.path.isdir('scripts'): setup_arguments['scripts'] = [ os.path.join('scripts', f) for f in os.listdir('scripts') ] - if __name__ == '__main__': - # We're being run from the command line so call setup with our arguments setup(**setup_arguments) diff --git a/python/src/ngtpy.cpp b/python/src/ngtpy.cpp new file mode 100644 index 0000000..ac34d34 --- /dev/null +++ b/python/src/ngtpy.cpp @@ -0,0 +1,176 @@ +// +// Copyright (C) 2018 Yahoo Japan Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +#include "NGT/Index.h" + +#include +#include +#include + +namespace py = pybind11; + +class Index : public NGT::Index { +public: + Index( + const string path, // ngt index path. + bool zeroBasedNumbering = true // object ID numbering. + ):NGT::Index(path) { + indexDecrement = zeroBasedNumbering ? 1 : 0; + } + + static void create( + const string path, + size_t dimension, + int edgeSizeForCreation = 10, + int edgeSizeForSearch = 40, + const string distanceType = "L2", + const string objectType = "Float" + ) { + NGT::Property prop; + prop.dimension = dimension; + prop.edgeSizeForCreation = edgeSizeForCreation; + prop.edgeSizeForSearch = edgeSizeForSearch; + + if (objectType == "Float" || objectType == "float") { + prop.objectType = NGT::Index::Property::ObjectType::Float; + } else if (objectType == "Byte" || objectType == "byte") { + prop.objectType = NGT::Index::Property::ObjectType::Uint8; + } else { + std::cerr << "ngtpy::create: invalid object type. " << objectType << std::endl; + return; + } + + if (distanceType == "L1") { + prop.distanceType = NGT::Index::Property::DistanceType::DistanceTypeL1; + } else if (distanceType == "L2") { + prop.distanceType = NGT::Property::DistanceType::DistanceTypeL2; + } else if (distanceType == "Hamming") { + prop.distanceType = NGT::Property::DistanceType::DistanceTypeHamming; + } else if (distanceType == "Angle") { + prop.distanceType = NGT::Property::DistanceType::DistanceTypeAngle; + } else if (distanceType == "Cosine") { + prop.distanceType = NGT::Property::DistanceType::DistanceTypeCosine; + } else { + std::cerr << "ngtpy::create: invalid distance type. 
" << distanceType << std::endl; + return; + } + NGT::Index::createGraphAndTree(path, prop); + } + + void batchInsert( + py::array_t objects, + size_t numThreads = 8, + bool debug = false + ) { + py::buffer_info info = objects.request(); + if (debug) { + std::cerr << info.shape.size() << ":" << info.shape[0] << ":" << info.shape[1] << std::endl; + } + auto ptr = static_cast(info.ptr); + assert(info.shape.size() == 2); + NGT::Property prop; + getProperty(prop); + if (prop.dimension != info.shape[1]) { + std::cerr << "ngtpy::insert: Error! dimensions are inconsitency. " << prop.dimension << ":" << info.shape[1] << std::endl; + return; + } +///- v(ptr, ptr + info.shape[1]); + ptr += info.shape[1]; + NGT::Index::insert(v); + } +#else +///->/ + NGT::Index::append(ptr, info.shape[0]); +///-/ + NGT::Index::createIndex(numThreads); + } + + std::vector search( + std::vector query, // query + size_t size = 10, // the number of resultant objects + float epsilon = 0.1, // search parameter epsilon. the adequate range is from 0.0 to 0.15. minus value is acceptable. + int edgeSize = -1 // the number of used edges for each node during the exploration of the graph. + ) { + NGT::Object *ngtquery = NGT::Index::allocateObject(query); + NGT::SearchContainer sc(*ngtquery); + NGT::ObjectDistances objects; + sc.setResults(&objects); // set the result set. + sc.setSize(size); // the number of resultant objects. + sc.setEpsilon(epsilon); // set exploration coefficient. + sc.setEdgeSize(edgeSize); // if maxEdge is minus, the specified value in advance is used. + + NGT::Index::search(sc); + + std::vector ids; + for (size_t i = 0; i < objects.size(); i++) { + ids.push_back(objects[i].id - indexDecrement); + } +///-/ + NGT::Index::deleteObject(ngtquery); + + return ids; + } + + size_t indexDecrement; // for object ID numbering. zero-based or one-based numbering. 
+}; + +PYBIND11_MODULE(ngtpy, m) { + m.doc() = "ngt python"; + + m.def("create", &::Index::create, + py::arg("path"), + py::arg("dimension"), + py::arg("edge_size_for_creation") = 10, + py::arg("edge_size_for_search") = 40, + py::arg("distance_type") = "L2", + py::arg("object_type") = "Float"); + + py::class_(m, "Index") + .def(py::init(), + py::arg("path"), + py::arg("zero_based_numbering") = true) + .def("search", &::Index::search, + py::arg("query"), + py::arg("size") = 10, + py::arg("epsilon") = 0.1, + py::arg("edge_size") = -1) + .def("save", &NGT::Index::saveIndex, + py::arg("path")) + .def("close", &NGT::Index::close) + .def("batch_insert", &::Index::batchInsert, + py::arg("objects"), + py::arg("num_threads") = 8, + py::arg("debug") = false); +} +