diff --git a/hnswlib-core/src/main/java/com/github/jelmerk/knn/bruteforce/BruteForceIndex.java b/hnswlib-core/src/main/java/com/github/jelmerk/knn/bruteforce/BruteForceIndex.java index 780a66e8..dadc348f 100644 --- a/hnswlib-core/src/main/java/com/github/jelmerk/knn/bruteforce/BruteForceIndex.java +++ b/hnswlib-core/src/main/java/com/github/jelmerk/knn/bruteforce/BruteForceIndex.java @@ -5,6 +5,7 @@ import com.github.jelmerk.knn.Item; import com.github.jelmerk.knn.SearchResult; import com.github.jelmerk.knn.util.ClassLoaderObjectInputStream; +import com.github.jelmerk.knn.util.DummyComparator; import java.io.*; import java.nio.file.Files; @@ -26,6 +27,7 @@ public class BruteForceIndex, TDi private static final long serialVersionUID = 1L; + private final boolean immutable; private final int dimensions; private final DistanceFunction distanceFunction; private final Comparator distanceComparator; @@ -34,6 +36,7 @@ public class BruteForceIndex, TDi private final Map deletedItemVersions; private BruteForceIndex(BruteForceIndex.Builder builder) { + this.immutable = builder.immutable; this.dimensions = builder.dimensions; this.distanceFunction = builder.distanceFunction; this.distanceComparator = builder.distanceComparator; @@ -79,6 +82,9 @@ public int getDimensions() { */ @Override public boolean add(TItem item) { + if (immutable) { + throw new UnsupportedOperationException("Index is immutable"); + } if (item.dimensions() != dimensions) { throw new IllegalArgumentException("Item does not have dimensionality of : " + dimensions); } @@ -286,7 +292,7 @@ public static , TDistance> BruteF Builder newBuilder(int dimensions, DistanceFunction distanceFunction) { Comparator distanceComparator = Comparator.naturalOrder(); - return new Builder<>(dimensions, distanceFunction, distanceComparator); + return new Builder<>(false, dimensions, distanceFunction, distanceComparator); } /** @@ -301,7 +307,23 @@ Builder newBuilder(int dimensions, DistanceFunction Builder newBuilder(int dimensions, DistanceFunction distanceFunction, Comparator distanceComparator) { - return new Builder<>(dimensions, distanceFunction, distanceComparator); + return new Builder<>(false, dimensions, distanceFunction, distanceComparator); + } + + /** + * Creates an immutable empty index. + * + * @return the empty index + * @param Type of the external identifier of an item + * @param Type of the vector to perform distance calculation on + * @param Type of items stored in the index + * @param Type of distance between items (expect any numeric type: float, double, int, ..) + */ + public static , TDistance> BruteForceIndex empty() { + BruteForceIndex.Builder builder = new BruteForceIndex.Builder<>(true,0, (DistanceFunction) (u, v) -> { + throw new UnsupportedOperationException(); + }, new DummyComparator<>()); + return builder.build(); } /** @@ -318,7 +340,10 @@ public static class Builder { private final Comparator distanceComparator; - Builder(int dimensions, DistanceFunction distanceFunction, Comparator distanceComparator) { + private final boolean immutable; + + Builder(boolean immutable, int dimensions, DistanceFunction distanceFunction, Comparator distanceComparator) { + this.immutable = immutable; this.dimensions = dimensions; this.distanceFunction = distanceFunction; this.distanceComparator = distanceComparator; diff --git a/hnswlib-core/src/main/java/com/github/jelmerk/knn/hnsw/HnswIndex.java b/hnswlib-core/src/main/java/com/github/jelmerk/knn/hnsw/HnswIndex.java index 229e76e3..cf6330d9 100644 --- a/hnswlib-core/src/main/java/com/github/jelmerk/knn/hnsw/HnswIndex.java +++ b/hnswlib-core/src/main/java/com/github/jelmerk/knn/hnsw/HnswIndex.java @@ -33,6 +33,7 @@ public class HnswIndex, TDistance implements Index { private static final byte VERSION_1 = 0x01; + private static final byte VERSION_2 = 0x02; private static final long serialVersionUID = 1L; @@ -42,6 +43,7 @@ public class HnswIndex, TDistance private Comparator distanceComparator; private MaxValueComparator maxValueDistanceComparator; + private boolean immutable; private int dimensions; private int maxItemCount; private int m; @@ -74,6 +76,7 @@ public class HnswIndex, TDistance private HnswIndex(RefinedBuilder builder) { + this.immutable = builder.immutable; this.dimensions = builder.dimensions; this.maxItemCount = builder.maxItemCount; this.distanceFunction = builder.distanceFunction; @@ -202,6 +205,9 @@ public boolean remove(TId id, long version) { */ @Override public boolean add(TItem item) { + if (immutable) { + throw new UnsupportedOperationException("Index is immutable"); + } if (item.dimensions() != dimensions) { throw new IllegalArgumentException("Item does not have dimensionality of : " + dimensions); } @@ -757,7 +763,7 @@ public void save(OutputStream out) throws IOException { } private void writeObject(ObjectOutputStream oos) throws IOException { - oos.writeByte(VERSION_1); + oos.writeByte(VERSION_2); oos.writeInt(dimensions); oos.writeObject(distanceFunction); oos.writeObject(distanceComparator); @@ -776,6 +782,7 @@ private void writeObject(ObjectOutputStream oos) throws IOException { writeMutableObjectLongMap(oos, deletedItemVersions); writeNodesArray(oos, nodes); oos.writeInt(entryPoint == null ? -1 : entryPoint.id); + oos.writeBoolean(immutable); } @SuppressWarnings("unchecked") @@ -802,6 +809,8 @@ private void readObject(ObjectInputStream ois) throws IOException, ClassNotFound this.nodes = readNodesArray(ois, itemSerializer, maxM0, maxM); int entrypointNodeId = ois.readInt(); + + this.immutable = version != VERSION_1 && ois.readBoolean(); this.entryPoint = entrypointNodeId == -1 ? null : nodes.get(entrypointNodeId); this.globalLock = new ReentrantLock(); @@ -1069,7 +1078,26 @@ public static > Builder distanceComparator = Comparator.naturalOrder(); - return new Builder<>(dimensions, distanceFunction, distanceComparator, maxItemCount); + return new Builder<>(false, dimensions, distanceFunction, distanceComparator, maxItemCount); + } + + /** + * Creates an immutable empty index. + * + * @return the empty index + * @param Type of the external identifier of an item + * @param Type of the vector to perform distance calculation on + * @param Type of items stored in the index + * @param Type of distance between items (expect any numeric type: float, double, int, ..) + */ + public static , TDistance> HnswIndex empty() { + Builder builder = new Builder<>(true, 0, new DistanceFunction() { + @Override + public TDistance distance(TVector u, TVector v) { + throw new UnsupportedOperationException(); + } + }, new DummyComparator<>(), 0); + return builder.build(); } /** @@ -1089,7 +1117,7 @@ public static Builder newBuilder( Comparator distanceComparator, int maxItemCount) { - return new Builder<>(dimensions, distanceFunction, distanceComparator, maxItemCount); + return new Builder<>(false, dimensions, distanceFunction, distanceComparator, maxItemCount); } private int assignLevel(TId value, double lambda) { @@ -1318,6 +1346,7 @@ public static abstract class BuilderBase distanceFunction; Comparator distanceComparator; @@ -1329,11 +1358,12 @@ public static abstract class BuilderBase distanceFunction, Comparator distanceComparator, int maxItemCount) { - + this.immutable = immutable; this.dimensions = dimensions; this.distanceFunction = distanceFunction; this.distanceComparator = distanceComparator; @@ -1417,12 +1447,13 @@ public static class Builder extends BuilderBase distanceFunction, Comparator distanceComparator, int maxItemCount) { - super(dimensions, distanceFunction, distanceComparator, maxItemCount); + super(immutable, dimensions, distanceFunction, distanceComparator, maxItemCount); } @Override @@ -1440,7 +1471,7 @@ Builder self() { * @return the builder */ public > RefinedBuilder withCustomSerializers(ObjectSerializer itemIdSerializer, ObjectSerializer itemSerializer) { - return new RefinedBuilder<>(dimensions, distanceFunction, distanceComparator, maxItemCount, m, ef, efConstruction, + return new RefinedBuilder<>(immutable, dimensions, distanceFunction, distanceComparator, maxItemCount, m, ef, efConstruction, removeEnabled, itemIdSerializer, itemSerializer); } @@ -1475,7 +1506,8 @@ public static class RefinedBuilder itemIdSerializer; private ObjectSerializer itemSerializer; - RefinedBuilder(int dimensions, + RefinedBuilder(boolean immutable, + int dimensions, DistanceFunction distanceFunction, Comparator distanceComparator, int maxItemCount, @@ -1486,7 +1518,7 @@ public static class RefinedBuilder itemIdSerializer, ObjectSerializer itemSerializer) { - super(dimensions, distanceFunction, distanceComparator, maxItemCount); + super(immutable, dimensions, distanceFunction, distanceComparator, maxItemCount); this.m = m; this.ef = ef; diff --git a/hnswlib-core/src/main/java/com/github/jelmerk/knn/util/DummyComparator.java b/hnswlib-core/src/main/java/com/github/jelmerk/knn/util/DummyComparator.java new file mode 100644 index 00000000..cff722f0 --- /dev/null +++ b/hnswlib-core/src/main/java/com/github/jelmerk/knn/util/DummyComparator.java @@ -0,0 +1,18 @@ +package com.github.jelmerk.knn.util; + +import java.io.Serializable; +import java.util.Comparator; + +/** + * Implementation of {@link Comparator} that is serializable and throws {@link UnsupportedOperationException} when + * compare is called. Useful as a dummy placeholder when you know it will never be called. + * + * @param the type of objects that may be compared by this comparator + */ +public class DummyComparator implements Comparator, Serializable { + + @Override + public int compare(T o1, T o2) { + throw new UnsupportedOperationException(); + } +} diff --git a/hnswlib-core/src/test/java/com/github/jelmerk/knn/bruteforce/BruteForceIndexTest.java b/hnswlib-core/src/test/java/com/github/jelmerk/knn/bruteforce/BruteForceIndexTest.java index a58d086a..6eec89d2 100644 --- a/hnswlib-core/src/test/java/com/github/jelmerk/knn/bruteforce/BruteForceIndexTest.java +++ b/hnswlib-core/src/test/java/com/github/jelmerk/knn/bruteforce/BruteForceIndexTest.java @@ -7,11 +7,13 @@ import java.io.IOException; import java.util.*; +import com.github.jelmerk.knn.hnsw.HnswIndex; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.BeforeEach; import static org.hamcrest.CoreMatchers.*; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; class BruteForceIndexTest { @@ -152,6 +154,20 @@ void saveAndLoadIndex() throws IOException { assertThat(loadedIndex.size(), is(1)); } + @Test + void createEmptyIndex() { + BruteForceIndex index = BruteForceIndex.empty(); + + assertThrows( + UnsupportedOperationException.class, + () -> index.add(item1), + "Index should be immutable" + ); + + assertThat(index.size(), is(0)); + assertThat(index.getDimensions(), is(0)); + } + } diff --git a/hnswlib-core/src/test/java/com/github/jelmerk/knn/hnsw/HnswIndexTest.java b/hnswlib-core/src/test/java/com/github/jelmerk/knn/hnsw/HnswIndexTest.java index f0b769ca..b7834643 100644 --- a/hnswlib-core/src/test/java/com/github/jelmerk/knn/hnsw/HnswIndexTest.java +++ b/hnswlib-core/src/test/java/com/github/jelmerk/knn/hnsw/HnswIndexTest.java @@ -12,6 +12,7 @@ import static org.hamcrest.CoreMatchers.*; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; class HnswIndexTest { @@ -215,4 +216,18 @@ void saveAndLoadIndex() throws IOException { assertThat(loadedIndex.size(), is(1)); } + + @Test + void emptyIndexIsImmutable() { + HnswIndex index = HnswIndex.empty(); + + assertThrows( + UnsupportedOperationException.class, + () -> index.add(item1), + "Index should be immutable" + ); + + assertThat(index.size(), is(0)); + assertThat(index.getDimensions(), is(0)); + } } diff --git a/hnswlib-scala/src/main/scala/com/github/jelmerk/knn/scalalike/bruteforce/BruteForceIndex.scala b/hnswlib-scala/src/main/scala/com/github/jelmerk/knn/scalalike/bruteforce/BruteForceIndex.scala index de85e7fe..4a89b37e 100644 --- a/hnswlib-scala/src/main/scala/com/github/jelmerk/knn/scalalike/bruteforce/BruteForceIndex.scala +++ b/hnswlib-scala/src/main/scala/com/github/jelmerk/knn/scalalike/bruteforce/BruteForceIndex.scala @@ -87,6 +87,20 @@ object BruteForceIndex { new BruteForceIndex[TId, TVector, TItem, TDistance](jIndex) } + + /** + * Creates an immutable empty index. + * + * @tparam TId Type of the external identifier of an item + * @tparam TVector Type of the vector to perform distance calculation on + * @tparam TItem Type of items stored in the index + * @tparam TDistance Type of distance between items (expect any numeric type: float, double, int, ..) + * @return the index + */ + def empty[TId, TVector, TItem <: Item[TId, TVector], TDistance]: BruteForceIndex[TId, TVector, TItem, TDistance] = { + val jIndex: JBruteForceIndex[TId, TVector, TItem, TDistance] = JBruteForceIndex.empty() + new BruteForceIndex(jIndex) + } } /** diff --git a/hnswlib-scala/src/main/scala/com/github/jelmerk/knn/scalalike/hnsw/HnswIndex.scala b/hnswlib-scala/src/main/scala/com/github/jelmerk/knn/scalalike/hnsw/HnswIndex.scala index 121a347d..2dbda75a 100644 --- a/hnswlib-scala/src/main/scala/com/github/jelmerk/knn/scalalike/hnsw/HnswIndex.scala +++ b/hnswlib-scala/src/main/scala/com/github/jelmerk/knn/scalalike/hnsw/HnswIndex.scala @@ -115,8 +115,21 @@ object HnswIndex { if(removeEnabled) builder.withRemoveEnabled().build() else builder.build() - new HnswIndex[TId, TVector, TItem, TDistance](jIndex) + new HnswIndex(jIndex) + } + /** + * Creates an immutable empty index. + * + * @tparam TId Type of the external identifier of an item + * @tparam TVector Type of the vector to perform distance calculation on + * @tparam TItem Type of items stored in the index + * @tparam TDistance Type of distance between items (expect any numeric type: float, double, int, ..) + * @return the index + */ + def empty[TId, TVector, TItem <: Item[TId, TVector], TDistance]: HnswIndex[TId, TVector, TItem, TDistance] = { + val jIndex: JHnswIndex[TId, TVector, TItem, TDistance] = JHnswIndex.empty() + new HnswIndex(jIndex) } } @@ -140,13 +153,18 @@ class HnswIndex[TId, TVector, TItem <: Item[TId, TVector], TDistance] private (d /** * This distance function. */ - val distanceFunction: DistanceFunction[TVector, TDistance] = delegate - .getDistanceFunction.asInstanceOf[ScalaDistanceFunctionAdapter[TVector, TDistance]].scalaFunction + val distanceFunction: DistanceFunction[TVector, TDistance] = delegate.getDistanceFunction match { + case a: ScalaDistanceFunctionAdapter[TVector, TDistance] => a.scalaFunction + case f => (v1: TVector, v2: TVector) => f.distance(v1, v2) + } /** * The ordering used to compare distances */ - val distanceOrdering: Ordering[TDistance] = delegate.getDistanceComparator.asInstanceOf[Ordering[TDistance]] + val distanceOrdering: Ordering[TDistance] = delegate.getDistanceComparator match { + case ordering: Ordering[TDistance] => ordering + case c => (x: TDistance, y: TDistance) => c.compare(x, y) + } /** * The maximum number of items the index can hold. diff --git a/hnswlib-scala/src/test/scala/com/github/jelmerk/knn/scalalike/hnsw/HnswIndexSpec.scala b/hnswlib-scala/src/test/scala/com/github/jelmerk/knn/scalalike/hnsw/HnswIndexSpec.scala index 170386be..cb8e55dd 100644 --- a/hnswlib-scala/src/test/scala/com/github/jelmerk/knn/scalalike/hnsw/HnswIndexSpec.scala +++ b/hnswlib-scala/src/test/scala/com/github/jelmerk/knn/scalalike/hnsw/HnswIndexSpec.scala @@ -187,4 +187,9 @@ class HnswIndexSpec extends AnyFunSuite { index.asExactIndex.size should be (1) } + test("creates an empty immutable index") { + val index = HnswIndex.empty[String, Array[Float], TestItem, Float] + index.size should be (0) + } + }