From a50066e9aed6fcef4309048d427770479c35abb9 Mon Sep 17 00:00:00 2001 From: Anush008 Date: Thu, 7 Dec 2023 21:14:30 +0530 Subject: [PATCH] feat: ShardKey helpers --- .../qdrant/client/utils/CollectionUtil.java | 33 ++++++++++++ .../io/qdrant/client/utils/PayloadUtil.java | 4 +- .../io/qdrant/client/utils/SelectorUtil.java | 50 +++++++++++++++++-- .../client/utils/CollectionUtilTest.java | 27 ++++++++++ .../qdrant/client/utils/SelectorUtilTest.java | 37 ++++++++++++-- 5 files changed, 142 insertions(+), 9 deletions(-) create mode 100644 src/main/java/io/qdrant/client/utils/CollectionUtil.java create mode 100644 src/test/java/io/qdrant/client/utils/CollectionUtilTest.java diff --git a/src/main/java/io/qdrant/client/utils/CollectionUtil.java b/src/main/java/io/qdrant/client/utils/CollectionUtil.java new file mode 100644 index 00000000..1633b544 --- /dev/null +++ b/src/main/java/io/qdrant/client/utils/CollectionUtil.java @@ -0,0 +1,33 @@ +package io.qdrant.client.utils; + +import io.qdrant.client.grpc.Collections.ShardKey; + +/** + * Utility class for working with collections. + */ +public class CollectionUtil { + + /** + * Creates a {@link ShardKey} based on a keyword. + * + * @param keyword The keyword to create the shard key from + * @return The {@link ShardKey} object + */ + public static ShardKey shardKey(String keyword) { + return ShardKey.newBuilder() + .setKeyword(keyword) + .build(); + } + + /** + * Creates a {@link ShardKey} based on a number. + * + * @param number The number to create the shard key from + * @return The {@link ShardKey} object + */ + public static ShardKey shardKey(long number) { + return ShardKey.newBuilder() + .setNumber(number) + .build(); + } +} diff --git a/src/main/java/io/qdrant/client/utils/PayloadUtil.java b/src/main/java/io/qdrant/client/utils/PayloadUtil.java index 01fa6938..2a14fffa 100644 --- a/src/main/java/io/qdrant/client/utils/PayloadUtil.java +++ b/src/main/java/io/qdrant/client/utils/PayloadUtil.java @@ -47,7 +47,7 @@ public static Map toPayload(Map inputMap) { * Converts a payload struct to a Java Map. * * @param struct The payload struct to convert. - * @return The converted hash map. + * @return The converted map. */ public static Map toMap(Struct struct) { Map structMap = toMap(struct.getFieldsMap()); @@ -58,7 +58,7 @@ public static Map toMap(Struct struct) { * Converts a payload map to a Java Map. * * @param payload The payload map to convert. - * @return The converted hash map. + * @return The converted map. */ public static Map toMap(Map payload) { Map hashMap = new HashMap<>(); diff --git a/src/main/java/io/qdrant/client/utils/SelectorUtil.java b/src/main/java/io/qdrant/client/utils/SelectorUtil.java index e2c6c0b1..afae25d6 100644 --- a/src/main/java/io/qdrant/client/utils/SelectorUtil.java +++ b/src/main/java/io/qdrant/client/utils/SelectorUtil.java @@ -1,10 +1,12 @@ package io.qdrant.client.utils; +import io.qdrant.client.grpc.Collections.ShardKey; import io.qdrant.client.grpc.Points.Filter; import io.qdrant.client.grpc.Points.PayloadIncludeSelector; import io.qdrant.client.grpc.Points.PointId; import io.qdrant.client.grpc.Points.PointsIdsList; import io.qdrant.client.grpc.Points.PointsSelector; +import io.qdrant.client.grpc.Points.ShardKeySelector; import io.qdrant.client.grpc.Points.VectorsSelector; import io.qdrant.client.grpc.Points.WithPayloadSelector; import io.qdrant.client.grpc.Points.WithVectorsSelector; @@ -34,19 +36,20 @@ public static WithVectorsSelector withVectors() { } /** - * Creates a {@link WithPayloadSelector} with the specified fields included in the payload. + * Creates a {@link WithPayloadSelector} with the specified fields included in + * the payload. * * @param fields The fields to include in the payload. * @return The created {@link WithPayloadSelector} object. */ public static WithPayloadSelector withPayload(String... fields) { - PayloadIncludeSelector include = - PayloadIncludeSelector.newBuilder().addAllFields(Arrays.asList(fields)).build(); + PayloadIncludeSelector include = PayloadIncludeSelector.newBuilder().addAllFields(Arrays.asList(fields)).build(); return WithPayloadSelector.newBuilder().setInclude(include).build(); } /** - * Creates a {@link WithVectorsSelector} with the specified vector fields included. + * Creates a {@link WithVectorsSelector} with the specified vector fields + * included. * * @param vectors The names of the vectors to include. * @return The created {@link WithVectorsSelector} object. @@ -89,4 +92,43 @@ public static PointsSelector idsSelector(PointId... ids) { public static PointsSelector filterSelector(Filter filter) { return PointsSelector.newBuilder().setFilter(filter).build(); } + + /** + * Creates a {@link ShardKeySelector} with the given shard keys. + * + * @param shardKeys The shard keys to include in the selector. + * @return The created {@link ShardKeySelector} object. + */ + public static ShardKeySelector shardKeySelector(ShardKey... shardKeys) { + return ShardKeySelector.newBuilder().addAllShardKeys(Arrays.asList(shardKeys)).build(); + } + + /** + * Creates a {@link ShardKeySelector} with the given shard key keywords. + * + * @param keywords The shard key keywords to include in the selector. + * @return The created {@link ShardKeySelector} object. + */ + public static ShardKeySelector shardKeySelector(String... keywords) { + ShardKeySelector.Builder builder = ShardKeySelector.newBuilder(); + for (String keyword : keywords) { + builder.addShardKeys(CollectionUtil.shardKey(keyword)); + } + return builder.build(); + } + + /** + * Creates a {@link ShardKeySelector} with the given shard key numbers. + * + * @param numbers The shard key numbers to include in the selector. + * @return The created {@link ShardKeySelector} object. + */ + public static ShardKeySelector shardKeySelector(long... numbers) { + ShardKeySelector.Builder builder = ShardKeySelector.newBuilder(); + for (long number : numbers) { + builder.addShardKeys(CollectionUtil.shardKey(number)); + } + return builder.build(); + } + } diff --git a/src/test/java/io/qdrant/client/utils/CollectionUtilTest.java b/src/test/java/io/qdrant/client/utils/CollectionUtilTest.java new file mode 100644 index 00000000..929f79fa --- /dev/null +++ b/src/test/java/io/qdrant/client/utils/CollectionUtilTest.java @@ -0,0 +1,27 @@ +package io.qdrant.client.utils; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import io.qdrant.client.grpc.Collections.ShardKey; + +import org.junit.jupiter.api.Test; + +class CollectionUtilTest { + + @Test + void testShardKeyWithString() { + String keyword = "somethinKeySomething"; + ShardKey shardKey = CollectionUtil.shardKey(keyword); + assertNotNull(shardKey); + assertEquals(keyword, shardKey.getKeyword()); + } + + @Test + void testShardKeyWithLong() { + long number = 12345L; + ShardKey shardKey = CollectionUtil.shardKey(number); + assertNotNull(shardKey); + assertEquals(number, shardKey.getNumber()); + } +} \ No newline at end of file diff --git a/src/test/java/io/qdrant/client/utils/SelectorUtilTest.java b/src/test/java/io/qdrant/client/utils/SelectorUtilTest.java index 0735867d..0eccfb0d 100644 --- a/src/test/java/io/qdrant/client/utils/SelectorUtilTest.java +++ b/src/test/java/io/qdrant/client/utils/SelectorUtilTest.java @@ -3,9 +3,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import io.qdrant.client.grpc.Collections.ShardKey; import io.qdrant.client.grpc.Points.Filter; import io.qdrant.client.grpc.Points.PointId; import io.qdrant.client.grpc.Points.PointsSelector; +import io.qdrant.client.grpc.Points.ShardKeySelector; import io.qdrant.client.grpc.Points.WithPayloadSelector; import io.qdrant.client.grpc.Points.WithVectorsSelector; import java.util.Arrays; @@ -28,7 +30,7 @@ void testWithVectors() { @Test void testWithPayloadWithFields() { - String[] fields = {"field1", "field2"}; + String[] fields = { "field1", "field2" }; WithPayloadSelector selector = SelectorUtil.withPayload(fields); List expectedFields = Arrays.asList(fields); assertEquals(expectedFields, selector.getInclude().getFieldsList()); @@ -36,7 +38,7 @@ void testWithPayloadWithFields() { @Test void testWithVectorsWithNames() { - String[] vectors = {"vector1", "vector2"}; + String[] vectors = { "vector1", "vector2" }; WithVectorsSelector selector = SelectorUtil.withVectors(vectors); List expectedVectors = Arrays.asList(vectors); assertEquals(expectedVectors, selector.getInclude().getNamesList()); @@ -51,7 +53,7 @@ void testIdsSelectorWithList() { @Test void testIdsSelectorWithArray() { - PointId[] ids = {PointUtil.pointId(1), PointUtil.pointId(2)}; + PointId[] ids = { PointUtil.pointId(1), PointUtil.pointId(2) }; PointsSelector selector = SelectorUtil.idsSelector(ids); List expectedIds = Arrays.asList(ids); assertEquals(expectedIds, selector.getPoints().getIdsList()); @@ -63,4 +65,33 @@ void testFilterSelector() { PointsSelector selector = SelectorUtil.filterSelector(filter); assertEquals(filter, selector.getFilter()); } + + @Test + void testShardKeySelectorWithShardKeys() { + ShardKey[] shardKeys = { CollectionUtil.shardKey("key1"), CollectionUtil.shardKey("key2") }; + ShardKeySelector selector = SelectorUtil.shardKeySelector(shardKeys); + List expectedShardKeys = Arrays.asList(shardKeys); + assertEquals(expectedShardKeys, selector.getShardKeysList()); + } + + @Test + void testShardKeySelectorWithKeywords() { + String[] keywords = { "keyword1", "keyword2" }; + ShardKeySelector selector = SelectorUtil.shardKeySelector(keywords); + List expectedShardKeys = Arrays.asList( + CollectionUtil.shardKey("keyword1"), + CollectionUtil.shardKey("keyword2")); + assertEquals(expectedShardKeys, selector.getShardKeysList()); + } + + @Test + void testShardKeySelectorWithNumbers() { + long[] numbers = { 1, 2 }; + ShardKeySelector selector = SelectorUtil.shardKeySelector(numbers); + List expectedShardKeys = Arrays.asList( + CollectionUtil.shardKey(1), + CollectionUtil.shardKey(2)); + assertEquals(expectedShardKeys, selector.getShardKeysList()); + + } }