diff --git a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/FilterBy.java b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/FilterBy.java index 09933ffdc5..1c64362b51 100644 --- a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/FilterBy.java +++ b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/FilterBy.java @@ -16,35 +16,45 @@ package io.hops.hopsworks.api.modelregistry.models; import io.hops.hopsworks.common.dao.AbstractFacade; +import io.hops.hopsworks.common.models.version.ModelVersionFacade; public class FilterBy implements AbstractFacade.FilterBy { - private String param = null; - private String value = null; + private final ModelVersionFacade.Filters filter; + private final String param; public FilterBy(String param) { - if(param.contains(":")) { - String[] paramSplit = param.split(":"); - this.param = paramSplit[0]; - this.value = paramSplit[1]; + if (param.contains(":")) { + this.filter = ModelVersionFacade.Filters.valueOf(param.substring(0, param.indexOf(':')).toUpperCase()); + this.param = param.substring(param.indexOf(':') + 1); + } else { + this.filter = ModelVersionFacade.Filters.valueOf(param); + this.param = this.filter.getDefaultParam(); } } + @Override public String getParam() { return param; } + @Override + public String getValue() { + return this.filter.getValue(); + } + @Override public String getSql() { - return null; + return this.filter.getSql(); } @Override public String getField() { - return null; + return this.filter.getField(); } - public String getValue() { - return value; + @Override + public String toString() { + return filter.toString(); } -} +} \ No newline at end of file diff --git a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelExpansions.java b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelExpansions.java index aa0f04e01f..19e02d01d1 100644 --- a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelExpansions.java +++ b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelExpansions.java @@ -17,6 +17,7 @@ package io.hops.hopsworks.api.modelregistry.models; import io.hops.hopsworks.api.modelregistry.models.provenance.ModelTrainingDatasetResourceRequest; +import io.hops.hopsworks.api.user.UserResourceRequest; import io.hops.hopsworks.common.api.Expansions; import io.hops.hopsworks.common.api.ResourceRequest; @@ -35,6 +36,9 @@ public ModelExpansions(String queryParam) { case TRAININGDATASETS: resourceRequest = new ModelTrainingDatasetResourceRequest(name, queryParam); break; + case USERS: + resourceRequest = new UserResourceRequest(name, queryParam); + break; case MODELSCHEMA: resourceRequest = new ModelSchemaResourceRequest(name, queryParam); break; diff --git a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelUtils.java b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelUtils.java index d09ba64939..12e72aae7d 100644 --- a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelUtils.java +++ b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelUtils.java @@ -15,7 +15,6 @@ */ package io.hops.hopsworks.api.modelregistry.models; -import com.google.common.base.Strings; import io.hops.hopsworks.api.modelregistry.models.dto.ModelDTO; import io.hops.hopsworks.common.dao.project.ProjectFacade; import io.hops.hopsworks.common.dataset.DatasetController; @@ -23,31 +22,21 @@ import io.hops.hopsworks.common.hdfs.DistributedFsService; import io.hops.hopsworks.common.hdfs.HdfsUsersController; import io.hops.hopsworks.common.hdfs.Utils; -import io.hops.hopsworks.common.provenance.state.dto.ProvStateDTO; -import io.hops.hopsworks.common.python.environment.EnvironmentController; import io.hops.hopsworks.common.util.AccessController; import io.hops.hopsworks.common.util.Settings; import io.hops.hopsworks.exceptions.DatasetException; import io.hops.hopsworks.exceptions.GenericException; -import io.hops.hopsworks.exceptions.JobException; -import io.hops.hopsworks.exceptions.MetadataException; -import io.hops.hopsworks.exceptions.ModelRegistryException; + import io.hops.hopsworks.exceptions.ProjectException; -import io.hops.hopsworks.exceptions.PythonException; -import io.hops.hopsworks.exceptions.ServiceException; import io.hops.hopsworks.persistence.entity.dataset.Dataset; import io.hops.hopsworks.persistence.entity.project.Project; import io.hops.hopsworks.persistence.entity.user.Users; import io.hops.hopsworks.restutils.RESTCodes; -import org.json.JSONObject; import javax.ejb.EJB; import javax.ejb.Stateless; import javax.ejb.TransactionAttribute; import javax.ejb.TransactionAttributeType; -import javax.ws.rs.core.Response; -import javax.ws.rs.core.UriBuilder; -import javax.ws.rs.core.UriInfo; import java.util.logging.Level; @Stateless @@ -64,12 +53,6 @@ public class ModelUtils { private HdfsUsersController hdfsUsersController; @EJB private DistributedFsService dfs; - @EJB - private ModelsController modelsController; - @EJB - private EnvironmentController environmentController; - @EJB - private ModelConverter modelConverter; public String getModelsDatasetPath(Project userProject, Project modelRegistryProject) { String modelsPath = Utils.getProjectPath(userProject.getName()) + Settings.HOPS_MODELS_DATASET + "/"; @@ -139,39 +122,10 @@ public ModelsController.Accessor getModelsAccessor(Users user, Project userProje } } - public Response createModel(UriInfo uriInfo, ModelsController.Accessor accessor, String mlId, ModelDTO modelDTO, - String jobName, String kernelId) - throws DatasetException, MetadataException, JobException, ServiceException, PythonException, - ModelRegistryException { - String realName = accessor.user.getFname() + " " + accessor.user.getLname(); - //Only attach program and environment if exporting inside Hopsworks - if (!Strings.isNullOrEmpty(jobName) || !Strings.isNullOrEmpty(kernelId)) { - - modelDTO.setProgram(modelsController.versionProgram(accessor, jobName, kernelId, - modelDTO.getName(), modelDTO.getVersion())); - //Export environment to correct path here - modelDTO.setEnvironment(environmentController.exportEnv(accessor.experimentProject, accessor.user, - getModelFullPath(accessor.modelProject, modelDTO.getName(), modelDTO.getVersion()) + - "/" + Settings.ENVIRONMENT_FILE - )); - } - - modelDTO.setModelRegistryId(accessor.modelProject.getId()); - - modelsController.attachModel(accessor.udfso, accessor.modelProject, realName, modelDTO); - UriBuilder builder = uriInfo.getAbsolutePathBuilder().path(mlId); - return Response.created(builder.build()).entity(modelDTO).build(); - } - public String getModelFullPath(Project modelRegistryProject, String modelName, Integer modelVersion) { return Utils.getProjectPath(modelRegistryProject.getName()) + Settings.HOPS_MODELS_DATASET + "/" + modelName + "/" + modelVersion; } - - public ModelDTO convertProvenanceHitToModel(ProvStateDTO model) throws ModelRegistryException { - JSONObject summary = new JSONObject(model.getXattrs().get(ModelsBuilder.MODEL_SUMMARY_XATTR_NAME)); - return modelConverter.unmarshalDescription(summary.toString()); - } public String[] getModelNameAndVersion(String mlId) { int splitIndex = mlId.lastIndexOf("_"); diff --git a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsBeanParam.java b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsBeanParam.java index 8933d8633e..a360c29220 100644 --- a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsBeanParam.java +++ b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsBeanParam.java @@ -37,6 +37,14 @@ public class ModelsBeanParam { @BeanParam private ModelExpansionBeanParam expansions; + public ModelsBeanParam( + @QueryParam("sort_by") String sortBy, + @QueryParam("filter_by") Set filter) { + this.sortBy = sortBy; + this.sortBySet = getSortBy(sortBy); + this.filter = filter; + } + private Set getSortBy(String param) { if (param == null || param.isEmpty()) { return new LinkedHashSet<>(); @@ -52,10 +60,12 @@ private Set getSortBy(String param) { return sortBys; } - public ModelsBeanParam(@QueryParam("filter_by") Set filter, @QueryParam("sort_by") String sortBy) { - this.filter = filter; + public String getSortBy() { + return sortBy; + } + + public void setSortBy(String sortBy) { this.sortBy = sortBy; - sortBySet = getSortBy(sortBy); } public Set getFilter() { @@ -66,15 +76,6 @@ public void setFilter(Set filter) { this.filter = filter; } - - public String getSortBy() { - return sortBy; - } - - public void setSortBy(String sortBy) { - this.sortBy = sortBy; - } - public Set getSortBySet() { return sortBySet; } diff --git a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsBuilder.java b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsBuilder.java index fc583c4c49..3c5fa11012 100644 --- a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsBuilder.java +++ b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsBuilder.java @@ -17,58 +17,33 @@ import io.hops.hopsworks.api.dataset.inode.InodeBuilder; import io.hops.hopsworks.api.dataset.inode.InodeDTO; -import io.hops.hopsworks.api.modelregistry.dto.ModelRegistryDTO; import io.hops.hopsworks.api.modelregistry.models.dto.ModelDTO; import io.hops.hopsworks.api.modelregistry.models.tags.ModelRegistryTagUri; import io.hops.hopsworks.api.tags.TagBuilder; +import io.hops.hopsworks.api.user.UsersBuilder; import io.hops.hopsworks.common.api.ResourceRequest; import io.hops.hopsworks.common.dao.AbstractFacade; -import io.hops.hopsworks.common.dao.hdfsUser.HdfsUsersFacade; -import io.hops.hopsworks.common.dao.project.ProjectFacade; -import io.hops.hopsworks.common.dao.project.team.ProjectTeamFacade; -import io.hops.hopsworks.common.dao.user.UserFacade; -import io.hops.hopsworks.common.dataset.DatasetController; import io.hops.hopsworks.common.dataset.FilePreviewMode; import io.hops.hopsworks.common.dataset.util.DatasetHelper; import io.hops.hopsworks.common.dataset.util.DatasetPath; -import io.hops.hopsworks.common.featurestore.FeaturestoreFacade; -import io.hops.hopsworks.common.hdfs.HdfsUsersController; -import io.hops.hopsworks.common.hdfs.Utils; -import io.hops.hopsworks.common.hdfs.inode.InodeController; -import io.hops.hopsworks.common.provenance.core.Provenance; -import io.hops.hopsworks.common.provenance.state.ProvStateParamBuilder; -import io.hops.hopsworks.common.provenance.state.ProvStateParser; -import io.hops.hopsworks.common.provenance.state.ProvStateController; -import io.hops.hopsworks.common.provenance.state.dto.ProvStateDTO; -import io.hops.hopsworks.common.provenance.util.ProvHelper; +import io.hops.hopsworks.common.models.version.ModelVersionFacade; import io.hops.hopsworks.common.util.Settings; import io.hops.hopsworks.exceptions.DatasetException; import io.hops.hopsworks.exceptions.GenericException; import io.hops.hopsworks.exceptions.MetadataException; import io.hops.hopsworks.exceptions.ModelRegistryException; -import io.hops.hopsworks.exceptions.ProvenanceException; import io.hops.hopsworks.exceptions.FeatureStoreMetadataException; -import io.hops.hopsworks.persistence.entity.dataset.Dataset; import io.hops.hopsworks.persistence.entity.dataset.DatasetType; -import io.hops.hopsworks.persistence.entity.hdfs.inode.Inode; -import io.hops.hopsworks.persistence.entity.hdfs.user.HdfsUsers; +import io.hops.hopsworks.persistence.entity.models.version.ModelVersion; import io.hops.hopsworks.persistence.entity.project.Project; -import io.hops.hopsworks.persistence.entity.project.team.ProjectTeam; import io.hops.hopsworks.persistence.entity.user.Users; import io.hops.hopsworks.restutils.RESTCodes; -import org.apache.hadoop.fs.Path; -import org.javatuples.Pair; -import org.opensearch.search.sort.SortOrder; import javax.ejb.EJB; import javax.ejb.Stateless; import javax.ejb.TransactionAttribute; import javax.ejb.TransactionAttributeType; import javax.ws.rs.core.UriInfo; -import java.util.EnumSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; @@ -76,27 +51,11 @@ @TransactionAttribute(TransactionAttributeType.NEVER) public class ModelsBuilder { - public final static String MODEL_SUMMARY_XATTR_NAME = "model_summary"; - private static final Logger LOGGER = Logger.getLogger(ModelsBuilder.class.getName()); + @EJB - private ProvStateController provenanceController; - @EJB - private Settings settings; - @EJB - private UserFacade userFacade; - @EJB - private HdfsUsersFacade hdfsUsersFacade; - @EJB - private HdfsUsersController hdfsUsersController; - @EJB - private ProjectFacade projectFacade; - @EJB - private ProjectTeamFacade projectTeamFacade; - @EJB - private FeaturestoreFacade featurestoreFacade; - @EJB - private ModelsController modelsController; + private ModelVersionFacade modelVersionFacade; + @EJB private InodeBuilder inodeBuilder; @EJB @@ -106,9 +65,7 @@ public class ModelsBuilder { @EJB private TagBuilder tagsBuilder; @EJB - private InodeController inodeController; - @EJB - private DatasetController datasetController; + private UsersBuilder usersBuilder; public ModelDTO uri(ModelDTO dto, UriInfo uriInfo, Project userProject, Project modelRegistryProject) { dto.setHref(uriInfo.getBaseUriBuilder() @@ -122,14 +79,14 @@ public ModelDTO uri(ModelDTO dto, UriInfo uriInfo, Project userProject, Project } public ModelDTO uri(ModelDTO dto, UriInfo uriInfo, Project userProject, Project modelRegistryProject, - ProvStateDTO fileProvenanceHit) { + ModelVersion modelVersion) { dto.setHref(uriInfo.getBaseUriBuilder() .path(ResourceRequest.Name.PROJECT.toString().toLowerCase()) .path(Integer.toString(userProject.getId())) .path(ResourceRequest.Name.MODELREGISTRIES.toString().toLowerCase()) .path(Integer.toString(modelRegistryProject.getId())) .path(ResourceRequest.Name.MODELS.toString().toLowerCase()) - .path(fileProvenanceHit.getMlId()) + .path(modelVersion.getModel().getName() + "_" + modelVersion.getModelVersionPK().getVersion()) .build()); return dto; } @@ -154,41 +111,19 @@ public ModelDTO build(UriInfo uriInfo, expand(dto, resourceRequest); dto.setCount(0l); if(dto.isExpand()) { - validatePagination(resourceRequest); - ProvStateDTO fileState; try { - Pair provFilesParamBuilder - = buildModelProvenanceParams(userProject, modelRegistryProject, resourceRequest); - if(provFilesParamBuilder.getValue1() == null) { - //no endpoint - no results - return dto; - } - Inode paramProjectInode = inodeController.getProjectRoot( - provFilesParamBuilder.getValue1().getParentProject().getName()); - fileState = provenanceController.provFileStateList( - paramProjectInode, - provFilesParamBuilder.getValue0()); - - List models = new LinkedList<>(fileState.getItems()); - - dto.setCount(fileState.getCount()); + AbstractFacade.CollectionInfo models = modelVersionFacade.findByProject( + resourceRequest.getOffset(), resourceRequest.getLimit(), resourceRequest.getFilter(), + resourceRequest.getSort(), modelRegistryProject); + dto.setCount(models.getCount()); String modelsDatasetPath = modelUtils.getModelsDatasetPath(userProject, modelRegistryProject); - for(ProvStateDTO fileProvStateHit: models) { - ModelDTO modelDTO - = build(uriInfo, resourceRequest, user, userProject, - modelRegistryProject, fileProvStateHit, modelsDatasetPath); + for(ModelVersion modelVersion: models.getItems()) { + ModelDTO modelDTO = build(uriInfo, resourceRequest, user, userProject, modelRegistryProject, modelVersion, + modelsDatasetPath); if(modelDTO != null) { dto.addItem(modelDTO); } } - } catch (ProvenanceException e) { - if (ProvHelper.missingMappingForField( e)) { - LOGGER.log(Level.WARNING, "Could not find opensearch mapping for experiments query", e); - return dto; - } else { - throw new ModelRegistryException(RESTCodes.ModelRegistryErrorCode.MODEL_LIST_FAILED, Level.FINE, - "Unable to list models for project " + modelRegistryProject.getName(), e.getMessage(), e); - } } catch(DatasetException e) { throw new ModelRegistryException(RESTCodes.ModelRegistryErrorCode.MODEL_LIST_FAILED, Level.FINE, "Unable to list models for project " + modelRegistryProject.getName(), e.getMessage(), e); @@ -203,175 +138,62 @@ public ModelDTO build(UriInfo uriInfo, Users user, Project userProject, Project modelRegistryProject, - ProvStateDTO fileProvenanceHit, + ModelVersion modelVersion, String modelsFolder) - throws DatasetException, ModelRegistryException, FeatureStoreMetadataException, MetadataException { + throws ModelRegistryException, GenericException, FeatureStoreMetadataException, + MetadataException, DatasetException { ModelDTO modelDTO = new ModelDTO(); - uri(modelDTO, uriInfo, userProject, modelRegistryProject, fileProvenanceHit); + uri(modelDTO, uriInfo, userProject, modelRegistryProject, modelVersion); if (expand(modelDTO, resourceRequest).isExpand()) { - if (fileProvenanceHit.getXattrs() != null - && fileProvenanceHit.getXattrs().containsKey(MODEL_SUMMARY_XATTR_NAME)) { - ModelDTO modelSummary = modelUtils.convertProvenanceHitToModel(fileProvenanceHit); - modelDTO.setId(fileProvenanceHit.getMlId()); - modelDTO.setName(modelSummary.getName()); - modelDTO.setVersion(modelSummary.getVersion()); - modelDTO.setUserFullName(modelSummary.getUserFullName()); - modelDTO.setCreated(fileProvenanceHit.getCreateTime()); - modelDTO.setMetrics(modelSummary.getMetrics()); - modelDTO.setDescription(modelSummary.getDescription()); - modelDTO.setProgram(modelSummary.getProgram()); - modelDTO.setFramework(modelSummary.getFramework()); - DatasetPath modelDsPath = datasetHelper.getDatasetPath(userProject, - modelUtils.getModelFullPath(modelRegistryProject, modelSummary.getName(), modelSummary.getVersion()), - DatasetType.DATASET); - ModelRegistryTagUri tagUri = new ModelRegistryTagUri(uriInfo, modelRegistryProject, - ResourceRequest.Name.MODELS, modelDTO.getId()); - modelDTO.setTags(tagsBuilder.build(tagUri, resourceRequest, user, modelDsPath)); - - String modelVersionPath = modelsFolder + "/" + modelDTO.getName() + "/" + modelDTO.getVersion() + "/"; - - DatasetPath modelSchemaPath = datasetHelper.getDatasetPath(userProject, - modelVersionPath + Settings.HOPS_MODELS_SCHEMA, DatasetType.DATASET); - if(resourceRequest.contains(ResourceRequest.Name.MODELSCHEMA) && modelSchemaPath.getInode() != null) { - InodeDTO modelSchemaDTO = inodeBuilder.buildBlob(uriInfo, new ResourceRequest(ResourceRequest.Name.INODES), - user, modelSchemaPath, modelSchemaPath.getInode(), FilePreviewMode.HEAD); - modelDTO.setModelSchema(modelSchemaDTO); - } else { - InodeDTO modelSchemaDTO = inodeBuilder.buildResource(uriInfo, modelRegistryProject, modelSchemaPath); - modelDTO.setModelSchema(modelSchemaDTO); - } - - DatasetPath inputExamplePath = datasetHelper.getDatasetPath(userProject, - modelVersionPath + Settings.HOPS_MODELS_INPUT_EXAMPLE, DatasetType.DATASET); - if(resourceRequest.contains(ResourceRequest.Name.INPUTEXAMPLE) && inputExamplePath.getInode() != null) { - InodeDTO inputExampleDTO = inodeBuilder.buildBlob(uriInfo, new ResourceRequest(ResourceRequest.Name.INODES), - user, inputExamplePath, - inputExamplePath.getInode(), FilePreviewMode.HEAD); - modelDTO.setInputExample(inputExampleDTO); - } else { - InodeDTO inputExampleDTO = inodeBuilder.buildResource(uriInfo, modelRegistryProject, inputExamplePath); - modelDTO.setInputExample(inputExampleDTO); - } - - modelDTO.setEnvironment(modelSummary.getEnvironment()); - modelDTO.setExperimentId(modelSummary.getExperimentId()); - modelDTO.setExperimentProjectName(modelSummary.getExperimentProjectName()); - modelDTO.setProjectName(modelSummary.getProjectName()); - modelDTO.setModelRegistryId(modelRegistryProject.getId()); - } - } - return modelDTO; - } - - private Pair buildFilter(Project project, - Project modelRegistryProject, Set filters) - throws GenericException, ProvenanceException, DatasetException { - ProvStateParamBuilder provFilesParamBuilder = new ProvStateParamBuilder(); - if(filters != null) { - Users filterUser = null; - Project filterUserProject = project; - for (AbstractFacade.FilterBy filterBy : filters) { - if(filterBy.getParam().compareToIgnoreCase(Filters.NAME_EQ.name()) == 0) { - provFilesParamBuilder.filterByXAttr(MODEL_SUMMARY_XATTR_NAME + ".name", filterBy.getValue()); - } else if(filterBy.getParam().compareToIgnoreCase(Filters.NAME_LIKE.name()) == 0) { - provFilesParamBuilder.filterLikeXAttr(MODEL_SUMMARY_XATTR_NAME + ".name", filterBy.getValue()); - } else if(filterBy.getParam().compareToIgnoreCase(Filters.VERSION.name()) == 0) { - provFilesParamBuilder.filterByXAttr(MODEL_SUMMARY_XATTR_NAME + ".version", filterBy.getValue()); - } else if(filterBy.getParam().compareToIgnoreCase(Filters.ID_EQ.name()) == 0) { - provFilesParamBuilder.filterByXAttr(MODEL_SUMMARY_XATTR_NAME + ".id", filterBy.getValue()); - } else if (filterBy.getParam().compareToIgnoreCase(Filters.USER.name()) == 0) { - try { - filterUser = userFacade.find(Integer.parseInt(filterBy.getValue())); - } catch(NumberFormatException e) { - throw new GenericException(RESTCodes.GenericErrorCode.ILLEGAL_ARGUMENT, Level.INFO, - "expected int user id, found: " + filterBy.getValue()); - } - } else if (filterBy.getParam().compareToIgnoreCase(Filters.USER_PROJECT.name()) == 0) { - try { - filterUserProject = projectFacade.find(Integer.parseInt(filterBy.getValue())); - } catch(NumberFormatException e) { - throw new GenericException(RESTCodes.GenericErrorCode.ILLEGAL_ARGUMENT, Level.INFO, - "expected int user project id, found: " + filterBy.getValue()); - } - } else { - throw new GenericException(RESTCodes.GenericErrorCode.ILLEGAL_ARGUMENT, Level.INFO, - "Filter by - found: " + filterBy.getParam() + " expected:" + EnumSet.allOf(Filters.class)); - } + modelDTO.setId(modelVersion.getModel().getName() + "_" + modelVersion.getModelVersionPK().getVersion()); + modelDTO.setName(modelVersion.getModel().getName()); + modelDTO.setVersion(modelVersion.getModelVersionPK().getVersion()); + modelDTO.setUserFullName(modelVersion.getUserFullName()); + modelDTO.setCreated(modelVersion.getCreated().getTime()); + modelDTO.setMetrics(modelVersion.getMetrics().getAttributes()); + modelDTO.setDescription(modelVersion.getDescription()); + modelDTO.setProgram(modelVersion.getProgram()); + modelDTO.setFramework(modelVersion.getFramework()); + modelDTO.setEnvironment(modelVersion.getEnvironment()); + modelDTO.setExperimentId(modelVersion.getExperimentId()); + modelDTO.setExperimentProjectName(modelVersion.getExperimentProjectName()); + modelDTO.setProjectName(modelRegistryProject.getName()); + modelDTO.setModelRegistryId(modelRegistryProject.getId()); + modelDTO.setCreator(usersBuilder.build(uriInfo, resourceRequest, modelVersion.getCreator())); + + DatasetPath modelDsPath = datasetHelper.getDatasetPath(userProject, + modelUtils.getModelFullPath(modelRegistryProject, modelVersion.getModel().getName(), + modelVersion.getModelVersionPK().getVersion()), + DatasetType.DATASET); + ModelRegistryTagUri tagUri = new ModelRegistryTagUri(uriInfo, modelRegistryProject, + ResourceRequest.Name.MODELS, modelDTO.getId()); + modelDTO.setTags(tagsBuilder.build(tagUri, resourceRequest, user, modelDsPath)); + + String modelVersionPath = modelsFolder + "/" + modelDTO.getName() + "/" + modelDTO.getVersion() + "/"; + + DatasetPath modelSchemaPath = datasetHelper.getDatasetPath(userProject, + modelVersionPath + Settings.HOPS_MODELS_SCHEMA, DatasetType.DATASET); + if(resourceRequest.contains(ResourceRequest.Name.MODELSCHEMA) && modelSchemaPath.getInode() != null) { + InodeDTO modelSchemaDTO = inodeBuilder.buildBlob(uriInfo, new ResourceRequest(ResourceRequest.Name.INODES), + user, modelSchemaPath, modelSchemaPath.getInode(), FilePreviewMode.HEAD); + modelDTO.setModelSchema(modelSchemaDTO); + } else { + InodeDTO modelSchemaDTO = inodeBuilder.buildResource(uriInfo, modelRegistryProject, modelSchemaPath); + modelDTO.setModelSchema(modelSchemaDTO); } - if(filterUser != null) { - ProjectTeam member = projectTeamFacade.findByPrimaryKey(filterUserProject, filterUser); - if(member == null) { - throw new GenericException(RESTCodes.GenericErrorCode.ILLEGAL_ARGUMENT, Level.INFO, - "Selected user: " + filterUser.getUid() + " is not part of project:" + filterUserProject.getId()); - } - String hdfsUserStr = hdfsUsersController.getHdfsUserName(filterUserProject, filterUser); - HdfsUsers hdfsUsers = hdfsUsersFacade.findByName(hdfsUserStr); - provFilesParamBuilder.filterByField(ProvStateParser.FieldsP.USER_ID, hdfsUsers.getId().toString()); - } - } - Inode projectInode = inodeController.getProjectRoot(modelRegistryProject.getName()); - Dataset modelsDataset = datasetController.getByName(modelRegistryProject, Settings.HOPS_MODELS_DATASET); - Path modelsDatasetPath = Utils.getDatasetPath(modelsDataset, settings); - Inode modelsDatasetInode = inodeController.getProjectDatasetInode(projectInode, modelsDatasetPath.toString(), - modelsDataset); - ModelRegistryDTO modelRegistryDTO = ModelRegistryDTO.fromDataset(modelRegistryProject, modelsDatasetInode); - provFilesParamBuilder - .filterByField(ProvStateParser.FieldsP.PROJECT_I_ID, projectInode.getId()) - .filterByField(ProvStateParser.FieldsP.DATASET_I_ID, modelRegistryDTO.getDatasetInodeId()); - return Pair.with(provFilesParamBuilder, modelRegistryDTO); - } - - private void buildSortOrder(ProvStateParamBuilder provFilesParamBuilder, Set sort) { - if(sort != null) { - for(AbstractFacade.SortBy sortBy: sort) { - if(sortBy.getValue().compareToIgnoreCase(SortBy.NAME.name()) == 0) { - provFilesParamBuilder.sortByXAttr(MODEL_SUMMARY_XATTR_NAME + ".name", - SortOrder.valueOf(sortBy.getParam().getValue())); - } else { - String sortKeyName = sortBy.getValue(); - String sortKeyOrder = sortBy.getParam().getValue(); - provFilesParamBuilder.sortByXAttr(MODEL_SUMMARY_XATTR_NAME + ".metrics." + sortKeyName, - SortOrder.valueOf(sortKeyOrder)); - } + DatasetPath inputExamplePath = datasetHelper.getDatasetPath(userProject, + modelVersionPath + Settings.HOPS_MODELS_INPUT_EXAMPLE, DatasetType.DATASET); + if(resourceRequest.contains(ResourceRequest.Name.INPUTEXAMPLE) && inputExamplePath.getInode() != null) { + InodeDTO inputExampleDTO = inodeBuilder.buildBlob(uriInfo, new ResourceRequest(ResourceRequest.Name.INODES), + user, inputExamplePath, + inputExamplePath.getInode(), FilePreviewMode.HEAD); + modelDTO.setInputExample(inputExampleDTO); + } else { + InodeDTO inputExampleDTO = inodeBuilder.buildResource(uriInfo, modelRegistryProject, inputExamplePath); + modelDTO.setInputExample(inputExampleDTO); } } - } - - private void validatePagination(ResourceRequest resourceRequest) { - if(resourceRequest.getLimit() == null || resourceRequest.getLimit() <= 0) { - resourceRequest.setLimit(settings.getOpenSearchDefaultScrollPageSize()); - } - - if(resourceRequest.getOffset() == null || resourceRequest.getOffset() <= 0) { - resourceRequest.setOffset(0); - } - } - - protected enum SortBy { - NAME - } - - protected enum Filters { - NAME_EQ, - NAME_LIKE, - VERSION, - ID_EQ, - USER, - USER_PROJECT - } - - private Pair buildModelProvenanceParams(Project project, - Project modelRegistryProject, - ResourceRequest resourceRequest) - throws ProvenanceException, GenericException, DatasetException { - Pair builder - = buildFilter(project, modelRegistryProject, resourceRequest.getFilter()); - builder.getValue0() - .filterByField(ProvStateParser.FieldsP.ML_TYPE, Provenance.MLType.MODEL.name()) - .hasXAttr(MODEL_SUMMARY_XATTR_NAME) - .paginate(resourceRequest.getOffset(), resourceRequest.getLimit()); - buildSortOrder(builder.getValue0(), resourceRequest.getSort()); - return builder; + return modelDTO; } } diff --git a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsController.java b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsController.java index 0f23fa7374..b9d896cc4e 100644 --- a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsController.java +++ b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsController.java @@ -26,14 +26,11 @@ import io.hops.hopsworks.common.hdfs.DistributedFileSystemOps; import io.hops.hopsworks.common.hdfs.Utils; import io.hops.hopsworks.common.hdfs.inode.InodeController; -import io.hops.hopsworks.common.hdfs.xattrs.XAttrsController; import io.hops.hopsworks.common.jobs.JobController; import io.hops.hopsworks.common.jupyter.JupyterController; -import io.hops.hopsworks.common.provenance.core.Provenance; -import io.hops.hopsworks.common.provenance.state.ProvStateParamBuilder; -import io.hops.hopsworks.common.provenance.state.ProvStateParser; -import io.hops.hopsworks.common.provenance.state.ProvStateController; -import io.hops.hopsworks.common.provenance.state.dto.ProvStateDTO; +import io.hops.hopsworks.common.models.ModelFacade; +import io.hops.hopsworks.common.models.version.ModelVersionFacade; +import io.hops.hopsworks.common.python.environment.EnvironmentController; import io.hops.hopsworks.common.serving.ServingController; import io.hops.hopsworks.common.serving.ServingWrapper; import io.hops.hopsworks.common.util.AccessController; @@ -42,9 +39,8 @@ import io.hops.hopsworks.exceptions.DatasetException; import io.hops.hopsworks.exceptions.JobException; import io.hops.hopsworks.exceptions.KafkaException; -import io.hops.hopsworks.exceptions.MetadataException; import io.hops.hopsworks.exceptions.ModelRegistryException; -import io.hops.hopsworks.exceptions.ProvenanceException; +import io.hops.hopsworks.exceptions.PythonException; import io.hops.hopsworks.exceptions.ServiceException; import io.hops.hopsworks.exceptions.ServingException; import io.hops.hopsworks.persistence.entity.dataset.Dataset; @@ -52,119 +48,147 @@ import io.hops.hopsworks.persistence.entity.dataset.DatasetType; import io.hops.hopsworks.persistence.entity.jobs.configuration.spark.SparkJobConfiguration; import io.hops.hopsworks.persistence.entity.jobs.description.Jobs; +import io.hops.hopsworks.persistence.entity.models.Model; +import io.hops.hopsworks.persistence.entity.models.version.Metrics; +import io.hops.hopsworks.persistence.entity.models.version.ModelVersion; +import io.hops.hopsworks.persistence.entity.models.version.ModelVersionPK; import io.hops.hopsworks.persistence.entity.project.Project; import io.hops.hopsworks.persistence.entity.user.Users; import io.hops.hopsworks.restutils.RESTCodes; import org.apache.hadoop.fs.Path; -import org.json.JSONObject; import javax.ejb.EJB; import javax.ejb.Stateless; import javax.ejb.TransactionAttribute; import javax.ejb.TransactionAttributeType; import javax.inject.Inject; +import java.util.Date; import java.util.List; import java.util.logging.Level; import java.util.logging.Logger; -import static io.hops.hopsworks.api.modelregistry.models.ModelsBuilder.MODEL_SUMMARY_XATTR_NAME; - @Stateless @TransactionAttribute(TransactionAttributeType.NEVER) public class ModelsController { private static final Logger LOGGER = Logger.getLogger(ModelsController.class.getName()); - @EJB - private ProvStateController provenanceController; @EJB private JobController jobController; @EJB private JupyterController jupyterController; @EJB - private XAttrsController xattrCtrl; - @EJB private DatasetController datasetController; @EJB - private ModelConverter modelConverter; - @EJB private DatasetHelper datasetHelper; @EJB private AccessController accessCtrl; @EJB private DatasetController datasetCtrl; @EJB + private ModelFacade modelFacade; + @EJB + private ModelVersionFacade modelVersionFacade; + @EJB private ProjectFacade projectFacade; @EJB private ModelUtils modelUtils; @EJB private InodeController inodeController; @EJB + private EnvironmentController environmentController; + @EJB private Settings settings; @Inject private ServingController servingController; - - public void attachModel(DistributedFileSystemOps udfso, Project modelProject, String userFullName, - ModelDTO modelDTO) - throws DatasetException, ModelRegistryException, MetadataException { - modelDTO.setUserFullName(userFullName); - - String modelPath = Utils.getProjectPath(modelProject.getName()) + Settings.HOPS_MODELS_DATASET + "/" + - modelDTO.getName() + "/" + modelDTO.getVersion(); - - byte[] modelSummaryB = modelConverter.marshalDescription(modelDTO); - xattrCtrl.upsertProvXAttr(udfso, modelPath, MODEL_SUMMARY_XATTR_NAME, modelSummaryB); + public ModelVersion createModelVersion(ModelsController.Accessor accessor, ModelDTO modelDTO, + String jobName, String kernelId) + throws JobException, ServiceException, PythonException { + + Model model = modelFacade.findByProjectAndName(accessor.modelProject, modelDTO.getName()); + if(model == null) { + model = modelFacade.put(accessor.modelProject, modelDTO.getName()); + } + + ModelVersion modelVersion = new ModelVersion(); + modelVersion.setCreated(new Date()); + modelVersion.setDescription(modelDTO.getDescription()); + modelVersion.setEnvironment(modelDTO.getEnvironment()); + modelVersion.setFramework(modelDTO.getFramework()); + modelVersion.setExperimentId(modelVersion.getExperimentId()); + Metrics metrics = new Metrics(); + metrics.setAttributes(modelDTO.getMetrics()); + modelVersion.setMetrics(metrics); + modelVersion.setExperimentProjectName(modelDTO.getExperimentProjectName()); + modelVersion.setCreator(accessor.user); + + ModelVersionPK modelVersionPK = new ModelVersionPK(); + modelVersionPK.setVersion(modelDTO.getVersion()); + modelVersionPK.setModelId(model.getId()); + modelVersion.setModelVersionPK(modelVersionPK); + + //Only attach program and environment if exporting inside Hopsworks + if (!Strings.isNullOrEmpty(jobName) || !Strings.isNullOrEmpty(kernelId)) { + modelVersion.setProgram(versionProgram(accessor, jobName, kernelId, + modelDTO.getName(), modelDTO.getVersion())); + //Export environment to correct path here + modelVersion.setEnvironment(environmentController.exportEnv(accessor.experimentProject, accessor.user, + modelUtils.getModelFullPath(accessor.modelProject, modelDTO.getName(), modelDTO.getVersion()) + + "/" + Settings.ENVIRONMENT_FILE + )[0]); + } + + modelVersionFacade.put(modelVersion); + + return modelVersionFacade.findByProjectAndMlId(model.getId(), modelDTO.getVersion()); } - public ProvStateDTO getModel(Project project, String mlId) throws ProvenanceException { - Inode projectInode = inodeController.getProjectRoot(project.getName()); - ProvStateParamBuilder provFilesParamBuilder = new ProvStateParamBuilder() - .filterByField(ProvStateParser.FieldsP.PROJECT_I_ID, projectInode.getId()) - .filterByField(ProvStateParser.FieldsP.ML_TYPE, Provenance.MLType.MODEL.name()) - .filterByField(ProvStateParser.FieldsP.ML_ID, mlId) - .paginate(0, 1); - ProvStateDTO fileState = provenanceController.provFileStateList(projectInode, provFilesParamBuilder); - if (fileState != null) { - List experiments = fileState.getItems(); - if (experiments != null && !experiments.isEmpty()) { - return experiments.iterator().next(); - } + public ModelVersion getModel(Project project, String mlId) throws ModelRegistryException { + int lastUnderscore = mlId.lastIndexOf("_"); + String[] nameVersionSplit = {mlId.substring(0, lastUnderscore), mlId.substring(lastUnderscore + 1)}; + Model model = modelFacade.findByProjectAndName(project, nameVersionSplit[0]); + if(model == null) { + throw new ModelRegistryException(RESTCodes.ModelRegistryErrorCode.MODEL_NOT_FOUND, + Level.FINE); } - return null; + return modelVersionFacade.findByProjectAndMlId(model.getId(), Integer.valueOf(nameVersionSplit[1])); } - public void delete(Users user, Project userProject, Project parentProject, ProvStateDTO fileState) + public void delete(Users user, Project userProject, Project parentProject, ModelVersion modelVersion) throws DatasetException, ModelRegistryException, KafkaException, ServingException, CryptoPasswordNotFoundException { if(userProject.getId().equals(parentProject.getId())) { - delete(user, userProject, fileState); + delete(user, userProject, modelVersion); } else { - verifyNoModelDeployments(userProject, fileState); - - JSONObject summary = new JSONObject(fileState.getXattrs().get(MODEL_SUMMARY_XATTR_NAME)); - ModelDTO modelSummary = modelConverter.unmarshalDescription(summary.toString()); + verifyNoModelDeployments(userProject, modelVersion); + String modelPath = Utils.getProjectPath(userProject.getName()) - + parentProject.getName() + "::" + Settings.HOPS_MODELS_DATASET + "/" + modelSummary.getName() - + "/" + modelSummary.getVersion(); - deleteInternal(user, userProject, modelPath); + + parentProject.getName() + "::" + Settings.HOPS_MODELS_DATASET + "/" + modelVersion.getModel().getName() + + "/" + modelVersion.getModelVersionPK().getVersion(); + deleteInternal(user, userProject, modelPath, modelVersion); } } - public void delete(Users user, Project project, ProvStateDTO fileState) throws DatasetException, + public void delete(Users user, Project project, ModelVersion modelVersion) throws DatasetException, ModelRegistryException, KafkaException, ServingException, CryptoPasswordNotFoundException { - verifyNoModelDeployments(project, fileState); - - JSONObject summary = new JSONObject(fileState.getXattrs().get(MODEL_SUMMARY_XATTR_NAME)); - ModelDTO modelSummary = modelConverter.unmarshalDescription(summary.toString()); + verifyNoModelDeployments(project, modelVersion); + String modelPath = Utils.getProjectPath(project.getName()) - + Settings.HOPS_MODELS_DATASET + "/" + modelSummary.getName() + "/" + modelSummary.getVersion(); - deleteInternal(user, project, modelPath); + + Settings.HOPS_MODELS_DATASET + "/" + modelVersion.getModel().getName() + "/" + + modelVersion.getModelVersionPK().getVersion(); + deleteInternal(user, project, modelPath, modelVersion); } - private void deleteInternal(Users user, Project project, String path) throws DatasetException { + private void deleteInternal(Users user, Project project, String path, ModelVersion modelVersion) + throws DatasetException { DatasetPath datasetPath = datasetHelper.getDatasetPath(project, path, DatasetType.DATASET); datasetController.delete(project, user, datasetPath.getFullPath(), datasetPath.getDataset(), datasetPath.isTopLevelDataset()); + Model model = modelFacade.findByProjectAndName(project, modelVersion.getModel().getName()); + modelVersionFacade.remove(modelVersion); + if(model.getVersions().isEmpty()) { + modelFacade.remove(model); + } } public String versionProgram(Accessor accessor, String jobName, String kernelId, String modelName, int modelVersion) @@ -227,9 +251,9 @@ public ModelRegistryDTO verifyModelRegistryAccess(Project userProject, Integer m } } - public void verifyNoModelDeployments(Project project, ProvStateDTO fileState) + public void verifyNoModelDeployments(Project project, ModelVersion modelVersion) throws ModelRegistryException, KafkaException, ServingException, CryptoPasswordNotFoundException { - String[] nameVersionSplit = modelUtils.getModelNameAndVersion(fileState.getMlId()); + String[] nameVersionSplit = modelUtils.getModelNameAndVersion(modelVersion.getMlId()); List deployments = servingController.getAll(project, nameVersionSplit[0], Integer.valueOf(nameVersionSplit[1]), null); if (deployments != null && deployments.size() > 0) { diff --git a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsResource.java b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsResource.java index 733d462f70..f3dcbe2ed1 100644 --- a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsResource.java +++ b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/ModelsResource.java @@ -26,7 +26,6 @@ import io.hops.hopsworks.api.util.Pagination; import io.hops.hopsworks.common.api.ResourceRequest; import io.hops.hopsworks.common.hdfs.DistributedFsService; -import io.hops.hopsworks.common.provenance.state.dto.ProvStateDTO; import io.hops.hopsworks.exceptions.CryptoPasswordNotFoundException; import io.hops.hopsworks.exceptions.DatasetException; import io.hops.hopsworks.exceptions.GenericException; @@ -41,6 +40,7 @@ import io.hops.hopsworks.exceptions.ServiceException; import io.hops.hopsworks.exceptions.ServingException; import io.hops.hopsworks.jwt.annotation.JWTRequired; +import io.hops.hopsworks.persistence.entity.models.version.ModelVersion; import io.hops.hopsworks.persistence.entity.project.Project; import io.hops.hopsworks.persistence.entity.user.Users; import io.hops.hopsworks.persistence.entity.user.security.apiKey.ApiScope; @@ -67,6 +67,7 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import javax.ws.rs.core.SecurityContext; +import javax.ws.rs.core.UriBuilder; import javax.ws.rs.core.UriInfo; import java.util.logging.Level; import java.util.logging.Logger; @@ -157,11 +158,11 @@ public Response get ( Users user = jwtHelper.getUserPrincipal(sc); ResourceRequest resourceRequest = new ResourceRequest(ResourceRequest.Name.MODELS); resourceRequest.setExpansions(modelsBeanParam.getExpansions().getResources()); - ProvStateDTO fileState = modelsController.getModel(modelRegistryProject, id); + ModelVersion modelVersion = modelsController.getModel(modelRegistryProject, id); - if(fileState != null) { + if(modelVersion != null) { ModelDTO dto = modelsBuilder.build(uriInfo, resourceRequest, user, userProject, modelRegistryProject, - fileState, modelUtils.getModelsDatasetPath(userProject, modelRegistryProject)); + modelVersion, modelUtils.getModelsDatasetPath(userProject, modelRegistryProject)); if(dto == null) { throw new GenericException(RESTCodes.GenericErrorCode.NOT_AUTHORIZED_TO_ACCESS, Level.FINE); } else { @@ -189,9 +190,9 @@ public Response delete ( throws DatasetException, ProvenanceException, ModelRegistryException, KafkaException, ServingException, CryptoPasswordNotFoundException { Users user = jwtHelper.getUserPrincipal(sc); - ProvStateDTO fileState = modelsController.getModel(userProject, id); - if(fileState != null) { - modelsController.delete(user, userProject, modelRegistryProject, fileState); + ModelVersion modelVersion = modelsController.getModel(userProject, id); + if(modelVersion != null) { + modelsController.delete(user, userProject, modelRegistryProject, modelVersion); } return Response.noContent().build(); } @@ -214,8 +215,8 @@ public Response put(@PathParam("id") String id, @Context HttpServletRequest req, @Context UriInfo uriInfo, @Context SecurityContext sc) - throws DatasetException, ModelRegistryException, JobException, ServiceException, PythonException, MetadataException, - GenericException, ProjectException { + throws DatasetException, ModelRegistryException, JobException, ServiceException, PythonException, + MetadataException, GenericException, ProjectException, FeatureStoreMetadataException, ProvenanceException { if (modelDTO == null) { throw new IllegalArgumentException("Model summary not provided"); } @@ -223,12 +224,20 @@ public Response put(@PathParam("id") String id, Users user = jwtHelper.getUserPrincipal(sc); Project modelProject = modelUtils.getModelsProjectAndCheckAccess(modelDTO, userProject); Project experimentProject = modelUtils.getExperimentProjectAndCheckAccess(modelDTO, userProject); - ModelsController.Accessor accessor = modelUtils.getModelsAccessor(user, userProject, modelProject, - experimentProject); + ModelsController.Accessor accessor = null; try { - return modelUtils.createModel(uriInfo, accessor, id, modelDTO, jobName, kernelId); + accessor = modelUtils.getModelsAccessor(user, userProject, modelProject, + experimentProject); + ModelVersion modelVersion = modelsController.createModelVersion(accessor, modelDTO, jobName, kernelId); + ModelDTO dto = modelsBuilder.build(uriInfo, new ResourceRequest(ResourceRequest.Name.MODELS), user, userProject, + modelRegistryProject, modelVersion, modelUtils.getModelFullPath(modelProject, modelVersion.getModel().getName(), + modelVersion.getModelVersionPK().getVersion())); + UriBuilder builder = uriInfo.getAbsolutePathBuilder().path(id); + return Response.created(builder.build()).entity(dto).build(); } finally { - dfs.closeDfsClient(accessor.udfso); + if(accessor != null) { + dfs.closeDfsClient(accessor.udfso); + } } } @@ -238,9 +247,8 @@ public ModelTagResource tags(@ApiParam(value = "Id of the model", required = tru throws ModelRegistryException, ProvenanceException { this.tagResource.setProject(userProject); this.tagResource.setModelRegistry(modelRegistryProject); - ProvStateDTO fileState = modelsController.getModel(modelRegistryProject, id); - ModelDTO model = modelUtils.convertProvenanceHitToModel(fileState); - this.tagResource.setModel(model, fileState); + ModelVersion modelVersion = modelsController.getModel(modelRegistryProject, id); + this.tagResource.setModel(modelVersion); return this.tagResource; } } diff --git a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/SortBy.java b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/SortBy.java index 5864fd2255..3e9bcde5f7 100644 --- a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/SortBy.java +++ b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/SortBy.java @@ -17,24 +17,33 @@ package io.hops.hopsworks.api.modelregistry.models; import io.hops.hopsworks.common.dao.AbstractFacade; +import io.hops.hopsworks.common.models.version.ModelVersionFacade; import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.Response; +import java.util.Arrays; public class SortBy implements AbstractFacade.SortBy { - private final Sorts sortBy; + private final ModelVersionFacade.Sorts sortBy; private final AbstractFacade.OrderBy param; public SortBy(String param) { - String[] sortByParams = param.split(":"); + final String[] sortByParams = param.split(":"); String sort = ""; try { - sort = sortByParams[0]; - this.sortBy = new Sorts(sort, "DESC"); + sort = sortByParams[0].toUpperCase(); + if (Arrays.stream(ModelVersionFacade.Sorts.values()) + .noneMatch(x -> x.name().equals(sortByParams[0].toUpperCase()))) { + this.sortBy = ModelVersionFacade.Sorts.valueOf("METRIC"); + this.sortBy.setJsonSortKey(sortByParams[0]); + } else { + sort = sortByParams[0].toUpperCase(); + this.sortBy = ModelVersionFacade.Sorts.valueOf(sort); + } } catch (IllegalArgumentException iae) { throw new WebApplicationException("Sort by need to set a valid sort parameter, but found: " + sort, - Response.Status.NOT_FOUND); + Response.Status.NOT_FOUND); } String order = ""; try { @@ -42,46 +51,23 @@ public SortBy(String param) { this.param = AbstractFacade.OrderBy.valueOf(order); } catch (IllegalArgumentException iae) { throw new WebApplicationException("Sort by " + sort + " need to set a valid order(asc|desc), but found: " + order - , Response.Status.NOT_FOUND); + , Response.Status.NOT_FOUND); } } + @Override public String getValue() { return this.sortBy.getValue(); } + @Override public AbstractFacade.OrderBy getParam() { return this.param; } @Override public String getSql() { - return null; + return this.sortBy.getSql(); } - public class Sorts { - - private final String value; - private final String defaultParam; - - Sorts(String value, String defaultParam) { - this.value = value; - this.defaultParam = defaultParam; - } - - public String getValue() { - return value; - } - - public String getDefaultParam() { - return defaultParam; - } - - @Override - public String toString() { - return value; - } - - } } - diff --git a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/dto/ModelDTO.java b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/dto/ModelDTO.java index 389f007dc8..513a580f27 100644 --- a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/dto/ModelDTO.java +++ b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/dto/ModelDTO.java @@ -17,6 +17,7 @@ package io.hops.hopsworks.api.modelregistry.models.dto; import io.hops.hopsworks.api.dataset.inode.InodeDTO; +import io.hops.hopsworks.api.user.UserDTO; import io.hops.hopsworks.common.tags.TagsDTO; import io.hops.hopsworks.common.api.RestDTO; import io.hops.hopsworks.common.featurestore.trainingdatasets.TrainingDatasetDTO; @@ -26,7 +27,6 @@ import javax.xml.bind.annotation.XmlAnyAttribute; import javax.xml.bind.annotation.XmlRootElement; import javax.xml.namespace.QName; -import java.util.Arrays; import java.util.HashMap; /** @@ -58,12 +58,14 @@ public ModelDTO() { private Long created; + private UserDTO creator; + @XmlAnyAttribute private HashMap metrics; private String description; - private String[] environment; + private String environment; private String program; @@ -143,11 +145,11 @@ public void setId(String id) { this.id = id; } - public String[] getEnvironment() { + public String getEnvironment() { return environment; } - public void setEnvironment(String[] environment) { + public void setEnvironment(String environment) { this.environment = environment; } @@ -224,6 +226,14 @@ public void setTags(TagsDTO tags) { } public String getType() { return type; } + + public UserDTO getCreator() { + return creator; + } + + public void setCreator(UserDTO creator) { + this.creator = creator; + } @Override public String toString() { @@ -239,7 +249,7 @@ public String toString() { ", created=" + created + ", metrics=" + metrics + ", description='" + description + '\'' + - ", environment=" + Arrays.toString(environment) + + ", environment=" + environment + ", program='" + program + '\'' + ", experimentId='" + experimentId + '\'' + ", experimentProjectName='" + experimentProjectName + '\'' + diff --git a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/tags/ModelTagResource.java b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/tags/ModelTagResource.java index 5622ad29a8..5671bdc0ac 100644 --- a/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/tags/ModelTagResource.java +++ b/hopsworks-api/src/main/java/io/hops/hopsworks/api/modelregistry/models/tags/ModelTagResource.java @@ -17,13 +17,14 @@ import io.hops.hopsworks.api.modelregistry.ModelRegistryTagResource; import io.hops.hopsworks.api.modelregistry.models.ModelUtils; -import io.hops.hopsworks.api.modelregistry.models.dto.ModelDTO; +import io.hops.hopsworks.audit.logger.LogLevel; +import io.hops.hopsworks.audit.logger.annotation.Logged; import io.hops.hopsworks.common.api.ResourceRequest; import io.hops.hopsworks.common.dataset.util.DatasetHelper; import io.hops.hopsworks.common.dataset.util.DatasetPath; -import io.hops.hopsworks.common.provenance.state.dto.ProvStateDTO; import io.hops.hopsworks.exceptions.DatasetException; import io.hops.hopsworks.persistence.entity.dataset.DatasetType; +import io.hops.hopsworks.persistence.entity.models.version.ModelVersion; import javax.ejb.EJB; import javax.ejb.TransactionAttribute; @@ -39,29 +40,27 @@ public class ModelTagResource extends ModelRegistryTagResource { @EJB private DatasetHelper datasetHelper; - private ModelDTO model; - private ProvStateDTO provState; + private ModelVersion modelVersion; /** - * Sets the model and prov state of the tag resource + * Sets the model version for the tag resource * - * @param model - * @param provState + * @param modelVersion */ - public void setModel(ModelDTO model, ProvStateDTO provState) { - this.model = model; - this.provState = provState; + @Logged(logLevel = LogLevel.OFF) + public void setModel(ModelVersion modelVersion) { + this.modelVersion = modelVersion; } @Override protected DatasetPath getDatasetPath() throws DatasetException { - return datasetHelper.getDatasetPath(project, modelUtils.getModelFullPath(modelRegistry, model.getName(), - model.getVersion()), DatasetType.DATASET); + return datasetHelper.getDatasetPath(project, modelUtils.getModelFullPath(modelRegistry, + modelVersion.getModel().getName(), modelVersion.getModelVersionPK().getVersion()), DatasetType.DATASET); } @Override protected String getItemId() { - return provState.getMlId(); + return modelVersion.getMlId(); } @Override diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/models/ModelFacade.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/models/ModelFacade.java new file mode 100644 index 0000000000..1dcca51fd4 --- /dev/null +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/models/ModelFacade.java @@ -0,0 +1,207 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2023, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . + */ + +package io.hops.hopsworks.common.models; + +import io.hops.hopsworks.common.dao.AbstractFacade; +import io.hops.hopsworks.persistence.entity.models.Model; +import io.hops.hopsworks.persistence.entity.project.Project; + +import javax.ejb.Stateless; +import javax.ejb.TransactionAttribute; +import javax.ejb.TransactionAttributeType; +import javax.persistence.EntityManager; +import javax.persistence.NoResultException; +import javax.persistence.PersistenceContext; +import javax.persistence.Query; +import javax.persistence.TypedQuery; +import java.util.Date; +import java.util.Set; +import java.util.logging.Logger; + +/** + * Facade for management of persistent Model objects. + */ +@Stateless +public class ModelFacade extends AbstractFacade { + + @PersistenceContext(unitName = "kthfsPU") + private EntityManager em; + + private static final Logger LOGGER = Logger.getLogger(ModelFacade.class.getName()); + + public ModelFacade() { + super(Model.class); + } + + @Override + protected EntityManager getEntityManager() { + return em; + } + + @TransactionAttribute(TransactionAttributeType.REQUIRES_NEW) + public Model put(Project project, String name) { + //Argument checking + if (project == null || name == null) { + throw new IllegalArgumentException("Project and name must be non-null."); + } + //First: create a model object + Model model = new Model(); + model.setName(name); + model.setProject(project); + model = em.merge(model); + em.flush(); //To get the id. + return model; + } + + + /** + * Checks if a model with the given name exists in this project. + * + * @param project project to search. + * @param name name of model. + * @return model if exactly one model with that name was found. + */ + public Model findByProjectAndName(Project project, String name) { + TypedQuery query = em.createNamedQuery("Model.findByProjectAndName", Model.class); + query.setParameter("name", name).setParameter("project", project); + try { + return query.getSingleResult(); + } catch (NoResultException e) { + return null; + } + } + + public CollectionInfo findByProject(Integer offset, Integer limit, + Set filters, + Set sorts, Project project) { + + String queryStr = buildQuery("SELECT m FROM Model m ", filters, sorts, "m.project = :project "); + String queryCountStr = + buildQuery("SELECT COUNT(DISTINCT m.name, m.project) FROM Model m ", filters, sorts, "m.project = :project "); + Query query = em.createQuery(queryStr, Model.class).setParameter("project", project); + Query queryCount = em.createQuery(queryCountStr, Model.class).setParameter("project", project); + setFilter(filters, query); + setFilter(filters, queryCount); + setOffsetAndLim(offset, limit, query); + return new CollectionInfo((Long) queryCount.getSingleResult(), query.getResultList()); + } + + + private void setFilter(Set filter, Query q) { + if (filter == null || filter.isEmpty()) { + return; + } + for (FilterBy aFilter : filter) { + setFilterQuery(aFilter, q); + } + } + + private void setFilterQuery(AbstractFacade.FilterBy filterBy, Query q) { + switch (Filters.valueOf(filterBy.getValue())) { + case DATE_CREATED: + Date date = getDate(filterBy.getField(), filterBy.getParam()); + q.setParameter(filterBy.getField(), date); + break; + case NAME: + case DESCRIPTION: + case CREATOR: + q.setParameter(filterBy.getField(), filterBy.getParam()); + break; + default: + break; + } + } + + public enum Sorts { + ID("ID", "e.id ", "ASC"), + NAME("NAME", "e.name ", "ASC"), + DATE_CREATED("DATE_CREATED", "e.created ", "DESC"), + CREATOR("CREATOR", "LOWER(CONCAT (e.creator.fname, e.creator.lname)) " , "ASC"); + private final String value; + private final String sql; + private final String defaultParam; + + private Sorts(String value, String sql, String defaultParam) { + this.value = value; + this.sql = sql; + this.defaultParam = defaultParam; + } + + public String getValue() { + return value; + } + + public String getDefaultParam() { + return defaultParam; + } + + public String getSql() { + return sql; + } + + @Override + public String toString() { + return value; + } + + } + + public enum Filters { + DATE_CREATED("DATE_CREATED", "e.created = :created ","created",""), + NAME("NAME", "e.name LIKE CONCAT('%', :name, '%') ", "name", " "), + DESCRIPTION("NAME", "e.description LIKE CONCAT('%', :description, '%') ", "description", " "), + CREATOR("CREATOR", "(e.creator.username LIKE CONCAT('%', :creator, '%') " + + "OR e.creator.fname LIKE CONCAT('%', :creator, '%') " + + "OR e.creator.lname LIKE CONCAT('%', :creator, '%')) ", "creator", " "); + + private final String value; + private final String sql; + private final String field; + private final String defaultParam; + + private Filters(String value, String sql, String field, String defaultParam) { + this.value = value; + this.sql = sql; + this.field = field; + this.defaultParam = defaultParam; + } + + public String getValue() { + return value; + } + + public String getDefaultParam() { + return defaultParam; + } + + public String getSql() { + return sql; + } + + public String getField() { + return field; + } + + @Override + public String toString() { + return value; + } + } + + +} + diff --git a/hopsworks-common/src/main/java/io/hops/hopsworks/common/models/version/ModelVersionFacade.java b/hopsworks-common/src/main/java/io/hops/hopsworks/common/models/version/ModelVersionFacade.java new file mode 100644 index 0000000000..7933850d55 --- /dev/null +++ b/hopsworks-common/src/main/java/io/hops/hopsworks/common/models/version/ModelVersionFacade.java @@ -0,0 +1,203 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2023, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . + */ + +package io.hops.hopsworks.common.models.version; + +import io.hops.hopsworks.common.dao.AbstractFacade; +import io.hops.hopsworks.persistence.entity.models.version.ModelVersion; +import io.hops.hopsworks.persistence.entity.project.Project; + +import javax.ejb.Stateless; +import javax.persistence.EntityManager; +import javax.persistence.NoResultException; +import javax.persistence.PersistenceContext; +import javax.persistence.Query; +import javax.persistence.TypedQuery; +import java.util.Set; + +/** + * Facade for management of persistent Model Version objects. + */ +@Stateless +public class ModelVersionFacade extends AbstractFacade { + + @PersistenceContext(unitName = "kthfsPU") + private EntityManager em; + + public ModelVersionFacade() { + super(ModelVersion.class); + } + + @Override + protected EntityManager getEntityManager() { + return em; + } + + public CollectionInfo findByProject(Integer offset, Integer limit, + Set filters, + Set sorts, + Project project) { + + String queryStr = buildQuery( + "SELECT * FROM hopsworks.`model_version` JOIN `hopsworks`.model ON `hopsworks`.model_version.model_id=model.id ", + filters, sorts, "`hopsworks`.model.project_id = ?project_id "); + + String queryCountStr = + buildQuery("SELECT COUNT(DISTINCT concat(model_version.model_id, model_version.version)) " + + "FROM hopsworks.`model_version` JOIN `hopsworks`.model ON `hopsworks`.model_version.model_id=model.id ", + filters, sorts, "`hopsworks`.model.project_id = ?project_id "); + + Query query = em.createNativeQuery(queryStr, ModelVersion.class).setParameter("project_id", project.getId()); + Query queryCount = em.createNativeQuery(queryCountStr).setParameter("project_id", project.getId()); + setFilter(filters, query); + setFilter(filters, queryCount); + setOffsetAndLim(offset, limit, query); + return new CollectionInfo((Long)queryCount.getSingleResult(), query.getResultList()); + } + + public ModelVersion findByProjectAndMlId(Integer modelId, Integer version) { + TypedQuery query = em.createNamedQuery("ModelVersion.findByProjectAndMlId", ModelVersion.class); + query.setParameter("modelId", modelId).setParameter("version", version); + try { + return query.getSingleResult(); + } catch (NoResultException e) { + return null; + } + } + + private void setFilter(Set filter, Query q) { + if (filter == null || filter.isEmpty()) { + return; + } + for (FilterBy aFilter : filter) { + setFilterQuery(aFilter, q); + } + } + + private void setFilterQuery(AbstractFacade.FilterBy filterBy, Query q) { + switch (Filters.valueOf(filterBy.getValue())) { + case NAME_EQ: + case NAME_LIKE: + case VERSION: + q.setParameter(filterBy.getField(), filterBy.getParam()); + break; + default: + break; + } + } + + public enum Sorts { + NAME("NAME", "`hopsworks`.model.name " , "ASC"), + METRIC("METRIC", "JSON_VALUE(`metrics`, '$.attributes.METRIC') IS NULL, " + + "CAST(JSON_VALUE(`metrics`, '$.attributes.METRIC') AS FLOAT) ", + "ASC"); //sort twice needed to make sure nulls always at the end of sorted items + private final String value; + private final String sql; + private final String defaultParam; + + private String jsonSortKey; + + private Sorts(String value, String sql, String defaultParam) { + this.value = value; + this.sql = sql; + this.defaultParam = defaultParam; + } + + public String getValue() { + return value; + } + + public String getDefaultParam() { + return defaultParam; + } + + public String getSql() { + if (this.value.equals(Sorts.METRIC.value)) { + return sql.replace("METRIC", this.getJsonSortKey()); + } else { + return sql; + } + } + + public String getJoin() { + return null; + } + + @Override + public String toString() { + return value; + } + + public String getJsonSortKey() { + return jsonSortKey; + } + + public void setJsonSortKey(String jsonSortKey) { + this.jsonSortKey = jsonSortKey; + } + } + public enum Filters { + NAME_EQ ("NAME_EQ", + "`hopsworks`.model.name = ?name", + "name", ""), + NAME_LIKE ("NAME_LIKE", + "`hopsworks`.model.name LIKE CONCAT('%', ?name, '%') ", + "name", " "), + VERSION ("VERSION", + "`hopsworks`.model_version.version = ?version ", + "version", ""); + private final String value; + private final String sql; + private final String field; + private final String defaultParam; + + private Filters(String value, String sql, String field, String defaultParam) { + this.value = value; + this.sql = sql; + this.field = field; + this.defaultParam = defaultParam; + } + + public String getValue() { + return value; + } + + public String getDefaultParam() { + return defaultParam; + } + + public String getSql() { + return sql; + } + + public String getField() { + return field; + } + + @Override + public String toString() { + return value; + } + + } + + public ModelVersion put(ModelVersion modelVersion) { + //Finally: persist it, getting the assigned id. + modelVersion = em.merge(modelVersion); + em.flush(); //To get the id. + return modelVersion; + } +} diff --git a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/Model.java b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/Model.java new file mode 100644 index 0000000000..b6f7cdb955 --- /dev/null +++ b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/Model.java @@ -0,0 +1,133 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2023, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . + */ + +package io.hops.hopsworks.persistence.entity.models; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import io.hops.hopsworks.persistence.entity.models.version.ModelVersion; +import io.hops.hopsworks.persistence.entity.project.Project; + +import javax.persistence.Basic; +import javax.persistence.Column; +import javax.persistence.Entity; +import javax.persistence.GeneratedValue; +import javax.persistence.GenerationType; +import javax.persistence.Id; +import javax.persistence.JoinColumn; +import javax.persistence.ManyToOne; +import javax.persistence.NamedQueries; +import javax.persistence.NamedQuery; +import javax.persistence.OneToMany; +import javax.persistence.Table; +import javax.validation.constraints.Size; +import javax.xml.bind.annotation.XmlRootElement; +import javax.xml.bind.annotation.XmlTransient; +import java.io.Serializable; +import java.util.Collection; + +/** + * Description of the model. + */ +@Entity +@Table(name = "model", catalog = "hopsworks") +@XmlRootElement +@NamedQueries({ + @NamedQuery(name = "Model.findAll", + query = "SELECT m FROM Model m"), + @NamedQuery(name = "Model.findByProjectAndName", + query + = "SELECT m FROM Model m WHERE m.name = :name AND m.project = :project"),}) +public class Model implements Serializable { + private static final long serialVersionUID = 1L; + + @Id + @GeneratedValue(strategy = GenerationType.IDENTITY) + @Basic(optional = false) + @Column(name = "id") + private Integer id; + + @Size(max = 255) + @Column(name = "name") + private String name; + + @JoinColumn(name = "project_id", + referencedColumnName = "id") + @ManyToOne(optional = false) + private Project project; + + @OneToMany(mappedBy = "model") + private Collection versions; + + @JsonIgnore + @XmlTransient + public Collection getVersions() { + return versions; + } + + public void setVersions(Collection versions) { + this.versions = versions; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public Project getProject() { + return project; + } + + public void setProject(Project project) { + this.project = project; + } + + public Integer getId() { + return id; + } + + public void setId(Integer id) { + this.id = id; + } + + @Override + public int hashCode() { + int hash = 0; + hash += getProject().getId(); + hash += (name != null ? name.hashCode() : 0); + return hash; + } + + @Override + public final boolean equals(Object object) { + // TODO: Warning - this method won't work in the case the id fields are not set + if (!(object instanceof Model)) { + return false; + } + Model other = (Model) object; + if (this.project.getId() != other.getProject().getId()) { + return false; + } + if ((this.name == null && other.name != null) || + (this.name != null && !this.name.equals(other.name))) { + return false; + } + return true; + } +} + diff --git a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/Metrics.java b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/Metrics.java new file mode 100644 index 0000000000..09ae4cf668 --- /dev/null +++ b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/Metrics.java @@ -0,0 +1,40 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2023, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . + */ + +package io.hops.hopsworks.persistence.entity.models.version; + +import javax.xml.bind.annotation.XmlAnyAttribute; +import javax.xml.namespace.QName; +import java.util.HashMap; + +public class Metrics { + + @XmlAnyAttribute + private HashMap attributes; + + public Metrics() { + //Needed for JAXB + } + + public HashMap getAttributes() { + return this.attributes; + } + + public void setAttributes(HashMap attributes) { + this.attributes = attributes; + } + +} diff --git a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/ModelMetricsConverter.java b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/ModelMetricsConverter.java new file mode 100644 index 0000000000..e6e806885c --- /dev/null +++ b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/ModelMetricsConverter.java @@ -0,0 +1,71 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2023, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . + */ + +package io.hops.hopsworks.persistence.entity.models.version; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import org.json.JSONObject; + +import javax.persistence.AttributeConverter; +import javax.persistence.Converter; +import java.util.logging.Logger; + +@Converter +public class ModelMetricsConverter implements AttributeConverter { + + public ModelMetricsConverter() { + objectMapper.configure(SerializationFeature.WRAP_ROOT_VALUE, false); + objectMapper.configure(DeserializationFeature.UNWRAP_ROOT_VALUE, false); + } + + private final ObjectMapper objectMapper = new ObjectMapper(); + + public T readValue(String jsonConfig, Class resultClass) throws JsonProcessingException { + if(jsonConfig == null) { + jsonConfig = new JSONObject().toString(); + } + return objectMapper.readValue(jsonConfig, resultClass); + } + + public String writeValue(Object value) throws JsonProcessingException { + return objectMapper.writeValueAsString(value); + } + + private static final Logger LOGGER = Logger.getLogger(ModelMetricsConverter.class.getName()); + + @Override + public String convertToDatabaseColumn(Metrics metrics) { + String jsonConfig; + try { + jsonConfig = writeValue(metrics); + } catch (JsonProcessingException e) { + throw new IllegalStateException("Failed to marshal value:" + metrics, e); + } + return jsonConfig; + } + + @Override + public Metrics convertToEntityAttribute(String jsonConfig) { + try { + return readValue(jsonConfig, Metrics.class); + } catch (JsonProcessingException e) { + throw new IllegalStateException("Failed to unmarshal value:" + jsonConfig, e); + } + } +} diff --git a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/ModelVersion.java b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/ModelVersion.java new file mode 100644 index 0000000000..e0b933be45 --- /dev/null +++ b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/ModelVersion.java @@ -0,0 +1,228 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2023, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . + */ + +package io.hops.hopsworks.persistence.entity.models.version; + +import io.hops.hopsworks.persistence.entity.models.Model; +import io.hops.hopsworks.persistence.entity.user.Users; + +import javax.persistence.Basic; +import javax.persistence.Column; +import javax.persistence.Convert; +import javax.persistence.EmbeddedId; +import javax.persistence.Entity; +import javax.persistence.JoinColumn; +import javax.persistence.ManyToOne; +import javax.persistence.NamedQueries; +import javax.persistence.NamedQuery; +import javax.persistence.Table; +import javax.persistence.Temporal; +import javax.persistence.TemporalType; +import javax.validation.constraints.NotNull; +import javax.validation.constraints.Size; +import javax.xml.bind.annotation.XmlRootElement; +import java.io.Serializable; +import java.util.Date; + +/** + * A ModelVersion is an instance of a Model. + */ +@Entity +@Table(name = "model_version", catalog = "hopsworks") +@XmlRootElement +@NamedQueries({ + @NamedQuery(name = "ModelVersion.findAll", + query = "SELECT mv FROM ModelVersion mv"), + @NamedQuery(name = "ModelVersion.findByProjectAndMlId", + query + = "SELECT mv FROM ModelVersion mv WHERE mv.modelVersionPK.version = :version" + + " AND mv.modelVersionPK.modelId = :modelId") + } +) +public class ModelVersion implements Serializable { + + private static final long serialVersionUID = 1L; + + @EmbeddedId + private ModelVersionPK modelVersionPK; + + @ManyToOne(optional = false) + @JoinColumn(name = "model_id", + referencedColumnName = "id", + insertable = false, + updatable = false) + private Model model; + + @JoinColumn(name = "user_id", + referencedColumnName = "uid") + @ManyToOne(optional = false) + private Users creator; + + @Basic(optional = false) + @NotNull + @Column(name = "created") + @Temporal(TemporalType.TIMESTAMP) + private Date created; + + @Size(max = 1000) + @Column(name = "description") + private String description; + + @Column(name = "metrics") + @Convert(converter = ModelMetricsConverter.class) + private Metrics metrics; + + @Size(max = 1000) + @Column(name = "program") + private String program; + + @Size(max = 128) + @Column(name = "framework") + private String framework; + + @Size(max = 1000) + @Column(name = "environment") + private String environment; + + @Size(max = 128) + @Column(name = "experiment_id") + private String experimentId; + + @Size(max = 128) + @Column(name = "experiment_project_name") + private String experimentProjectName; + + public ModelVersion() { + } + + public Metrics getMetrics() { + return metrics; + } + + public void setMetrics(Metrics metrics) { + this.metrics = metrics; + } + + public String getProgram() { + return program; + } + + public void setProgram(String program) { + this.program = program; + } + + public String getEnvironment() { + return environment; + } + + public void setEnvironment(String environment) { + this.environment = environment; + } + + public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + + public Date getCreated() { + return created; + } + + public void setCreated(Date created) { + this.created = created; + } + + public String getUserFullName() { + return this.creator.getFname() + " " + this.creator.getLname(); + } + + public String getFramework() { + return framework; + } + + public void setFramework(String framework) { + this.framework = framework; + } + + public String getExperimentId() { + return experimentId; + } + + public void setExperimentId(String experimentId) { + this.experimentId = experimentId; + } + + public String getExperimentProjectName() { + return experimentProjectName; + } + + public void setExperimentProjectName(String experimentProjectName) { + this.experimentProjectName = experimentProjectName; + } + + public ModelVersionPK getModelVersionPK() { + return modelVersionPK; + } + + public void setModelVersionPK(ModelVersionPK modelVersionPK) { + this.modelVersionPK = modelVersionPK; + } + + public Model getModel() { + return model; + } + + public void setModel(Model model) { + this.model = model; + } + + public String getMlId() { + return model.getName() + "_" + modelVersionPK.getVersion(); + } + + public Users getCreator() { + return creator; + } + + public void setCreator(Users creator) { + this.creator = creator; + } + + @Override + public int hashCode() { + int hash = 0; + hash += (getModelVersionPK() != null ? getModelVersionPK().hashCode() : 0); + return hash; + } + + @Override + public boolean equals(Object object) { + // TODO: Warning - this method won't work in the case the id fields are not set + if (!(object instanceof ModelVersion)) { + return false; + } + ModelVersion other = (ModelVersion) object; + if ((this.getModelVersionPK() == null && other.getModelVersionPK() != null) || + (this.getModelVersionPK() != null && !this.getModelVersionPK().equals(other.getModelVersionPK()))) { + return false; + } + return true; + } +} + diff --git a/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/ModelVersionPK.java b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/ModelVersionPK.java new file mode 100644 index 0000000000..962d8be41a --- /dev/null +++ b/hopsworks-persistence/src/main/java/io/hops/hopsworks/persistence/entity/models/version/ModelVersionPK.java @@ -0,0 +1,88 @@ +/* + * This file is part of Hopsworks + * Copyright (C) 2023, Hopsworks AB. All rights reserved + * + * Hopsworks is free software: you can redistribute it and/or modify it under the terms of + * the GNU Affero General Public License as published by the Free Software Foundation, + * either version 3 of the License, or (at your option) any later version. + * + * Hopsworks is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; + * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + * PURPOSE. See the GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License along with this program. + * If not, see . + */ + +package io.hops.hopsworks.persistence.entity.models.version; + +import java.io.Serializable; +import javax.persistence.Basic; +import javax.persistence.Column; +import javax.persistence.Embeddable; + +@Embeddable +public class ModelVersionPK implements Serializable { + + @Basic(optional = false) + @Column(name = "version") + private Integer version; + + @Basic(optional = false) + @Column(name = "model_id") + private Integer modelId; + + public ModelVersionPK() { + } + + public ModelVersionPK(Integer version, Integer modelId) { + this.version = version; + this.modelId = modelId; + } + + public Integer getModelId() { + return modelId; + } + + public void setModelId(Integer model) { + this.modelId = model; + } + + @Override + public int hashCode() { + int hash = 0; + hash += (int) version; + hash += (modelId != null ? modelId.hashCode() : 0); + return hash; + } + + @Override + public boolean equals(Object object) { + // TODO: Warning - this method won't work in the case the id fields are not set + if (!(object instanceof ModelVersionPK)) { + return false; + } + ModelVersionPK other = (ModelVersionPK) object; + if (this.version != other.version) { + return false; + } + if (this.modelId != other.modelId) { + return false; + } + return true; + } + + @Override + public String toString() { + return "io.hops.hopsworks.persistence.entity.models.version.ModelVersionPK[ model=" + + modelId + ", version=" + version + " ]"; + } + + public Integer getVersion() { + return version; + } + + public void setVersion(Integer version) { + this.version = version; + } +} diff --git a/hopsworks-persistence/src/main/resources/META-INF/persistence.xml b/hopsworks-persistence/src/main/resources/META-INF/persistence.xml index 8960220556..2fc5a0248a 100644 --- a/hopsworks-persistence/src/main/resources/META-INF/persistence.xml +++ b/hopsworks-persistence/src/main/resources/META-INF/persistence.xml @@ -120,6 +120,8 @@ io.hops.hopsworks.persistence.entity.log.operation.OperationsLog io.hops.hopsworks.persistence.entity.maggy.MaggyDriver io.hops.hopsworks.persistence.entity.message.Message + io.hops.hopsworks.persistence.entity.models.Model + io.hops.hopsworks.persistence.entity.models.version.ModelVersion io.hops.hopsworks.persistence.entity.project.Project io.hops.hopsworks.persistence.entity.project.service.ProjectServices io.hops.hopsworks.persistence.entity.project.team.ProjectTeam @@ -171,6 +173,7 @@ io.hops.hopsworks.persistence.entity.alertmanager.ConfigConverter io.hops.hopsworks.persistence.entity.alertmanager.AlertReceiver io.hops.hopsworks.persistence.entity.serving.BatchingConfigurationConverter + io.hops.hopsworks.persistence.entity.models.version.ModelMetricsConverter io.hops.hopsworks.persistence.entity.pki.SerialNumber io.hops.hopsworks.persistence.entity.pki.PKIKey io.hops.hopsworks.persistence.entity.pki.PKICertificate