From d34dfd7901e77c47e745214a30c3ea9c4cfe4824 Mon Sep 17 00:00:00 2001 From: kennethmhc Date: Fri, 17 May 2024 10:30:07 +0200 Subject: [PATCH] [FSTORE-1078] Support Similarity Search in the Feature Store v1.5 (#1529) * [FSTORE-1224] Validate the embedding index and features before creating feature group (#1747) * validate embedding index and feature * fix style * fix comment (cherry picked from commit 798c399c15d7ba9619709e5cbb0ab1ed608b2fb4) * [APPEND][FSTORE-1202] Fix deleting external embedding feature group (#1752) * remove metadata at the end * fix comment (cherry picked from commit 28490920ccdb3ccadef96c21472fa5dc12a2dc76) * [FSTORE-1226] Make sure the index mapping limit does not exceed when creating feature group (#1756) * validateWithinMappingLimit * validate when appending features * skip columns check if embedding * check opensearch status code * fix validate index creation * fix style * fix test * remove test (cherry picked from commit 5ff93f8336243c1ec1d858902c53ed182a85c045) * [FSTORE-1317] Handle opensearch bulk response with partial failures (#1751) * check hasFailures only * set refresh policy * increase timeout * retry index creation * increase retry * cap sleep time * refactor constant variable * set thread * address comments * rename variable (cherry picked from commit a310525fcbca1e94f85445a084d9af84555480f1) * [FSTORE-1264] Test ingesting different data types to opensearch (#1764) * define mapping * check supported feature types * add test (cherry picked from commit d94a981c2790af1c6967c8756dc5355db54ed48f) * [FSTORE-1241] Return online type in FG metadata when vector store is used (#1767) * set embedding fg online type * get opensearch type * fix test (cherry picked from commit 242034988dbb93a732aad21b5a9c3f9f401c95f5) * [APPEND][FSTORE-1226] Validate mapping size including sub-field (#1769) * validate mapping size * fix test (cherry picked from commit 18285b3020d851b3f305a497661a1e72a273f9e8) * [FSTORE-1378] Handle Opensearch vector database 
error (#1774) * refactor retry * opensearch setting * do not throw exception * refactor index identifier * index cleaner * fix import * return index name * fix get index * address comments * fix timer interval * address comments * address comments * constrained retry (cherry picked from commit 0110e63f6c42d3fab794fe9a403de917ca40eb69) * [APPEND][FSTORE-1378] Handle opensearch restart in OpensearchVectorDatabase #1782 (cherry picked from commit 964aaa9a09b65b4523083b0661415560087e23a1) * [APPEND][FSTORE-1378] Fix get index if not exist (#1785) (cherry picked from commit d228a9be9fe9ccc951a39560d7e94cb14353502e) * [FSTORE-1314] Support defining similar function in embedding index (#1783) * use sim function * add license * update license * refactor getOpensearchFunction (cherry picked from commit 5bf89809eaf0aa3663596d7480efd901b9f39aff) --- .../ruby/spec/helpers/featurestore_helper.rb | 1 + .../test/ruby/spec/similarity_search_spec.rb | 20 ++ .../embedding/EmbeddingController.java | 189 +++++++++-- .../embedding/EmbeddingIndexCleaner.java | 110 ++++++ ...nsearchVectorDatabaseConstrainedRetry.java | 73 ++++ .../embedding/VectorDatabaseClient.java | 34 +- .../featuregroup/EmbeddingDTO.java | 2 + .../featuregroup/EmbeddingFeatureDTO.java | 8 +- .../FeatureGroupInputValidation.java | 106 +++++- .../featuregroup/FeaturegroupController.java | 11 +- .../cached/CachedFeaturegroupController.java | 3 + .../OnDemandFeaturegroupController.java | 5 +- .../online/OnlineFeaturegroupController.java | 29 +- .../stream/StreamFeatureGroupController.java | 3 + .../online/OnlineFeaturestoreController.java | 7 + .../common/opensearch/OpenSearchClient.java | 3 +- .../hops/hopsworks/common/util/Settings.java | 13 +- .../embedding/EmbeddingControllerTest.java | 291 ++++++++++++++-- .../TestOnlineFeatureGroupController.java | 16 +- .../TestFeatureGroupInputValidation.java | 321 +++++++++++++++--- .../featuregroup/EmbeddingFeature.java | 13 +- .../featuregroup/SimilarityFunctionType.java | 39 
+++ .../SimilarityFunctionTypeConverter.java | 33 ++ .../main/resources/META-INF/persistence.xml | 1 + .../hops/hopsworks/restutils/RESTCodes.java | 14 +- .../vectordb/OpensearchVectorDatabase.java | 319 ++++++++++++----- .../hopsworks/vectordb/VectorDatabase.java | 2 + .../vectordb/VectorDatabaseFactory.java | 53 +-- 28 files changed, 1443 insertions(+), 276 deletions(-) create mode 100644 hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/EmbeddingIndexCleaner.java create mode 100644 hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/OpensearchVectorDatabaseConstrainedRetry.java create mode 100644 hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/SimilarityFunctionType.java create mode 100644 hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/SimilarityFunctionTypeConverter.java diff --git a/hopsworks-IT/src/test/ruby/spec/helpers/featurestore_helper.rb b/hopsworks-IT/src/test/ruby/spec/helpers/featurestore_helper.rb index c07c7aefbd..bb47d0322b 100644 --- a/hopsworks-IT/src/test/ruby/spec/helpers/featurestore_helper.rb +++ b/hopsworks-IT/src/test/ruby/spec/helpers/featurestore_helper.rb @@ -126,6 +126,7 @@ def create_cached_featuregroup(project_id, featurestore_id, features: nil, featu { "name": "col1", "type": "array", + "description": "testfeaturedescription", "primary": false, "partition": false, "hudiPrecombineKey": false, diff --git a/hopsworks-IT/src/test/ruby/spec/similarity_search_spec.rb b/hopsworks-IT/src/test/ruby/spec/similarity_search_spec.rb index 1138b544ff..b7767906f5 100644 --- a/hopsworks-IT/src/test/ruby/spec/similarity_search_spec.rb +++ b/hopsworks-IT/src/test/ruby/spec/similarity_search_spec.rb @@ -89,6 +89,26 @@ expect(parsed_json["embeddingIndex"]["indexName"]).to eq("#{project.id}__embedding_test_index") expect(parsed_json["embeddingIndex"]["features"].length).to be 1 
expect(parsed_json["embeddingIndex"]["colPrefix"]).to eq("") + # check features type + expect(parsed_json.first["features"][0].key?("name")).to be true + expect(parsed_json.first["features"][0]["name"]).to eql("id") + expect(parsed_json.first["features"][0].key?("type")).to be true + expect(parsed_json.first["features"][0]["type"].downcase).to eql("int") + expect(parsed_json.first["features"][0].key?("onlineType")).to be true + expect(parsed_json.first["features"][0]["onlineType"].downcase).to eql("int") + expect(parsed_json.first["features"][0].key?("description")).to be true + expect(parsed_json.first["features"][0].key?("primary")).to be true + expect(parsed_json.first["features"][0]["primary"]).to eql(true) + + expect(parsed_json.first["features"][1].key?("name")).to be true + expect(parsed_json.first["features"][1]["name"]).to eql("col1") + expect(parsed_json.first["features"][1].key?("type")).to be true + expect(parsed_json.first["features"][1]["type"].downcase).to eql("array") + expect(parsed_json.first["features"][1].key?("onlineType")).to be true + expect(parsed_json.first["features"][0]["onlineType"].downcase).to eql("knn_vector") + expect(parsed_json.first["features"][1].key?("description")).to be true + expect(parsed_json.first["features"][1].key?("primary")).to be true + expect(parsed_json.first["features"][1]["primary"]).to eql(false) end it "should be able to delete a feature group with embedding and custom index" do diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/EmbeddingController.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/EmbeddingController.java index acf28edc33..412ca5783c 100644 --- a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/EmbeddingController.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/EmbeddingController.java @@ -19,6 +19,7 @@ import com.google.common.base.Strings; import 
com.google.common.collect.Lists; import com.google.common.collect.Sets; +import io.hops.hopsworks.common.featurestore.feature.FeatureGroupFeatureDTO; import io.hops.hopsworks.common.featurestore.featuregroup.EmbeddingDTO; import io.hops.hopsworks.common.featurestore.featuregroup.FeaturegroupController; import io.hops.hopsworks.common.models.ModelFacade; @@ -32,7 +33,9 @@ import io.hops.hopsworks.persistence.entity.models.version.ModelVersion; import io.hops.hopsworks.persistence.entity.project.Project; import io.hops.hopsworks.restutils.RESTCodes; +import io.hops.hopsworks.vectordb.Field; import io.hops.hopsworks.vectordb.Index; +import io.hops.hopsworks.vectordb.OpensearchVectorDatabase; import io.hops.hopsworks.vectordb.VectorDatabaseException; import javax.ejb.EJB; @@ -43,6 +46,8 @@ import java.util.Collection; import java.util.Comparator; import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.Random; import java.util.Set; import java.util.logging.Level; @@ -62,20 +67,105 @@ public class EmbeddingController { private ModelVersionFacade modelVersionFacade; @EJB private ModelFacade modelFacade; + private static final String embeddingIndexIdentifier = "__embedding"; - public void createVectorDbIndex(Project project, Featuregroup featureGroup) + public EmbeddingController() { + } + + // For testing + EmbeddingController(VectorDatabaseClient vectorDatabaseClient, Settings settings) { + this.vectorDatabaseClient = vectorDatabaseClient; + this.settings = settings; + } + + public void createVectorDbIndex(Project project, Featuregroup featureGroup, List features) throws FeaturestoreException { Index index = new Index(featureGroup.getEmbedding().getVectorDbIndexName()); try { - vectorDatabaseClient.getClient().createIndex(index, createIndex(featureGroup.getEmbedding().getColPrefix(), - featureGroup.getEmbedding().getEmbeddingFeatures()), true); if (isDefaultVectorDbIndex(project, index.getName())) { + 
vectorDatabaseClient.getClient().createIndex(index, createIndex(featureGroup.getEmbedding().getColPrefix(), + featureGroup.getEmbedding().getEmbeddingFeatures(), features), true); vectorDatabaseClient.getClient().addFields(index, createMapping(featureGroup.getEmbedding().getColPrefix(), - featureGroup.getEmbedding().getEmbeddingFeatures())); + featureGroup.getEmbedding().getEmbeddingFeatures(), features)); + } else { + vectorDatabaseClient.getClient().createIndex(index, createIndex(featureGroup.getEmbedding().getColPrefix(), + featureGroup.getEmbedding().getEmbeddingFeatures(), features), false); } + } catch (VectorDatabaseException e) { - throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.COULD_NOT_CREATE_FEATUREGROUP, - Level.FINE, "Cannot create opensearch vectordb index: " + index.getName()); + throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.COULD_NOT_GET_VECTOR_DB_INDEX, + Level.FINE, String.format( + "Cannot create opensearch vectordb index: %s. Reason: %s", index.getName(), e.getMessage())); + } + } + + public void validateWithinMappingLimit(Project project, Index index, Integer numFeatures) + throws FeaturestoreException { + String indexName = getProjectIndexName(project, index.getName()); + try { + int remainingMappingSize; + if (indexExist(indexName)) { + remainingMappingSize = settings.getOpensearchDefaultIndexMappingLimit() + - vectorDatabaseClient.getClient().getSchema(new Index(indexName)).stream() + .mapToInt(field -> countMappingSizeIncludingSubFields(field.getType())).sum(); + } else { + remainingMappingSize = settings.getOpensearchDefaultIndexMappingLimit(); + } + if (numFeatures > remainingMappingSize) { + throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.VECTOR_DATABASE_INDEX_MAPPING_LIMIT_EXCEEDED, + Level.FINE, String.format("Number of features exceeds the limit of the index '%s'." + + " Maximum number of features can be added/created is %d." 
+ + " Reduce the number of features or use a different embedding index.", + index.getName(), remainingMappingSize)); + } + } catch (VectorDatabaseException e) { + throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.COULD_NOT_GET_VECTOR_DB_INDEX, + Level.FINE, "Cannot create opensearch vectordb index: " + indexName, e.getMessage()); + } + } + + // In opensearch, if the type has sub-fields, it is considered as 2 mappings + // `"4160_col_98":{"type":"long"}` --> 1 mapping + // `"5438_binary":{"type":"text","fields":{"keyword":{"type":"keyword","ignore_above":256}}}` --> 2 mappings + private int countMappingSizeIncludingSubFields(Object value) { + int count = 1; + if (value instanceof Map) { + if (((Map) value).containsKey("fields")) { + count += 1; + } + } + return count; + } + + public boolean indexExist(String name) throws FeaturestoreException { + try { + return vectorDatabaseClient.getClient().getIndex(name).isPresent(); + } catch (VectorDatabaseException e) { + throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.COULD_NOT_GET_VECTOR_DB_INDEX, + Level.FINE, "Cannot get opensearch vectordb index: " + name); + } + } + + public void verifyIndexName(Project project, String name) throws FeaturestoreException { + if (name != null && !Strings.isNullOrEmpty(name)) { + String projectIndexName = getProjectIndexName(project, name); + if (indexExist(projectIndexName) && !isDefaultVectorDbIndex(project, projectIndexName)) { + throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.EMBEDDING_INDEX_EXISTED, Level.FINE, + String.format("Provided embedding index `%s` already exists in the vector database.", projectIndexName)); + } + } + } + + String getProjectIndexName(Project project, String name) throws FeaturestoreException { + if (Strings.isNullOrEmpty(name)) { + return getDefaultVectorDbIndex(project); + } else { + String vectorDbIndexPrefix = getVectorDbIndexPrefix(project); + // In hopsworks opensearch, users can only access indexes which 
start with specific prefix + if (!name.startsWith(vectorDbIndexPrefix)) { + return vectorDbIndexPrefix + "_" + name; + } + return name; } } @@ -94,15 +184,13 @@ public Embedding getEmbedding(Project project, EmbeddingDTO embeddingDTO, Featur throws FeaturestoreException { Embedding embedding = new Embedding(); embedding.setFeaturegroup(featuregroup); + String projectIndexName = getProjectIndexName(project, embeddingDTO.getIndexName()); + embedding.setVectorDbIndexName(projectIndexName); if (Strings.isNullOrEmpty(embeddingDTO.getIndexName())) { - embedding.setVectorDbIndexName(getDefaultVectorDbIndex(project)); embedding.setColPrefix(getVectorDbColPrefix(featuregroup)); } else { String vectorDbIndexPrefix = getVectorDbIndexPrefix(project); - // In hopsworks opensearch, users can only access indexes which start with specific prefix if (!embeddingDTO.getIndexName().startsWith(vectorDbIndexPrefix)) { - embedding.setVectorDbIndexName( - vectorDbIndexPrefix + "_" + embeddingDTO.getIndexName()); embedding.setColPrefix(""); } if (isDefaultVectorDbIndex(project, embeddingDTO.getIndexName())) { @@ -133,7 +221,7 @@ public Embedding getEmbedding(Project project, EmbeddingDTO embeddingDTO, Featur public void dropEmbeddingForProject(Project project) throws FeaturestoreException { try { - for (Index index: vectorDatabaseClient.getClient().getAllIndices().stream() + for (Index index : vectorDatabaseClient.getClient().getAllIndices().stream() .filter(index -> index.getName().startsWith(getVectorDbIndexPrefix(project))).collect(Collectors.toSet())) { vectorDatabaseClient.getClient().deleteIndex(index); } @@ -143,6 +231,14 @@ public void dropEmbeddingForProject(Project project) } } + public Boolean isEmbeddingIndex(String indexName) { + return indexName.matches("^\\d+" + embeddingIndexIdentifier + ".*"); + } + + public Integer getProjectId(String indexName) { + return Integer.valueOf(indexName.split(embeddingIndexIdentifier)[0]); + } + public void dropEmbedding(Project project, 
Featuregroup featureGroup) throws FeaturestoreException { Index index = new Index(featureGroup.getEmbedding().getVectorDbIndexName()); @@ -160,7 +256,7 @@ public void dropEmbedding(Project project, Featuregroup featureGroup) } } catch (VectorDatabaseException e) { throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.COULD_NOT_DELETE_FEATUREGROUP, - Level.FINE, "Cannot delete documents from vectordb for feature group: " + + Level.FINE, "Cannot delete index from vectordb for feature group: " + featureGroup.getName(), e.getMessage(), e); } } @@ -174,37 +270,66 @@ private boolean isPreviousDefaultVectorDbIndex(Embedding embedding) } private void removeDocuments(Featuregroup featureGroup) throws FeaturestoreException, VectorDatabaseException { - EmbeddingFeature feature = featureGroup.getEmbedding().getEmbeddingFeatures().stream().findFirst().get(); + Set fields = + vectorDatabaseClient.getClient().getSchema(new Index(featureGroup.getEmbedding().getVectorDbIndexName())) + .stream().map(Field::getName).collect(Collectors.toSet()); + // Get any of the embedding feature which exists in the vector database for removing document if it is not null + Optional embeddingFeatureName = featureGroup + .getEmbedding().getEmbeddingFeatures().stream().map(feature -> feature.getEmbedding().getColPrefix() == null + ? feature.getName() + : feature.getEmbedding().getColPrefix() + feature.getName()).filter(fields::contains).findFirst(); String matchQuery = "%s:*"; - - String field = feature.getEmbedding().getColPrefix() == null - ? 
feature.getName() - : feature.getEmbedding().getColPrefix() + feature.getName(); - vectorDatabaseClient.getClient().deleteByQuery( - new Index(featureGroup.getEmbedding().getVectorDbIndexName()), - String.format(matchQuery, field) - ); + if (embeddingFeatureName.isPresent()) { + vectorDatabaseClient.getClient().deleteByQuery( + new Index(featureGroup.getEmbedding().getVectorDbIndexName()), + String.format(matchQuery, embeddingFeatureName.get()) + ); + } } - protected String createMapping(String prefix, Collection features) { + protected String createMapping(String prefix, Collection embeddingFeatures, + List features) { + Set embeddingFeatureNames = + embeddingFeatures.stream().map(EmbeddingFeature::getName).collect(Collectors.toSet()); String mappingString = "{\n" + " \"properties\": {\n" + "%s\n" + " }\n" + " }"; - String fieldString = " \"%s\": {\n" + + String embeddingFieldString = " \"%s\": {\n" + " \"type\": \"knn_vector\",\n" + - " \"dimension\": %d\n" + + " \"dimension\": %d,\n" + + " \"method\": {\n" + + " \"name\": \"hnsw\",\n" + + " \"space_type\": \"%s\",\n" + + " \"engine\": \"nmslib\"\n" + + " }\n" + + " }"; + String fieldString = " \"%s\": {\n" + + " \"type\": \"%s\"\n" + " }"; List fieldMapping = Lists.newArrayList(); - for (EmbeddingFeature feature : features) { + + for (EmbeddingFeature feature : embeddingFeatures) { fieldMapping.add(String.format( - fieldString, prefix + feature.getName(), feature.getDimension())); + embeddingFieldString, + prefix + feature.getName(), feature.getDimension(), + feature.getSimilarityFunctionType().getOpensearchFunction())); + } + for (FeatureGroupFeatureDTO feature : features) { + if (!embeddingFeatureNames.contains(feature.getName())) { + String type = OpensearchVectorDatabase.getDataType(feature.getType()); + if (type != null) { // if type cannot be converted, opensearch will infer the type + fieldMapping.add(String.format( + fieldString, prefix + feature.getName(), type)); + } + } } return 
String.format(mappingString, String.join(",\n", fieldMapping)); } - protected String createIndex(String prefix, Collection features) { + protected String createIndex(String prefix, Collection embeddingFeatures, + List features) { String jsonString = "{\n" + " \"settings\": {\n" + " \"index\": {\n" + @@ -214,17 +339,17 @@ protected String createIndex(String prefix, Collection feature " },\n" + " \"mappings\": %s\n" + "}"; - return String.format(jsonString, createMapping(prefix, features)); + return String.format(jsonString, createMapping(prefix, embeddingFeatures, features)); } - private String getDefaultVectorDbIndex(Project project) throws FeaturestoreException { + String getDefaultVectorDbIndex(Project project) throws FeaturestoreException { Set indexName = getAllDefaultVectorDbIndex(project); // randomly select an index return indexName.stream().sorted(Comparator.comparingInt(i -> RANDOM.nextInt())).findFirst().get(); } - private boolean isDefaultVectorDbIndex(Project project, String index) throws FeaturestoreException { + boolean isDefaultVectorDbIndex(Project project, String index) throws FeaturestoreException { return getAllDefaultVectorDbIndex(project).contains(index); } @@ -247,8 +372,8 @@ private Set getAllDefaultVectorDbIndex(Project project) throws Featurest return indices; } - private String getVectorDbIndexPrefix(Project project) { - return project.getId() + "__embedding"; + String getVectorDbIndexPrefix(Project project) { + return project.getId() + embeddingIndexIdentifier; } private String getVectorDbColPrefix(Featuregroup featuregroup) { diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/EmbeddingIndexCleaner.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/EmbeddingIndexCleaner.java new file mode 100644 index 0000000000..6c8ed19965 --- /dev/null +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/EmbeddingIndexCleaner.java @@ -0,0 +1,110 
@@ +/* + * This file is part of Hopsworks + * Copyright (C) 2024, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . + */ + +package io.hops.hopsworks.common.featurestore.embedding; + +import com.google.common.collect.Sets; +import io.hops.hopsworks.common.dao.project.ProjectFacade; +import io.hops.hopsworks.common.util.PayaraClusterManager; +import io.hops.hopsworks.exceptions.FeaturestoreException; +import io.hops.hopsworks.persistence.entity.project.Project; +import io.hops.hopsworks.vectordb.Index; +import io.hops.hopsworks.vectordb.VectorDatabaseException; + +import javax.annotation.PostConstruct; +import javax.annotation.PreDestroy; +import javax.annotation.Resource; +import javax.ejb.EJB; +import javax.ejb.Singleton; +import javax.ejb.Startup; +import javax.ejb.Timeout; +import javax.ejb.Timer; +import javax.ejb.TimerConfig; +import javax.ejb.TimerService; +import javax.ejb.TransactionAttribute; +import javax.ejb.TransactionAttributeType; +import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; + +@Singleton +@Startup +@TransactionAttribute(TransactionAttributeType.NOT_SUPPORTED) +public class EmbeddingIndexCleaner { + + private final static Logger LOGGER = Logger.getLogger(EmbeddingIndexCleaner.class.getName()); + + @Resource + private TimerService timerService; + private Timer timer; + @EJB + private ProjectFacade 
projectFacade; + @EJB + private VectorDatabaseClient vectorDatabaseClient; + @EJB + private EmbeddingController embeddingController; + @EJB + private PayaraClusterManager payaraClusterManager; + + @PostConstruct + public void init() { + // Schedule the cleaner to run every 6 hours + timer = timerService.createIntervalTimer(10 * 60 * 1000, 6 * 60 * 60 * 1000, + new TimerConfig("EmbeddingIndexCleaner", false)); + + } + + @PreDestroy + private void destroyTimer() { + if (timer != null) { + timer.cancel(); + } + } + + @TransactionAttribute(TransactionAttributeType.NOT_SUPPORTED) + @Timeout + public void cleanExpiredIndexes() { + if (!payaraClusterManager.amIThePrimary()) { + return; + } + LOGGER.log(Level.INFO, "Checking index to be removed"); + + try { + Set indexesToRemove = getIndexesToRemove(); + for (Index index : indexesToRemove) { + vectorDatabaseClient.getClient().deleteIndex(index); + LOGGER.log(Level.INFO, "Removed embedding index: " + index.getName()); + } + } catch (Exception e) { + LOGGER.log(Level.SEVERE, "Error occurred while cleaning embedding indexes", e); + } + } + + private Set getIndexesToRemove() throws VectorDatabaseException, FeaturestoreException { + Set indexesToRemove = Sets.newHashSet(); + Set projectIds = Sets.newHashSet(); + for (Project project : projectFacade.findAll()) { + projectIds.add(project.getId()); + } + for (Index index : vectorDatabaseClient.getClient().getAllIndices()) { + if (embeddingController.isEmbeddingIndex(index.getName()) && + !projectIds.contains(embeddingController.getProjectId(index.getName()))) { + indexesToRemove.add(index); + } + } + return indexesToRemove; + } +} diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/OpensearchVectorDatabaseConstrainedRetry.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/OpensearchVectorDatabaseConstrainedRetry.java new file mode 100644 index 0000000000..9efa8fe337 --- /dev/null +++ 
b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/OpensearchVectorDatabaseConstrainedRetry.java @@ -0,0 +1,73 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2024, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . + */ + +package io.hops.hopsworks.common.featurestore.embedding; + +import com.logicalclocks.servicediscoverclient.exceptions.ServiceDiscoveryException; +import io.hops.hopsworks.common.opensearch.OpenSearchClient; +import io.hops.hopsworks.common.util.LongRunningHttpRequests; +import io.hops.hopsworks.common.util.Settings; +import io.hops.hopsworks.exceptions.OpenSearchException; +import io.hops.hopsworks.vectordb.OpensearchVectorDatabase; +import io.hops.hopsworks.vectordb.VectorDatabaseException; +import org.opensearch.client.RestHighLevelClient; + +import javax.ejb.ConcurrencyManagement; +import javax.ejb.ConcurrencyManagementType; +import javax.ejb.DependsOn; +import javax.ejb.EJB; +import javax.ejb.Stateless; +import javax.ejb.TransactionAttribute; +import javax.ejb.TransactionAttributeType; + +@Stateless +@TransactionAttribute(TransactionAttributeType.NOT_SUPPORTED) +@ConcurrencyManagement(ConcurrencyManagementType.BEAN) +@DependsOn("OpenSearchClient") +public class OpensearchVectorDatabaseConstrainedRetry extends OpensearchVectorDatabase { + + @EJB + private LongRunningHttpRequests 
longRunningHttpRequests; + @EJB + private Settings settings; + @EJB + private OpenSearchClient openSearchClient; + + @Override + protected Boolean shouldRetry() { + return longRunningHttpRequests.get() < settings.getMaxLongRunningHttpRequests(); + } + + @Override + protected void startRetry() { + longRunningHttpRequests.increment(); + } + + @Override + protected void doneRetry() { + longRunningHttpRequests.decrement(); + } + + @Override + protected RestHighLevelClient getClient() throws VectorDatabaseException { + try { + return openSearchClient.getClient(); + } catch (OpenSearchException | ServiceDiscoveryException e) { + throw new VectorDatabaseException("Cannot create opensearch client. " + e.getMessage()); + } + } + +} \ No newline at end of file diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/VectorDatabaseClient.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/VectorDatabaseClient.java index 10aa7beb20..9ddd107e15 100644 --- a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/VectorDatabaseClient.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/VectorDatabaseClient.java @@ -22,11 +22,12 @@ import io.hops.hopsworks.exceptions.OpenSearchException; import io.hops.hopsworks.restutils.RESTCodes; import io.hops.hopsworks.vectordb.VectorDatabase; -import io.hops.hopsworks.vectordb.VectorDatabaseFactory; +import javax.annotation.PostConstruct; import javax.annotation.PreDestroy; import javax.ejb.ConcurrencyManagement; import javax.ejb.ConcurrencyManagementType; +import javax.ejb.DependsOn; import javax.ejb.EJB; import javax.ejb.Singleton; import javax.ejb.TransactionAttribute; @@ -37,29 +38,40 @@ @Singleton @TransactionAttribute(TransactionAttributeType.NOT_SUPPORTED) @ConcurrencyManagement(ConcurrencyManagementType.BEAN) +@DependsOn("OpenSearchClient") public class VectorDatabaseClient { @EJB private OpenSearchClient 
openSearchClient; - private VectorDatabase vectorDatabase; + @EJB + private OpensearchVectorDatabaseConstrainedRetry vectorDatabase; private static final Logger LOG = Logger.getLogger(EmbeddingController.class.getName()); + @PostConstruct + public void init() { + try { + vectorDatabase.init(openSearchClient.getClient()); + } catch (OpenSearchException | ServiceDiscoveryException e) { + vectorDatabase = null; + LOG.log(Level.SEVERE, "Cannot create opensearch vectordb client"); + } + } + public synchronized VectorDatabase getClient() throws FeaturestoreException { - if (vectorDatabase == null) { - try { - vectorDatabase = VectorDatabaseFactory.getOpensearchDatabase(openSearchClient.getClient()); - } catch (OpenSearchException | ServiceDiscoveryException e) { - throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.COULD_NOT_CREATE_FEATUREGROUP, - Level.FINE, "Cannot create opensearch vectordb"); - } + if (vectorDatabase != null) { + return vectorDatabase; + } else { + throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.COULD_NOT_CREATE_FEATUREGROUP, + Level.FINE, "Cannot create opensearch vectordb client."); } - return vectorDatabase; } @PreDestroy private void close() { try { - vectorDatabase.close(); + if (vectorDatabase != null) { + vectorDatabase.close(); + } } catch (Exception ex) { LOG.log(Level.SEVERE, null, ex); } diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingDTO.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingDTO.java index 785060132a..1f7fce432d 100644 --- a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingDTO.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingDTO.java @@ -17,6 +17,7 @@ package io.hops.hopsworks.common.featurestore.featuregroup; import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Embedding; +import 
lombok.AllArgsConstructor; import lombok.Getter; import lombok.NoArgsConstructor; @@ -24,6 +25,7 @@ import java.util.stream.Collectors; @NoArgsConstructor +@AllArgsConstructor public class EmbeddingDTO { @Getter diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingFeatureDTO.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingFeatureDTO.java index 5a65a20ff9..77ede7f9a5 100644 --- a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingFeatureDTO.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingFeatureDTO.java @@ -17,6 +17,7 @@ package io.hops.hopsworks.common.featurestore.featuregroup; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.SimilarityFunctionType; import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.EmbeddingFeature; import lombok.AllArgsConstructor; import lombok.Getter; @@ -30,12 +31,17 @@ public class EmbeddingFeatureDTO { @Getter private String name; @Getter - private String similarityFunctionType; + private SimilarityFunctionType similarityFunctionType; @Getter private Integer dimension; @Getter private ModelDto model; + public EmbeddingFeatureDTO(String name, SimilarityFunctionType similarityFunctionType, Integer dimension) { + this.name = name; + this.similarityFunctionType = similarityFunctionType; + this.dimension = dimension; + } public EmbeddingFeatureDTO(EmbeddingFeature feature) { name = feature.getName(); diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/FeatureGroupInputValidation.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/FeatureGroupInputValidation.java index 3e46f57596..18c6e1b93b 100644 --- 
a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/FeatureGroupInputValidation.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/FeatureGroupInputValidation.java @@ -16,8 +16,10 @@ package io.hops.hopsworks.common.featurestore.featuregroup; +import com.google.common.base.Joiner; import com.google.common.base.Strings; import io.hops.hopsworks.common.featurestore.FeaturestoreConstants; +import io.hops.hopsworks.common.featurestore.embedding.EmbeddingController; import io.hops.hopsworks.common.featurestore.feature.FeatureGroupFeatureDTO; import io.hops.hopsworks.common.featurestore.featuregroup.cached.CachedFeaturegroupDTO; import io.hops.hopsworks.common.featurestore.featuregroup.online.OnlineFeaturegroupController; @@ -25,7 +27,10 @@ import io.hops.hopsworks.common.featurestore.utils.FeaturestoreInputValidation; import io.hops.hopsworks.exceptions.FeaturestoreException; import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.cached.TimeTravelFormat; +import io.hops.hopsworks.persistence.entity.project.Project; import io.hops.hopsworks.restutils.RESTCodes; +import io.hops.hopsworks.vectordb.Index; +import io.hops.hopsworks.vectordb.OpensearchVectorDatabase; import org.apache.commons.lang.StringUtils; import javax.ejb.EJB; @@ -36,6 +41,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.logging.Level; import java.util.stream.Collectors; @@ -50,15 +56,19 @@ public class FeatureGroupInputValidation { protected FeaturestoreInputValidation featureStoreInputValidation; @EJB protected OnlineFeaturegroupController onlineFeaturegroupController; + @EJB + protected EmbeddingController embeddingController; public FeatureGroupInputValidation() { } // for testing public FeatureGroupInputValidation(FeaturestoreInputValidation featureStoreInputValidation, - OnlineFeaturegroupController onlineFeaturegroupController) { + 
OnlineFeaturegroupController onlineFeaturegroupController, + EmbeddingController embeddingController) { this.featureStoreInputValidation = featureStoreInputValidation; this.onlineFeaturegroupController = onlineFeaturegroupController; + this.embeddingController = embeddingController; } /** @@ -166,6 +176,10 @@ public void verifyNoDuplicatedFeatures(FeaturegroupDTO featureGroupDTO) */ public void verifyOnlineOfflineTypeMatch(FeaturegroupDTO featuregroupDTO) throws FeaturestoreException{ if (featuregroupDTO.getOnlineEnabled()) { + // Users cannot specify the online type for embedding fg + if (featuregroupDTO.getEmbeddingIndex() != null) { + return; + } for (FeatureGroupFeatureDTO feature : featuregroupDTO.getFeatures()) { String offlineType = feature.getType().toLowerCase().replace(" ", ""); String onlineType = @@ -223,14 +237,14 @@ public void verifyOnlineOfflineTypeMatch(FeaturegroupDTO featuregroupDTO) throws * @throws FeaturestoreException */ public void verifyOnlineSchemaValid(FeaturegroupDTO featuregroupDTO) throws FeaturestoreException{ - if (featuregroupDTO.getOnlineEnabled()) { + if (featuregroupDTO.getOnlineEnabled() && featuregroupDTO.getEmbeddingIndex() == null) { if (featuregroupDTO.getFeatures().size() > FeaturestoreConstants.MAX_MYSQL_COLUMNS) { throw new FeaturestoreException( COULD_NOT_CREATE_ONLINE_FEATUREGROUP, Level.SEVERE, "Cannot create an online feature group because it contains > " + - FeaturestoreConstants.MAX_MYSQL_COLUMNS + " rows (provided: " + - featuregroupDTO.getFeatures().size() + " rows)."); + FeaturestoreConstants.MAX_MYSQL_COLUMNS + " columns (provided: " + + featuregroupDTO.getFeatures().size() + " columns)."); } Integer totalBytes = 0; @@ -258,7 +272,7 @@ public void verifyOnlineSchemaValid(FeaturegroupDTO featuregroupDTO) throws Feat * @throws FeaturestoreException */ public void verifyPrimaryKeySupported(FeaturegroupDTO featuregroupDTO) throws FeaturestoreException{ - if (featuregroupDTO.getOnlineEnabled()) { + if 
(featuregroupDTO.getOnlineEnabled() && featuregroupDTO.getEmbeddingIndex() == null) { Integer totalBytes = 0; for (FeatureGroupFeatureDTO feature : featuregroupDTO.getFeatures()) { if (feature.getPrimary()) { @@ -388,4 +402,86 @@ public List verifyAndGetNewFeatures(List features = + featureGroupDTO.getFeatures().stream().map(FeatureGroupFeatureDTO::getName).collect(Collectors.toSet()); + if (featureGroupDTO.getEmbeddingIndex() != null) { + for (EmbeddingFeatureDTO embeddingFeatureDTO : featureGroupDTO.getEmbeddingIndex().getFeatures()) { + if (!features.contains(embeddingFeatureDTO.getName())) { + throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.EMBEDDING_FEATURE_NOT_FOUND, Level.FINE, + String.format("Provided embedding feature `%s` does not exist in the feature group.", + embeddingFeatureDTO.getName())); + } + } + } + } + + public void verifyEmbeddingIndexNotExist(Project project, FeaturegroupDTO featureGroupDTO) + throws FeaturestoreException { + if (featureGroupDTO.getEmbeddingIndex() != null) { + embeddingController.verifyIndexName(project, featureGroupDTO.getEmbeddingIndex().getIndexName()); + } + } + + public void verifyEmbeddingIndexName(FeaturegroupDTO featureGroupDTO) throws FeaturestoreException { + if (featureGroupDTO.getEmbeddingIndex() != null && featureGroupDTO.getEmbeddingIndex().getIndexName() != null) { + String indexName = featureGroupDTO.getEmbeddingIndex().getIndexName(); + String errorMessage = String.format("Provided embedding index name `%s` is not valid. Naming rules: " + + "1. All letters must be lowercase. " + + "2. Index names cannot begin with _ or -. " + + "3. Index names can't contain specified special characters.", indexName); + // https://docs.aws.amazon.com/opensearch-service/latest/developerguide/indexing.html + // Rule 1: All letters must be lowercase.
+ if (!indexName.equals(indexName.toLowerCase())) { + throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.INVALID_EMBEDDING_INDEX_NAME, Level.FINE, + errorMessage); + } + + // Rule 2: Index names cannot begin with _ or -. + if (indexName.startsWith("_") || indexName.startsWith("-")) { + throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.INVALID_EMBEDDING_INDEX_NAME, Level.FINE, + errorMessage); + } + + // Rule 3: Index names can't contain specified special characters. + String[] forbiddenChars = {" ", ",", ":", "\"", "*", "+", "/", "\\", "|", "?", "#", ">", "<"}; + for (String forbiddenChar : forbiddenChars) { + if (indexName.contains(forbiddenChar)) { + throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.INVALID_EMBEDDING_INDEX_NAME, Level.FINE, + errorMessage); + } + } + } + } + + public void verifyVectorDatabaseIndexMappingLimit(Project project, FeaturegroupDTO featureGroupDTO, + Integer numFeatures) + throws FeaturestoreException { + if (featureGroupDTO.getEmbeddingIndex() != null) { + embeddingController.validateWithinMappingLimit(project, + new Index(featureGroupDTO.getEmbeddingIndex().getIndexName()), + numFeatures); + } + } + + public void verifyVectorDatabaseSupportedDataType(FeaturegroupDTO featureGroupDTO) + throws FeaturestoreException { + if (featureGroupDTO.getEmbeddingIndex() != null) { + verifyVectorDatabaseSupportedDataType(featureGroupDTO.getFeatures()); + } + } + + public void verifyVectorDatabaseSupportedDataType(List featureDTOS) + throws FeaturestoreException { + Set unsupportedFeatures = + featureDTOS.stream().filter(f -> OpensearchVectorDatabase.getDataType(f.getType()) == null) + .collect(Collectors.toSet()); + if (unsupportedFeatures.size() > 0) { + throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.VECTOR_DATABASE_DATA_TYPE_NOT_SUPPORTED, + Level.FINE, "Vector database does not support data type in the following features: " + Joiner.on(", ") + .join(unsupportedFeatures.stream().map(f -> 
f.getName() + ": " + f.getType()).collect( + Collectors.toSet()))); + } + } } diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/FeaturegroupController.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/FeaturegroupController.java index 8a77799cb8..514b77d2e4 100644 --- a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/FeaturegroupController.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/FeaturegroupController.java @@ -224,6 +224,12 @@ public FeaturegroupDTO createFeaturegroup(Featurestore featurestore, Featuregrou enforceFeaturegroupQuotas(featurestore, featuregroupDTO); featureGroupInputValidation.verifySchemaProvided(featuregroupDTO); featureGroupInputValidation.verifyNoDuplicatedFeatures(featuregroupDTO); + featureGroupInputValidation.verifyEmbeddingFeatureExist(featuregroupDTO); + featureGroupInputValidation.verifyEmbeddingIndexNotExist(project, featuregroupDTO); + featureGroupInputValidation.verifyEmbeddingIndexName(featuregroupDTO); + featureGroupInputValidation.verifyVectorDatabaseIndexMappingLimit(project, featuregroupDTO, + featuregroupDTO.getFeatures().size()); + featureGroupInputValidation.verifyVectorDatabaseSupportedDataType(featuregroupDTO); // if version not provided, get latest and increment if (featuregroupDTO.getVersion() == null) { @@ -694,13 +700,14 @@ public void deleteFeaturegroup(Featuregroup featuregroup, Project project, Users streamFeatureGroupController.deleteFeatureGroup(featuregroup, project, user); break; case ON_DEMAND_FEATURE_GROUP: - // Delete on_demand_feature_group will cascade to feature_group table - onDemandFeaturegroupController.removeOnDemandFeaturegroup(featuregroup, project, user); // Delete mysql table and metadata if (settings.isOnlineFeaturestore() && featuregroup.isOnlineEnabled() && !featuregroup.getOnDemandFeaturegroup().isSpine()) { 
onlineFeaturegroupController.disableOnlineFeatureGroup(featuregroup, project, user); } + // Delete the metadata at the end as `disableOnlineFeatureGroup` requires `Embedding` metadata. + // Delete on_demand_feature_group will cascade to feature_group table + onDemandFeaturegroupController.removeOnDemandFeaturegroup(featuregroup, project, user); break; default: throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.ILLEGAL_FEATUREGROUP_TYPE, Level.FINE, diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/cached/CachedFeaturegroupController.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/cached/CachedFeaturegroupController.java index f5081d0891..b95d898692 100644 --- a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/cached/CachedFeaturegroupController.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/cached/CachedFeaturegroupController.java @@ -423,6 +423,9 @@ public void updateMetadata(Project project, Users user, Featuregroup featuregrou if (featuregroupDTO.getFeatures() != null) { verifyPreviousSchemaUnchanged(previousSchema, featuregroupDTO.getFeatures()); newFeatures = featureGroupInputValidation.verifyAndGetNewFeatures(previousSchema, featuregroupDTO.getFeatures()); + featureGroupInputValidation.verifyVectorDatabaseIndexMappingLimit(project, featuregroupDTO, newFeatures.size()); + featureGroupInputValidation.verifyVectorDatabaseSupportedDataType(newFeatures); + } // change feature descriptions diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/ondemand/OnDemandFeaturegroupController.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/ondemand/OnDemandFeaturegroupController.java index ca94585caa..662034418f 100644 --- 
a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/ondemand/OnDemandFeaturegroupController.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/ondemand/OnDemandFeaturegroupController.java @@ -53,8 +53,8 @@ import javax.ejb.TransactionAttributeType; import java.io.IOException; import java.sql.SQLException; -import java.util.Collection; import java.util.ArrayList; +import java.util.Collection; import java.util.Comparator; import java.util.List; import java.util.Optional; @@ -189,6 +189,9 @@ public void updateOnDemandFeaturegroupMetadata(Project project, Users user, Feat List newFeatures = featureGroupInputValidation.verifyAndGetNewFeatures(previousSchema, onDemandFeaturegroupDTO.getFeatures()); + featureGroupInputValidation.verifyVectorDatabaseIndexMappingLimit(project, onDemandFeaturegroupDTO, + newFeatures.size()); + featureGroupInputValidation.verifyVectorDatabaseSupportedDataType(newFeatures); // append new features and update existing ones updateOnDemandFeatures(onDemandFeaturegroup, onDemandFeaturegroupDTO.getFeatures()); diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/online/OnlineFeaturegroupController.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/online/OnlineFeaturegroupController.java index 6fb4d22357..8d7b517677 100644 --- a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/online/OnlineFeaturegroupController.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/online/OnlineFeaturegroupController.java @@ -42,6 +42,7 @@ import io.hops.hopsworks.exceptions.SchemaException; import io.hops.hopsworks.exceptions.ServiceException; import io.hops.hopsworks.persistence.entity.featurestore.Featurestore; +import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.EmbeddingFeature; import 
io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Featuregroup; import io.hops.hopsworks.persistence.entity.kafka.schemas.SchemaCompatibility; import io.hops.hopsworks.persistence.entity.project.Project; @@ -49,6 +50,7 @@ import io.hops.hopsworks.restutils.RESTCodes; import io.hops.hopsworks.vectordb.Field; import io.hops.hopsworks.vectordb.Index; +import io.hops.hopsworks.vectordb.OpensearchVectorDatabase; import io.hops.hopsworks.vectordb.VectorDatabaseException; import org.apache.calcite.sql.SqlDialect; import org.apache.calcite.sql.SqlIdentifier; @@ -57,6 +59,7 @@ import org.apache.calcite.sql.SqlSelect; import org.apache.calcite.sql.dialect.MysqlSqlDialect; import org.apache.calcite.sql.parser.SqlParserPos; +import org.glassfish.jersey.internal.guava.Sets; import javax.ejb.EJB; import javax.ejb.Stateless; @@ -115,9 +118,11 @@ public class OnlineFeaturegroupController { public OnlineFeaturegroupController() { } - protected OnlineFeaturegroupController(Settings settings, EmbeddingController embeddingController) { + protected OnlineFeaturegroupController(Settings settings, EmbeddingController embeddingController, + FeaturegroupController featuregroupController) { this.settings = settings; this.embeddingController = embeddingController; + this.featuregroupController = featuregroupController; } /** @@ -156,10 +161,10 @@ public void setupOnlineFeatureGroup(Featurestore featureStore, Featuregroup feat checkOnlineFsUserExist(project); createFeatureGroupKafkaTopic(project, featureGroup, features); + // create mysql table for embedding feature group also for storing feature description + createMySQLTable(featureStore, Utils.getFeaturegroupName(featureGroup), features, project, user); if (featureGroup.getEmbedding() != null) { - embeddingController.createVectorDbIndex(project, featureGroup); - } else { - createMySQLTable(featureStore, Utils.getFeaturegroupName(featureGroup), features, project, user); + embeddingController.createVectorDbIndex(project, 
featureGroup, features); } } @@ -450,10 +455,24 @@ public List getFeaturegroupFeatures(Featuregroup feature List onlineFeatureGroupFeatureDTOS = onlineFeaturestoreFacade.getMySQLFeatures( Utils.getFeatureStoreEntityName(featuregroup.getName(), featuregroup.getVersion()), onlineFeaturestoreController.getOnlineFeaturestoreDbName(featuregroup.getFeaturestore().getProject())); + Set embeddingFeatureNames = Sets.newHashSet(); + if (featuregroup.getEmbedding() != null) { + embeddingFeatureNames = featuregroup.getEmbedding().getEmbeddingFeatures().stream().map(EmbeddingFeature::getName) + .collect(Collectors.toSet()); + } for (FeatureGroupFeatureDTO featureGroupFeatureDTO : featureGroupFeatureDTOS) { for (FeatureGroupFeatureDTO onlineFeatureGroupFeatureDTO : onlineFeatureGroupFeatureDTOS) { if(featureGroupFeatureDTO.getName().equalsIgnoreCase(onlineFeatureGroupFeatureDTO.getName())){ - featureGroupFeatureDTO.setOnlineType(onlineFeatureGroupFeatureDTO.getType()); + if (featuregroup.getEmbedding() != null) { + if (embeddingFeatureNames.contains(featureGroupFeatureDTO.getName())) { + featureGroupFeatureDTO.setOnlineType("knn_vector"); + } else { + featureGroupFeatureDTO.setOnlineType( + OpensearchVectorDatabase.getDataType(featureGroupFeatureDTO.getType())); + } + } else { + featureGroupFeatureDTO.setOnlineType(onlineFeatureGroupFeatureDTO.getType()); + } } } } diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/stream/StreamFeatureGroupController.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/stream/StreamFeatureGroupController.java index 6154bc230c..de315abc8b 100644 --- a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/stream/StreamFeatureGroupController.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/stream/StreamFeatureGroupController.java @@ -225,6 +225,9 @@ public void updateMetadata(Project project, Users 
user, Featuregroup featuregrou if (featuregroupDTO.getFeatures() != null) { cachedFeaturegroupController.verifyPreviousSchemaUnchanged(previousSchema, featuregroupDTO.getFeatures()); newFeatures = featureGroupInputValidation.verifyAndGetNewFeatures(previousSchema, featuregroupDTO.getFeatures()); + featureGroupInputValidation.verifyVectorDatabaseIndexMappingLimit(project, featuregroupDTO, newFeatures.size()); + featureGroupInputValidation.verifyVectorDatabaseSupportedDataType(newFeatures); + } // change feature descriptions diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/online/OnlineFeaturestoreController.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/online/OnlineFeaturestoreController.java index 404cf22c64..7ac86c7cb8 100644 --- a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/online/OnlineFeaturestoreController.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/online/OnlineFeaturestoreController.java @@ -352,6 +352,13 @@ public void removeOnlineFeatureStore(Project project) throws FeaturestoreExcepti String db = getOnlineFeaturestoreDbName(project); onlineFeaturestoreFacade.removeOnlineFeaturestoreDatabase(db, connection); + try { + embeddingController.dropEmbeddingForProject(project); + } catch (FeaturestoreException e) { + // Do not interrupt project deletion, instead clean up the orphan index in batch later, because + // project deletion cannot be retried and fails with “Your project role does not allow to perform this action." 
+ LOGGER.log(Level.WARNING, "Failed to drop embedding for project.", e); + } } catch (SQLException se) { throw new FeaturestoreException( RESTCodes.FeaturestoreErrorCode.COULD_NOT_INITIATE_MYSQL_CONNECTION_TO_ONLINE_FEATURESTORE, diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/opensearch/OpenSearchClient.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/opensearch/OpenSearchClient.java index ab4f1809f5..9c7fcb0c50 100644 --- a/hopsworks-common/src/main/java/io/hops/hopsworks/common/opensearch/OpenSearchClient.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/opensearch/OpenSearchClient.java @@ -136,7 +136,8 @@ public synchronized RestHighLevelClient getClient() throws OpenSearchException, RestClient.builder(elasticAddr) .setHttpClientConfigCallback(httpAsyncClientBuilder -> { httpAsyncClientBuilder.setDefaultIOReactorConfig( - IOReactorConfig.custom().setIoThreadCount(Settings.OPENSEARCH_KIBANA_NO_CONNECTIONS).build()); + IOReactorConfig.custom().setIoThreadCount( + Settings.OPENSEARCH_KIBANA_NO_CONNECTIONS).setSoKeepAlive(true).build()); if (isSecurityEnabled) { return httpAsyncClientBuilder.setSSLContext(finalSslCtx) .setDefaultCredentialsProvider( diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/util/Settings.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/util/Settings.java index 70f27809aa..ccac7ecfd2 100644 --- a/hopsworks-common/src/main/java/io/hops/hopsworks/common/util/Settings.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/util/Settings.java @@ -361,6 +361,8 @@ public class Settings { "opensearch_default_embedding_index"; private static final String VARIABLE_NUM_OPENSEARCH_DEFAULT_EMBEDDING_INDEX = "opensearch_num_default_embedding_index"; + private static final String VARIABLE_OPENSEARCH_INDEX_MAPPING_LIMIT = + "opensearch_index_mapping_limit"; /* -------------------- Cloud --------------- */ private static final String 
VARIABLE_CLOUD_EVENTS_ENDPOINT= @@ -980,7 +982,8 @@ private void populateCache() { VARIABLE_OPENSEARCH_DEFAULT_EMBEDDING_INDEX, OPENSEARCH_DEFAULT_EMBEDDING_INDEX_NAME); OPENSEARCH_NUM_DEFAULT_EMBEDDING_INDEX = setIntVar( VARIABLE_NUM_OPENSEARCH_DEFAULT_EMBEDDING_INDEX, OPENSEARCH_NUM_DEFAULT_EMBEDDING_INDEX); - + OPENSEARCH_DEFAULT_INDEX_MAPPING_LIMIT = setIntVar( + VARIABLE_OPENSEARCH_INDEX_MAPPING_LIMIT, OPENSEARCH_DEFAULT_INDEX_MAPPING_LIMIT); ENABLE_CONDA_INSTALL = setBoolVar(VARIABLE_ENABLE_CONDA_INSTALL, ENABLE_CONDA_INSTALL); DEFAULT_FEATURE_STORE_PROJECT_ID = setIntVar(VARIABLE_FEATURE_STORE_PROJECT_ID, null); cached = true; @@ -1680,6 +1683,12 @@ public synchronized Integer getOpensearchNumDefaultEmbeddingIndex() { return OPENSEARCH_NUM_DEFAULT_EMBEDDING_INDEX; } + private Integer OPENSEARCH_DEFAULT_INDEX_MAPPING_LIMIT = 1000; + public synchronized Integer getOpensearchDefaultIndexMappingLimit() { + checkCache(); + return OPENSEARCH_DEFAULT_INDEX_MAPPING_LIMIT; + } + // Kibana public static final String KIBANA_INDEX_PREFIX = ".kibana"; @@ -3491,7 +3500,7 @@ public synchronized Boolean isKibanaMultiTenancyEnabled() { return KIBANA_MULTI_TENANCY_ENABELED; } - public static final int OPENSEARCH_KIBANA_NO_CONNECTIONS = 5; + public static final int OPENSEARCH_KIBANA_NO_CONNECTIONS = 30; //-------------------------------- PROVENANCE ----------------------------------------------// private static final String VARIABLE_PROVENANCE_TYPE = "provenance_type"; //disabled/meta/min/full diff --git a/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/embedding/EmbeddingControllerTest.java b/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/embedding/EmbeddingControllerTest.java index 4c987fd1af..a67e9dd877 100644 --- a/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/embedding/EmbeddingControllerTest.java +++ b/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/embedding/EmbeddingControllerTest.java @@ -16,51 +16,284 
@@ package io.hops.hopsworks.common.featurestore.embedding; +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; +import io.hops.hopsworks.common.featurestore.FeaturestoreConstants; +import io.hops.hopsworks.common.featurestore.feature.FeatureGroupFeatureDTO; +import io.hops.hopsworks.common.util.Settings; +import io.hops.hopsworks.exceptions.FeaturestoreException; import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.EmbeddingFeature; -import org.junit.Assert; +import io.hops.hopsworks.persistence.entity.project.Project; +import io.hops.hopsworks.vectordb.Field; +import io.hops.hopsworks.vectordb.Index; +import io.hops.hopsworks.vectordb.OpensearchVectorDatabase; +import io.hops.hopsworks.vectordb.VectorDatabase; +import io.hops.hopsworks.vectordb.VectorDatabaseException; import org.junit.Before; import org.junit.Test; +import java.lang.reflect.Type; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import static io.hops.hopsworks.persistence.entity.featurestore.featuregroup.SimilarityFunctionType.COSINE; +import static io.hops.hopsworks.persistence.entity.featurestore.featuregroup.SimilarityFunctionType.DOT_PRODUCT; +import static io.hops.hopsworks.persistence.entity.featurestore.featuregroup.SimilarityFunctionType.L2_NORM; +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; public class EmbeddingControllerTest { - private EmbeddingController embeddingController; + private VectorDatabaseClient vectorDatabaseClient; + private VectorDatabase vectorDatabase; + private Settings settings; + private EmbeddingController target; + private final int defaultMappingSize = 
1000; + private Project project; @Before - public void setup() { - embeddingController = spy(new EmbeddingController()); + public void setup() throws Exception { + project = mock(Project.class); + vectorDatabaseClient = mock(VectorDatabaseClient.class); + settings = mock(Settings.class); + when(settings.getOpensearchDefaultIndexMappingLimit()).thenReturn(defaultMappingSize); + vectorDatabase = mock(VectorDatabase.class); + when(vectorDatabaseClient.getClient()).thenReturn(vectorDatabase); + target = spy(new EmbeddingController(vectorDatabaseClient, settings)); } @Test public void testCreateIndex() { - List features = new ArrayList<>(); - features.add(new EmbeddingFeature(null, "vector", 512, "l2")); - features.add(new EmbeddingFeature(null, "vector2", 128, "l2")); - Assert.assertEquals( - "{\n" - + " \"settings\": {\n" - + " \"index\": {\n" - + " \"knn\": \"true\",\n" - + " \"knn.algo_param.ef_search\": 512\n" - + " }\n" - + " },\n" - + " \"mappings\": {\n" - + " \"properties\": {\n" - + " \"vector\": {\n" - + " \"type\": \"knn_vector\",\n" - + " \"dimension\": 512\n" - + " },\n" - + " \"vector2\": {\n" - + " \"type\": \"knn_vector\",\n" - + " \"dimension\": 128\n" - + " }\n" - + " }\n" - + " }\n" - + "}", - embeddingController.createIndex("", features)); + List embeddingFeatures = new ArrayList<>(); + embeddingFeatures.add(new EmbeddingFeature(null, "vector", 512, L2_NORM)); + embeddingFeatures.add(new EmbeddingFeature(null, "vector2", 128, COSINE)); + embeddingFeatures.add(new EmbeddingFeature(null, "vector3", 64, DOT_PRODUCT)); + List features = new ArrayList<>(); + Set offlineTypes = + FeaturestoreConstants.SUGGESTED_HIVE_FEATURE_TYPES.stream().map(type -> type.split(" ")[0]) + .collect(Collectors.toSet()); + offlineTypes.remove("DECIMAL"); // not supported by opensearch + for (String offlineType : offlineTypes) { + features.add(new FeatureGroupFeatureDTO("feature_" + offlineType, offlineType)); + } + features.add(new FeatureGroupFeatureDTO("vector", "ARRAY")); 
+ features.add(new FeatureGroupFeatureDTO("vector2", "ARRAY")); + features.add(new FeatureGroupFeatureDTO("vector3", "ARRAY")); + + + String expectedMapping = "{\n" + + " \"settings\": {\n" + + " \"index\": {\n" + + " \"knn\": \"true\",\n" + + " \"knn.algo_param.ef_search\": 512\n" + + " }\n" + + " },\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"vector\": {\n" + + " \"type\": \"knn_vector\",\n" + + " \"dimension\": 512,\n" + + " \"method\": {\n" + + " \"name\": \"hnsw\",\n" + + " \"space_type\": \"l2\",\n" + + " \"engine\": \"nmslib\"\n" + + " }\n" + + " },\n" + + " \"vector2\": {\n" + + " \"type\": \"knn_vector\",\n" + + " \"dimension\": 128,\n" + + " \"method\": {\n" + + " \"name\": \"hnsw\",\n" + + " \"space_type\": \"cosinesimil\",\n" + + " \"engine\": \"nmslib\"\n" + + " }\n" + + " },\n" + + " \"vector3\": {\n" + + " \"type\": \"knn_vector\",\n" + + " \"dimension\": 64,\n" + + " \"method\": {\n" + + " \"name\": \"hnsw\",\n" + + " \"space_type\": \"innerproduct\",\n" + + " \"engine\": \"nmslib\"\n" + + " }\n" + + " }"; + + for (String offlineType : offlineTypes) { + expectedMapping += ",\n \"feature_" + offlineType + "\": {\n" + + " \"type\": \"" + OpensearchVectorDatabase.getDataType(offlineType) + "\"\n" + + " }"; + } + + expectedMapping += "\n }\n" + + " }\n" + + "}"; + + assertEquals(expectedMapping, target.createIndex("", embeddingFeatures, features)); + } + + @Test + public void testValidateWithinMappingLimit_Success() throws Exception { + Index index = new Index("testIndex"); + int numFeatures = 5; + + when(vectorDatabase.getSchema(any())).thenReturn(new ArrayList<>(Collections.nCopies(defaultMappingSize - numFeatures, null))); + + // Call the method + target.validateWithinMappingLimit(project, index, 0); + + // Verify that no exception is thrown + } + + @Test(expected = FeaturestoreException.class) + public void testValidateWithinMappingLimit_ExceedLimit_IndexExists() throws Exception { + Index index = new Index("testIndex"); + int 
numFeatures = 5; + + doReturn(true).when(target).indexExist(eq(index.getName())); + when(vectorDatabase.getSchema(any())).thenReturn(new ArrayList<>(Collections.nCopies(defaultMappingSize, new Field("f1", "int")))); + doReturn(index.getName()).when(target).getProjectIndexName(any(), any()); + // Call the method + target.validateWithinMappingLimit(project, index, numFeatures); + + // Verify that FeaturestoreException is thrown + } + + @Test(expected = FeaturestoreException.class) + public void testValidateWithinMappingLimit_ExceedLimit_IndexExists_SubField() throws Exception { + Index index = new Index("testIndex"); + int numFeatures = 5; + + String jsonString = "{\"type\":\"text\",\"fields\":{\"keyword\":{\"type\":\"keyword\",\"ignore_above\":256}}}"; + Gson gson = new Gson(); + // Define the type of the map using TypeToken + Type type = new TypeToken>(){}.getType(); + + // Convert JSON string to Map + Map opensearchType = gson.fromJson(jsonString, type); + doReturn(true).when(target).indexExist(eq(index.getName())); + when(vectorDatabase.getSchema(any())).thenReturn(new ArrayList<>(Collections.nCopies(defaultMappingSize/2, new Field("f1", opensearchType)))); + doReturn(index.getName()).when(target).getProjectIndexName(any(), any()); + // Call the method + target.validateWithinMappingLimit(project, index, numFeatures); + + // Verify that FeaturestoreException is thrown + } + + @Test(expected = FeaturestoreException.class) + public void testValidateWithinMappingLimit_ExceedLimit_IndexNotExists() throws Exception { + Index index = new Index("testIndex"); + int numFeatures = defaultMappingSize + 1; + + doReturn(false).when(target).indexExist(eq(index.getName())); + doReturn(index.getName()).when(target).getProjectIndexName(any(), any()); + + // Call the method + target.validateWithinMappingLimit(project, index, numFeatures); + + // Verify that FeaturestoreException is thrown + } + + @Test(expected = FeaturestoreException.class) + public void 
testVerifyIndexName_IndexExists() throws FeaturestoreException { + String name = "testIndex"; + String projectIndexName = "project_testIndex"; + + // Mocking behavior to simulate index existence and non-default index + doReturn(projectIndexName).when(target).getProjectIndexName(any(), any()); + doReturn(true).when(target).indexExist(any()); + doReturn(false).when(target).isDefaultVectorDbIndex(any(), any()); + + // Call the method + target.verifyIndexName(project, name); + } + + @Test + public void testVerifyIndexName_DefaultIndex() throws FeaturestoreException { + String name = "testIndex"; + String projectIndexName = "project_testIndex"; + + // Mocking behavior to simulate an existing index that is the default index + doReturn(projectIndexName).when(target).getProjectIndexName(any(), any()); + doReturn(true).when(target).indexExist(any()); + doReturn(true).when(target).isDefaultVectorDbIndex(any(), any()); + + // Call the method; the test passes only if no exception is thrown + target.verifyIndexName(project, name); + } + + @Test + public void testVerifyIndexName_IndexNotExists() throws FeaturestoreException { + String name = "testIndex"; + String projectIndexName = "project_testIndex"; + + // Mocking behavior to simulate a non-existing, non-default index + doReturn(projectIndexName).when(target).getProjectIndexName(any(), any()); + doReturn(false).when(target).indexExist(any()); + doReturn(false).when(target).isDefaultVectorDbIndex(any(), any()); + + // Call the method; the test passes only if no exception is thrown + target.verifyIndexName(project, name); + } + + @Test + public void testVerifyIndexName_NullName() throws FeaturestoreException { + // Mocking behavior to simulate a non-existing, non-default index + doReturn("").when(target).getProjectIndexName(any(), any()); + doReturn(false).when(target).indexExist(any()); + doReturn(false).when(target).isDefaultVectorDbIndex(any(), any()); + + // Call the method with a null name and with an empty name + target.verifyIndexName(project, null); + target.verifyIndexName(project, ""); + } + + @Test + public void
testGetProjectIndexName_NullOrEmptyName() throws FeaturestoreException { + + // Stub the default vector db index that is expected back for an empty or null name + doReturn("defaultIndex").when(target).getDefaultVectorDbIndex(any()); + + // Call the method with empty name + String emptyNameResult = target.getProjectIndexName(project, ""); + assertEquals("defaultIndex", emptyNameResult); + + // Call the method with null name + String nullNameResult = target.getProjectIndexName(project, null); + assertEquals("defaultIndex", nullNameResult); + } + + @Test + public void testGetProjectIndexName_NonEmptyName_NoPrefix() throws FeaturestoreException { + String name = "testIndex"; + + // Stub the index prefix; the given name does not start with it + doReturn("prefix").when(target).getVectorDbIndexPrefix(any()); + + // Call the method with a non-empty, non-prefixed name + String result = target.getProjectIndexName(project, name); + assertEquals("prefix_testIndex", result); + } + + @Test + public void testGetProjectIndexName_NonEmptyName_WithPrefix() throws FeaturestoreException { + String name = "prefix_testIndex"; + + // Stub the index prefix; the given name already starts with it + doReturn("prefix").when(target).getVectorDbIndexPrefix(any()); + + // Call the method + String result = target.getProjectIndexName(project, name); + + // The name is expected back unchanged (the prefix is not added a second time) + assertEquals("prefix_testIndex", result); } } \ No newline at end of file diff --git a/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/featuregroup/online/TestOnlineFeatureGroupController.java b/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/featuregroup/online/TestOnlineFeatureGroupController.java index 931e4738d1..3553771e0c 100644 --- a/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/featuregroup/online/TestOnlineFeatureGroupController.java +++ b/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/featuregroup/online/TestOnlineFeatureGroupController.java @@ -18,6 +18,7 @@ import 
io.hops.hopsworks.common.featurestore.embedding.EmbeddingController; import io.hops.hopsworks.common.featurestore.feature.FeatureGroupFeatureDTO; +import io.hops.hopsworks.common.featurestore.featuregroup.FeaturegroupController; import io.hops.hopsworks.common.util.Settings; import io.hops.hopsworks.persistence.entity.featurestore.Featurestore; import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Embedding; @@ -25,6 +26,7 @@ import io.hops.hopsworks.persistence.entity.project.Project; import io.hops.hopsworks.persistence.entity.user.Users; import io.hops.hopsworks.vectordb.VectorDatabase; +import org.apache.commons.compress.utils.Lists; import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -49,20 +51,22 @@ public class TestOnlineFeatureGroupController { private OnlineFeaturegroupController onlineFeaturegroupController; private EmbeddingController embeddingController; - + private FeaturegroupController featuregroupController; private VectorDatabase vectorDatabase; private Project project; private Featurestore featureStore; private Featuregroup featureGroup; @Before - public void setup() { + public void setup() throws Exception { settings = mock(Settings.class); vectorDatabase = mock(VectorDatabase.class); + featuregroupController = mock(FeaturegroupController.class); featureStore = mock(Featurestore.class); project = mock(Project.class); embeddingController = spy(new EmbeddingController()); - onlineFeaturegroupController = spy(new OnlineFeaturegroupController(settings, embeddingController)); + Mockito.when(featuregroupController.getFeatures(any(), any(), any())).thenReturn(Lists.newArrayList()); + onlineFeaturegroupController = spy(new OnlineFeaturegroupController(settings, embeddingController, featuregroupController)); featureGroup = new Featuregroup(); featureGroup.setEmbedding(null); featureGroup.setName("fg"); @@ -208,8 +212,10 @@ public void testSetupOnlineFeatureGroupWithEmbedding() throws Exception { 
featureGroup.setEmbedding(embedding); // Mock the behavior for vectorDatabase initialization - doNothing().when(embeddingController).createVectorDbIndex(any(), any()); + doNothing().when(embeddingController).createVectorDbIndex(any(), any(), any()); doNothing().when(onlineFeaturegroupController).checkOnlineFsUserExist(eq(project)); + doNothing().when(onlineFeaturegroupController) + .createMySQLTable(eq(featureStore), anyString(), anyList(), eq(project), eq(user)); doNothing().when(onlineFeaturegroupController) .createFeatureGroupKafkaTopic(eq(project), eq(featureGroup), eq(features)); doNothing().when(onlineFeaturegroupController).createOnlineFeatureStore(any(), any(), any()); @@ -219,7 +225,7 @@ public void testSetupOnlineFeatureGroupWithEmbedding() throws Exception { // Assert // Verify that vectorDatabase.createIndex is called with the correct parameters - verify(embeddingController, times(1)).createVectorDbIndex(any(), any()); + verify(embeddingController, times(1)).createVectorDbIndex(any(), any(), any()); verify(onlineFeaturegroupController, times(1)).checkOnlineFsUserExist(eq(project)); verify(onlineFeaturegroupController, times(1)).createFeatureGroupKafkaTopic(eq(project), eq(featureGroup), eq(features)); diff --git a/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/utils/TestFeatureGroupInputValidation.java b/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/utils/TestFeatureGroupInputValidation.java index ee7b358a39..c204fb6d9c 100644 --- a/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/utils/TestFeatureGroupInputValidation.java +++ b/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/utils/TestFeatureGroupInputValidation.java @@ -16,73 +16,97 @@ package io.hops.hopsworks.common.featurestore.utils; +import com.google.common.collect.Lists; +import io.hops.hopsworks.common.featurestore.embedding.EmbeddingController; import io.hops.hopsworks.common.featurestore.feature.FeatureGroupFeatureDTO; +import 
io.hops.hopsworks.common.featurestore.featuregroup.EmbeddingDTO; +import io.hops.hopsworks.common.featurestore.featuregroup.EmbeddingFeatureDTO; import io.hops.hopsworks.common.featurestore.featuregroup.FeatureGroupInputValidation; +import io.hops.hopsworks.common.featurestore.featuregroup.FeaturegroupDTO; import io.hops.hopsworks.common.featurestore.featuregroup.cached.CachedFeaturegroupDTO; import io.hops.hopsworks.common.featurestore.featuregroup.ondemand.OnDemandFeaturegroupDTO; import io.hops.hopsworks.common.featurestore.featuregroup.online.OnlineFeaturegroupController; import io.hops.hopsworks.common.featurestore.featuregroup.stream.StreamFeatureGroupDTO; import io.hops.hopsworks.exceptions.FeaturestoreException; import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.cached.TimeTravelFormat; +import io.hops.hopsworks.persistence.entity.project.Project; import org.apache.commons.lang.StringUtils; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; - +import org.mockito.Mockito; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import static io.hops.hopsworks.persistence.entity.featurestore.featuregroup.SimilarityFunctionType.L2_NORM; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + public class TestFeatureGroupInputValidation { - - private FeatureGroupInputValidation featureGroupInputValidation = - new FeatureGroupInputValidation(new FeaturestoreInputValidation(), new OnlineFeaturegroupController()); - + + private EmbeddingController embeddingController; + + private FeatureGroupInputValidation 
featureGroupInputValidation; + private Project project; + List features; - + @Rule public ExpectedException thrown = ExpectedException.none(); - + @Before public void setup() { + project = mock(Project.class); + embeddingController = Mockito.mock(EmbeddingController.class); features = new ArrayList<>(); features.add(new FeatureGroupFeatureDTO("feature", "TIMESTAMP", "", true, false, "10", null)); features.add(new FeatureGroupFeatureDTO("feature2", "String", "", false, false, null, null)); + featureGroupInputValidation = + new FeatureGroupInputValidation(new FeaturestoreInputValidation(), new OnlineFeaturegroupController(), + embeddingController); } - + @Test public void testVerifyEventTimeFeature() throws Exception { featureGroupInputValidation.verifyEventTimeFeature("feature", features); } - + @Test public void testVerifyEventTimeFeatureType() throws Exception { thrown.expect(FeaturestoreException.class); featureGroupInputValidation.verifyEventTimeFeature("feature2", features); } - + @Test public void testVerifyEventTimeUnavailable() throws Exception { thrown.expect(FeaturestoreException.class); featureGroupInputValidation.verifyEventTimeFeature("time", features); } - + @Test public void testverifySchemaProvided_success() throws Exception { CachedFeaturegroupDTO featuregroupDTO = new CachedFeaturegroupDTO(); featuregroupDTO.setFeatures(features); featuregroupDTO.setOnlineEnabled(true); - + featureGroupInputValidation.verifySchemaProvided(featuregroupDTO); } - + @Test(expected = FeaturestoreException.class) public void testverifySchemaProvided_fail() throws Exception { CachedFeaturegroupDTO featuregroupDTO = new CachedFeaturegroupDTO(); featuregroupDTO.setFeatures(new ArrayList<>()); featuregroupDTO.setOnlineEnabled(true); - + featureGroupInputValidation.verifySchemaProvided(featuregroupDTO); } @@ -107,91 +131,91 @@ public void verifyNoDuplicatedFeatures_fail() throws Exception { @Test(expected = FeaturestoreException.class) public void 
testVerifyFeatureOfflineTypeProvided_null() throws Exception { FeatureGroupFeatureDTO featureDTO = new FeatureGroupFeatureDTO("feature_name", null); - + featureGroupInputValidation.verifyOfflineFeatureType(featureDTO); } - + @Test(expected = FeaturestoreException.class) public void testVerifyFeatureOfflineTypeProvided_empty() throws Exception { FeatureGroupFeatureDTO featureDTO = new FeatureGroupFeatureDTO("feature_name", ""); - + featureGroupInputValidation.verifyOfflineFeatureType(featureDTO); } @Test(expected = FeaturestoreException.class) public void testVerifyFeatureGroupFeatureList_name() throws Exception { List featureList = Arrays.asList( - new FeatureGroupFeatureDTO("feature_name", "string", "description"), - new FeatureGroupFeatureDTO("1234", "string", "description") + new FeatureGroupFeatureDTO("feature_name", "string", "description"), + new FeatureGroupFeatureDTO("1234", "string", "description") ); - + featureGroupInputValidation.verifyFeatureGroupFeatureList(featureList); } @Test(expected = FeaturestoreException.class) public void testVerifyFeatureGroupFeatureList_description() throws Exception { List featureList = Arrays.asList( - new FeatureGroupFeatureDTO("feature_name", "string", StringUtils.repeat("a", 300)), - new FeatureGroupFeatureDTO("ft2", "string", "description") + new FeatureGroupFeatureDTO("feature_name", "string", StringUtils.repeat("a", 300)), + new FeatureGroupFeatureDTO("ft2", "string", "description") ); - + featureGroupInputValidation.verifyFeatureGroupFeatureList(featureList); } @Test(expected = FeaturestoreException.class) public void testVerifyFeatureGroupFeatureList_type() throws Exception { List featureList = Arrays.asList( - new FeatureGroupFeatureDTO("feature_name", "string", "description"), - new FeatureGroupFeatureDTO("1234", "", "description") + new FeatureGroupFeatureDTO("feature_name", "string", "description"), + new FeatureGroupFeatureDTO("1234", "", "description") ); - + 
featureGroupInputValidation.verifyFeatureGroupFeatureList(featureList); } - + @Test public void testVerifyUserInputFeatureGroup() throws Exception { CachedFeaturegroupDTO featuregroupDTO = new CachedFeaturegroupDTO(); featuregroupDTO.setTimeTravelFormat(TimeTravelFormat.HUDI); - + // timestamp type camel case List newSchema = new ArrayList<>(); newSchema.add(new FeatureGroupFeatureDTO("part_param", "Integer", "", true, false)); - newSchema.add(new FeatureGroupFeatureDTO("part_param2", "String", "", false , false)); - newSchema.add(new FeatureGroupFeatureDTO("part_param3", "Timestamp", "", false , true)); + newSchema.add(new FeatureGroupFeatureDTO("part_param2", "String", "", false, false)); + newSchema.add(new FeatureGroupFeatureDTO("part_param3", "Timestamp", "", false, true)); featuregroupDTO.setFeatures(newSchema); thrown.expect(FeaturestoreException.class); featureGroupInputValidation.verifyPartitionKeySupported(featuregroupDTO); } - + @Test public void testVerifyAndGetNewFeaturesIfPrimary() throws Exception { List newSchema = new ArrayList<>(); newSchema.add(new FeatureGroupFeatureDTO("part_param", "Integer", "", true, false)); - newSchema.add(new FeatureGroupFeatureDTO("part_param2", "String", "", false , false)); - newSchema.add(new FeatureGroupFeatureDTO("part_param3", "String", "", true , false)); - + newSchema.add(new FeatureGroupFeatureDTO("part_param2", "String", "", false, false)); + newSchema.add(new FeatureGroupFeatureDTO("part_param3", "String", "", true, false)); + thrown.expect(FeaturestoreException.class); featureGroupInputValidation.verifyAndGetNewFeatures(features, newSchema); } - + @Test public void testVerifyAndGetNewFeaturesIfPartition() throws Exception { List newSchema = new ArrayList<>(); newSchema.add(new FeatureGroupFeatureDTO("part_param", "Integer", "", true, false)); - newSchema.add(new FeatureGroupFeatureDTO("part_param2", "String", "", false , false)); - newSchema.add(new FeatureGroupFeatureDTO("part_param3", "String", "", false , 
true)); - + newSchema.add(new FeatureGroupFeatureDTO("part_param2", "String", "", false, false)); + newSchema.add(new FeatureGroupFeatureDTO("part_param3", "String", "", false, true)); + thrown.expect(FeaturestoreException.class); featureGroupInputValidation.verifyAndGetNewFeatures(features, newSchema); } - + @Test public void testVerifyAndGetNewFeaturesIfMissingType() throws Exception { List newSchema = new ArrayList<>(); newSchema.add(new FeatureGroupFeatureDTO("part_param", "Integer", "", true, false)); - newSchema.add(new FeatureGroupFeatureDTO("part_param2", "String", "", false , false)); - newSchema.add(new FeatureGroupFeatureDTO("part_param3", null, "", false , false)); - + newSchema.add(new FeatureGroupFeatureDTO("part_param2", "String", "", false, false)); + newSchema.add(new FeatureGroupFeatureDTO("part_param3", null, "", false, false)); + thrown.expect(FeaturestoreException.class); featureGroupInputValidation.verifyAndGetNewFeatures(features, newSchema); } @@ -676,4 +700,217 @@ public void testVerifyOnlineOfflineTypeMatchMapBlob() throws Exception { // Act featureGroupInputValidation.verifyOnlineOfflineTypeMatch(featuregroupDTO); } + + @Test + public void testVerifyEmbeddingFeatureExist_pass() throws FeaturestoreException { + FeatureGroupFeatureDTO feature1 = new FeatureGroupFeatureDTO("feature1"); + FeatureGroupFeatureDTO feature2 = new FeatureGroupFeatureDTO("feature2"); + FeatureGroupFeatureDTO feature3 = new FeatureGroupFeatureDTO("feature3"); + + EmbeddingFeatureDTO embeddingFeature1 = new EmbeddingFeatureDTO("feature1", L2_NORM, 3); + EmbeddingFeatureDTO embeddingFeature2 = new EmbeddingFeatureDTO("feature2", L2_NORM, 3); + EmbeddingFeatureDTO embeddingFeature3 = new EmbeddingFeatureDTO("feature3", L2_NORM, 3); + + List features = Arrays.asList(feature1, feature2, feature3); + List embeddingFeatures = + Arrays.asList(embeddingFeature1, embeddingFeature2, embeddingFeature3); + + EmbeddingDTO embeddingDTO = new EmbeddingDTO("name", "prefix", 
embeddingFeatures); + + FeaturegroupDTO featureGroupDTO = new FeaturegroupDTO(); + featureGroupDTO.setEmbeddingIndex(embeddingDTO); + featureGroupDTO.setFeatures(features); + + featureGroupInputValidation = + new FeatureGroupInputValidation(new FeaturestoreInputValidation(), new OnlineFeaturegroupController(), + new EmbeddingController()); + featureGroupInputValidation.verifyEmbeddingFeatureExist(featureGroupDTO); + } + + @Test + public void testVerifyEmbeddingFeatureExist_fail() throws FeaturestoreException { + FeatureGroupFeatureDTO feature1 = new FeatureGroupFeatureDTO("feature1"); + FeatureGroupFeatureDTO feature2 = new FeatureGroupFeatureDTO("feature2"); + FeatureGroupFeatureDTO feature3 = new FeatureGroupFeatureDTO("feature3"); + + EmbeddingFeatureDTO embeddingFeature1 = new EmbeddingFeatureDTO("feature1", L2_NORM, 3); + EmbeddingFeatureDTO embeddingFeature2 = new EmbeddingFeatureDTO("feature2", L2_NORM, 3); + EmbeddingFeatureDTO embeddingFeature3 = + new EmbeddingFeatureDTO("feature4", L2_NORM, 3); // this does not exist in feature group + + List features = Arrays.asList(feature1, feature2, feature3); + List embeddingFeatures = + Arrays.asList(embeddingFeature1, embeddingFeature2, embeddingFeature3); + + EmbeddingDTO embeddingDTO = new EmbeddingDTO("name", "prefix", embeddingFeatures); + + FeaturegroupDTO featureGroupDTO = new FeaturegroupDTO(); + featureGroupDTO.setEmbeddingIndex(embeddingDTO); + featureGroupDTO.setFeatures(features); + + assertThrows(FeaturestoreException.class, + () -> featureGroupInputValidation.verifyEmbeddingFeatureExist(featureGroupDTO)); + } + + @Test + public void testVerifyEmbeddingIndex_pass() throws FeaturestoreException { + EmbeddingFeatureDTO embeddingFeature1 = new EmbeddingFeatureDTO("feature1", L2_NORM, 3); + EmbeddingFeatureDTO embeddingFeature2 = new EmbeddingFeatureDTO("feature2", L2_NORM, 3); + EmbeddingFeatureDTO embeddingFeature3 = new EmbeddingFeatureDTO("feature3", L2_NORM, 3); + + List embeddingFeatures = + 
Arrays.asList(embeddingFeature1, embeddingFeature2, embeddingFeature3); + EmbeddingDTO embeddingDTO = new EmbeddingDTO("name", "prefix", embeddingFeatures); + + FeaturegroupDTO featureGroupDTO = new FeaturegroupDTO(); + featureGroupDTO.setEmbeddingIndex(embeddingDTO); + + doNothing().when(embeddingController).verifyIndexName(any(), anyString()); + + featureGroupInputValidation = + new FeatureGroupInputValidation(new FeaturestoreInputValidation(), new OnlineFeaturegroupController(), + embeddingController); + + featureGroupInputValidation.verifyEmbeddingIndexNotExist(project, featureGroupDTO); + } + + @Test + public void testVerifyEmbeddingIndex_fail() throws FeaturestoreException { + EmbeddingFeatureDTO embeddingFeature1 = new EmbeddingFeatureDTO("feature1", L2_NORM, 3); + EmbeddingFeatureDTO embeddingFeature2 = new EmbeddingFeatureDTO("feature2", L2_NORM, 3); + EmbeddingFeatureDTO embeddingFeature3 = new EmbeddingFeatureDTO("feature3", L2_NORM, 3); + + List embeddingFeatures = + Arrays.asList(embeddingFeature1, embeddingFeature2, embeddingFeature3); + EmbeddingDTO embeddingDTO = new EmbeddingDTO("name", "prefix", embeddingFeatures); + + FeaturegroupDTO featureGroupDTO = new FeaturegroupDTO(); + featureGroupDTO.setEmbeddingIndex(embeddingDTO); + + doThrow(FeaturestoreException.class).when(embeddingController).verifyIndexName(any(), anyString()); + + featureGroupInputValidation = + new FeatureGroupInputValidation(new FeaturestoreInputValidation(), new OnlineFeaturegroupController(), + embeddingController); + + assertThrows(FeaturestoreException.class, () -> featureGroupInputValidation.verifyEmbeddingIndexNotExist(project, featureGroupDTO)); + } + + private FeaturegroupDTO createFeaturegroupDtoWithIndexName(String indexName) { + EmbeddingFeatureDTO embeddingFeature = new EmbeddingFeatureDTO("feature3", L2_NORM, 3); + + List embeddingFeatures = + Arrays.asList(embeddingFeature); + EmbeddingDTO embeddingDTO = new EmbeddingDTO(indexName, "prefix", embeddingFeatures); + + 
FeaturegroupDTO featureGroupDTO = new FeaturegroupDTO(); + featureGroupDTO.setEmbeddingIndex(embeddingDTO); + return featureGroupDTO; + } + + @Test(expected = FeaturestoreException.class) + public void testInvalidUpperCaseName() throws Exception { + featureGroupInputValidation.verifyEmbeddingIndexName(createFeaturegroupDtoWithIndexName("MyInvalidName")); + featureGroupInputValidation.verifyEmbeddingIndexName(createFeaturegroupDtoWithIndexName("NAMEWITHUPPERCASE")); + } + + @Test(expected = FeaturestoreException.class) + public void testInvalidStartingCharacters() throws Exception { + featureGroupInputValidation.verifyEmbeddingIndexName(createFeaturegroupDtoWithIndexName("_invalid_name")); + featureGroupInputValidation.verifyEmbeddingIndexName(createFeaturegroupDtoWithIndexName("-starting_with_hyphen")); + } + + @Test(expected = FeaturestoreException.class) + public void testInvalidCharacters() throws Exception { + featureGroupInputValidation.verifyEmbeddingIndexName(createFeaturegroupDtoWithIndexName("name,with,comma")); + featureGroupInputValidation.verifyEmbeddingIndexName(createFeaturegroupDtoWithIndexName("invalid*name")); + featureGroupInputValidation.verifyEmbeddingIndexName(createFeaturegroupDtoWithIndexName("withchars")); + } + + @Test + public void testNullOrEmptyName() throws Exception { + featureGroupInputValidation.verifyEmbeddingIndexName(createFeaturegroupDtoWithIndexName(null)); + featureGroupInputValidation.verifyEmbeddingIndexName(createFeaturegroupDtoWithIndexName("")); + } + + @Test + public void testValidLowerCaseName() throws Exception { + featureGroupInputValidation.verifyEmbeddingIndexName(createFeaturegroupDtoWithIndexName("myindexname")); + featureGroupInputValidation.verifyEmbeddingIndexName(createFeaturegroupDtoWithIndexName("valid_name123")); + } + + @Test + public void testValidNameWithAllowedCharacters() throws Exception { + featureGroupInputValidation.verifyEmbeddingIndexName(createFeaturegroupDtoWithIndexName("name.with.dots")); + 
featureGroupInputValidation.verifyEmbeddingIndexName(createFeaturegroupDtoWithIndexName("name_with_underscores")); + } + + @Test + public void testVerifyVectorDatabaseIndexMappingLimit() throws FeaturestoreException { + // Call the method under test + featureGroupInputValidation.verifyVectorDatabaseIndexMappingLimit(project, createFeaturegroupDtoWithIndexName("myindexname"), 10); + + // Verify that the method delegated to the embeddingController once (arguments are matched with any()) + verify(embeddingController).validateWithinMappingLimit(any(), any(), any()); + } + + @Test + public void testVerifyVectorDatabaseIndexMappingLimit_EmbeddingIndexNull() throws FeaturestoreException { + // Call the method under test + featureGroupInputValidation.verifyVectorDatabaseIndexMappingLimit(project, createFeaturegroupDtoWithIndexName(null), 10); + + // Verify that the embeddingController is still called once: the embedding index is set, only its name is null + verify(embeddingController).validateWithinMappingLimit(any(), any(), any()); + } + + @Test + public void testVerifyVectorDatabaseIndexMappingLimit_EmbeddingIndexWithoutName() throws FeaturestoreException { + // Call the method under test + featureGroupInputValidation.verifyVectorDatabaseIndexMappingLimit(project, createFeaturegroupDtoWithIndexName(""), 10); + + // Verify that the embeddingController is still called once: the embedding index is set, only its name is empty + verify(embeddingController).validateWithinMappingLimit(any(), any(), any()); + } + + @Test + public void testVerifyVectorDatabaseSupportedDataType_Success() throws Exception { + // Build real (non-mock) FeatureGroupFeatureDTO objects of type "int" and "float" + FeatureGroupFeatureDTO feature1 = new FeatureGroupFeatureDTO("feature1", "int"); + FeatureGroupFeatureDTO feature2 = new FeatureGroupFeatureDTO("feature2", "float"); + FeaturegroupDTO featureGroupDTO = createFeaturegroupDtoWithIndexName("test"); + featureGroupDTO.setFeatures(Lists.newArrayList(feature1, feature2)); + + // Call the method; no exception is expected + featureGroupInputValidation.verifyVectorDatabaseSupportedDataType(featureGroupDTO); + } + + @Test + public 
void testVerifyVectorDatabaseSupportedDataType_Fail() throws Exception { + // Mocking FeatureGroupFeatureDTO objects + FeatureGroupFeatureDTO feature1 = new FeatureGroupFeatureDTO("feature1", "not_supported_type"); + FeaturegroupDTO featureGroupDTO = createFeaturegroupDtoWithIndexName("test"); + featureGroupDTO.setFeatures(Lists.newArrayList(feature1)); + + // Call the method + assertThrows(FeaturestoreException.class, () -> featureGroupInputValidation.verifyVectorDatabaseSupportedDataType(featureGroupDTO)); + } + + @Test + public void testVerifyVectorDatabaseSupportedDataType_FeatureGroupFeatureDTO_Success() throws Exception { + // Mocking FeatureGroupFeatureDTO objects + FeatureGroupFeatureDTO feature1 = new FeatureGroupFeatureDTO("feature1", "int"); + FeatureGroupFeatureDTO feature2 = new FeatureGroupFeatureDTO("feature2", "float"); + + // Call the method + featureGroupInputValidation.verifyVectorDatabaseSupportedDataType(Lists.newArrayList(feature1, feature2)); + } + + @Test + public void testVerifyVectorDatabaseSupportedDataType_FeatureGroupFeatureDTO_Fail() throws Exception { + // Mocking FeatureGroupFeatureDTO objects + FeatureGroupFeatureDTO feature1 = new FeatureGroupFeatureDTO("feature1", "not_supported_type"); + + // Call the method + assertThrows(FeaturestoreException.class, () -> featureGroupInputValidation.verifyVectorDatabaseSupportedDataType(Lists.newArrayList(feature1))); + } } diff --git a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/EmbeddingFeature.java b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/EmbeddingFeature.java index 81979f7892..3996f4f1a2 100644 --- a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/EmbeddingFeature.java +++ b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/EmbeddingFeature.java @@ -22,6 +22,8 @@ import 
javax.persistence.Basic; import javax.persistence.Column; import javax.persistence.Entity; +import javax.persistence.EnumType; +import javax.persistence.Enumerated; import javax.persistence.GeneratedValue; import javax.persistence.GenerationType; import javax.persistence.Id; @@ -48,7 +50,8 @@ public class EmbeddingFeature implements Serializable { @Column private Integer dimension; @Column(name = "similarity_function_type") - private String similarityFunctionType; + @Enumerated(EnumType.STRING) + private SimilarityFunctionType similarityFunctionType; @JoinColumn(name = "model_version_id", referencedColumnName = "id") @OneToOne private ModelVersion modelVersion; @@ -57,7 +60,7 @@ public EmbeddingFeature() { } public EmbeddingFeature(Embedding embedding, String name, Integer dimension, - String similarityFunctionType) { + SimilarityFunctionType similarityFunctionType) { this.embedding = embedding; this.name = name; this.dimension = dimension; @@ -65,7 +68,7 @@ public EmbeddingFeature(Embedding embedding, String name, Integer dimension, } public EmbeddingFeature(Embedding embedding, String name, Integer dimension, - String similarityFunctionType, ModelVersion modelVersion) { + SimilarityFunctionType similarityFunctionType, ModelVersion modelVersion) { this.embedding = embedding; this.name = name; this.dimension = dimension; @@ -74,7 +77,7 @@ public EmbeddingFeature(Embedding embedding, String name, Integer dimension, } public EmbeddingFeature(Integer id, Embedding embedding, String name, Integer dimension, - String similarityFunctionType) { + SimilarityFunctionType similarityFunctionType) { this.id = id; this.embedding = embedding; this.name = name; @@ -98,7 +101,7 @@ public Integer getDimension() { return dimension; } - public String getSimilarityFunctionType() { + public SimilarityFunctionType getSimilarityFunctionType() { return similarityFunctionType; } diff --git 
a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/SimilarityFunctionType.java b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/SimilarityFunctionType.java new file mode 100644 index 0000000000..0d0c0581c2 --- /dev/null +++ b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/SimilarityFunctionType.java @@ -0,0 +1,39 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2024, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . 
+ */ + +package io.hops.hopsworks.persistence.entity.featurestore.featuregroup; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public enum SimilarityFunctionType { + + @JsonProperty(value = "l2_norm") + L2_NORM("l2"), + @JsonProperty(value = "cosine") + COSINE("cosinesimil"), + @JsonProperty(value = "dot_product") + DOT_PRODUCT("innerproduct"); + + private final String opensearchFunction; + + SimilarityFunctionType(String opensearchFunction) { + this.opensearchFunction = opensearchFunction; + } + + public String getOpensearchFunction() { + return opensearchFunction; + } +} diff --git a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/SimilarityFunctionTypeConverter.java b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/SimilarityFunctionTypeConverter.java new file mode 100644 index 0000000000..37e4e9e248 --- /dev/null +++ b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/SimilarityFunctionTypeConverter.java @@ -0,0 +1,33 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2024, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . 
+ */ + +package io.hops.hopsworks.persistence.entity.featurestore.featuregroup; + +import javax.persistence.AttributeConverter; +import javax.persistence.Converter; + +@Converter(autoApply = true) +public class SimilarityFunctionTypeConverter implements AttributeConverter<SimilarityFunctionType, String> { + @Override + public String convertToDatabaseColumn(SimilarityFunctionType attribute) { + return attribute == null ? null : attribute.name().toLowerCase(java.util.Locale.ROOT); // null-safe, locale-free + } + + @Override + public SimilarityFunctionType convertToEntityAttribute(String dbData) { + return dbData == null ? null : SimilarityFunctionType.valueOf(dbData.toUpperCase(java.util.Locale.ROOT)); + } +} diff --git a/hopsworks-persistence/src/main/resources/META-INF/persistence.xml b/hopsworks-persistence/src/main/resources/META-INF/persistence.xml index 2d953c4043..80e84335c9 100644 --- a/hopsworks-persistence/src/main/resources/META-INF/persistence.xml +++ b/hopsworks-persistence/src/main/resources/META-INF/persistence.xml @@ -66,6 +66,7 @@ io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Featuregroup io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Embedding io.hops.hopsworks.persistence.entity.featurestore.featuregroup.EmbeddingFeature + io.hops.hopsworks.persistence.entity.featurestore.featuregroup.SimilarityFunctionTypeConverter io.hops.hopsworks.persistence.entity.featurestore.featuregroup.ondemand.OnDemandFeaturegroup io.hops.hopsworks.persistence.entity.featurestore.featuregroup.cached.CachedFeaturegroup io.hops.hopsworks.persistence.entity.featurestore.featuregroup.cached.FeatureGroupCommit diff --git a/hopsworks-rest-utils/src/main/java/io/hops/hopsworks/restutils/RESTCodes.java b/hopsworks-rest-utils/src/main/java/io/hops/hopsworks/restutils/RESTCodes.java index 4ae39d34e4..b546df624c 100644 --- a/hopsworks-rest-utils/src/main/java/io/hops/hopsworks/restutils/RESTCodes.java +++ b/hopsworks-rest-utils/src/main/java/io/hops/hopsworks/restutils/RESTCodes.java @@ -1329,7 +1329,7 @@ 
public String toString() { */ public enum FeaturestoreErrorCode implements RESTErrorCode { - COULD_NOT_CREATE_FEATUREGROUP(1, "Could not create feature group and corresponding Hive table", + COULD_NOT_CREATE_FEATUREGROUP(1, "Could not create feature group and corresponding online/offline store.", Response.Status.INTERNAL_SERVER_ERROR), FEATURESTORE_ID_NOT_PROVIDED(2, "Featurestore Id was not provided", Response.Status.BAD_REQUEST), FEATUREGROUP_ID_NOT_PROVIDED(3, "Featuregroup Id was not provided", Response.Status.BAD_REQUEST), @@ -1702,7 +1702,17 @@ public enum FeaturestoreErrorCode implements RESTErrorCode { FEATURE_NOT_FOUND_IN_VECTOR_DB(235, "Feature not found in vector db.", Response.Status.INTERNAL_SERVER_ERROR), COULD_NOT_PREVIEW_DATA_IN_VECTOR_DB(236, "Could not preview data in vector database.", - Response.Status.INTERNAL_SERVER_ERROR); + Response.Status.INTERNAL_SERVER_ERROR), + EMBEDDING_FEATURE_NOT_FOUND(237, "Embedding feature cannot be found in feature group.", + Response.Status.BAD_REQUEST), + COULD_NOT_GET_VECTOR_DB_INDEX(238, "Could not get index from vector db.", + Response.Status.INTERNAL_SERVER_ERROR), + EMBEDDING_INDEX_EXISTED(239, "Embedding index already exists.", Response.Status.BAD_REQUEST), + INVALID_EMBEDDING_INDEX_NAME(240, "Embedding index name is not valid.", Response.Status.BAD_REQUEST), + VECTOR_DATABASE_INDEX_MAPPING_LIMIT_EXCEEDED(241, "Index mapping limit exceeded.", Response.Status.BAD_REQUEST), + VECTOR_DATABASE_DATA_TYPE_NOT_SUPPORTED(242, "Provided data type is not supported by vector database.", + Response.Status.BAD_REQUEST); + private int code; private String message; diff --git a/vector-db/src/main/java/io/hops/hopsworks/vectordb/OpensearchVectorDatabase.java b/vector-db/src/main/java/io/hops/hopsworks/vectordb/OpensearchVectorDatabase.java index 65fed15e72..a161877bcf 100644 --- a/vector-db/src/main/java/io/hops/hopsworks/vectordb/OpensearchVectorDatabase.java +++ 
b/vector-db/src/main/java/io/hops/hopsworks/vectordb/OpensearchVectorDatabase.java @@ -17,9 +17,14 @@ package io.hops.hopsworks.vectordb; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import lombok.AllArgsConstructor; +import org.apache.http.client.config.RequestConfig; import org.opensearch.OpenSearchException; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.bulk.BulkResponse; @@ -28,15 +33,15 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.master.AcknowledgedResponse; -import org.opensearch.client.Request; import org.opensearch.client.RequestOptions; -import org.opensearch.client.Response; import org.opensearch.client.RestHighLevelClient; import org.opensearch.client.indices.CreateIndexRequest; import org.opensearch.client.indices.CreateIndexResponse; import org.opensearch.client.indices.GetIndexRequest; import org.opensearch.client.indices.GetIndexResponse; import org.opensearch.client.indices.PutMappingRequest; +import org.opensearch.common.Strings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.XContentType; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; @@ -50,7 +55,9 @@ import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; +import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -62,58 +69,124 @@ public class OpensearchVectorDatabase implements VectorDatabase { private static final Logger LOGGER = Logger.getLogger( 
OpensearchVectorDatabase.class.getName()); + private Integer requestTimeout = 60000; + private Integer socketTimeout = 61000; + private final Integer maxRetry = 3; + + private static final Map dataTypeMap = ImmutableMap.builder() + .put("BOOLEAN", "byte") + .put("TINYINT", "byte") + .put("INT", "integer") + .put("SMALLINT", "short") + .put("BIGINT", "long") + .put("FLOAT", "float") + .put("DOUBLE", "double") + .put("TIMESTAMP", "date") + .put("DATE", "date") + .put("STRING", "text") + .put("ARRAY", "binary") + .put("STRUCT", "binary") + .put("BINARY", "binary") + .put("MAP", "binary") + .build(); + + public static String getDataType(String offlineType) { + if (!Strings.isNullOrEmpty(offlineType)) { + offlineType = offlineType.split("<")[0]; + } else { + return null; + } + return dataTypeMap.get(offlineType.toUpperCase()); + } + + public OpensearchVectorDatabase() { + + } + public OpensearchVectorDatabase(RestHighLevelClient client) { this.client = client; } + public OpensearchVectorDatabase(RestHighLevelClient client, Integer requestTimeout) { + this.client = client; + this.requestTimeout = requestTimeout; + this.socketTimeout = requestTimeout + 1000; + } + + public void init(RestHighLevelClient client) { + this.client = client; + } + + protected RestHighLevelClient getClient() throws VectorDatabaseException { + return client; + } + @Override public void createIndex(Index index, String mapping, Boolean skipIfExist) throws VectorDatabaseException { - try { - if (skipIfExist) { - Request request = new Request("HEAD", "/" + index.getName()); - Response response = client.getLowLevelClient().performRequest(request); - if (response.getStatusLine().getStatusCode() == 200) { - return; - } - } + if (skipIfExist && getIndex(index.getName()).isPresent()) { + return; + } + retry(() -> { CreateIndexRequest createIndexRequest = new CreateIndexRequest(index.getName()); + createIndexRequest.setTimeout(new TimeValue(requestTimeout)); + createIndexRequest.setMasterTimeout(new 
TimeValue(requestTimeout)); createIndexRequest.source(mapping, XContentType.JSON); - CreateIndexResponse response = client.indices().create(createIndexRequest, RequestOptions.DEFAULT); - if (!response.isAcknowledged()) { - throw new VectorDatabaseException("Failed to create opensearch index: " + index.getName()); + CreateIndexResponse response = getClient().indices().create(createIndexRequest, getRequestOptions()); + if (response.isAcknowledged()) { + return new OperationResult(true, null); } - } catch (IOException e) { - throw new VectorDatabaseException("Failed to create opensearch index: " + index.getName() + "Err: " + e); - } + return new OperationResult(false, null); + }, "create index", Sets.newHashSet(RestStatus.OK, RestStatus.CREATED)); + } + + public Optional getIndex(String name) throws VectorDatabaseException { + return retry(() -> { + GetIndexRequest getIndexRequest = new GetIndexRequest(name); + GetIndexResponse getIndexResponse = getClient().indices().get(getIndexRequest, RequestOptions.DEFAULT); + Optional result = getIndexResponse.getMappings().keySet().stream().map(Index::new).findFirst(); + return result.map(index -> new OperationResult<>( + true, index + )).orElseGet(() -> new OperationResult<>(false, null)); + + }, "get index", Sets.newHashSet(RestStatus.OK, RestStatus.NOT_FOUND)); } /** * Get all indices from OpenSearch. * * @return A set of index names. - * @throws VectorDatabaseException If there is an error while fetching indices. + * @throws VectorDatabaseException + * If there is an error while fetching indices. 
*/ public Set getAllIndices() throws VectorDatabaseException { - try { + Optional> result = retry(() -> { GetIndexRequest getIndexRequest = new GetIndexRequest("*"); // "*" retrieves all indices - GetIndexResponse getIndexResponse = client.indices().get(getIndexRequest, RequestOptions.DEFAULT); - return getIndexResponse.getMappings().keySet().stream().map(Index::new).collect(Collectors.toSet()); - } catch (IOException e) { - throw new VectorDatabaseException("Failed to fetch opensearch indices. Err: " + e); - } + GetIndexResponse getIndexResponse = getClient().indices().get(getIndexRequest, RequestOptions.DEFAULT); + return new OperationResult>(true, + getIndexResponse.getMappings().keySet().stream().map(Index::new).collect(Collectors.toSet())); + }, "get all indices", Sets.newHashSet(RestStatus.OK)); + return result.orElseGet(Sets::newHashSet); } @Override public void deleteIndex(Index index) throws VectorDatabaseException { - try { - DeleteIndexRequest deleteIndexRequest = new DeleteIndexRequest(index.getName()); - AcknowledgedResponse response = client.indices().delete(deleteIndexRequest, RequestOptions.DEFAULT); - if (!response.isAcknowledged()) { - throw new VectorDatabaseException("Failed to delete opensearch index: " + index.getName()); + retry(() -> { + DeleteIndexRequest deleteIndexRequest = + new DeleteIndexRequest(index.getName()) + .timeout(new TimeValue(requestTimeout)) + .masterNodeTimeout(new TimeValue(requestTimeout)); + RequestConfig requestConfig = RequestConfig.custom() + .setSocketTimeout(socketTimeout) + .build(); + RequestOptions options = RequestOptions.DEFAULT.toBuilder() + .setRequestConfig(requestConfig) + .build(); + AcknowledgedResponse response = getClient().indices().delete(deleteIndexRequest, options); + if (response.isAcknowledged()) { + return new OperationResult(true, null); } - } catch (IOException e) { - throw new VectorDatabaseException("Failed to delete opensearch index: " + index.getName() + "Err: " + e); - } + return new 
OperationResult(false, null); + }, "delete index", Sets.newHashSet(RestStatus.OK, RestStatus.NOT_FOUND)); } @Override @@ -121,7 +194,7 @@ public void addFields(Index index, String mapping) throws VectorDatabaseExceptio PutMappingRequest request = new PutMappingRequest(index.getName()); request.source(mapping, XContentType.JSON); try { - AcknowledgedResponse response = client.indices().putMapping(request, RequestOptions.DEFAULT); + AcknowledgedResponse response = getClient().indices().putMapping(request, RequestOptions.DEFAULT); if (!response.isAcknowledged()) { throw new VectorDatabaseException("Failed to add fields to opensearch index: " + index.getName()); } @@ -136,7 +209,7 @@ public List getSchema(Index index) throws VectorDatabaseException { GetIndexRequest request = new GetIndexRequest(index.getName()); // Get the index mapping try { - GetIndexResponse response = client.indices().get(request, RequestOptions.DEFAULT); + GetIndexResponse response = getClient().indices().get(request, RequestOptions.DEFAULT); Object mapping = response.getMappings().get(index.getName()).getSourceAsMap().getOrDefault("properties", null); if (mapping != null) { return ((Map) mapping).entrySet().stream() @@ -154,7 +227,7 @@ public List getSchema(Index index) throws VectorDatabaseException { public void write(Index index, String data, String docId) throws VectorDatabaseException { try { IndexRequest indexRequest = makeIndexRequest(index.getName(), data, docId); - IndexResponse response = client.index(indexRequest, RequestOptions.DEFAULT); + IndexResponse response = getClient().index(indexRequest, RequestOptions.DEFAULT); if (!(response.status().equals(RestStatus.CREATED) || response.status().equals(RestStatus.OK))) { throw new VectorDatabaseException("Cannot index data. 
Status: " + response.status()); } @@ -208,54 +281,42 @@ public List> preview(Index index, Set fields, int n) if (fields.size() == 0) { return results; } + Optional>> result = retry(() -> { + // Create a bool query + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); - // Create a bool query - BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); - - // Add exists queries for each field - for (Field field : fields) { - boolQueryBuilder.must(QueryBuilders.existsQuery(field.getName())); - } + // Add exists queries for each field + for (Field field : fields) { + boolQueryBuilder.must(QueryBuilders.existsQuery(field.getName())); + } - // Create a SearchSourceBuilder to define the search query - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.query(boolQueryBuilder); - sourceBuilder.size(n); + // Create a SearchSourceBuilder to define the search query + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(boolQueryBuilder); + sourceBuilder.size(n); - // Create a SearchRequest with the specified index and source builder - SearchRequest searchRequest = new SearchRequest(index.getName()); - searchRequest.source(sourceBuilder); + // Create a SearchRequest with the specified index and source builder + SearchRequest searchRequest = new SearchRequest(index.getName()); + searchRequest.source(sourceBuilder); - try { // Execute the search request - SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT); + SearchResponse searchResponse = getClient().search(searchRequest, RequestOptions.DEFAULT); // Process the search hits for (SearchHit hit : searchResponse.getHits().getHits()) { results.add(hit.getSourceAsMap()); } - - } catch (IOException e) { - throw new VectorDatabaseException("Error occurred while querying OpenSearch index"); - } - - return results; + return new OperationResult(true, results); + }, "preview", Sets.newHashSet(RestStatus.OK)); + return 
result.orElseGet(() -> results); } private void bulkRequest(BulkRequest bulkRequest) throws VectorDatabaseException { try { - BulkResponse response = client.bulk(bulkRequest, RequestOptions.DEFAULT); - if (response.status().getStatus() != 200) { - throw new VectorDatabaseException( - String.format("Cannot index data. Response status %d; Message: %s", - response.status().getStatus(), response.buildFailureMessage()) - ); - } else { - if (response.hasFailures()) { - throw new VectorDatabaseException( - String.format("Index data failed partially. Response status %d; Message: %s", - response.status().getStatus(), response.buildFailureMessage()) - ); - } + BulkResponse response = getClient().bulk(bulkRequest, RequestOptions.DEFAULT); + if (response.hasFailures()) { + // Do not include message from `response.buildFailureMessage()` as this method failed to execute. + String msg = String.format("Index data failed partially. Response status %d", response.status().getStatus()); + throw new VectorDatabaseException(msg); } } catch (IOException e) { throw new VectorDatabaseException("Cannot index data. 
Err: " + e); @@ -305,23 +366,115 @@ public void batchWriteMap(Index index, Map> data) @Override public void deleteByQuery(Index index, String query) throws VectorDatabaseException { - DeleteByQueryRequest deleteByQueryRequest = new DeleteByQueryRequest(index.getName()); - deleteByQueryRequest.setQuery(new QueryStringQueryBuilder(query)); + retry(() -> { + DeleteByQueryRequest deleteByQueryRequest = new DeleteByQueryRequest(index.getName()); + deleteByQueryRequest.setQuery(new QueryStringQueryBuilder(query)); + deleteByQueryRequest.setTimeout(new TimeValue(requestTimeout)); + + BulkByScrollResponse response = getClient().deleteByQuery(deleteByQueryRequest, getRequestOptions()); + if (response.getBulkFailures().isEmpty()) { + return new OperationResult(true, null); + } + return new OperationResult(false, null); + + }, "delete by query", Sets.newHashSet(RestStatus.OK)); + } + + private long getDelayMillis(long delayMillis) { + return Math.min(delayMillis, 5000); + } + + // Retry wrapper for opensearch operations. It retries ONLY when the opensearch cluster is in operation but + // is busy, i.e. it gives event timeout on operations. Note that this retry operation may block the http connection + // thread for an extensive amount of time if retries continue to fail, but the retry is necessary from the user's + // perspective because otherwise they need to retry manually when, for example, creating feature groups. + protected Optional retry(OperationSupplier operation, String operationName, Set expectedStatus) + throws VectorDatabaseException { + long delayMillis = 1000; // Initial delay between retries (1 second) + OperationResult operationResult; + boolean retryStarted = false; try { - BulkByScrollResponse response = client.deleteByQuery(deleteByQueryRequest, RequestOptions.DEFAULT); - if (response.getBulkFailures().size() != 0) { - throw new VectorDatabaseException( - "Drop index failed partially. 
Message: " + - response.getBulkFailures() - .stream() - .map(f -> String.format("Index: %s , responseId: %s, status: %d, message: %s", - f.getIndex(), f.getId(), f.getStatus().getStatus(), f.getMessage())) - .collect(Collectors.joining("\t")) - ); + for (int i = 0; i < maxRetry; i++) { + try { + operationResult = operation.perform(); + if (operationResult.success) { // Operation succeeded, no need to retry + return Optional.ofNullable(operationResult.result); + } + + if (i < maxRetry - 1 && shouldRetry()) { + if (!retryStarted) { + startRetry(); + retryStarted = true; + } + Thread.sleep(getDelayMillis(delayMillis)); + delayMillis *= 2; + } else { + break; + } + } catch (IOException e) { + throw new VectorDatabaseException(String.format("Failed to %s index: %s", operationName, e.getMessage())); + } catch (OpenSearchStatusException e) { + if (expectedStatus.contains(e.status())) { + return Optional.empty(); + } else if (e.getDetailedMessage().contains("process_cluster_event_timeout_exception")) { + LOGGER.log(Level.INFO, + String.format("Failed to %s: %s", operationName, e.getDetailedMessage())); + if (i < maxRetry - 1 && shouldRetry()) { + if (!retryStarted) { + startRetry(); + retryStarted = true; + } + Thread.sleep(getDelayMillis(delayMillis)); + delayMillis *= 2; + } else { + break; + } + } else { + throw new VectorDatabaseException( + String.format("Failed to %s index: %s", operationName, e.getDetailedMessage())); + } + } + } + } catch (InterruptedException e) { + LOGGER.log(Level.INFO, String.format("Retry %s interrupted.", operationName)); + } finally { + if (retryStarted) { + doneRetry(); } - } catch (IOException e) { - throw new VectorDatabaseException("Failed to delete opensearch data."); } + throw new VectorDatabaseException(String.format("Operation %s failed after retries.", operationName)); + } + + protected Boolean shouldRetry() { + return true; + } + + protected void startRetry() { + + } + + protected void doneRetry() { + + } + + @FunctionalInterface + 
protected interface OperationSupplier { + OperationResult perform() throws IOException, OpenSearchStatusException, VectorDatabaseException; + } + + private RequestOptions getRequestOptions() { + RequestConfig requestConfig = RequestConfig.custom() + .setSocketTimeout(socketTimeout) + .build(); + return RequestOptions.DEFAULT.toBuilder() + .setRequestConfig(requestConfig) + .build(); + } + + @AllArgsConstructor + protected static class OperationResult { + private final Boolean success; + private final T result; } @Override diff --git a/vector-db/src/main/java/io/hops/hopsworks/vectordb/VectorDatabase.java b/vector-db/src/main/java/io/hops/hopsworks/vectordb/VectorDatabase.java index e0c6b493e4..6603f2bc1c 100644 --- a/vector-db/src/main/java/io/hops/hopsworks/vectordb/VectorDatabase.java +++ b/vector-db/src/main/java/io/hops/hopsworks/vectordb/VectorDatabase.java @@ -18,9 +18,11 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; public interface VectorDatabase { + Optional getIndex(String name) throws VectorDatabaseException; Set getAllIndices() throws VectorDatabaseException; void createIndex(Index index, String mapping, Boolean skipIfExist) throws VectorDatabaseException; void deleteIndex(Index index) throws VectorDatabaseException; diff --git a/vector-db/src/main/java/io/hops/hopsworks/vectordb/VectorDatabaseFactory.java b/vector-db/src/main/java/io/hops/hopsworks/vectordb/VectorDatabaseFactory.java index 4ab79010af..96cca17f41 100644 --- a/vector-db/src/main/java/io/hops/hopsworks/vectordb/VectorDatabaseFactory.java +++ b/vector-db/src/main/java/io/hops/hopsworks/vectordb/VectorDatabaseFactory.java @@ -16,63 +16,16 @@ package io.hops.hopsworks.vectordb; -import org.apache.http.HttpHost; -import org.apache.http.auth.AuthScope; -import org.apache.http.auth.UsernamePasswordCredentials; -import org.apache.http.client.CredentialsProvider; -import org.apache.http.conn.ssl.NoopHostnameVerifier; -import 
org.apache.http.impl.client.BasicCredentialsProvider; -import org.apache.http.impl.nio.reactor.IOReactorConfig; -import org.apache.http.ssl.SSLContexts; -import org.opensearch.client.RestClient; import org.opensearch.client.RestHighLevelClient; -import javax.net.ssl.SSLContext; -import java.io.IOException; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.security.GeneralSecurityException; - public class VectorDatabaseFactory { - public static VectorDatabase getOpensearchDatabase(String host, String user, String password, String certPath, - String trustStorePassword) throws IOException { - SSLContext sslCtx = null; - Path trustStore = Paths.get(certPath); - char[] trustStorePw = null; - if (trustStorePassword != null) { - trustStorePw = - trustStorePassword.toCharArray(); - } - try { - sslCtx = SSLContexts.custom() - .loadTrustMaterial(trustStore.toFile(), trustStorePw) - .build(); - } catch (GeneralSecurityException | IOException e) { - throw new IOException("Failed to load ssl context."); - } - CredentialsProvider credentialsProvider = - new BasicCredentialsProvider(); - credentialsProvider.setCredentials(AuthScope.ANY, - new UsernamePasswordCredentials(user, password)); - final SSLContext finalSslCtx = sslCtx; - final CredentialsProvider finalCredentialsProvider = credentialsProvider; - - RestHighLevelClient client = new RestHighLevelClient( - RestClient.builder(HttpHost.create(host)) - .setHttpClientConfigCallback(httpAsyncClientBuilder -> { - httpAsyncClientBuilder.setDefaultIOReactorConfig( - IOReactorConfig.custom().setIoThreadCount(5).build()); - return httpAsyncClientBuilder - .setSSLContext(finalSslCtx) - .setDefaultCredentialsProvider(finalCredentialsProvider) - .setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE); - })); + public static VectorDatabase getOpensearchDatabase(RestHighLevelClient client) { return new OpensearchVectorDatabase(client); } - public static VectorDatabase getOpensearchDatabase(RestHighLevelClient client) { - 
return new OpensearchVectorDatabase(client); + public static VectorDatabase getOpensearchDatabase(RestHighLevelClient client, Integer requestTimeout) { + return new OpensearchVectorDatabase(client, requestTimeout); } }