From c1c0fd4312b10e874ba2026802097a9fb170d51d Mon Sep 17 00:00:00 2001 From: Muki Kiboigo Date: Fri, 24 Nov 2023 00:42:57 -0800 Subject: [PATCH] feat: add db flagging & checking --- src/databases.cpp | 56 ++++++++++++++++++++++++++++------------------- src/databases.hpp | 35 ++++++++++++++++------------- src/main.cpp | 7 +++++- 3 files changed, 60 insertions(+), 38 deletions(-) diff --git a/src/databases.cpp b/src/databases.cpp index 93302d61..9765fd76 100644 --- a/src/databases.cpp +++ b/src/databases.cpp @@ -19,7 +19,7 @@ const int32_t PairDistanceKVectorDatabase::kMagicValue = 0x2536f009; struct KVectorPair { int16_t index1; int16_t index2; - float distance; + decimal distance; }; bool CompareKVectorPairs(const KVectorPair &p1, const KVectorPair &p2) { @@ -33,8 +33,8 @@ bool CompareKVectorPairs(const KVectorPair &p1, const KVectorPair &p2) { | size | name | description | |---------------+------------+-------------------------------------------------------------| | 4 | numEntries | | - | sizeof float | min | minimum value contained in the database | - | sizeof float | max | max value contained in index | + | sizeof decimal | min | minimum value contained in the database | + | sizeof decimal | max | max value contained in index | | 4 | numBins | | | 4*(numBins+1) | bins | The `i'th bin (starting from zero) stores how many pairs of | | | | stars have a distance lesst han or equal to: | @@ -57,9 +57,9 @@ bool CompareKVectorPairs(const KVectorPair &p1, const KVectorPair &p2) { * @param numBins the number of "bins" the KVector should use. A higher number makes query results "tighter" but takes up more disk space. Usually should be set somewhat smaller than (max-min) divided by the "width" of the typical query. * @param buffer[out] index is written here. */ -void SerializeKVectorIndex(SerializeContext *ser, const std::vector &values, float min, float max, long numBins) { +void SerializeKVectorIndex(SerializeContext *ser, const std::vector &values, decimal min, decimal max, long numBins) { std::vector kVector(numBins+1); // We store sums before and after each bin - float binWidth = (max - min) / numBins; + decimal binWidth = (max - min) / numBins; // generate the k-vector part // Idea: When we find the first star that's across any bin boundary, we want to update all the newly sealed bins @@ -92,8 +92,8 @@ void SerializeKVectorIndex(SerializeContext *ser, const std::vector &valu // metadata fields SerializePrimitive(ser, values.size()); - SerializePrimitive(ser, min); - SerializePrimitive(ser, max); + SerializePrimitive(ser, min); + SerializePrimitive(ser, max); SerializePrimitive(ser, numBins); // kvector index field @@ -106,8 +106,8 @@ void SerializeKVectorIndex(SerializeContext *ser, const std::vector &valu KVectorIndex::KVectorIndex(DeserializeContext *des) { numValues = DeserializePrimitive(des); - min = DeserializePrimitive(des); - max = DeserializePrimitive(des); + min = DeserializePrimitive(des); + max = DeserializePrimitive(des); numBins = DeserializePrimitive(des); assert(min >= 0.0f); @@ -122,7 +122,7 @@ KVectorIndex::KVectorIndex(DeserializeContext *des) { * @param upperIndex[out] Is set to the index of the last returned value +1. * @return the index (starting from zero) of the first value matching the query */ -long KVectorIndex::QueryLiberal(float minQueryDistance, float maxQueryDistance, long *upperIndex) const { +long KVectorIndex::QueryLiberal(decimal minQueryDistance, decimal maxQueryDistance, long *upperIndex) const { assert(maxQueryDistance > minQueryDistance); if (maxQueryDistance >= max) { maxQueryDistance = max - 0.00001; // TODO: better way to avoid hitting the bottom bin @@ -152,7 +152,7 @@ long KVectorIndex::QueryLiberal(float minQueryDistance, float maxQueryDistance, } /// return the lowest-indexed bin that contains the number of pairs with distance <= dist -long KVectorIndex::BinFor(float query) const { +long KVectorIndex::BinFor(decimal query) const { long result = (long)ceil((query - min) / binWidth); assert(result >= 0); assert(result <= numBins); @@ -168,7 +168,7 @@ long KVectorIndex::BinFor(float query) const { | sizeof kvectorIndex | kVectorIndex | Serialized KVector index | | 2*sizeof(int16)*numPairs | pairs | Bulk pair data | */ -std::vector CatalogToPairDistances(const Catalog &catalog, float minDistance, float maxDistance) { +std::vector CatalogToPairDistances(const Catalog &catalog, decimal minDistance, decimal maxDistance) { std::vector result; for (int16_t i = 0; i < (int16_t)catalog.size(); i++) { for (int16_t k = i+1; k < (int16_t)catalog.size(); k++) { @@ -191,13 +191,13 @@ std::vector CatalogToPairDistances(const Catalog &catalog, float mi * Serialize a pair-distance KVector into buffer. * Use SerializeLengthPairDistanceKVector to determine how large the buffer needs to be. See command line documentation for other options. */ -void SerializePairDistanceKVector(SerializeContext *ser, const Catalog &catalog, float minDistance, float maxDistance, long numBins) { +void SerializePairDistanceKVector(SerializeContext *ser, const Catalog &catalog, decimal minDistance, decimal maxDistance, long numBins) { std::vector kVector(numBins+1); // numBins = length, all elements zero std::vector pairs = CatalogToPairDistances(catalog, minDistance, maxDistance); // sort pairs in increasing order. std::sort(pairs.begin(), pairs.end(), CompareKVectorPairs); - std::vector distances; + std::vector distances; for (const KVectorPair &pair : pairs) { distances.push_back(pair.distance); @@ -221,7 +221,7 @@ PairDistanceKVectorDatabase::PairDistanceKVectorDatabase(DeserializeContext *des } /// Return the value in the range [low,high] which is closest to num -float Clamp(float num, float low, float high) { +decimal Clamp(decimal num, decimal low, decimal high) { return num < low ? low : num > high ? high : num; } @@ -231,7 +231,7 @@ float Clamp(float num, float low, float high) { * @return A pointer to the start of the matched pairs. Each pair is stored as simply two 16-bit integers, each of which is a catalog index. (you must increment the pointer twice to get to the next pair). */ const int16_t *PairDistanceKVectorDatabase::FindPairsLiberal( - float minQueryDistance, float maxQueryDistance, const int16_t **end) const { + decimal minQueryDistance, decimal maxQueryDistance, const int16_t **end) const { assert(maxQueryDistance <= M_PI); @@ -242,7 +242,7 @@ const int16_t *PairDistanceKVectorDatabase::FindPairsLiberal( } const int16_t *PairDistanceKVectorDatabase::FindPairsExact(const Catalog &catalog, - float minQueryDistance, float maxQueryDistance, const int16_t **end) const { + decimal minQueryDistance, decimal maxQueryDistance, const int16_t **end) const { // Instead of computing the angle for every pair in the database, we pre-compute the /cosines/ // of the min and max query distances so that we can compare against dot products directly! As @@ -250,8 +250,8 @@ const int16_t *PairDistanceKVectorDatabase::FindPairsExact(const Catalog &catalo // sense anyway) assert(maxQueryDistance <= M_PI); - float maxQueryCos = cos(minQueryDistance); - float minQueryCos = cos(maxQueryDistance); + decimal maxQueryCos = cos(minQueryDistance); + decimal minQueryCos = cos(maxQueryDistance); long liberalUpperIndex; long liberalLowerIndex = index.QueryLiberal(minQueryDistance, maxQueryDistance, &liberalUpperIndex); @@ -280,8 +280,8 @@ long PairDistanceKVectorDatabase::NumPairs() const { } /// Return the distances from the given star to each star it's paired with in the database (for debugging). -std::vector PairDistanceKVectorDatabase::StarDistances(int16_t star, const Catalog &catalog) const { - std::vector result; +std::vector PairDistanceKVectorDatabase::StarDistances(int16_t star, const Catalog &catalog) const { + std::vector result; for (int i = 0; i < NumPairs(); i++) { if (pairs[i*2] == star || pairs[i*2+1] == star) { result.push_back(AngleUnit(catalog[pairs[i*2]].spatial, catalog[pairs[i*2+1]].spatial)); @@ -296,6 +296,7 @@ std::vector PairDistanceKVectorDatabase::StarDistances(int16_t star, cons | size | name | description | |------+----------------+---------------------------------------------| | 4 | magicValue | unique database identifier | + | 4 | flags | [X, X, X, isDouble?] | | 4 | databaseLength | length in bytes (32-bit unsigned) | | n | database | the entire database. 8-byte aligned | | ... | ... | More databases (each has value, length, db) | @@ -318,6 +319,15 @@ const unsigned char *MultiDatabase::SubDatabasePointer(int32_t magicValue) const if (curMagicValue == 0) { return nullptr; } + uint32_t dbFlags = DeserializePrimitive(des); + + // Ensure that our database is using the same type as the runtime. + if(dbFlags & MULTI_DB_IS_DOUBLE) { + assert(typeid(decimal) == typeid(double)); + } else { + assert(typeid(decimal) == typeid(float)); + } + uint32_t dbLength = DeserializePrimitive(des); assert(dbLength > 0); DeserializePadding(des); // align to an 8-byte boundary @@ -331,9 +341,11 @@ const unsigned char *MultiDatabase::SubDatabasePointer(int32_t magicValue) const } void SerializeMultiDatabase(SerializeContext *ser, - const MultiDatabaseDescriptor &dbs) { + const MultiDatabaseDescriptor &dbs, + uint32_t flags) { for (const MultiDatabaseEntry &multiDbEntry : dbs) { SerializePrimitive(ser, multiDbEntry.magicValue); + SerializePrimitive(ser, flags); SerializePrimitive(ser, multiDbEntry.bytes.size()); SerializePadding(ser); std::copy(multiDbEntry.bytes.cbegin(), multiDbEntry.bytes.cend(), std::back_inserter(ser->buffer)); diff --git a/src/databases.hpp b/src/databases.hpp index c8337c11..e7d4cec4 100644 --- a/src/databases.hpp +++ b/src/databases.hpp @@ -22,27 +22,27 @@ class KVectorIndex { public: explicit KVectorIndex(DeserializeContext *des); - long QueryLiberal(float minQueryDistance, float maxQueryDistance, long *upperIndex) const; + long QueryLiberal(decimal minQueryDistance, decimal maxQueryDistance, long *upperIndex) const; /// The number of data points in the data referred to by the kvector long NumValues() const { return numValues; }; long NumBins() const { return numBins; }; /// Upper bound on elements - float Max() const { return max; }; + decimal Max() const { return max; }; // Lower bound on elements - float Min() const { return min; }; + decimal Min() const { return min; }; private: - long BinFor(float dist) const; + long BinFor(decimal dist) const; long numValues; - float min; - float max; - float binWidth; + decimal min; + decimal max; + decimal binWidth; long numBins; const int32_t *bins; }; -void SerializePairDistanceKVector(SerializeContext *, const Catalog &, float minDistance, float maxDistance, long numBins); +void SerializePairDistanceKVector(SerializeContext *, const Catalog &, decimal minDistance, decimal maxDistance, long numBins); /** * A database storing distances between pairs of stars. @@ -53,14 +53,14 @@ class PairDistanceKVectorDatabase { public: explicit PairDistanceKVectorDatabase(DeserializeContext *des); - const int16_t *FindPairsLiberal(float min, float max, const int16_t **end) const; - const int16_t *FindPairsExact(const Catalog &, float min, float max, const int16_t **end) const; - std::vector StarDistances(int16_t star, const Catalog &) const; + const int16_t *FindPairsLiberal(decimal min, decimal max, const int16_t **end) const; + const int16_t *FindPairsExact(const Catalog &, decimal min, decimal max, const int16_t **end) const; + std::vector StarDistances(int16_t star, const Catalog &) const; /// Upper bound on stored star pair distances - float MaxDistance() const { return index.Max(); }; + decimal MaxDistance() const { return index.Max(); }; /// Lower bound on stored star pair distances - float MinDistance() const { return index.Min(); }; + decimal MinDistance() const { return index.Min(); }; /// Exact number of stored pairs long NumPairs() const; @@ -85,7 +85,7 @@ class PairDistanceKVectorDatabase { // public: // explicit TripleInnerKVectorDatabase(const unsigned char *databaseBytes); -// void FindTriplesLiberal(float min, float max, long **begin, long **end) const; +// void FindTriplesLiberal(decimal min, decimal max, long **begin, long **end) const; // private: // KVectorIndex index; // int16_t *triples; @@ -96,6 +96,10 @@ class PairDistanceKVectorDatabase { * This is almost always the database that is actually passed to star-id algorithms in the real world, since you'll want to store at least the catalog plus one specific database. * Multi-databases are essentially a map from "magic values" to database buffers. */ + +#define MULTI_DB_IS_DOUBLE 0x0001 +#define MULTI_DB_IS_FLOAT 0x0000 + class MultiDatabase { public: /// Create a multidatabase from a serialized multidatabase. @@ -111,12 +115,13 @@ class MultiDatabaseEntry { : magicValue(magicValue), bytes(bytes) { } int32_t magicValue; + uint32_t flags; std::vector bytes; }; typedef std::vector MultiDatabaseDescriptor; -void SerializeMultiDatabase(SerializeContext *, const MultiDatabaseDescriptor &dbs); +void SerializeMultiDatabase(SerializeContext *, const MultiDatabaseDescriptor &dbs, uint32_t flags); } diff --git a/src/main.cpp b/src/main.cpp index 5d26fec4..cc936e16 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -17,6 +17,7 @@ #include "databases.hpp" #include "centroiders.hpp" +#include "decimal.hpp" #include "io.hpp" #include "man-database.h" #include "man-pipeline.h" @@ -30,9 +31,13 @@ static void DatabaseBuild(const DatabaseOptions &values) { MultiDatabaseDescriptor dbEntries = GenerateDatabases(narrowedCatalog, values); SerializeContext ser = serFromDbValues(values); - SerializeMultiDatabase(&ser, dbEntries); + + // Inject flags into the Serialized Database. + uint32_t dbFlags = typeid(decimal) == typeid(double) ? MULTI_DB_IS_DOUBLE : MULTI_DB_IS_FLOAT; + SerializeMultiDatabase(&ser, dbEntries, dbFlags); std::cerr << "Generated database with " << ser.buffer.size() << " bytes" << std::endl; + std::cerr << "Database flagged with " << dbFlags << std::endl; UserSpecifiedOutputStream pos = UserSpecifiedOutputStream(values.outputPath, true); pos.Stream().write((char *) ser.buffer.data(), ser.buffer.size());