diff --git a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsBuilder.java b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsBuilder.java index 3c5fa11012..537c24fa92 100644 --- a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsBuilder.java +++ b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsBuilder.java @@ -86,7 +86,7 @@ public ModelDTO uri(ModelDTO dto, UriInfo uriInfo, Project userProject, Project .path(ResourceRequest.Name.MODELREGISTRIES.toString().toLowerCase()) .path(Integer.toString(modelRegistryProject.getId())) .path(ResourceRequest.Name.MODELS.toString().toLowerCase()) - .path(modelVersion.getModel().getName() + "_" + modelVersion.getModelVersionPK().getVersion()) + .path(modelVersion.getModel().getName() + "_" + modelVersion.getVersion()) .build()); return dto; } @@ -145,9 +145,9 @@ public ModelDTO build(UriInfo uriInfo, ModelDTO modelDTO = new ModelDTO(); uri(modelDTO, uriInfo, userProject, modelRegistryProject, modelVersion); if (expand(modelDTO, resourceRequest).isExpand()) { - modelDTO.setId(modelVersion.getModel().getName() + "_" + modelVersion.getModelVersionPK().getVersion()); + modelDTO.setId(modelVersion.getModel().getName() + "_" + modelVersion.getVersion()); modelDTO.setName(modelVersion.getModel().getName()); - modelDTO.setVersion(modelVersion.getModelVersionPK().getVersion()); + modelDTO.setVersion(modelVersion.getVersion()); modelDTO.setUserFullName(modelVersion.getUserFullName()); modelDTO.setCreated(modelVersion.getCreated().getTime()); modelDTO.setMetrics(modelVersion.getMetrics().getAttributes()); @@ -162,8 +162,7 @@ public ModelDTO build(UriInfo uriInfo, modelDTO.setCreator(usersBuilder.build(uriInfo, resourceRequest, modelVersion.getCreator())); DatasetPath modelDsPath = datasetHelper.getDatasetPath(userProject, - modelUtils.getModelFullPath(modelRegistryProject, modelVersion.getModel().getName(), - modelVersion.getModelVersionPK().getVersion()), + modelUtils.getModelFullPath(modelRegistryProject, modelVersion.getModel().getName(), modelVersion.getVersion()), DatasetType.DATASET); ModelRegistryTagUri tagUri = new ModelRegistryTagUri(uriInfo, modelRegistryProject, ResourceRequest.Name.MODELS, modelDTO.getId()); diff --git a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsController.java b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsController.java index b9d896cc4e..de486d9688 100644 --- a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsController.java +++ b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsController.java @@ -51,7 +51,6 @@ import io.hops.hopsworks.persistence.entity.models.Model; import io.hops.hopsworks.persistence.entity.models.version.Metrics; import io.hops.hopsworks.persistence.entity.models.version.ModelVersion; -import io.hops.hopsworks.persistence.entity.models.version.ModelVersionPK; import io.hops.hopsworks.persistence.entity.project.Project; import io.hops.hopsworks.persistence.entity.user.Users; import io.hops.hopsworks.restutils.RESTCodes; @@ -123,10 +122,8 @@ public ModelVersion createModelVersion(ModelsController.Accessor accessor, Model modelVersion.setExperimentProjectName(modelDTO.getExperimentProjectName()); modelVersion.setCreator(accessor.user); - ModelVersionPK modelVersionPK = new ModelVersionPK(); - modelVersionPK.setVersion(modelDTO.getVersion()); - modelVersionPK.setModelId(model.getId()); - modelVersion.setModelVersionPK(modelVersionPK); + modelVersion.setVersion(modelDTO.getVersion()); + modelVersion.setModel(model); //Only attach program and environment if exporting inside Hopsworks if (!Strings.isNullOrEmpty(jobName) || !Strings.isNullOrEmpty(kernelId)) { @@ -164,7 +161,7 @@ public void delete(Users user, Project userProject, Project parentProject, Model String modelPath = Utils.getProjectPath(userProject.getName()) + parentProject.getName() + "::" + Settings.HOPS_MODELS_DATASET + "/" + modelVersion.getModel().getName() - + "/" + modelVersion.getModelVersionPK().getVersion(); + + "/" + modelVersion.getVersion(); deleteInternal(user, userProject, modelPath, modelVersion); } } @@ -174,8 +171,7 @@ public void delete(Users user, Project project, ModelVersion modelVersion) throw verifyNoModelDeployments(project, modelVersion); String modelPath = Utils.getProjectPath(project.getName()) - + Settings.HOPS_MODELS_DATASET + "/" + modelVersion.getModel().getName() + "/" - + modelVersion.getModelVersionPK().getVersion(); + + Settings.HOPS_MODELS_DATASET + "/" + modelVersion.getModel().getName() + "/" + modelVersion.getVersion(); deleteInternal(user, project, modelPath, modelVersion); } diff --git a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsResource.java b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsResource.java index f3dcbe2ed1..bb9945223b 100644 --- a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsResource.java +++ b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsResource.java @@ -231,7 +231,7 @@ public Response put(@PathParam("id") String id, ModelVersion modelVersion = modelsController.createModelVersion(accessor, modelDTO, jobName, kernelId); ModelDTO dto = modelsBuilder.build(uriInfo, new ResourceRequest(ResourceRequest.Name.MODELS), user, userProject, modelRegistryProject, modelVersion, modelUtils.getModelFullPath(modelProject, modelVersion.getModel().getName(), - modelVersion.getModelVersionPK().getVersion())); + modelVersion.getVersion())); UriBuilder builder = uriInfo.getAbsolutePathBuilder().path(id); return Response.created(builder.build()).entity(dto).build(); } finally { diff --git a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/tags/ModelTagResource.java b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/tags/ModelTagResource.java index 83b9d9e94b..a9cc79b983 100644 --- a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/tags/ModelTagResource.java +++ b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/tags/ModelTagResource.java @@ -52,7 +52,7 @@ public void setModel(ModelVersion modelVersion) { @Override protected DatasetPath getDatasetPath() throws DatasetException { return datasetHelper.getDatasetPath(project, modelUtils.getModelFullPath(modelRegistry, - modelVersion.getModel().getName(), modelVersion.getModelVersionPK().getVersion()), DatasetType.DATASET); + modelVersion.getModel().getName(), modelVersion.getVersion()), DatasetType.DATASET); } @Override diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingFeatureDTO.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingFeatureDTO.java index e93e5a6bb4..5a65a20ff9 100644 --- a/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingFeatureDTO.java +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/featurestore/featuregroup/EmbeddingFeatureDTO.java @@ -46,7 +46,7 @@ public EmbeddingFeatureDTO(EmbeddingFeature feature) { // model registry id is same as project id feature.getModelVersion().getModel().getProject().getId(), feature.getModelVersion().getModel().getName(), - feature.getModelVersion().getModelVersionPK().getVersion()); + feature.getModelVersion().getVersion()); } } } diff --git a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/EmbeddingFeature.java b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/EmbeddingFeature.java index 2a8eccb2ca..c77c75dbce 100644 --- a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/EmbeddingFeature.java +++ b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/featurestore/featuregroup/EmbeddingFeature.java @@ -25,7 +25,6 @@ import javax.persistence.GenerationType; import javax.persistence.Id; import javax.persistence.JoinColumn; -import javax.persistence.JoinColumns; import javax.persistence.OneToOne; import javax.persistence.Table; import javax.xml.bind.annotation.XmlRootElement; @@ -47,9 +46,7 @@ public class EmbeddingFeature { private Integer dimension; @Column(name = "similarity_function_type") private String similarityFunctionType; - @JoinColumns({ - @JoinColumn(name = "hsml_model_version", referencedColumnName = "version"), - @JoinColumn(name = "hsml_model_id", referencedColumnName = "model_id")}) + @JoinColumn(name = "model_version_id", referencedColumnName = "id") @OneToOne private ModelVersion modelVersion; diff --git a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/ModelVersion.java b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/ModelVersion.java index e0b933be45..20f0db50d3 100644 --- a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/ModelVersion.java +++ b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/ModelVersion.java @@ -22,8 +22,10 @@ import javax.persistence.Basic; import javax.persistence.Column; import javax.persistence.Convert; -import javax.persistence.EmbeddedId; import javax.persistence.Entity; +import javax.persistence.GeneratedValue; +import javax.persistence.GenerationType; +import javax.persistence.Id; import javax.persistence.JoinColumn; import javax.persistence.ManyToOne; import javax.persistence.NamedQueries; @@ -36,6 +38,7 @@ import javax.xml.bind.annotation.XmlRootElement; import java.io.Serializable; import java.util.Date; +import java.util.Objects; /** * A ModelVersion is an instance of a Model. @@ -48,22 +51,27 @@ query = "SELECT mv FROM ModelVersion mv"), @NamedQuery(name = "ModelVersion.findByProjectAndMlId", query - = "SELECT mv FROM ModelVersion mv WHERE mv.modelVersionPK.version = :version" + - " AND mv.modelVersionPK.modelId = :modelId") + = "SELECT mv FROM ModelVersion mv WHERE mv.version = :version" + + " AND mv.model.id = :modelId") } ) public class ModelVersion implements Serializable { private static final long serialVersionUID = 1L; - @EmbeddedId - private ModelVersionPK modelVersionPK; + @Id + @GeneratedValue(strategy = GenerationType.IDENTITY) + @Basic(optional = false) + @Column(name = "id") + private Integer id; + + @Basic(optional = false) + @Column(name = "version") + private Integer version; @ManyToOne(optional = false) @JoinColumn(name = "model_id", - referencedColumnName = "id", - insertable = false, - updatable = false) + referencedColumnName = "id") private Model model; @JoinColumn(name = "user_id", @@ -108,6 +116,14 @@ public class ModelVersion implements Serializable { public ModelVersion() { } + public Integer getVersion() { + return version; + } + + public void setVersion(Integer version) { + this.version = version; + } + public Metrics getMetrics() { return metrics; } @@ -176,14 +192,6 @@ public void setExperimentProjectName(String experimentProjectName) { this.experimentProjectName = experimentProjectName; } - public ModelVersionPK getModelVersionPK() { - return modelVersionPK; - } - - public void setModelVersionPK(ModelVersionPK modelVersionPK) { - this.modelVersionPK = modelVersionPK; - } - public Model getModel() { return model; } @@ -193,7 +201,7 @@ public void setModel(Model model) { } public String getMlId() { - return model.getName() + "_" + modelVersionPK.getVersion(); + return model.getName() + "_" + version; } public Users getCreator() { @@ -206,9 +214,7 @@ public void setCreator(Users creator) { @Override public int hashCode() { - int hash = 0; - hash += (getModelVersionPK() != null ? getModelVersionPK().hashCode() : 0); - return hash; + return Objects.hash(id); } @Override @@ -218,8 +224,7 @@ public boolean equals(Object object) { return false; } ModelVersion other = (ModelVersion) object; - if ((this.getModelVersionPK() == null && other.getModelVersionPK() != null) || - (this.getModelVersionPK() != null && !this.getModelVersionPK().equals(other.getModelVersionPK()))) { + if ((this.id == null && other.id != null) || (this.id != null && !Objects.equals(id, other.id))) { return false; } return true; diff --git a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/ModelVersionPK.java b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/ModelVersionPK.java deleted file mode 100644 index 962d8be41a..0000000000 --- a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/ModelVersionPK.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * This file is part of Hopsworks - * Copyright (C) 2023, Hopsworks AB. All rights reserved - * - * Hopsworks is free software: you can redistribute it and/or modify it under the terms of - * the GNU Affero General Public License as published by the Free Software Foundation, - * either version 3 of the License, or (at your option) any later version. - * - * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; - * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR - * PURPOSE. See the GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License along with this program. - * If not, see . - */ - -package io.hops.hopsworks.persistence.entity.models.version; - -import java.io.Serializable; -import javax.persistence.Basic; -import javax.persistence.Column; -import javax.persistence.Embeddable; - -@Embeddable -public class ModelVersionPK implements Serializable { - - @Basic(optional = false) - @Column(name = "version") - private Integer version; - - @Basic(optional = false) - @Column(name = "model_id") - private Integer modelId; - - public ModelVersionPK() { - } - - public ModelVersionPK(Integer version, Integer modelId) { - this.version = version; - this.modelId = modelId; - } - - public Integer getModelId() { - return modelId; - } - - public void setModelId(Integer model) { - this.modelId = model; - } - - @Override - public int hashCode() { - int hash = 0; - hash += (int) version; - hash += (modelId != null ? modelId.hashCode() : 0); - return hash; - } - - @Override - public boolean equals(Object object) { - // TODO: Warning - this method won't work in the case the id fields are not set - if (!(object instanceof ModelVersionPK)) { - return false; - } - ModelVersionPK other = (ModelVersionPK) object; - if (this.version != other.version) { - return false; - } - if (this.modelId != other.modelId) { - return false; - } - return true; - } - - @Override - public String toString() { - return "io.hops.hopsworks.persistence.entity.models.version.ModelVersionPK[ model=" + - modelId + ", version=" + version + " ]"; - } - - public Integer getVersion() { - return version; - } - - public void setVersion(Integer version) { - this.version = version; - } -}