From 6aadde2d85f4a3fb09362ee51ce9c73bbd96b465 Mon Sep 17 00:00:00 2001 From: kennethmhc Date: Mon, 12 Feb 2024 10:33:21 +0100 Subject: [PATCH] [FSTORE-1190] Attach model to embedding feature (#1481) * save model to embedding feature * remove modelId * address comment * fix NPE (cherry picked from commit 6e36379eb4e0e375e01b2874581eac2a2b21ed8d) --- .../embedding/EmbeddingController.java | 28 +++++++++++++-- .../featuregroup/EmbeddingFeatureDTO.java | 10 ++++++ .../featurestore/featuregroup/ModelDto.java | 36 +++++++++++++++++++ .../hopsworks/common/models/ModelFacade.java | 10 ++++++ .../featuregroup/EmbeddingFeature.java | 23 ++++++++++++ .../persistence/entity/models/Model.java | 13 ++++--- 6 files changed, 113 insertions(+), 7 deletions(-) create mode 100644 hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/ModelDto.java diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/EmbeddingController.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/EmbeddingController.java index 074a88befd..bb6d6133d7 100644 --- a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/EmbeddingController.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/embedding/EmbeddingController.java @@ -22,11 +22,15 @@ import io.hops.hopsworks.common.featurestore.featuregroup.EmbeddingDTO; import io.hops.hopsworks.common.featurestore.featuregroup.FeaturegroupController; import io.hops.hopsworks.common.hdfs.Utils; +import io.hops.hopsworks.common.models.ModelFacade; +import io.hops.hopsworks.common.models.version.ModelVersionFacade; import io.hops.hopsworks.common.util.Settings; import io.hops.hopsworks.exceptions.FeaturestoreException; import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Embedding; import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.EmbeddingFeature; import 
io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Featuregroup; +import io.hops.hopsworks.persistence.entity.models.Model; +import io.hops.hopsworks.persistence.entity.models.version.ModelVersion; import io.hops.hopsworks.persistence.entity.project.Project; import io.hops.hopsworks.restutils.RESTCodes; import io.hops.hopsworks.vectordb.Index; @@ -55,6 +59,10 @@ public class EmbeddingController { private VectorDatabaseClient vectorDatabaseClient; @EJB private FeaturegroupController featuregroupController; + @EJB + private ModelVersionFacade modelVersionFacade; + @EJB + private ModelFacade modelFacade; public void createVectorDbIndex(Project project, Featuregroup featureGroup) throws FeaturestoreException { @@ -72,6 +80,11 @@ public void createVectorDbIndex(Project project, Featuregroup featureGroup) } } + private ModelVersion getModel(Integer projectId, String modelName, Integer modelVersion) { + Model model = modelFacade.findByProjectIdAndName(projectId, modelName); + return modelVersionFacade.findByProjectAndMlId(model.getId(), modelVersion); + } + public Embedding getEmbedding(Project project, EmbeddingDTO embeddingDTO, Featuregroup featuregroup) throws FeaturestoreException { Embedding embedding = new Embedding(); @@ -94,8 +107,19 @@ public Embedding getEmbedding(Project project, EmbeddingDTO embeddingDTO, Featur embedding.setEmbeddingFeatures( embeddingDTO.getFeatures() .stream() - .map(mapping -> new EmbeddingFeature(embedding, mapping.getName(), mapping.getDimension(), - mapping.getSimilarityFunctionType())) + .map(mapping -> { + if (mapping.getModel() != null) { + return new EmbeddingFeature(embedding, mapping.getName(), mapping.getDimension(), + mapping.getSimilarityFunctionType(), + getModel(mapping.getModel().getModelRegistryId(), + mapping.getModel().getModelName(), + mapping.getModel().getModelVersion())); + } else { + return new EmbeddingFeature(embedding, mapping.getName(), mapping.getDimension(), + mapping.getSimilarityFunctionType()); + 
} + } + ) .collect(Collectors.toList()) ); return embedding; diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingFeatureDTO.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingFeatureDTO.java index fa8c37308e..e93e5a6bb4 100644 --- a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingFeatureDTO.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingFeatureDTO.java @@ -33,10 +33,20 @@ public class EmbeddingFeatureDTO { private String similarityFunctionType; @Getter private Integer dimension; + @Getter + private ModelDto model; + public EmbeddingFeatureDTO(EmbeddingFeature feature) { name = feature.getName(); similarityFunctionType = feature.getSimilarityFunctionType(); dimension = feature.getDimension(); + if (feature.getModelVersion() != null) { + model = new ModelDto( + // model registry id is same as project id + feature.getModelVersion().getModel().getProject().getId(), + feature.getModelVersion().getModel().getName(), + feature.getModelVersion().getModelVersionPK().getVersion()); + } } } diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/ModelDto.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/ModelDto.java new file mode 100644 index 0000000000..6789376ee0 --- /dev/null +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/ModelDto.java @@ -0,0 +1,36 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2024, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. 
+ * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see <https://www.gnu.org/licenses/>. + */ + +package io.hops.hopsworks.common.featurestore.featuregroup; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; + +@NoArgsConstructor +@AllArgsConstructor +@JsonIgnoreProperties(ignoreUnknown = true) +public class ModelDto { + + @Getter + private Integer modelRegistryId; + @Getter + private String modelName; + @Getter + private Integer modelVersion; + +} \ No newline at end of file diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/models/ModelFacade.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/models/ModelFacade.java index 1dcca51fd4..619fb7aff6 100644 --- a/hopsworks-common/src/main/java/io/hops/hopsworks/common/models/ModelFacade.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/models/ModelFacade.java @@ -85,6 +85,16 @@ public Model findByProjectAndName(Project project, String name) { } } + public Model findByProjectIdAndName(Integer projectId, String name) { + TypedQuery query = em.createNamedQuery("Model.findByProjectIdAndName", Model.class); + query.setParameter("name", name).setParameter("projectId", projectId); + try { + return query.getSingleResult(); + } catch (NoResultException e) { + return null; + } + } + public CollectionInfo findByProject(Integer offset, Integer limit, Set filters, Set sorts, Project project) { diff --git a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/EmbeddingFeature.java 
b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/EmbeddingFeature.java index 2604ec5aa3..2a8eccb2ca 100644 --- a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/EmbeddingFeature.java +++ b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/EmbeddingFeature.java @@ -16,6 +16,8 @@ package io.hops.hopsworks.persistence.entity.featurestore.featuregroup; +import io.hops.hopsworks.persistence.entity.models.version.ModelVersion; + import javax.persistence.Basic; import javax.persistence.Column; import javax.persistence.Entity; @@ -23,10 +25,14 @@ import javax.persistence.GenerationType; import javax.persistence.Id; import javax.persistence.JoinColumn; +import javax.persistence.JoinColumns; +import javax.persistence.OneToOne; import javax.persistence.Table; +import javax.xml.bind.annotation.XmlRootElement; @Entity @Table(name = "embedding_feature", catalog = "hopsworks") +@XmlRootElement public class EmbeddingFeature { @Id @GeneratedValue(strategy = GenerationType.IDENTITY) @@ -41,6 +47,11 @@ public class EmbeddingFeature { private Integer dimension; @Column(name = "similarity_function_type") private String similarityFunctionType; + @JoinColumns({ + @JoinColumn(name = "hsml_model_version", referencedColumnName = "version"), + @JoinColumn(name = "hsml_model_id", referencedColumnName = "model_id")}) + @OneToOne + private ModelVersion modelVersion; public EmbeddingFeature() { } @@ -53,6 +64,15 @@ public EmbeddingFeature(Embedding embedding, String name, Integer dimension, this.similarityFunctionType = similarityFunctionType; } + public EmbeddingFeature(Embedding embedding, String name, Integer dimension, + String similarityFunctionType, ModelVersion modelVersion) { + this.embedding = embedding; + this.name = name; + this.dimension = dimension; + this.similarityFunctionType = similarityFunctionType; + this.modelVersion = 
modelVersion; + } + public EmbeddingFeature(Integer id, Embedding embedding, String name, Integer dimension, String similarityFunctionType) { this.id = id; @@ -82,4 +102,7 @@ public String getSimilarityFunctionType() { return similarityFunctionType; } + public ModelVersion getModelVersion() { + return modelVersion; + } } diff --git a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/Model.java b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/Model.java index b6f7cdb955..85bd7db31c 100644 --- a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/Model.java +++ b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/Model.java @@ -45,11 +45,14 @@ @Table(name = "model", catalog = "hopsworks") @XmlRootElement @NamedQueries({ - @NamedQuery(name = "Model.findAll", - query = "SELECT m FROM Model m"), - @NamedQuery(name = "Model.findByProjectAndName", - query - = "SELECT m FROM Model m WHERE m.name = :name AND m.project = :project"),}) + @NamedQuery(name = "Model.findAll", + query = "SELECT m FROM Model m"), + @NamedQuery(name = "Model.findByProjectAndName", + query + = "SELECT m FROM Model m WHERE m.name = :name AND m.project = :project"), + @NamedQuery(name = "Model.findByProjectIdAndName", + query + = "SELECT m FROM Model m WHERE m.name = :name AND m.project.id = :projectId"),}) public class Model implements Serializable { private static final long serialVersionUID = 1L;