Skip to content

Commit

Permalink
feat: ShardKey helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
Anush008 committed Dec 7, 2023
1 parent e2cdc80 commit a50066e
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 9 deletions.
33 changes: 33 additions & 0 deletions src/main/java/io/qdrant/client/utils/CollectionUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package io.qdrant.client.utils;

import io.qdrant.client.grpc.Collections.ShardKey;

/**
* Utility class for working with collections.
*/
public class CollectionUtil {

/**
* Creates a {@link ShardKey} based on a keyword.
*
* @param keyword The keyword to create the shard key from
* @return The {@link ShardKey} object
*/
public static ShardKey shardKey(String keyword) {
return ShardKey.newBuilder()
.setKeyword(keyword)
.build();
}

/**
* Creates a {@link ShardKey} based on a number.
*
* @param number The number to create the shard key from
* @return The {@link ShardKey} object
*/
public static ShardKey shardKey(long number) {
return ShardKey.newBuilder()
.setNumber(number)
.build();
}
}
4 changes: 2 additions & 2 deletions src/main/java/io/qdrant/client/utils/PayloadUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public static Map<String, Value> toPayload(Map<String, Object> inputMap) {
* Converts a payload struct to a Java Map.
*
* @param struct The payload struct to convert.
* @return The converted hash map.
* @return The converted map.
*/
public static Map<String, Object> toMap(Struct struct) {
Map<String, Object> structMap = toMap(struct.getFieldsMap());
Expand All @@ -58,7 +58,7 @@ public static Map<String, Object> toMap(Struct struct) {
* Converts a payload map to a Java Map.
*
* @param payload The payload map to convert.
* @return The converted hash map.
* @return The converted map.
*/
public static Map<String, Object> toMap(Map<String, Value> payload) {
Map<String, Object> hashMap = new HashMap<>();
Expand Down
50 changes: 46 additions & 4 deletions src/main/java/io/qdrant/client/utils/SelectorUtil.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package io.qdrant.client.utils;

import io.qdrant.client.grpc.Collections.ShardKey;
import io.qdrant.client.grpc.Points.Filter;
import io.qdrant.client.grpc.Points.PayloadIncludeSelector;
import io.qdrant.client.grpc.Points.PointId;
import io.qdrant.client.grpc.Points.PointsIdsList;
import io.qdrant.client.grpc.Points.PointsSelector;
import io.qdrant.client.grpc.Points.ShardKeySelector;
import io.qdrant.client.grpc.Points.VectorsSelector;
import io.qdrant.client.grpc.Points.WithPayloadSelector;
import io.qdrant.client.grpc.Points.WithVectorsSelector;
Expand Down Expand Up @@ -34,19 +36,20 @@ public static WithVectorsSelector withVectors() {
}

/**
* Creates a {@link WithPayloadSelector} with the specified fields included in the payload.
* Creates a {@link WithPayloadSelector} with the specified fields included in
* the payload.
*
* @param fields The fields to include in the payload.
* @return The created {@link WithPayloadSelector} object.
*/
public static WithPayloadSelector withPayload(String... fields) {
PayloadIncludeSelector include =
PayloadIncludeSelector.newBuilder().addAllFields(Arrays.asList(fields)).build();
PayloadIncludeSelector include = PayloadIncludeSelector.newBuilder().addAllFields(Arrays.asList(fields)).build();
return WithPayloadSelector.newBuilder().setInclude(include).build();
}

/**
* Creates a {@link WithVectorsSelector} with the specified vector fields included.
* Creates a {@link WithVectorsSelector} with the specified vector fields
* included.
*
* @param vectors The names of the vectors to include.
* @return The created {@link WithVectorsSelector} object.
Expand Down Expand Up @@ -89,4 +92,43 @@ public static PointsSelector idsSelector(PointId... ids) {
public static PointsSelector filterSelector(Filter filter) {
return PointsSelector.newBuilder().setFilter(filter).build();
}

/**
* Creates a {@link ShardKeySelector} with the given shard keys.
*
* @param shardKeys The shard keys to include in the selector.
* @return The created {@link ShardKeySelector} object.
*/
public static ShardKeySelector shardKeySelector(ShardKey... shardKeys) {
return ShardKeySelector.newBuilder().addAllShardKeys(Arrays.asList(shardKeys)).build();
}

/**
* Creates a {@link ShardKeySelector} with the given shard key keywords.
*
* @param keywords The shard key keywords to include in the selector.
* @return The created {@link ShardKeySelector} object.
*/
public static ShardKeySelector shardKeySelector(String... keywords) {
ShardKeySelector.Builder builder = ShardKeySelector.newBuilder();
for (String keyword : keywords) {
builder.addShardKeys(CollectionUtil.shardKey(keyword));
}
return builder.build();
}

/**
* Creates a {@link ShardKeySelector} with the given shard key numbers.
*
* @param numbers The shard key numbers to include in the selector.
* @return The created {@link ShardKeySelector} object.
*/
public static ShardKeySelector shardKeySelector(long... numbers) {
ShardKeySelector.Builder builder = ShardKeySelector.newBuilder();
for (long number : numbers) {
builder.addShardKeys(CollectionUtil.shardKey(number));
}
return builder.build();
}

}
27 changes: 27 additions & 0 deletions src/test/java/io/qdrant/client/utils/CollectionUtilTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package io.qdrant.client.utils;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import io.qdrant.client.grpc.Collections.ShardKey;

import org.junit.jupiter.api.Test;

class CollectionUtilTest {

@Test
void testShardKeyWithString() {
String keyword = "somethinKeySomething";
ShardKey shardKey = CollectionUtil.shardKey(keyword);
assertNotNull(shardKey);
assertEquals(keyword, shardKey.getKeyword());
}

@Test
void testShardKeyWithLong() {
long number = 12345L;
ShardKey shardKey = CollectionUtil.shardKey(number);
assertNotNull(shardKey);
assertEquals(number, shardKey.getNumber());
}
}
37 changes: 34 additions & 3 deletions src/test/java/io/qdrant/client/utils/SelectorUtilTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import io.qdrant.client.grpc.Collections.ShardKey;
import io.qdrant.client.grpc.Points.Filter;
import io.qdrant.client.grpc.Points.PointId;
import io.qdrant.client.grpc.Points.PointsSelector;
import io.qdrant.client.grpc.Points.ShardKeySelector;
import io.qdrant.client.grpc.Points.WithPayloadSelector;
import io.qdrant.client.grpc.Points.WithVectorsSelector;
import java.util.Arrays;
Expand All @@ -28,15 +30,15 @@ void testWithVectors() {

@Test
void testWithPayloadWithFields() {
String[] fields = {"field1", "field2"};
String[] fields = { "field1", "field2" };
WithPayloadSelector selector = SelectorUtil.withPayload(fields);
List<String> expectedFields = Arrays.asList(fields);
assertEquals(expectedFields, selector.getInclude().getFieldsList());
}

@Test
void testWithVectorsWithNames() {
String[] vectors = {"vector1", "vector2"};
String[] vectors = { "vector1", "vector2" };
WithVectorsSelector selector = SelectorUtil.withVectors(vectors);
List<String> expectedVectors = Arrays.asList(vectors);
assertEquals(expectedVectors, selector.getInclude().getNamesList());
Expand All @@ -51,7 +53,7 @@ void testIdsSelectorWithList() {

@Test
void testIdsSelectorWithArray() {
PointId[] ids = {PointUtil.pointId(1), PointUtil.pointId(2)};
PointId[] ids = { PointUtil.pointId(1), PointUtil.pointId(2) };
PointsSelector selector = SelectorUtil.idsSelector(ids);
List<PointId> expectedIds = Arrays.asList(ids);
assertEquals(expectedIds, selector.getPoints().getIdsList());
Expand All @@ -63,4 +65,33 @@ void testFilterSelector() {
PointsSelector selector = SelectorUtil.filterSelector(filter);
assertEquals(filter, selector.getFilter());
}

@Test
void testShardKeySelectorWithShardKeys() {
ShardKey[] shardKeys = { CollectionUtil.shardKey("key1"), CollectionUtil.shardKey("key2") };
ShardKeySelector selector = SelectorUtil.shardKeySelector(shardKeys);
List<ShardKey> expectedShardKeys = Arrays.asList(shardKeys);
assertEquals(expectedShardKeys, selector.getShardKeysList());
}

@Test
void testShardKeySelectorWithKeywords() {
String[] keywords = { "keyword1", "keyword2" };
ShardKeySelector selector = SelectorUtil.shardKeySelector(keywords);
List<ShardKey> expectedShardKeys = Arrays.asList(
CollectionUtil.shardKey("keyword1"),
CollectionUtil.shardKey("keyword2"));
assertEquals(expectedShardKeys, selector.getShardKeysList());
}

@Test
void testShardKeySelectorWithNumbers() {
long[] numbers = { 1, 2 };
ShardKeySelector selector = SelectorUtil.shardKeySelector(numbers);
List<ShardKey> expectedShardKeys = Arrays.asList(
CollectionUtil.shardKey(1),
CollectionUtil.shardKey(2));
assertEquals(expectedShardKeys, selector.getShardKeysList());

}
}

0 comments on commit a50066e

Please sign in to comment.