diff --git a/hopsworks-api/src/main/java/io/hops/hopsworks/api/featurestore/FeaturestoreService.java b/hopsworks-api/src/main/java/io/hops/hopsworks/api/featurestore/FeaturestoreService.java index 71e1fa1b86..9835fb4dc9 100644 --- a/hopsworks-api/src/main/java/io/hops/hopsworks/api/featurestore/FeaturestoreService.java +++ b/hopsworks-api/src/main/java/io/hops/hopsworks/api/featurestore/FeaturestoreService.java @@ -17,6 +17,7 @@ package io.hops.hopsworks.api.featurestore; import com.google.common.base.Strings; +import com.google.common.collect.Lists; import io.hops.hopsworks.api.featurestore.datavalidationv2.greatexpectations.GreatExpectationResource; import io.hops.hopsworks.api.featurestore.featuregroup.FeaturegroupService; import io.hops.hopsworks.api.featurestore.featureview.FeatureViewService; @@ -50,10 +51,12 @@ import javax.enterprise.context.RequestScoped; import javax.inject.Inject; import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.DefaultValue; import javax.ws.rs.GET; import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; +import javax.ws.rs.QueryParam; import javax.ws.rs.core.Context; import javax.ws.rs.core.GenericEntity; import javax.ws.rs.core.MediaType; @@ -119,16 +122,32 @@ public void setProjectId(Integer projectId) { @AllowedProjectRoles({AllowedProjectRoles.DATA_OWNER, AllowedProjectRoles.DATA_SCIENTIST}) @JWTRequired(acceptedTokens = {Audience.API, Audience.JOB}, allowedUserRoles = {"HOPS_ADMIN", "HOPS_USER", "HOPS_SERVICE_USER"}) - @ApiKeyRequired(acceptedScopes = {ApiScope.FEATURESTORE}, - allowedUserRoles = {"HOPS_ADMIN", "HOPS_USER", "HOPS_SERVICE_USER"}) + @ApiKeyRequired(acceptedScopes = {ApiScope.FEATURESTORE, ApiScope.KAFKA}, + allowedUserRoles = {"HOPS_ADMIN", "HOPS_USER", "HOPS_SERVICE_USER", "AGENT"}) @ApiOperation(value = "Get the list of feature stores for the project", response = FeaturestoreDTO.class, responseContainer = "List") - public Response getFeaturestores(@Context 
SecurityContext sc) throws FeaturestoreException { - List featurestores = featurestoreController.getFeaturestoresForProject(project); + public Response getFeaturestores( + @Context + SecurityContext sc, + @QueryParam("include_shared") + @DefaultValue("true") + @ApiParam(value = "include_shared=false", + allowableValues = "include_shared=false,include_shared=true", + defaultValue = "true") + Boolean includeShared + ) + throws FeaturestoreException { + List featurestores; + if (includeShared) { + featurestores = featurestoreController.getFeaturestoresForProject(project); + } else { + featurestores = Lists.newArrayList(featurestoreController.convertFeaturestoreToDTO( + featurestoreController.getProjectFeaturestore(project) + )); + } GenericEntity> featurestoresGeneric = - new GenericEntity>(featurestores) { - }; + new GenericEntity>(featurestores) {}; return noCacheResponse.getNoCacheResponseBuilder(Response.Status.OK).entity(featurestoresGeneric).build(); } @@ -142,7 +161,8 @@ public Response getFeaturestores(@Context SecurityContext sc) throws Featurestor @Path("/{featurestoreId: [0-9]+}") @Produces(MediaType.APPLICATION_JSON) @AllowedProjectRoles({AllowedProjectRoles.DATA_OWNER, AllowedProjectRoles.DATA_SCIENTIST}) - @JWTRequired(acceptedTokens = {Audience.API}, allowedUserRoles = {"HOPS_ADMIN", "HOPS_USER", "HOPS_SERVICE_USER"}) + @JWTRequired(acceptedTokens = {Audience.API, Audience.JOB}, + allowedUserRoles = {"HOPS_ADMIN", "HOPS_USER", "HOPS_SERVICE_USER"}) @ApiKeyRequired(acceptedScopes = {ApiScope.FEATURESTORE}, allowedUserRoles = {"HOPS_ADMIN", "HOPS_USER", "HOPS_SERVICE_USER"}) @ApiOperation(value = "Get featurestore with specific Id", diff --git a/hopsworks-api/src/main/java/io/hops/hopsworks/api/featurestore/featuregroup/FeaturegroupService.java b/hopsworks-api/src/main/java/io/hops/hopsworks/api/featurestore/featuregroup/FeaturegroupService.java index f4f429c363..79bacf2da6 100644 --- 
a/hopsworks-api/src/main/java/io/hops/hopsworks/api/featurestore/featuregroup/FeaturegroupService.java +++ b/hopsworks-api/src/main/java/io/hops/hopsworks/api/featurestore/featuregroup/FeaturegroupService.java @@ -27,6 +27,7 @@ import io.hops.hopsworks.api.featurestore.datavalidationv2.suites.ExpectationSuiteResource; import io.hops.hopsworks.api.featurestore.statistics.StatisticsResource; import io.hops.hopsworks.api.featurestore.tag.FeatureGroupTagResource; +import io.hops.hopsworks.api.filter.JWTNotRequired; import io.hops.hopsworks.api.jobs.JobDTO; import io.hops.hopsworks.api.jobs.JobsBuilder; import io.hops.hopsworks.api.provenance.FeatureGroupProvenanceResource; @@ -301,6 +302,35 @@ public Response getFeatureGroup(@ApiParam(value = "Id of the featuregroup", requ return noCacheResponse.getNoCacheResponseBuilder(Response.Status.OK).entity(featuregroupGeneric).build(); } + /** + * Endpoint for retrieving a featuregroup with a specified id in a specified featurestore for onlinefs + * + * @param featuregroupId id of the featuregroup + * @return JSON representation of the featuregroup + */ + @GET + @Path("/{featuregroupId: [0-9]+}/onlinefs") + @JWTNotRequired + @Produces(MediaType.APPLICATION_JSON) + @AllowedProjectRoles({AllowedProjectRoles.DATA_OWNER, AllowedProjectRoles.DATA_SCIENTIST}) + @ApiKeyRequired(acceptedScopes = {ApiScope.FEATURESTORE, ApiScope.KAFKA}, + allowedUserRoles = {"HOPS_SERVICE_USER", "AGENT"}) + @ApiOperation(value = "Get specific featuregroup for onlinefs from a specific featurestore", + response = FeaturegroupDTO.class) + public Response getFeatureGroupForOnlinefs( + @ApiParam(value = "Id of the featuregroup", required = true) + @PathParam("featuregroupId") Integer featuregroupId, + @Context HttpServletRequest req, + @Context SecurityContext sc) + throws FeaturestoreException { + verifyIdProvided(featuregroupId); + Featuregroup featuregroup = featuregroupController.getFeaturegroupById(featurestore, featuregroupId); + FeaturegroupDTO 
featuregroupDTO = new FeaturegroupDTO(featuregroup); + GenericEntity featuregroupGeneric = + new GenericEntity(featuregroupDTO) {}; + return noCacheResponse.getNoCacheResponseBuilder(Response.Status.OK).entity(featuregroupGeneric).build(); + } + /** * Retrieve a specific feature group based name. Allow filtering on version. * diff --git a/hopsworks-api/src/main/java/io/hops/hopsworks/api/featurestore/featureview/FeatureViewBuilder.java b/hopsworks-api/src/main/java/io/hops/hopsworks/api/featurestore/featureview/FeatureViewBuilder.java index cc9c5256aa..f1bcd52d6b 100644 --- a/hopsworks-api/src/main/java/io/hops/hopsworks/api/featurestore/featureview/FeatureViewBuilder.java +++ b/hopsworks-api/src/main/java/io/hops/hopsworks/api/featurestore/featureview/FeatureViewBuilder.java @@ -196,15 +196,18 @@ private List makeFeatures(FeatureView featureView) { Map fsLookupTable = trainingDatasetController.getFsLookupTableFeatures(tdFeatures); return tdFeatures .stream() - .map(f -> new TrainingDatasetFeatureDTO(trainingDatasetController.checkPrefix(f), f.getType(), + .map(f -> new TrainingDatasetFeatureDTO( + trainingDatasetController.checkPrefix(f), + f.getType(), f.getFeatureGroup() != null ? 
new FeaturegroupDTO(f.getFeatureGroup().getFeaturestore().getId(), fsLookupTable.get(f.getFeatureGroup().getFeaturestore().getId()), - f.getFeatureGroup().getId(), f.getFeatureGroup().getName(), + f.getFeatureGroup().getId(), + f.getFeatureGroup().getName(), f.getFeatureGroup().getVersion(), f.getFeatureGroup().isDeprecated()) : null, - f.getIndex(), f.isLabel(), f.isInferenceHelperColumn(), f.isTrainingHelperColumn())) + f.getName(), f.getIndex(), f.isLabel(), f.isInferenceHelperColumn(), f.isTrainingHelperColumn())) .collect(Collectors.toList()); } } diff --git a/hopsworks-common/pom.xml b/hopsworks-common/pom.xml index 1911f42bdb..b26e5dd47f 100644 --- a/hopsworks-common/pom.xml +++ b/hopsworks-common/pom.xml @@ -187,6 +187,11 @@ hopsworks-service-discovery + + io.hops.hopsworks + vector-db + + org.jsoup jsoup diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/FeaturestoreController.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/FeaturestoreController.java index 06885eb8c5..80f36bf3a3 100644 --- a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/FeaturestoreController.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/FeaturestoreController.java @@ -373,7 +373,7 @@ public FeaturestoreStorageConnectorDTO createOfflineJdbcConnector(String databas * @param featurestore the featurestore entity * @return a DTO representation of the featurestore */ - private FeaturestoreDTO convertFeaturestoreToDTO(Featurestore featurestore) { + public FeaturestoreDTO convertFeaturestoreToDTO(Featurestore featurestore) { FeaturestoreDTO featurestoreDTO = new FeaturestoreDTO(featurestore); String featureStoreName = getOfflineFeaturestoreDbName(featurestore); // TODO(Fabio): remove this when we switch to the new UI. 
diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/EmbeddingController.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/EmbeddingController.java new file mode 100644 index 0000000000..938e344865 --- /dev/null +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/EmbeddingController.java @@ -0,0 +1,170 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2023, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . 
+ */ + +package io.hops.hopsworks.common.featurestore.embedding; + +import com.google.common.base.Strings; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import io.hops.hopsworks.common.featurestore.featuregroup.EmbeddingDTO; +import io.hops.hopsworks.common.hdfs.Utils; +import io.hops.hopsworks.common.util.Settings; +import io.hops.hopsworks.exceptions.FeaturestoreException; +import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Embedding; +import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.EmbeddingFeature; +import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Featuregroup; +import io.hops.hopsworks.persistence.entity.project.Project; +import io.hops.hopsworks.restutils.RESTCodes; +import io.hops.hopsworks.vectordb.Index; +import io.hops.hopsworks.vectordb.VectorDatabaseException; + +import javax.ejb.EJB; +import javax.ejb.Stateless; +import javax.ejb.TransactionAttribute; +import javax.ejb.TransactionAttributeType; +import java.util.Arrays; +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.logging.Level; +import java.util.stream.Collectors; + +@Stateless +@TransactionAttribute(TransactionAttributeType.NEVER) +public class EmbeddingController { + + @EJB + private Settings settings; + @EJB + private VectorDatabaseClient vectorDatabaseClient; + + public void createVectorDbIndex(Project project, Featuregroup featureGroup) + throws FeaturestoreException { + Index index = new Index(featureGroup.getEmbedding().getVectorDbIndexName()); + try { + vectorDatabaseClient.getClient().createIndex(index, createIndex(featureGroup.getEmbedding().getColPrefix(), + featureGroup.getEmbedding().getEmbeddingFeatures()), true); + if (isDefaultVectorDbIndex(project, index.getName())) { + vectorDatabaseClient.getClient().addFields(index, createMapping(featureGroup.getEmbedding().getColPrefix(), + 
featureGroup.getEmbedding().getEmbeddingFeatures())); + } + } catch (VectorDatabaseException e) { + throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.COULD_NOT_CREATE_FEATUREGROUP, + Level.FINE, "Cannot create opensearch vectordb index: " + index.getName()); + } + } + + public Embedding getEmbedding(Project project, EmbeddingDTO embeddingDTO, Featuregroup featuregroup) + throws FeaturestoreException { + Embedding embedding = new Embedding(); + embedding.setFeaturegroup(featuregroup); + if (embeddingDTO.getIndexName() == null) { + embedding.setVectorDbIndexName(getDefaultVectorDbIndex(project)); + embedding.setColPrefix(getVectorDbColPrefix(featuregroup)); + } else { + String vectorDbIndexPrefix = getVectorDbIndexPrefix(project); + // In hopsworks opensearch, users can only access indexes which start with specific prefix + if (!embeddingDTO.getIndexName().startsWith(vectorDbIndexPrefix)) { + embedding.setVectorDbIndexName( + vectorDbIndexPrefix + "_" + embeddingDTO.getIndexName()); + embedding.setColPrefix(""); + } + if (isDefaultVectorDbIndex(project, embeddingDTO.getIndexName())) { + embedding.setColPrefix(getVectorDbColPrefix(featuregroup)); + } + } + embedding.setEmbeddingFeatures( + embeddingDTO.getFeatures() + .stream() + .map(mapping -> new EmbeddingFeature(embedding, mapping.getName(), mapping.getDimension(), + mapping.getSimilarityFunctionType())) + .collect(Collectors.toList()) + ); + return embedding; + } + + protected String createMapping(String prefix, Collection features) { + String mappingString = "{\n" + + " \"properties\": {\n" + + "%s\n" + + " }\n" + + " }"; + String fieldString = " \"%s\": {\n" + + " \"type\": \"knn_vector\",\n" + + " \"dimension\": %d\n" + + " }"; + List fieldMapping = Lists.newArrayList(); + for (EmbeddingFeature feature : features) { + fieldMapping.add(String.format( + fieldString, prefix + feature.getName(), feature.getDimension())); + } + return String.format(mappingString, String.join(",\n", fieldMapping)); + } 
+ + protected String createIndex(String prefix, Collection features) { + String jsonString = "{\n" + + " \"settings\": {\n" + + " \"index\": {\n" + + " \"knn\": \"true\",\n" + + " \"knn.algo_param.ef_search\": 512\n" + + " }\n" + + " },\n" + + " \"mappings\": %s\n" + + "}"; + return String.format(jsonString, createMapping(prefix, features)); + + } + + private String getDefaultVectorDbIndex(Project project) throws FeaturestoreException { + Set indexName = getAllDefaultVectorDbIndex(project); + // randomly select an index + return indexName.stream().sorted(Comparator.comparingInt(i -> new Random().nextInt())).findFirst().get(); + } + + private boolean isDefaultVectorDbIndex(Project project, String index) throws FeaturestoreException { + return getAllDefaultVectorDbIndex(project).contains(index); + } + + private Set getAllDefaultVectorDbIndex(Project project) throws FeaturestoreException { + Set indices; + if (!Strings.isNullOrEmpty(settings.getOpensearchDefaultEmbeddingIndexName())) { + indices = Arrays.stream(settings.getOpensearchDefaultEmbeddingIndexName().split(",")) + .collect(Collectors.toSet()); + } else { + indices = Sets.newHashSet(); + for (int i = 0; i < settings.getOpensearchNumDefaultEmbeddingIndex(); i++) { + indices.add(getVectorDbIndexPrefix(project) + "_default_project_embedding_" + i); + } + } + if (indices.size() == 0) { + throw new FeaturestoreException( + RESTCodes.FeaturestoreErrorCode.OPENSEARCH_DEFAULT_EMBEDDING_INDEX_SUFFIX_NOT_DEFINED, Level.FINE, + "Default vector db index is not defined."); + } + return indices; + } + + private String getVectorDbIndexPrefix(Project project) { + return project.getId() + "__embedding"; + } + + private String getVectorDbColPrefix(Featuregroup featuregroup) { + return Utils.getFeaturegroupName(featuregroup) + "_"; + } + +} diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/VectorDatabaseClient.java 
b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/VectorDatabaseClient.java new file mode 100644 index 0000000000..10aa7beb20 --- /dev/null +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/VectorDatabaseClient.java @@ -0,0 +1,67 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2023, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . 
+ */ + +package io.hops.hopsworks.common.featurestore.embedding; + +import com.logicalclocks.servicediscoverclient.exceptions.ServiceDiscoveryException; +import io.hops.hopsworks.common.opensearch.OpenSearchClient; +import io.hops.hopsworks.exceptions.FeaturestoreException; +import io.hops.hopsworks.exceptions.OpenSearchException; +import io.hops.hopsworks.restutils.RESTCodes; +import io.hops.hopsworks.vectordb.VectorDatabase; +import io.hops.hopsworks.vectordb.VectorDatabaseFactory; + +import javax.annotation.PreDestroy; +import javax.ejb.ConcurrencyManagement; +import javax.ejb.ConcurrencyManagementType; +import javax.ejb.EJB; +import javax.ejb.Singleton; +import javax.ejb.TransactionAttribute; +import javax.ejb.TransactionAttributeType; +import java.util.logging.Level; +import java.util.logging.Logger; + +@Singleton +@TransactionAttribute(TransactionAttributeType.NOT_SUPPORTED) +@ConcurrencyManagement(ConcurrencyManagementType.BEAN) +public class VectorDatabaseClient { + + @EJB + private OpenSearchClient openSearchClient; + private VectorDatabase vectorDatabase; + private static final Logger LOG = Logger.getLogger(EmbeddingController.class.getName()); + + public synchronized VectorDatabase getClient() throws FeaturestoreException { + if (vectorDatabase == null) { + try { + vectorDatabase = VectorDatabaseFactory.getOpensearchDatabase(openSearchClient.getClient()); + } catch (OpenSearchException | ServiceDiscoveryException e) { + throw new FeaturestoreException(RESTCodes.FeaturestoreErrorCode.COULD_NOT_CREATE_FEATUREGROUP, + Level.FINE, "Cannot create opensearch vectordb"); + } + } + return vectorDatabase; + } + + @PreDestroy + private void close() { + try { + vectorDatabase.close(); + } catch (Exception ex) { + LOG.log(Level.SEVERE, null, ex); + } + } +} diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/feature/TrainingDatasetFeatureDTO.java 
/**
 * Creates a fully populated training dataset feature DTO.
 *
 * @param name name of the feature as exposed by the training dataset / feature view
 * @param type type of the feature
 * @param featuregroupDTO feature group the feature originates from, may be {@code null}
 * @param featureGroupFeatureName name of the feature inside its feature group
 * @param index position of the feature
 * @param label whether the feature is a label
 * @param inferenceHelperColumn whether the feature is an inference helper column
 * @param trainingHelperColumn whether the feature is a training helper column
 */
public TrainingDatasetFeatureDTO(
    String name,
    String type,
    FeaturegroupDTO featuregroupDTO,
    String featureGroupFeatureName,
    Integer index,
    Boolean label,
    Boolean inferenceHelperColumn,
    Boolean trainingHelperColumn) {
  // Assignments follow the parameter order.
  this.name = name;
  this.type = type;
  this.featuregroup = featuregroupDTO;
  this.featureGroupFeatureName = featureGroupFeatureName;
  this.index = index;
  this.label = label;
  this.inferenceHelperColumn = inferenceHelperColumn;
  this.trainingHelperColumn = trainingHelperColumn;
}
+ * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . + */ + +package io.hops.hopsworks.common.featurestore.featuregroup; + +import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Embedding; +import lombok.Getter; +import lombok.NoArgsConstructor; + +import java.util.List; +import java.util.stream.Collectors; + +@NoArgsConstructor +public class EmbeddingDTO { + + @Getter + private String indexName; + @Getter + private String colPrefix; + @Getter + private List features; + + public EmbeddingDTO(Embedding embedding) { + indexName = embedding.getVectorDbIndexName(); + features = embedding.getEmbeddingFeatures().stream().map(EmbeddingFeatureDTO::new).collect(Collectors.toList()); + colPrefix = embedding.getColPrefix(); + } + +} diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingFeatureDTO.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingFeatureDTO.java new file mode 100644 index 0000000000..fa8c37308e --- /dev/null +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingFeatureDTO.java @@ -0,0 +1,42 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2023, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. 
+ * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . + */ + +package io.hops.hopsworks.common.featurestore.featuregroup; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.EmbeddingFeature; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; + +@NoArgsConstructor +@AllArgsConstructor +@JsonIgnoreProperties(ignoreUnknown = true) +public class EmbeddingFeatureDTO { + + @Getter + private String name; + @Getter + private String similarityFunctionType; + @Getter + private Integer dimension; + + public EmbeddingFeatureDTO(EmbeddingFeature feature) { + name = feature.getName(); + similarityFunctionType = feature.getSimilarityFunctionType(); + dimension = feature.getDimension(); + } +} diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/FeaturegroupController.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/FeaturegroupController.java index a1a0364655..7a264452e1 100644 --- a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/FeaturegroupController.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/FeaturegroupController.java @@ -25,6 +25,7 @@ import io.hops.hopsworks.common.featurestore.datavalidationv2.reports.ValidationReportController; import io.hops.hopsworks.common.featurestore.datavalidationv2.suites.ExpectationSuiteController; import io.hops.hopsworks.common.featurestore.datavalidationv2.suites.ExpectationSuiteDTO; +import 
io.hops.hopsworks.common.featurestore.embedding.EmbeddingController; import io.hops.hopsworks.common.featurestore.feature.FeatureGroupFeatureDTO; import io.hops.hopsworks.common.featurestore.featuregroup.cached.CachedFeaturegroupController; import io.hops.hopsworks.common.featurestore.featuregroup.cached.CachedFeaturegroupDTO; @@ -133,6 +134,8 @@ public class FeaturegroupController { private FeatureGroupCommitController featureGroupCommitController; @EJB private SearchFSCommandLogger searchCommandLogger; + @EJB + private EmbeddingController embeddingController; /** * Gets all featuregroups for a particular featurestore and project, using the userCerts to query Hive @@ -203,7 +206,7 @@ public FeaturegroupDTO createFeaturegroup(Featurestore featurestore, Featuregrou enforceFeaturegroupQuotas(featurestore, featuregroupDTO); featureGroupInputValidation.verifySchemaProvided(featuregroupDTO); featureGroupInputValidation.verifyNoDuplicatedFeatures(featuregroupDTO); - + // if version not provided, get latest and increment if (featuregroupDTO.getVersion() == null) { // returns ordered list by desc version @@ -237,10 +240,10 @@ public FeaturegroupDTO createFeaturegroupNoValidation(Featurestore featurestore, boolean isSpine = false; CachedFeaturegroup cachedFeaturegroup = null; StreamFeatureGroup streamFeatureGroup = null; - + // make copy of schema without hudi columns List featuresNoHudi = new ArrayList<>(featuregroupDTO.getFeatures());; - + if (featuregroupDTO instanceof CachedFeaturegroupDTO) { cachedFeaturegroup = cachedFeaturegroupController.createCachedFeaturegroup(featurestore, (CachedFeaturegroupDTO) featuregroupDTO, project, user); @@ -262,7 +265,7 @@ public FeaturegroupDTO createFeaturegroupNoValidation(Featurestore featurestore, //Persist basic feature group metadata Featuregroup featuregroup = persistFeaturegroupMetadata(featurestore, project, user, featuregroupDTO, cachedFeaturegroup, streamFeatureGroup, onDemandFeaturegroup); - + // online feature group needs 
to be set up after persisting metadata in order to get feature group id // don't setup online storage for spine group for now if (settings.isOnlineFeaturestore() && featuregroup.isOnlineEnabled() && !isSpine) { @@ -278,7 +281,7 @@ public FeaturegroupDTO createFeaturegroupNoValidation(Featurestore featurestore, fsActivityFacade.logMetadataActivity(user, featuregroup, FeaturestoreActivityMeta.FG_CREATED, null); if (featuregroup.getExpectationSuite() != null) { fsActivityFacade.logExpectationSuiteActivity( - user, featuregroup, featuregroup.getExpectationSuite(), + user, featuregroup, featuregroup.getExpectationSuite(), FeaturestoreActivityMeta.EXPECTATION_SUITE_ATTACHED_ON_FG_CREATION, ""); } @@ -429,14 +432,14 @@ public FeaturegroupDTO updateFeaturegroupMetadata(Project project, Users user, F // adding new features // feature group description // feature descriptions - + // Verify general entity related information featurestoreInputValidation.verifyDescription(featuregroupDTO); featureGroupInputValidation.verifyFeatureGroupFeatureList(featuregroupDTO.getFeatures()); featureGroupInputValidation.verifyOnlineOfflineTypeMatch(featuregroupDTO); featureGroupInputValidation.verifyOnlineSchemaValid(featuregroupDTO); featureGroupInputValidation.verifyPrimaryKeySupported(featuregroupDTO); - + // Update on-demand feature group metadata if (featuregroup.getFeaturegroupType() == FeaturegroupType.CACHED_FEATURE_GROUP) { cachedFeaturegroupController @@ -741,7 +744,12 @@ private Featuregroup persistFeaturegroupMetadata(Featurestore featurestore, Proj featuregroup.setCreator(user); featuregroup.setVersion(featuregroupDTO.getVersion()); featuregroup.setDescription(featuregroupDTO.getDescription()); - + // set embedding + if (featuregroupDTO.getEmbeddingIndex() != null) { + featuregroup.setEmbedding( + embeddingController.getEmbedding(project, featuregroupDTO.getEmbeddingIndex(), featuregroup) + ); + } if (featuregroupDTO instanceof CachedFeaturegroupDTO) { 
featuregroup.setFeaturegroupType(FeaturegroupType.CACHED_FEATURE_GROUP); } else if (featuregroupDTO instanceof StreamFeatureGroupDTO) { @@ -776,7 +784,7 @@ private Featuregroup persistFeaturegroupMetadata(Featurestore featurestore, Proj featuregroup.setExpectationSuite(expectationSuiteController.convertExpectationSuiteDTOToPersistent( featuregroup, featuregroupDTO.getExpectationSuite())); } - + featuregroupFacade.persist(featuregroup); searchCommandLogger.create(featuregroup); if(cachedFeaturegroup != null) { diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/FeaturegroupDTO.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/FeaturegroupDTO.java index d2f7bc3d01..43094bc11c 100644 --- a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/FeaturegroupDTO.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/FeaturegroupDTO.java @@ -59,6 +59,7 @@ public class FeaturegroupDTO extends FeaturestoreEntityDTO { @JsonSetter(nulls = Nulls.SKIP) private Boolean deprecated = false; private String topicName; + private EmbeddingDTO embeddingIndex; public FeaturegroupDTO() { } @@ -76,6 +77,9 @@ public FeaturegroupDTO(Featuregroup featuregroup) { this.eventTime = featuregroup.getEventTime(); this.deprecated = featuregroup.isDeprecated(); this.topicName = featuregroup.getTopicName(); + if (featuregroup.getEmbedding() != null) { + this.embeddingIndex = new EmbeddingDTO(featuregroup.getEmbedding()); + } } // for testing @@ -138,7 +142,15 @@ public Boolean getDeprecated() { public void setDeprecated(Boolean deprecated) { this.deprecated = deprecated; } - + + public EmbeddingDTO getEmbeddingIndex() { + return embeddingIndex; + } + + public void setEmbeddingIndex(EmbeddingDTO embeddingIndex) { + this.embeddingIndex = embeddingIndex; + } + @Override public String toString() { return "FeaturegroupDTO{" + diff --git 
/** No-argument constructor; the collaborators are the fields annotated with {@code @EJB}. */
public OnlineFeaturegroupController() {
}

/**
 * Constructor that sets the collaborators explicitly instead of relying on injection.
 * NOTE(review): presumably intended for unit tests - confirm against callers.
 *
 * @param settings settings bean to use
 * @param embeddingController embedding controller to use
 */
protected OnlineFeaturegroupController(Settings settings, EmbeddingController embeddingController) {
  this.embeddingController = embeddingController;
  this.settings = settings;
}
+139,32 @@ public void createMySQLTable(Featurestore featurestore, String tableName, List features, Project project, Users user) + List features, Project project, Users user) throws KafkaException, SchemaException, ProjectException, FeaturestoreException, IOException, HopsSecurityException, ServiceException { // check if onlinefs user is part of project + checkOnlineFsUserExist(project); + + createFeatureGroupKafkaTopic(project, featureGroup, features); + if (featureGroup.getEmbedding() != null) { + embeddingController.createVectorDbIndex(project, featureGroup); + } else { + createMySQLTable(featureStore, Utils.getFeaturegroupName(featureGroup), features, project, user); + } + } + + void checkOnlineFsUserExist(Project project) + throws ServiceException, HopsSecurityException, IOException, ProjectException { if (project.getProjectTeamCollection().stream().noneMatch(pt -> - pt.getUser().getUsername().equals(OnlineFeaturestoreController.ONLINEFS_USERNAME))) { + pt.getUser().getUsername().equals(OnlineFeaturestoreController.ONLINEFS_USERNAME))) { try { // wait for the future projectController.addOnlineFsUser(project).get(); } catch (InterruptedException | ExecutionException e) { throw new ServiceException(RESTCodes.ServiceErrorCode.SERVICE_GENERIC_ERROR, - Level.SEVERE, "failed to add onlinefs user to project: " + project.getName(), e.getMessage(), e); + Level.SEVERE, "failed to add onlinefs user to project: " + project.getName(), e.getMessage(), e); } } - - String featureGroupEntityName = Utils.getFeaturegroupName(featureGroup); - createMySQLTable(featureStore, featureGroupEntityName, features, project, user); - - createFeatureGroupKafkaTopic(project, featureGroup, features); } // For ingesting data in the online feature store, we set up a topic for project/feature group @@ -204,7 +216,7 @@ public void alterOnlineFeatureGroupSchema(Featuregroup featureGroup, List fullNewSchema, Project project) throws FeaturestoreException, SchemaException, KafkaException { @@ -214,11 
+226,12 @@ public void alterFeatureGroupSchema(Featuregroup featureGroup, List. + */ + +package io.hops.hopsworks.common.featurestore.embedding; + +import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.EmbeddingFeature; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.mockito.Mockito.spy; + +public class EmbeddingControllerTest { + + private EmbeddingController embeddingController; + + @Before + public void setup() { + embeddingController = spy(new EmbeddingController()); + } + + @Test + public void testCreateIndex() { + List features = new ArrayList<>(); + features.add(new EmbeddingFeature(null, "vector", 512, "l2")); + features.add(new EmbeddingFeature(null, "vector2", 128, "l2")); + Assert.assertEquals( + "{\n" + + " \"settings\": {\n" + + " \"index\": {\n" + + " \"knn\": \"true\",\n" + + " \"knn.algo_param.ef_search\": 512\n" + + " }\n" + + " },\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"vector\": {\n" + + " \"type\": \"knn_vector\",\n" + + " \"dimension\": 512\n" + + " },\n" + + " \"vector2\": {\n" + + " \"type\": \"knn_vector\",\n" + + " \"dimension\": 128\n" + + " }\n" + + " }\n" + + " }\n" + + "}", + embeddingController.createIndex("", features)); + } +} \ No newline at end of file diff --git a/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/featuregroup/online/TestOnlineFeatureGroupController.java b/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/featuregroup/online/TestOnlineFeatureGroupController.java index a59ff06955..678194c117 100644 --- a/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/featuregroup/online/TestOnlineFeatureGroupController.java +++ b/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/featuregroup/online/TestOnlineFeatureGroupController.java @@ -16,8 +16,15 @@ package io.hops.hopsworks.common.featurestore.featuregroup.online; +import 
io.hops.hopsworks.common.featurestore.embedding.EmbeddingController; import io.hops.hopsworks.common.featurestore.feature.FeatureGroupFeatureDTO; import io.hops.hopsworks.common.util.Settings; +import io.hops.hopsworks.persistence.entity.featurestore.Featurestore; +import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Embedding; +import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Featuregroup; +import io.hops.hopsworks.persistence.entity.project.Project; +import io.hops.hopsworks.persistence.entity.user.Users; +import io.hops.hopsworks.vectordb.VectorDatabase; import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -26,18 +33,42 @@ import java.util.ArrayList; import java.util.List; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + public class TestOnlineFeatureGroupController { - + private Settings settings; - + private OnlineFeaturegroupController onlineFeaturegroupController; - + private EmbeddingController embeddingController; + + private VectorDatabase vectorDatabase; + private Project project; + private Featurestore featureStore; + private Featuregroup featureGroup; + @Before public void setup() { - settings = Mockito.mock(Settings.class); - onlineFeaturegroupController = new OnlineFeaturegroupController(settings); + settings = mock(Settings.class); + vectorDatabase = mock(VectorDatabase.class); + featureStore = mock(Featurestore.class); + project = mock(Project.class); + embeddingController = spy(new EmbeddingController()); + onlineFeaturegroupController = spy(new OnlineFeaturegroupController(settings, embeddingController)); + featureGroup = 
new Featuregroup(); + featureGroup.setEmbedding(null); + featureGroup.setName("fg"); + featureGroup.setVersion(1); } - + @Test public void testBuildAlterStatementNoDefaultValue() { List features = new ArrayList<>(); @@ -46,22 +77,22 @@ public void testBuildAlterStatementNoDefaultValue() { String expected = "ALTER TABLE `db`.`tbl` ADD COLUMN `feature` VARCHAR(100) DEFAULT NULL, ALGORITHM=INPLACE;"; Assert.assertEquals(expected, output); } - + @Test public void testBuildAlterStatementDefaultValueString() { List features = new ArrayList<>(); features.add(new FeatureGroupFeatureDTO("feature", "string", "", - false, "defaultValue")); + false, "defaultValue")); String output = onlineFeaturegroupController.buildAlterStatement("tbl", "db", features); String expected = "ALTER TABLE `db`.`tbl` ADD COLUMN `feature` VARCHAR(100), ALGORITHM=INPLACE;"; Assert.assertEquals(expected, output); } - + @Test public void testBuildAlterStatementDefaultValueOther() { List features = new ArrayList<>(); features.add(new FeatureGroupFeatureDTO("feature", "float", "", - false, "10.0")); + false, "10.0")); String output = onlineFeaturegroupController.buildAlterStatement("tbl", "db", features); String expected = "ALTER TABLE `db`.`tbl` ADD COLUMN `feature` float, ALGORITHM=INPLACE;"; Assert.assertEquals(expected, output); @@ -71,7 +102,7 @@ public void testBuildAlterStatementDefaultValueOther() { public void testBuildAlterStatementMultiColumns() { List features = new ArrayList<>(); features.add(new FeatureGroupFeatureDTO("feature", "float", "", - false, "10.0")); + false, "10.0")); features.add(new FeatureGroupFeatureDTO("feature2", "float", "", false, "19.0")); features.add(new FeatureGroupFeatureDTO("feature3", "float", "")); @@ -80,62 +111,62 @@ public void testBuildAlterStatementMultiColumns() { "ADD COLUMN `feature3` float DEFAULT NULL, ALGORITHM=INPLACE;"; Assert.assertEquals(expected, output); } - + @Test public void testBuildCreateStatementNoDefaultValue() { 
Mockito.when(settings.getOnlineFeatureStoreTableSpace()).thenReturn(""); - + List features = new ArrayList<>(); features.add(new FeatureGroupFeatureDTO("pk", "int", "", true, false)); features.add(new FeatureGroupFeatureDTO("feature", "String", "", false, false)); - + String output = onlineFeaturegroupController.buildCreateStatement("db", "tbl", features); String expected = "CREATE TABLE IF NOT EXISTS `db`.`tbl`(`pk` int, `feature` VARCHAR(100), PRIMARY KEY (`pk`))" + - "ENGINE=ndbcluster COMMENT='NDB_TABLE=READ_BACKUP=1'"; + "ENGINE=ndbcluster COMMENT='NDB_TABLE=READ_BACKUP=1'"; Assert.assertEquals(expected, output); } - + @Test public void testBuildCreateStatementTableSpaceNoDefaultValue() { Mockito.when(settings.getOnlineFeatureStoreTableSpace()).thenReturn("abc"); - + List features = new ArrayList<>(); features.add(new FeatureGroupFeatureDTO("pk", "int", "", true, false)); features.add(new FeatureGroupFeatureDTO("feature", "String", "", false, false)); - + String output = onlineFeaturegroupController.buildCreateStatement("db", "tbl", features); String expected = "CREATE TABLE IF NOT EXISTS `db`.`tbl`(`pk` int, `feature` VARCHAR(100), PRIMARY KEY (`pk`))" + - "ENGINE=ndbcluster COMMENT='NDB_TABLE=READ_BACKUP=1'/*!50100 TABLESPACE `abc` STORAGE DISK */"; + "ENGINE=ndbcluster COMMENT='NDB_TABLE=READ_BACKUP=1'/*!50100 TABLESPACE `abc` STORAGE DISK */"; Assert.assertEquals(expected, output); } - + @Test public void testBuildCreateStatementDefaultValueString() { Mockito.when(settings.getOnlineFeatureStoreTableSpace()).thenReturn(""); - + List features = new ArrayList<>(); features.add(new FeatureGroupFeatureDTO("pk", "int", "", true, false)); features.add(new FeatureGroupFeatureDTO("feature", "String", "", false, - "hello")); - + "hello")); + String output = onlineFeaturegroupController.buildCreateStatement("db", "tbl", features); String expected = "CREATE TABLE IF NOT EXISTS `db`.`tbl`(`pk` int, `feature` VARCHAR(100) NOT NULL DEFAULT " + - "'hello', PRIMARY KEY 
(`pk`))ENGINE=ndbcluster COMMENT='NDB_TABLE=READ_BACKUP=1'"; + "'hello', PRIMARY KEY (`pk`))ENGINE=ndbcluster COMMENT='NDB_TABLE=READ_BACKUP=1'"; Assert.assertEquals(expected, output); } - + @Test public void testBuildCreateStatementDefaultValueOther() { Mockito.when(settings.getOnlineFeatureStoreTableSpace()).thenReturn(""); - + List features = new ArrayList<>(); features.add(new FeatureGroupFeatureDTO("pk", "int", "", true, false)); features.add(new FeatureGroupFeatureDTO("feature", "float", "", false, - "10.0")); - + "10.0")); + String output = onlineFeaturegroupController.buildCreateStatement("db", "tbl", features); String expected = "CREATE TABLE IF NOT EXISTS `db`.`tbl`(`pk` int, `feature` float NOT NULL DEFAULT " + - "10.0, PRIMARY KEY (`pk`))ENGINE=ndbcluster COMMENT='NDB_TABLE=READ_BACKUP=1'"; + "10.0, PRIMARY KEY (`pk`))ENGINE=ndbcluster COMMENT='NDB_TABLE=READ_BACKUP=1'"; Assert.assertEquals(expected, output); } @@ -166,4 +197,55 @@ public void testBuildCreateStatementDecimal() { "ENGINE=ndbcluster COMMENT='NDB_TABLE=READ_BACKUP=1'"; Assert.assertEquals(expected, output); } + + @Test + public void testSetupOnlineFeatureGroupWithEmbedding() throws Exception { + // Arrange + List features = new ArrayList<>(); + Users user = new Users(); + // Set up the scenario where featureGroup.getEmbedding() is not null + Embedding embedding = new Embedding(); + featureGroup.setEmbedding(embedding); + + // Mock the behavior for vectorDatabase initialization + doNothing().when(embeddingController).createVectorDbIndex(any(), any()); + doNothing().when(onlineFeaturegroupController).checkOnlineFsUserExist(eq(project)); + doNothing().when(onlineFeaturegroupController) + .createFeatureGroupKafkaTopic(eq(project), eq(featureGroup), eq(features)); + + // Act + onlineFeaturegroupController.setupOnlineFeatureGroup(featureStore, featureGroup, features, project, user); + + // Assert + // Verify that vectorDatabase.createIndex is called with the correct parameters + 
verify(embeddingController, times(1)).createVectorDbIndex(any(), any()); + verify(onlineFeaturegroupController, times(1)).checkOnlineFsUserExist(eq(project)); + verify(onlineFeaturegroupController, times(1)).createFeatureGroupKafkaTopic(eq(project), eq(featureGroup), + eq(features)); + } + + @Test + public void testSetupOnlineFeatureGroupWithoutEmbedding() throws Exception { + // Arrange + + List features = new ArrayList<>(); + Users user = new Users(); + + // Mock the behavior for createMySQLTable + doNothing().when(onlineFeaturegroupController) + .createMySQLTable(eq(featureStore), anyString(), anyList(), eq(project), eq(user)); + doNothing().when(onlineFeaturegroupController).checkOnlineFsUserExist(eq(project)); + doNothing().when(onlineFeaturegroupController) + .createFeatureGroupKafkaTopic(eq(project), eq(featureGroup), eq(features)); + + // Act + onlineFeaturegroupController.setupOnlineFeatureGroup(featureStore, featureGroup, features, project, user); + + // Assert + verify(onlineFeaturegroupController, times(1)).createMySQLTable(eq(featureStore), anyString(), anyList(), + eq(project), eq(user)); + verify(onlineFeaturegroupController, times(1)).checkOnlineFsUserExist(eq(project)); + verify(onlineFeaturegroupController, times(1)).createFeatureGroupKafkaTopic(eq(project), eq(featureGroup), + eq(features)); + } } diff --git a/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/storageconnectors/TestStorageConnectorUtil.java b/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/storageconnectors/TestStorageConnectorUtil.java index 00ce139d99..42410cd3e0 100644 --- a/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/storageconnectors/TestStorageConnectorUtil.java +++ b/hopsworks-common/src/test/io/hops/hopsworks/common/featurestore/storageconnectors/TestStorageConnectorUtil.java @@ -115,7 +115,7 @@ public void testNumberOfStorageConnectorTypes() { // as well as their corresponding tests: // 
StorageConnectorUtil.getEnabledStorageConnectorTypes() -> testGetEnabledStorageConnectorTypes // StorageConnectorUtil.isStorageConnectorTypeEnabled() -> testIsStorageConnectorTypeEnabled - Assert.assertEquals(FeaturestoreConnectorType.values().length, 9); + Assert.assertEquals(FeaturestoreConnectorType.values().length, 10); } @Test diff --git a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/Embedding.java b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/Embedding.java new file mode 100644 index 0000000000..3d84a077d2 --- /dev/null +++ b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/Embedding.java @@ -0,0 +1,104 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2023, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . 
+ */ + +package io.hops.hopsworks.persistence.entity.featurestore.featuregroup; + +import javax.persistence.Basic; +import javax.persistence.CascadeType; +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.GeneratedValue; +import javax.persistence.GenerationType; +import javax.persistence.Id; +import javax.persistence.JoinColumn; +import javax.persistence.OneToMany; +import javax.persistence.Table; +import java.util.Collection; +import java.util.Objects; + +@Entity +@Table(name = "embedding", catalog = "hopsworks") +public class Embedding { + + @Id + @GeneratedValue(strategy = GenerationType.IDENTITY) + @Basic(optional = false) + @Column(name = "id") + private Integer id; + @JoinColumn(name = "feature_group_id", referencedColumnName = "id") + private Featuregroup featuregroup; + @Column(name = "vector_db_index_name") + private String vectorDbIndexName; + @Column(name = "col_prefix") + private String colPrefix; + @OneToMany(cascade = CascadeType.ALL, mappedBy = "embedding") + private Collection embeddingFeatures; + + public Embedding() { + } + + public Featuregroup getFeaturegroup() { + return featuregroup; + } + + public void setFeaturegroup(Featuregroup featuregroup) { + this.featuregroup = featuregroup; + } + + public Collection getEmbeddingFeatures() { + return embeddingFeatures; + } + + public void setEmbeddingFeatures( + Collection embeddingFeatures) { + this.embeddingFeatures = embeddingFeatures; + } + + public String getVectorDbIndexName() { + return vectorDbIndexName; + } + + public void setVectorDbIndexName(String vectorDbIndexName) { + this.vectorDbIndexName = vectorDbIndexName; + } + + public String getColPrefix() { + return colPrefix; + } + + public void setColPrefix(String colPrefix) { + this.colPrefix = colPrefix; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Embedding embedding = (Embedding) o; 
+ return Objects.equals(id, embedding.id) + && Objects.equals(vectorDbIndexName, embedding.vectorDbIndexName) && Objects.equals(colPrefix, + embedding.colPrefix) && Objects.equals(embeddingFeatures, embedding.embeddingFeatures); + } + + @Override + public int hashCode() { + return Objects.hash(id, vectorDbIndexName, colPrefix, embeddingFeatures); + } +} diff --git a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/EmbeddingFeature.java b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/EmbeddingFeature.java new file mode 100644 index 0000000000..2604ec5aa3 --- /dev/null +++ b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/EmbeddingFeature.java @@ -0,0 +1,85 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2023, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . 
+ */ + +package io.hops.hopsworks.persistence.entity.featurestore.featuregroup; + +import javax.persistence.Basic; +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.GeneratedValue; +import javax.persistence.GenerationType; +import javax.persistence.Id; +import javax.persistence.JoinColumn; +import javax.persistence.Table; + +@Entity +@Table(name = "embedding_feature", catalog = "hopsworks") +public class EmbeddingFeature { + @Id + @GeneratedValue(strategy = GenerationType.IDENTITY) + @Basic(optional = false) + @Column(name = "id") + private Integer id; + @JoinColumn(name = "embedding_id", referencedColumnName = "id") + private Embedding embedding; + @Column + private String name; + @Column + private Integer dimension; + @Column(name = "similarity_function_type") + private String similarityFunctionType; + + public EmbeddingFeature() { + } + + public EmbeddingFeature(Embedding embedding, String name, Integer dimension, + String similarityFunctionType) { + this.embedding = embedding; + this.name = name; + this.dimension = dimension; + this.similarityFunctionType = similarityFunctionType; + } + + public EmbeddingFeature(Integer id, Embedding embedding, String name, Integer dimension, + String similarityFunctionType) { + this.id = id; + this.embedding = embedding; + this.name = name; + this.dimension = dimension; + this.similarityFunctionType = similarityFunctionType; + } + + public Integer getId() { + return id; + } + + public Embedding getEmbedding() { + return embedding; + } + + public String getName() { + return name; + } + + public Integer getDimension() { + return dimension; + } + + public String getSimilarityFunctionType() { + return similarityFunctionType; + } + +} diff --git a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/Featuregroup.java b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/Featuregroup.java index 
6d8c54ab0b..697c65b0b9 100644 --- a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/Featuregroup.java +++ b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/Featuregroup.java @@ -131,6 +131,8 @@ public class Featuregroup implements Serializable { private String topicName; @Column(name = "deprecated") private boolean deprecated; + @OneToOne(cascade = CascadeType.ALL, mappedBy = "featuregroup") + private Embedding embedding; @NotNull @Enumerated(EnumType.ORDINAL) @Column(name = "feature_group_type") @@ -326,7 +328,15 @@ public boolean isDeprecated() { public void setDeprecated(boolean deprecated) { this.deprecated = deprecated; } - + + public Embedding getEmbedding() { + return embedding; + } + + public void setEmbedding(Embedding embedding) { + this.embedding = embedding; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -350,6 +360,7 @@ public boolean equals(Object o) { if (!Objects.equals(topicName, that.topicName)) return false; if (!Objects.equals(deprecated, that.deprecated)) return false; if (!Objects.equals(expectationSuite, that.expectationSuite)) return false; + if (!Objects.equals(embedding, that.embedding)) return false; return Objects.equals(statisticsConfig, that.statisticsConfig); } @@ -372,6 +383,7 @@ public int hashCode() { result = 31 * result + (topicName != null ? topicName.hashCode() : 0); result = 31 * result + (deprecated ? 1: 0); result = 31 * result + (expectationSuite != null ? expectationSuite.hashCode(): 0); + result = 31 * result + (embedding != null ? 
embedding.hashCode(): 0); return result; } } diff --git a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/storageconnector/FeaturestoreConnectorType.java b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/storageconnector/FeaturestoreConnectorType.java index bbf9da6c57..770c8df542 100644 --- a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/storageconnector/FeaturestoreConnectorType.java +++ b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/storageconnector/FeaturestoreConnectorType.java @@ -25,5 +25,6 @@ public enum FeaturestoreConnectorType { SNOWFLAKE, KAFKA, GCS, - BIGQUERY; + BIGQUERY, + OPENSEARCH; } diff --git a/hopsworks-persistence/src/main/resources/META-INF/persistence.xml b/hopsworks-persistence/src/main/resources/META-INF/persistence.xml index f6e4dfcd18..8960220556 100644 --- a/hopsworks-persistence/src/main/resources/META-INF/persistence.xml +++ b/hopsworks-persistence/src/main/resources/META-INF/persistence.xml @@ -64,6 +64,8 @@ io.hops.hopsworks.persistence.entity.featurestore.storageconnector.gcs.FeatureStoreGcsConnector io.hops.hopsworks.persistence.entity.featurestore.storageconnector.bigquery.FeatureStoreBigqueryConnector io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Featuregroup + io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Embedding + io.hops.hopsworks.persistence.entity.featurestore.featuregroup.EmbeddingFeature io.hops.hopsworks.persistence.entity.featurestore.featuregroup.ondemand.OnDemandFeaturegroup io.hops.hopsworks.persistence.entity.featurestore.featuregroup.cached.CachedFeaturegroup io.hops.hopsworks.persistence.entity.featurestore.featuregroup.cached.FeatureGroupCommit diff --git a/hopsworks-rest-utils/src/main/java/io/hops/hopsworks/restutils/RESTCodes.java b/hopsworks-rest-utils/src/main/java/io/hops/hopsworks/restutils/RESTCodes.java index 
aeadc96136..64df9c3012 100644 --- a/hopsworks-rest-utils/src/main/java/io/hops/hopsworks/restutils/RESTCodes.java +++ b/hopsworks-rest-utils/src/main/java/io/hops/hopsworks/restutils/RESTCodes.java @@ -1678,7 +1678,9 @@ public enum FeaturestoreErrorCode implements RESTErrorCode { "of a feature view join.", Response.Status.BAD_REQUEST), FEATURE_GROUP_DUPLICATE_FEATURE(224, "Feature list contains duplicate", Response.Status.BAD_REQUEST), HELPER_COL_NOT_FOUND(225, "Could not find helper column in feature view schema", - Response.Status.NOT_FOUND); + Response.Status.NOT_FOUND), + OPENSEARCH_DEFAULT_EMBEDDING_INDEX_SUFFIX_NOT_DEFINED(226, "Opensearch default embedding index not defined", + Response.Status.INTERNAL_SERVER_ERROR); private int code; private String message; diff --git a/pom.xml b/pom.xml index 4b29ecef74..1ad25e9cbc 100644 --- a/pom.xml +++ b/pom.xml @@ -60,6 +60,7 @@ alerting hopsworks-alert hopsworks-service-discovery + vector-db @@ -562,6 +563,17 @@ hopsworks-persistence ${project.version} + + io.hops.hopsworks + vector-db + ${project.version} + + + io.hops.hopsworks + hopsworks-remote-user + ${project.version} + ejb + io.hops.hopsworks hopsworks-rest-utils diff --git a/vector-db/pom.xml b/vector-db/pom.xml new file mode 100644 index 0000000000..371cc18b25 --- /dev/null +++ b/vector-db/pom.xml @@ -0,0 +1,79 @@ + + + + + 4.0.0 + + + io.hops + hopsworks + 3.7.0-SNAPSHOT + ../pom.xml + + + io.hops.hopsworks + vector-db + 3.7.0-SNAPSHOT + vector-db + + + UTF-8 + 1.8 + 1.8 + 5.7.1 + + + + + org.junit.jupiter + junit-jupiter-api + ${junit.version} + test + + + org.junit.jupiter + junit-jupiter-engine + ${junit.version} + test + + + org.opensearch.client + opensearch-rest-high-level-client + 1.3.3 + compile + + + org.projectlombok + lombok + + + com.fasterxml.jackson.core + jackson-databind + + + com.google.guava + guava + + + + + + + + \ No newline at end of file diff --git a/vector-db/src/main/java/io/hops/hopsworks/vectordb/Field.java 
b/vector-db/src/main/java/io/hops/hopsworks/vectordb/Field.java new file mode 100644 index 0000000000..c391739f76 --- /dev/null +++ b/vector-db/src/main/java/io/hops/hopsworks/vectordb/Field.java @@ -0,0 +1,29 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2023, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . + */ + +package io.hops.hopsworks.vectordb; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +@AllArgsConstructor +public class Field { + + @Getter + private String name; + @Getter + private Object type; +} diff --git a/vector-db/src/main/java/io/hops/hopsworks/vectordb/Index.java b/vector-db/src/main/java/io/hops/hopsworks/vectordb/Index.java new file mode 100644 index 0000000000..d67aa54589 --- /dev/null +++ b/vector-db/src/main/java/io/hops/hopsworks/vectordb/Index.java @@ -0,0 +1,27 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2023, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. 
+ * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . + */ + +package io.hops.hopsworks.vectordb; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +@AllArgsConstructor +public class Index { + + @Getter + private String name; +} diff --git a/vector-db/src/main/java/io/hops/hopsworks/vectordb/OpensearchVectorDatabase.java b/vector-db/src/main/java/io/hops/hopsworks/vectordb/OpensearchVectorDatabase.java new file mode 100644 index 0000000000..8a27c73d92 --- /dev/null +++ b/vector-db/src/main/java/io/hops/hopsworks/vectordb/OpensearchVectorDatabase.java @@ -0,0 +1,249 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2023, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . 
+ */
+
+package io.hops.hopsworks.vectordb;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.opensearch.OpenSearchException;
+import org.opensearch.action.admin.indices.delete.DeleteIndexRequest;
+import org.opensearch.action.bulk.BulkRequest;
+import org.opensearch.action.bulk.BulkResponse;
+import org.opensearch.action.index.IndexRequest;
+import org.opensearch.action.index.IndexResponse;
+import org.opensearch.action.support.master.AcknowledgedResponse;
+import org.opensearch.client.Request;
+import org.opensearch.client.RequestOptions;
+import org.opensearch.client.Response;
+import org.opensearch.client.RestHighLevelClient;
+import org.opensearch.client.indices.CreateIndexRequest;
+import org.opensearch.client.indices.CreateIndexResponse;
+import org.opensearch.client.indices.GetIndexRequest;
+import org.opensearch.client.indices.GetIndexResponse;
+import org.opensearch.client.indices.PutMappingRequest;
+import org.opensearch.common.xcontent.XContentType;
+import org.opensearch.rest.RestStatus;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.logging.Logger;
+import java.util.stream.Collectors;
+
+
+/**
+ * {@link VectorDatabase} implementation that talks to OpenSearch through a
+ * {@link RestHighLevelClient} supplied by the caller.
+ */
+public class OpensearchVectorDatabase implements VectorDatabase {
+
+  private RestHighLevelClient client = null;
+  private ObjectMapper objectMapper = new ObjectMapper();
+  private static final Logger LOGGER = Logger.getLogger(
+      OpensearchVectorDatabase.class.getName());
+
+  /**
+   * @param client OpenSearch client used for every request; {@link #close()} closes it
+   */
+  public OpensearchVectorDatabase(RestHighLevelClient client) {
+    this.client = client;
+  }
+
+  /**
+   * Creates the index named by {@code index}, using {@code mapping} as the JSON source of the
+   * create-index request.
+   *
+   * @param index index to create
+   * @param mapping JSON body sent as the create-index request source
+   * @param skipIfExist when true, a HEAD request is issued first and nothing is created if it
+   *                    answers 200; {@code null} is treated as false
+   * @throws VectorDatabaseException if an I/O error occurs or the creation is not acknowledged
+   */
+  @Override
+  public void createIndex(Index index, String mapping, Boolean skipIfExist) throws VectorDatabaseException {
+    try {
+      // Boolean.TRUE.equals avoids a NullPointerException from unboxing a null flag.
+      if (Boolean.TRUE.equals(skipIfExist)) {
+        Request request = new Request("HEAD", "/" + index.getName());
+        Response response = client.getLowLevelClient().performRequest(request);
+        if (response.getStatusLine().getStatusCode() == 200) {
+          return;
+        }
+      }
+      CreateIndexRequest createIndexRequest = new CreateIndexRequest(index.getName());
+      createIndexRequest.source(mapping, XContentType.JSON);
+      CreateIndexResponse response = client.indices().create(createIndexRequest, RequestOptions.DEFAULT);
+      if (!response.isAcknowledged()) {
+        throw new VectorDatabaseException("Failed to create opensearch index: " + index.getName());
+      }
+    } catch (IOException e) {
+      throw failure("Failed to create opensearch index: " + index.getName() + " Err: " + e, e);
+    }
+  }
+
+  /**
+   * Deletes the index named by {@code index}.
+   *
+   * @param index index to delete
+   * @throws VectorDatabaseException if an I/O error occurs or the deletion is not acknowledged
+   */
+  @Override
+  public void deleteIndex(Index index) throws VectorDatabaseException {
+    try {
+      DeleteIndexRequest deleteIndexRequest = new DeleteIndexRequest(index.getName());
+      AcknowledgedResponse response = client.indices().delete(deleteIndexRequest, RequestOptions.DEFAULT);
+      if (!response.isAcknowledged()) {
+        throw new VectorDatabaseException("Failed to delete opensearch index: " + index.getName());
+      }
+    } catch (IOException e) {
+      throw failure("Failed to delete opensearch index: " + index.getName() + " Err: " + e, e);
+    }
+  }
+
+  /**
+   * Sends {@code mapping} as a put-mapping request for the index named by {@code index}.
+   *
+   * @param index index whose mapping is updated
+   * @param mapping JSON body of the put-mapping request
+   * @throws VectorDatabaseException if an I/O error occurs or the update is not acknowledged
+   */
+  @Override
+  public void addFields(Index index, String mapping) throws VectorDatabaseException {
+    PutMappingRequest request = new PutMappingRequest(index.getName());
+    request.source(mapping, XContentType.JSON);
+    try {
+      AcknowledgedResponse response = client.indices().putMapping(request, RequestOptions.DEFAULT);
+      if (!response.isAcknowledged()) {
+        throw new VectorDatabaseException("Failed to add fields to opensearch index: " + index.getName());
+      }
+    } catch (IOException e) {
+      throw failure("Failed to add fields to opensearch index: " + index.getName() + " Err: " + e, e);
+    }
+  }
+
+  /**
+   * Reads the index mapping and turns every entry under its {@code properties} key into a
+   * {@link Field}.
+   *
+   * @param index index whose mapping is read
+   * @return one field per entry of {@code properties}; an empty list when the mapping has no
+   *         {@code properties} key
+   * @throws VectorDatabaseException if an I/O error occurs or the response holds no mapping
+   *                                 under the requested index name
+   */
+  @Override
+  public List<Field> getSchema(Index index) throws VectorDatabaseException {
+    GetIndexRequest request = new GetIndexRequest(index.getName());
+    try {
+      GetIndexResponse response = client.indices().get(request, RequestOptions.DEFAULT);
+      // Report a missing entry explicitly instead of failing with a NullPointerException below.
+      if (response.getMappings().get(index.getName()) == null) {
+        throw new VectorDatabaseException(
+            "Failed to get schema from opensearch index: " + index.getName() + " Err: no mapping returned");
+      }
+      Object properties = response.getMappings().get(index.getName()).getSourceAsMap().get("properties");
+      if (properties == null) {
+        return Lists.newArrayList();
+      }
+      // The mapping source is parsed JSON, so "properties" is expected to be a string-keyed map.
+      @SuppressWarnings("unchecked")
+      Map<String, Object> fields = (Map<String, Object>) properties;
+      return fields.entrySet().stream()
+          .map(entry -> new Field(entry.getKey(), entry.getValue()))
+          .collect(Collectors.toList());
+    } catch (IOException e) {
+      throw failure("Failed to get schema from opensearch index: " + index.getName() + " Err: " + e, e);
+    }
+  }
+
+  /**
+   * Indexes one JSON document.
+   *
+   * @param index target index
+   * @param data JSON document source
+   * @param docId document id to set on the request, or {@code null} to send the request without one
+   * @throws VectorDatabaseException if the request fails or the response status is neither
+   *                                 {@code CREATED} nor {@code OK}
+   */
+  @Override
+  public void write(Index index, String data, String docId) throws VectorDatabaseException {
+    try {
+      IndexRequest indexRequest = makeIndexRequest(index.getName(), data, docId);
+      IndexResponse response = client.index(indexRequest, RequestOptions.DEFAULT);
+      if (!(response.status().equals(RestStatus.CREATED) || response.status().equals(RestStatus.OK))) {
+        throw new VectorDatabaseException("Cannot index data. Status: " + response.status());
+      }
+    } catch (IOException | OpenSearchException e) {
+      throw failure("Cannot index data. Err: " + e, e);
+    }
+  }
+
+  /**
+   * Builds an index request for {@code data}; the id is only set when {@code docId} is not null.
+   */
+  private IndexRequest makeIndexRequest(String indexName, String data, String docId) {
+    IndexRequest indexRequest = new IndexRequest(indexName)
+        .source(data, XContentType.JSON);
+    if (docId != null) {
+      indexRequest.id(docId);
+    }
+    return indexRequest;
+  }
+
+  /**
+   * Indexes one JSON document without setting a document id.
+   *
+   * @param index target index
+   * @param data JSON document source
+   * @throws VectorDatabaseException if the document cannot be indexed
+   */
+  @Override
+  public void write(Index index, String data) throws VectorDatabaseException {
+    write(index, data, null);
+  }
+
+  /**
+   * Indexes several JSON documents in one bulk request, without setting document ids.
+   *
+   * @param index target index
+   * @param data JSON document sources; an empty list is a no-op
+   * @throws VectorDatabaseException if the bulk request fails completely or partially
+   */
+  @Override
+  public void batchWrite(Index index, List<String> data) throws VectorDatabaseException {
+    BulkRequest bulkRequest = new BulkRequest();
+    for (String doc : data) {
+      bulkRequest.add(makeIndexRequest(index.getName(), doc, null));
+    }
+    bulkRequest(bulkRequest);
+  }
+
+  /**
+   * Indexes several JSON documents in one bulk request, using the map keys as document ids.
+   *
+   * @param index target index
+   * @param data JSON document sources keyed by document id; an empty map is a no-op
+   * @throws VectorDatabaseException if the bulk request fails completely or partially
+   */
+  @Override
+  public void batchWrite(Index index, Map<String, String> data) throws VectorDatabaseException {
+    BulkRequest bulkRequest = new BulkRequest();
+    for (Map.Entry<String, String> entry : data.entrySet()) {
+      bulkRequest.add(makeIndexRequest(index.getName(), entry.getValue(), entry.getKey()));
+    }
+    bulkRequest(bulkRequest);
+  }
+
+  /**
+   * Executes {@code bulkRequest} and fails when the response status is not 200 or any item failed.
+   */
+  private void bulkRequest(BulkRequest bulkRequest) throws VectorDatabaseException {
+    // A bulk request without actions is rejected by the client's request validation,
+    // so an empty batch is handled as a no-op.
+    if (bulkRequest.numberOfActions() == 0) {
+      return;
+    }
+    try {
+      BulkResponse response = client.bulk(bulkRequest, RequestOptions.DEFAULT);
+      int status = response.status().getStatus();
+      if (status != 200) {
+        throw new VectorDatabaseException(
+            String.format("Cannot index data. Response status %d; Message: %s",
+                status, response.buildFailureMessage())
+        );
+      }
+      if (response.hasFailures()) {
+        throw new VectorDatabaseException(
+            String.format("Index data failed partially. Response status %d; Message: %s",
+                status, response.buildFailureMessage())
+        );
+      }
+    } catch (IOException e) {
+      throw failure("Cannot index data. Err: " + e, e);
+    }
+  }
+
+  /**
+   * Serializes {@code data} to JSON and indexes it without setting a document id.
+   *
+   * @param index target index
+   * @param data document content
+   * @throws VectorDatabaseException if the map cannot be serialized or indexed
+   */
+  @Override
+  public void writeMap(Index index, Map<String, Object> data) throws VectorDatabaseException {
+    writeMap(index, data, null);
+  }
+
+  /**
+   * Serializes {@code data} to JSON and indexes it.
+   *
+   * @param index target index
+   * @param data document content
+   * @param docId document id to set on the request, or {@code null} to send the request without one
+   * @throws VectorDatabaseException if the map cannot be serialized or indexed
+   */
+  @Override
+  public void writeMap(Index index, Map<String, Object> data, String docId) throws VectorDatabaseException {
+    try {
+      write(index, objectMapper.writeValueAsString(data), docId);
+    } catch (IOException e) {
+      throw failure("Failed to index data because data cannot be written to String.", e);
+    }
+  }
+
+  /**
+   * Serializes every map to JSON and indexes them in one bulk request, without document ids.
+   *
+   * @param index target index
+   * @param data document contents; an empty list is a no-op
+   * @throws VectorDatabaseException if a map cannot be serialized or the bulk request fails
+   */
+  @Override
+  public void batchWriteMap(Index index, List<Map<String, Object>> data) throws VectorDatabaseException {
+    List<String> batchData = Lists.newArrayList();
+    try {
+      for (Map<String, Object> map : data) {
+        batchData.add(objectMapper.writeValueAsString(map));
+      }
+    } catch (IOException e) {
+      throw failure("Failed to index data because data cannot be written to String.", e);
+    }
+    batchWrite(index, batchData);
+  }
+
+  /**
+   * Serializes every map to JSON and indexes them in one bulk request, using the outer map keys
+   * as document ids.
+   *
+   * @param index target index
+   * @param data document contents keyed by document id; an empty map is a no-op
+   * @throws VectorDatabaseException if a map cannot be serialized or the bulk request fails
+   */
+  @Override
+  public void batchWriteMap(Index index, Map<String, Map<String, Object>> data)
+      throws VectorDatabaseException {
+    Map<String, String> batchData = Maps.newHashMap();
+    try {
+      for (Map.Entry<String, Map<String, Object>> entry : data.entrySet()) {
+        batchData.put(entry.getKey(), objectMapper.writeValueAsString(entry.getValue()));
+      }
+    } catch (IOException e) {
+      throw failure("Failed to index data because data cannot be written to String.", e);
+    }
+    batchWrite(index, batchData);
+  }
+
+  /**
+   * Closes the underlying client, if one is still held, and drops the reference once the close
+   * succeeded.
+   *
+   * @throws OpenSearchException if closing the client fails with an I/O error
+   */
+  @Override
+  public void close() {
+    if (client != null) {
+      try {
+        client.close();
+        client = null;
+      } catch (IOException e) {
+        throw new OpenSearchException("Error while shutting down client", e);
+      }
+    }
+  }
+
+  /**
+   * Builds a {@link VectorDatabaseException} carrying {@code message} and attaches {@code cause},
+   * so the original stack trace is not lost.
+   */
+  private static VectorDatabaseException failure(String message, Throwable cause) {
+    VectorDatabaseException exception = new VectorDatabaseException(message);
+    exception.initCause(cause);
+    return exception;
+  }
+}
diff --git a/vector-db/src/main/java/io/hops/hopsworks/vectordb/VectorDatabase.java b/vector-db/src/main/java/io/hops/hopsworks/vectordb/VectorDatabase.java
new file mode 100644
index 0000000000..3f2cac55f6
--- /dev/null
+++ b/vector-db/src/main/java/io/hops/hopsworks/vectordb/VectorDatabase.java
@@ -0,0 +1,36 @@
+/*
+ * This file is part of Hopsworks
+ * Copyright (C) 2023, Hopsworks AB. 
All rights reserved
+ *
+ * Hopsworks is free software: you can redistribute it and/or modify it under the terms of
+ * the GNU Affero General Public License as published by the Free Software Foundation,
+ * either version 3 of the License, or (at your option) any later version.
+ *
+ * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
+ * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+ * PURPOSE. See the GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License along with this program.
+ * If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package io.hops.hopsworks.vectordb;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Operations for managing indexes and writing documents in a vector database.
+ * Every operation except {@link #close()} reports failures as {@link VectorDatabaseException}.
+ */
+public interface VectorDatabase {
+  /**
+   * Creates {@code index} from the given {@code mapping}.
+   *
+   * @param skipIfExist whether an already existing index should be left as it is
+   */
+  void createIndex(Index index, String mapping, Boolean skipIfExist) throws VectorDatabaseException;
+
+  /** Deletes {@code index}. */
+  void deleteIndex(Index index) throws VectorDatabaseException;
+
+  /** Adds the fields described by {@code mapping} to {@code index}. */
+  void addFields(Index index, String mapping) throws VectorDatabaseException;
+
+  /** Returns the fields of {@code index}. */
+  List<Field> getSchema(Index index) throws VectorDatabaseException;
+
+  /** Writes one document, given as a map, to {@code index}. */
+  void writeMap(Index index, Map<String, Object> data) throws VectorDatabaseException;
+
+  /** Writes one document, given as a map, to {@code index} under the id {@code docId}. */
+  void writeMap(Index index, Map<String, Object> data, String docId) throws VectorDatabaseException;
+
+  /** Writes several documents, each given as a map, to {@code index}. */
+  void batchWriteMap(Index index, List<Map<String, Object>> data) throws VectorDatabaseException;
+
+  /** Writes several documents, each given as a map and keyed by document id, to {@code index}. */
+  void batchWriteMap(Index index, Map<String, Map<String, Object>> data) throws VectorDatabaseException;
+
+  /** Writes one document, given as a string, to {@code index}. */
+  void write(Index index, String data) throws VectorDatabaseException;
+
+  /** Writes one document, given as a string, to {@code index} under the id {@code docId}. */
+  void write(Index index, String data, String docId) throws VectorDatabaseException;
+
+  /** Writes several documents, each given as a string, to {@code index}. */
+  void batchWrite(Index index, List<String> data) throws VectorDatabaseException;
+
+  /** Writes several documents, each given as a string and keyed by document id, to {@code index}. */
+  void batchWrite(Index index, Map<String, String> data) throws VectorDatabaseException;
+
+  /** Releases the resources held by this database handle. */
+  void close();
+}
diff --git a/vector-db/src/main/java/io/hops/hopsworks/vectordb/VectorDatabaseException.java b/vector-db/src/main/java/io/hops/hopsworks/vectordb/VectorDatabaseException.java
new file mode 100644
index 0000000000..21256b093a
--- /dev/null
+++ 
b/vector-db/src/main/java/io/hops/hopsworks/vectordb/VectorDatabaseException.java
@@ -0,0 +1,44 @@
+/*
+ * This file is part of Hopsworks
+ * Copyright (C) 2023, Hopsworks AB. All rights reserved
+ *
+ * Hopsworks is free software: you can redistribute it and/or modify it under the terms of
+ * the GNU Affero General Public License as published by the Free Software Foundation,
+ * either version 3 of the License, or (at your option) any later version.
+ *
+ * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
+ * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+ * PURPOSE. See the GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License along with this program.
+ * If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package io.hops.hopsworks.vectordb;
+
+import java.io.Serializable;
+
+/**
+ * Checked exception signalling that a vector database operation failed.
+ */
+public class VectorDatabaseException extends Exception implements Serializable {
+
+  /**
+   * Creates an exception without an underlying cause.
+   *
+   * @param msg description of the failure
+   */
+  public VectorDatabaseException(String msg) {
+    super(msg);
+  }
+
+  /**
+   * Creates an exception that keeps the underlying error as its cause.
+   *
+   * @param msg description of the failure
+   * @param cause error that triggered this failure; exposed through {@link #getCause()}
+   */
+  public VectorDatabaseException(String msg, Throwable cause) {
+    super(msg, cause);
+  }
+}
diff --git a/vector-db/src/main/java/io/hops/hopsworks/vectordb/VectorDatabaseFactory.java b/vector-db/src/main/java/io/hops/hopsworks/vectordb/VectorDatabaseFactory.java
new file mode 100644
index 0000000000..4ab79010af
--- /dev/null
+++ b/vector-db/src/main/java/io/hops/hopsworks/vectordb/VectorDatabaseFactory.java
@@ -0,0 +1,78 @@
+/*
+ * This file is part of Hopsworks
+ * Copyright (C) 2023, Hopsworks AB. All rights reserved
+ *
+ * Hopsworks is free software: you can redistribute it and/or modify it under the terms of
+ * the GNU Affero General Public License as published by the Free Software Foundation,
+ * either version 3 of the License, or (at your option) any later version.
+ *
+ * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
+ * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+ * PURPOSE. 
See the GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License along with this program.
+ * If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package io.hops.hopsworks.vectordb;
+
+import org.apache.http.HttpHost;
+import org.apache.http.auth.AuthScope;
+import org.apache.http.auth.UsernamePasswordCredentials;
+import org.apache.http.client.CredentialsProvider;
+import org.apache.http.conn.ssl.NoopHostnameVerifier;
+import org.apache.http.impl.client.BasicCredentialsProvider;
+import org.apache.http.impl.nio.reactor.IOReactorConfig;
+import org.apache.http.ssl.SSLContexts;
+import org.opensearch.client.RestClient;
+import org.opensearch.client.RestHighLevelClient;
+
+import javax.net.ssl.SSLContext;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.security.GeneralSecurityException;
+
+/**
+ * Static factory methods that build {@link VectorDatabase} instances backed by OpenSearch.
+ */
+public class VectorDatabaseFactory {
+
+  // Number of I/O reactor threads configured on the OpenSearch HTTP client.
+  private static final int IO_THREAD_COUNT = 5;
+
+  /**
+   * Builds an OpenSearch client secured with the given trust store and basic-auth credentials
+   * and wraps it in a {@link VectorDatabase}.
+   *
+   * @param host OpenSearch endpoint, parsed with {@link HttpHost#create(String)}
+   * @param user user name for basic authentication
+   * @param password password for basic authentication
+   * @param certPath path of the trust store file
+   * @param trustStorePassword trust store password, or {@code null} when there is none
+   * @return a vector database using the newly created client
+   * @throws IOException if the SSL context cannot be built from the trust store
+   */
+  public static VectorDatabase getOpensearchDatabase(String host, String user, String password, String certPath,
+      String trustStorePassword) throws IOException {
+    final SSLContext sslCtx = loadSslContext(certPath, trustStorePassword);
+    final CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
+    credentialsProvider.setCredentials(AuthScope.ANY,
+        new UsernamePasswordCredentials(user, password));
+
+    RestHighLevelClient client = new RestHighLevelClient(
+        RestClient.builder(HttpHost.create(host))
+            .setHttpClientConfigCallback(httpAsyncClientBuilder -> {
+              httpAsyncClientBuilder.setDefaultIOReactorConfig(
+                  IOReactorConfig.custom().setIoThreadCount(IO_THREAD_COUNT).build());
+              // NOTE(review): hostname verification is switched off here, so the server certificate
+              // is only checked against the trust store. Confirm this is intended.
+              return httpAsyncClientBuilder
+                  .setSSLContext(sslCtx)
+                  .setDefaultCredentialsProvider(credentialsProvider)
+                  .setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE);
+            }));
+    return new OpensearchVectorDatabase(client);
+  }
+
+  /**
+   * Wraps an already configured client in a {@link VectorDatabase}.
+   *
+   * @param client OpenSearch client to use
+   * @return a vector database using {@code client}
+   */
+  public static VectorDatabase getOpensearchDatabase(RestHighLevelClient client) {
+    return new OpensearchVectorDatabase(client);
+  }
+
+  /**
+   * Builds an SSL context whose trust material is loaded from the trust store at {@code certPath}.
+   */
+  private static SSLContext loadSslContext(String certPath, String trustStorePassword) throws IOException {
+    Path trustStore = Paths.get(certPath);
+    char[] trustStorePw = trustStorePassword == null ? null : trustStorePassword.toCharArray();
+    try {
+      return SSLContexts.custom()
+          .loadTrustMaterial(trustStore.toFile(), trustStorePw)
+          .build();
+    } catch (GeneralSecurityException | IOException e) {
+      // Keep the cause: it tells whether the file, the password or the key store format was wrong.
+      throw new IOException("Failed to load ssl context.", e);
+    }
+  }
+}
+