Skip to content

Commit

Permalink
[HWORKS-987] model version id should be monotonically increasing (#1706)
Browse files Browse the repository at this point in the history
  • Loading branch information
o-alex committed Feb 13, 2024
1 parent b94f853 commit 92a0373
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public ModelDTO uri(ModelDTO dto, UriInfo uriInfo, Project userProject, Project
.path(ResourceRequest.Name.MODELREGISTRIES.toString().toLowerCase())
.path(Integer.toString(modelRegistryProject.getId()))
.path(ResourceRequest.Name.MODELS.toString().toLowerCase())
.path(modelVersion.getModel().getName() + "_" + modelVersion.getModelVersionPK().getVersion())
.path(modelVersion.getModel().getName() + "_" + modelVersion.getVersion())
.build());
return dto;
}
Expand Down Expand Up @@ -145,9 +145,9 @@ public ModelDTO build(UriInfo uriInfo,
ModelDTO modelDTO = new ModelDTO();
uri(modelDTO, uriInfo, userProject, modelRegistryProject, modelVersion);
if (expand(modelDTO, resourceRequest).isExpand()) {
modelDTO.setId(modelVersion.getModel().getName() + "_" + modelVersion.getModelVersionPK().getVersion());
modelDTO.setId(modelVersion.getModel().getName() + "_" + modelVersion.getVersion());
modelDTO.setName(modelVersion.getModel().getName());
modelDTO.setVersion(modelVersion.getModelVersionPK().getVersion());
modelDTO.setVersion(modelVersion.getVersion());
modelDTO.setUserFullName(modelVersion.getUserFullName());
modelDTO.setCreated(modelVersion.getCreated().getTime());
modelDTO.setMetrics(modelVersion.getMetrics().getAttributes());
Expand All @@ -162,8 +162,7 @@ public ModelDTO build(UriInfo uriInfo,
modelDTO.setCreator(usersBuilder.build(uriInfo, resourceRequest, modelVersion.getCreator()));

DatasetPath modelDsPath = datasetHelper.getDatasetPath(userProject,
modelUtils.getModelFullPath(modelRegistryProject, modelVersion.getModel().getName(),
modelVersion.getModelVersionPK().getVersion()),
modelUtils.getModelFullPath(modelRegistryProject, modelVersion.getModel().getName(), modelVersion.getVersion()),
DatasetType.DATASET);
ModelRegistryTagUri tagUri = new ModelRegistryTagUri(uriInfo, modelRegistryProject,
ResourceRequest.Name.MODELS, modelDTO.getId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
import io.hops.hopsworks.persistence.entity.models.Model;
import io.hops.hopsworks.persistence.entity.models.version.Metrics;
import io.hops.hopsworks.persistence.entity.models.version.ModelVersion;
import io.hops.hopsworks.persistence.entity.models.version.ModelVersionPK;
import io.hops.hopsworks.persistence.entity.project.Project;
import io.hops.hopsworks.persistence.entity.user.Users;
import io.hops.hopsworks.restutils.RESTCodes;
Expand Down Expand Up @@ -123,10 +122,8 @@ public ModelVersion createModelVersion(ModelsController.Accessor accessor, Model
modelVersion.setExperimentProjectName(modelDTO.getExperimentProjectName());
modelVersion.setCreator(accessor.user);

ModelVersionPK modelVersionPK = new ModelVersionPK();
modelVersionPK.setVersion(modelDTO.getVersion());
modelVersionPK.setModelId(model.getId());
modelVersion.setModelVersionPK(modelVersionPK);
modelVersion.setVersion(modelDTO.getVersion());
modelVersion.setModel(model);

//Only attach program and environment if exporting inside Hopsworks
if (!Strings.isNullOrEmpty(jobName) || !Strings.isNullOrEmpty(kernelId)) {
Expand Down Expand Up @@ -164,7 +161,7 @@ public void delete(Users user, Project userProject, Project parentProject, Model

String modelPath = Utils.getProjectPath(userProject.getName())
+ parentProject.getName() + "::" + Settings.HOPS_MODELS_DATASET + "/" + modelVersion.getModel().getName()
+ "/" + modelVersion.getModelVersionPK().getVersion();
+ "/" + modelVersion.getVersion();
deleteInternal(user, userProject, modelPath, modelVersion);
}
}
Expand All @@ -174,8 +171,7 @@ public void delete(Users user, Project project, ModelVersion modelVersion) throw
verifyNoModelDeployments(project, modelVersion);

String modelPath = Utils.getProjectPath(project.getName())
+ Settings.HOPS_MODELS_DATASET + "/" + modelVersion.getModel().getName() + "/"
+ modelVersion.getModelVersionPK().getVersion();
+ Settings.HOPS_MODELS_DATASET + "/" + modelVersion.getModel().getName() + "/" + modelVersion.getVersion();
deleteInternal(user, project, modelPath, modelVersion);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ public Response put(@PathParam("id") String id,
ModelVersion modelVersion = modelsController.createModelVersion(accessor, modelDTO, jobName, kernelId);
ModelDTO dto = modelsBuilder.build(uriInfo, new ResourceRequest(ResourceRequest.Name.MODELS), user, userProject,
modelRegistryProject, modelVersion, modelUtils.getModelFullPath(modelProject, modelVersion.getModel().getName(),
modelVersion.getModelVersionPK().getVersion()));
modelVersion.getVersion()));
UriBuilder builder = uriInfo.getAbsolutePathBuilder().path(id);
return Response.created(builder.build()).entity(dto).build();
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public void setModel(ModelVersion modelVersion) {
@Override
protected DatasetPath getDatasetPath() throws DatasetException {
return datasetHelper.getDatasetPath(project, modelUtils.getModelFullPath(modelRegistry,
modelVersion.getModel().getName(), modelVersion.getModelVersionPK().getVersion()), DatasetType.DATASET);
modelVersion.getModel().getName(), modelVersion.getVersion()), DatasetType.DATASET);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public EmbeddingFeatureDTO(EmbeddingFeature feature) {
// model registry id is same as project id
feature.getModelVersion().getModel().getProject().getId(),
feature.getModelVersion().getModel().getName(),
feature.getModelVersion().getModelVersionPK().getVersion());
feature.getModelVersion().getVersion());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import javax.persistence.GenerationType;
import javax.persistence.Id;
import javax.persistence.JoinColumn;
import javax.persistence.JoinColumns;
import javax.persistence.OneToOne;
import javax.persistence.Table;
import javax.xml.bind.annotation.XmlRootElement;
Expand All @@ -47,9 +46,7 @@ public class EmbeddingFeature {
private Integer dimension;
@Column(name = "similarity_function_type")
private String similarityFunctionType;
@JoinColumns({
@JoinColumn(name = "hsml_model_version", referencedColumnName = "version"),
@JoinColumn(name = "hsml_model_id", referencedColumnName = "model_id")})
@JoinColumn(name = "model_version_id", referencedColumnName = "id")
@OneToOne
private ModelVersion modelVersion;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
import javax.persistence.Basic;
import javax.persistence.Column;
import javax.persistence.Convert;
import javax.persistence.EmbeddedId;
import javax.persistence.Entity;
import javax.persistence.GeneratedValue;
import javax.persistence.GenerationType;
import javax.persistence.Id;
import javax.persistence.JoinColumn;
import javax.persistence.ManyToOne;
import javax.persistence.NamedQueries;
Expand All @@ -36,6 +38,7 @@
import javax.xml.bind.annotation.XmlRootElement;
import java.io.Serializable;
import java.util.Date;
import java.util.Objects;

/**
* A ModelVersion is an instance of a Model.
Expand All @@ -48,22 +51,27 @@
query = "SELECT mv FROM ModelVersion mv"),
@NamedQuery(name = "ModelVersion.findByProjectAndMlId",
query
= "SELECT mv FROM ModelVersion mv WHERE mv.modelVersionPK.version = :version" +
" AND mv.modelVersionPK.modelId = :modelId")
= "SELECT mv FROM ModelVersion mv WHERE mv.version = :version" +
" AND mv.model.id = :modelId")
}
)
public class ModelVersion implements Serializable {

private static final long serialVersionUID = 1L;

@EmbeddedId
private ModelVersionPK modelVersionPK;
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
@Basic(optional = false)
@Column(name = "id")
private Integer id;

@Basic(optional = false)
@Column(name = "version")
private Integer version;

@ManyToOne(optional = false)
@JoinColumn(name = "model_id",
referencedColumnName = "id",
insertable = false,
updatable = false)
referencedColumnName = "id")
private Model model;

@JoinColumn(name = "user_id",
Expand Down Expand Up @@ -108,6 +116,14 @@ public class ModelVersion implements Serializable {
public ModelVersion() {
}

public Integer getVersion() {
return version;
}

public void setVersion(Integer version) {
this.version = version;
}

public Metrics getMetrics() {
return metrics;
}
Expand Down Expand Up @@ -176,14 +192,6 @@ public void setExperimentProjectName(String experimentProjectName) {
this.experimentProjectName = experimentProjectName;
}

public ModelVersionPK getModelVersionPK() {
return modelVersionPK;
}

public void setModelVersionPK(ModelVersionPK modelVersionPK) {
this.modelVersionPK = modelVersionPK;
}

public Model getModel() {
return model;
}
Expand All @@ -193,7 +201,7 @@ public void setModel(Model model) {
}

public String getMlId() {
return model.getName() + "_" + modelVersionPK.getVersion();
return model.getName() + "_" + version;
}

public Users getCreator() {
Expand All @@ -206,9 +214,7 @@ public void setCreator(Users creator) {

@Override
public int hashCode() {
int hash = 0;
hash += (getModelVersionPK() != null ? getModelVersionPK().hashCode() : 0);
return hash;
return Objects.hash(id);
}

@Override
Expand All @@ -218,8 +224,7 @@ public boolean equals(Object object) {
return false;
}
ModelVersion other = (ModelVersion) object;
if ((this.getModelVersionPK() == null && other.getModelVersionPK() != null) ||
(this.getModelVersionPK() != null && !this.getModelVersionPK().equals(other.getModelVersionPK()))) {
if ((this.id == null && other.id != null) || (this.id != null && !Objects.equals(id, other.id))) {
return false;
}
return true;
Expand Down

This file was deleted.

0 comments on commit 92a0373

Please sign in to comment.