Skip to content

Commit

Permalink
[FSTORE-1190] Attach model to embedding feature (#1481)
Browse files Browse the repository at this point in the history
* save model to embedding feature

* remove modelId

* address comment

* fix NPE

(cherry picked from commit 6e36379)
  • Loading branch information
kennethmhc authored Feb 12, 2024
1 parent 79a7139 commit 6aadde2
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@
import io.hops.hopsworks.common.featurestore.featuregroup.EmbeddingDTO;
import io.hops.hopsworks.common.featurestore.featuregroup.FeaturegroupController;
import io.hops.hopsworks.common.hdfs.Utils;
import io.hops.hopsworks.common.models.ModelFacade;
import io.hops.hopsworks.common.models.version.ModelVersionFacade;
import io.hops.hopsworks.common.util.Settings;
import io.hops.hopsworks.exceptions.FeaturestoreException;
import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Embedding;
import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.EmbeddingFeature;
import io.hops.hopsworks.persistence.entity.featurestore.featuregroup.Featuregroup;
import io.hops.hopsworks.persistence.entity.models.Model;
import io.hops.hopsworks.persistence.entity.models.version.ModelVersion;
import io.hops.hopsworks.persistence.entity.project.Project;
import io.hops.hopsworks.restutils.RESTCodes;
import io.hops.hopsworks.vectordb.Index;
Expand Down Expand Up @@ -55,6 +59,10 @@ public class EmbeddingController {
private VectorDatabaseClient vectorDatabaseClient;
@EJB
private FeaturegroupController featuregroupController;
@EJB
private ModelVersionFacade modelVersionFacade;
@EJB
private ModelFacade modelFacade;

public void createVectorDbIndex(Project project, Featuregroup featureGroup)
throws FeaturestoreException {
Expand All @@ -72,6 +80,11 @@ public void createVectorDbIndex(Project project, Featuregroup featureGroup)
}
}

/**
 * Looks up the model version that an embedding feature refers to.
 *
 * @param projectId id of the project owning the model, passed to
 *                  {@code modelFacade.findByProjectIdAndName}
 * @param modelName name of the model
 * @param modelVersion version of the model
 * @return whatever {@code modelVersionFacade.findByProjectAndMlId} yields for the model's id and the
 *         given version
 * @throws IllegalArgumentException if no model named {@code modelName} exists for {@code projectId}
 */
private ModelVersion getModel(Integer projectId, String modelName, Integer modelVersion) {
  Model model = modelFacade.findByProjectIdAndName(projectId, modelName);
  if (model == null) {
    // findByProjectIdAndName returns null when nothing matches; fail with a clear message here
    // instead of a bare NullPointerException on model.getId() below.
    throw new IllegalArgumentException(
        "Model '" + modelName + "' not found in model registry with id " + projectId);
  }
  // NOTE(review): the version lookup may itself come back empty for an unknown version - confirm
  // how findByProjectAndMlId reports that before relying on a non-null result.
  return modelVersionFacade.findByProjectAndMlId(model.getId(), modelVersion);
}

public Embedding getEmbedding(Project project, EmbeddingDTO embeddingDTO, Featuregroup featuregroup)
throws FeaturestoreException {
Embedding embedding = new Embedding();
Expand All @@ -94,8 +107,19 @@ public Embedding getEmbedding(Project project, EmbeddingDTO embeddingDTO, Featur
embedding.setEmbeddingFeatures(
embeddingDTO.getFeatures()
.stream()
.map(mapping -> new EmbeddingFeature(embedding, mapping.getName(), mapping.getDimension(),
mapping.getSimilarityFunctionType()))
.map(mapping -> {
if (mapping.getModel() != null) {
return new EmbeddingFeature(embedding, mapping.getName(), mapping.getDimension(),
mapping.getSimilarityFunctionType(),
getModel(mapping.getModel().getModelRegistryId(),
mapping.getModel().getModelName(),
mapping.getModel().getModelVersion()));
} else {
return new EmbeddingFeature(embedding, mapping.getName(), mapping.getDimension(),
mapping.getSimilarityFunctionType());
}
}
)
.collect(Collectors.toList())
);
return embedding;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,20 @@ public class EmbeddingFeatureDTO {
private String similarityFunctionType;
@Getter
private Integer dimension;
@Getter
private ModelDto model;


/**
 * Creates the DTO from an embedding feature entity, copying its name, similarity function type
 * and dimension, and describing the attached model version when there is one.
 *
 * @param feature entity to copy from
 */
public EmbeddingFeatureDTO(EmbeddingFeature feature) {
  this.name = feature.getName();
  this.similarityFunctionType = feature.getSimilarityFunctionType();
  this.dimension = feature.getDimension();
  this.model = toModelDto(feature);
}

/**
 * Describes the model version attached to the feature.
 *
 * @param feature entity whose model version is read
 * @return a {@code ModelDto} for the attached model version, or {@code null} if none is attached
 */
private static ModelDto toModelDto(EmbeddingFeature feature) {
  if (feature.getModelVersion() == null) {
    return null;
  }
  return new ModelDto(
      // model registry id is same as project id
      feature.getModelVersion().getModel().getProject().getId(),
      feature.getModelVersion().getModel().getName(),
      feature.getModelVersion().getModelVersionPK().getVersion());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* This file is part of Hopsworks
* Copyright (C) 2024, Hopsworks AB. All rights reserved
*
* Hopsworks is free software: you can redistribute it and/or modify it under the terms of
* the GNU Affero General Public License as published by the Free Software Foundation,
* either version 3 of the License, or (at your option) any later version.
*
* Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
* without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
* PURPOSE. See the GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License along with this program.
* If not, see <https://www.gnu.org/licenses/>.
*/

package io.hops.hopsworks.common.featurestore.featuregroup;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;

/**
 * Reference to a model version, given as a model registry id, a model name and a version number.
 * Lombok generates a getter for every field plus a no-argument and an all-arguments constructor
 * (argument order follows field order). Unknown JSON properties are ignored.
 */
@Getter
@NoArgsConstructor
@AllArgsConstructor
@JsonIgnoreProperties(ignoreUnknown = true)
public class ModelDto {

  private Integer modelRegistryId;
  private String modelName;
  private Integer modelVersion;

}
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ public Model findByProjectAndName(Project project, String name) {
}
}

/**
 * Finds a model by the id of its project and its name, using the named query
 * {@code Model.findByProjectIdAndName}.
 *
 * @param projectId id of the project the model belongs to
 * @param name name of the model
 * @return the matching model, or {@code null} if the query yields no result
 */
public Model findByProjectIdAndName(Integer projectId, String name) {
  try {
    return em.createNamedQuery("Model.findByProjectIdAndName", Model.class)
        .setParameter("name", name)
        .setParameter("projectId", projectId)
        .getSingleResult();
  } catch (NoResultException e) {
    // No model with that name in the project.
    return null;
  }
}

public CollectionInfo findByProject(Integer offset, Integer limit,
Set<? extends AbstractFacade.FilterBy> filters,
Set<? extends AbstractFacade.SortBy> sorts, Project project) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,23 @@

package io.hops.hopsworks.persistence.entity.featurestore.featuregroup;

import io.hops.hopsworks.persistence.entity.models.version.ModelVersion;

import javax.persistence.Basic;
import javax.persistence.Column;
import javax.persistence.Entity;
import javax.persistence.GeneratedValue;
import javax.persistence.GenerationType;
import javax.persistence.Id;
import javax.persistence.JoinColumn;
import javax.persistence.JoinColumns;
import javax.persistence.OneToOne;
import javax.persistence.Table;
import javax.xml.bind.annotation.XmlRootElement;

@Entity
@Table(name = "embedding_feature", catalog = "hopsworks")
@XmlRootElement
public class EmbeddingFeature {
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
Expand All @@ -41,6 +47,11 @@ public class EmbeddingFeature {
private Integer dimension;
@Column(name = "similarity_function_type")
private String similarityFunctionType;
@JoinColumns({
@JoinColumn(name = "hsml_model_version", referencedColumnName = "version"),
@JoinColumn(name = "hsml_model_id", referencedColumnName = "model_id")})
@OneToOne
private ModelVersion modelVersion;

/**
 * Creates an embedding feature with no fields set.
 * NOTE(review): presumably kept for the persistence provider, which needs a no-argument
 * constructor on entity classes - confirm before removing.
 */
public EmbeddingFeature() {
}
Expand All @@ -53,6 +64,15 @@ public EmbeddingFeature(Embedding embedding, String name, Integer dimension,
this.similarityFunctionType = similarityFunctionType;
}

/**
 * Creates an embedding feature with a model version attached. The id is left unset.
 *
 * @param embedding embedding this feature belongs to
 * @param name name of the feature
 * @param dimension dimension of the feature
 * @param similarityFunctionType similarity function type of the feature
 * @param modelVersion model version to attach to the feature
 */
public EmbeddingFeature(Embedding embedding, String name, Integer dimension,
    String similarityFunctionType, ModelVersion modelVersion) {
  this.modelVersion = modelVersion;
  this.similarityFunctionType = similarityFunctionType;
  this.dimension = dimension;
  this.name = name;
  this.embedding = embedding;
}

public EmbeddingFeature(Integer id, Embedding embedding, String name, Integer dimension,
String similarityFunctionType) {
this.id = id;
Expand Down Expand Up @@ -82,4 +102,7 @@ public String getSimilarityFunctionType() {
return similarityFunctionType;
}

/**
 * @return the model version attached to this feature, or {@code null} if none has been set
 */
public ModelVersion getModelVersion() {
  return this.modelVersion;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,14 @@
@Table(name = "model", catalog = "hopsworks")
@XmlRootElement
@NamedQueries({
@NamedQuery(name = "Model.findAll",
query = "SELECT m FROM Model m"),
@NamedQuery(name = "Model.findByProjectAndName",
query
= "SELECT m FROM Model m WHERE m.name = :name AND m.project = :project"),})
@NamedQuery(name = "Model.findAll",
query = "SELECT m FROM Model m"),
@NamedQuery(name = "Model.findByProjectAndName",
query
= "SELECT m FROM Model m WHERE m.name = :name AND m.project = :project"),
@NamedQuery(name = "Model.findByProjectIdAndName",
query
= "SELECT m FROM Model m WHERE m.name = :name AND m.project.id = :projectId"),})
public class Model implements Serializable {
private static final long serialVersionUID = 1L;

Expand Down

0 comments on commit 6aadde2

Please sign in to comment.