diff --git a/src/dioptra/client/experiments.py b/src/dioptra/client/experiments.py
index 605559e8b..bbc0d3b31 100644
--- a/src/dioptra/client/experiments.py
+++ b/src/dioptra/client/experiments.py
@@ -31,6 +31,7 @@ from .tags import TagsSubCollectionClient
 
 ARTIFACTS: Final[str] = "artifacts"
+METRICS: Final[str] = "metrics"
 MLFLOW_RUN: Final[str] = "mlflowRun"
 STATUS: Final[str] = "status"
 
@@ -702,3 +703,46 @@ def delete_by_id(self, experiment_id: str | int) -> T:
             The response from the Dioptra API.
         """
         return self._session.delete(self.url, str(experiment_id))
+
+    def get_metrics_by_id(
+        self,
+        experiment_id: str | int,
+        index: int = 0,
+        page_length: int = 10,
+        sort_by: str | None = None,
+        descending: bool | None = None,
+        search: str | None = None,
+    ) -> T:
+        """Get the latest metrics for the jobs in this experiment.
+
+        Args:
+            experiment_id: The experiment id, an integer.
+            index: The paging index. Optional, defaults to 0.
+            page_length: The maximum number of jobs to return in the paged
+                response. Optional, defaults to 10.
+            sort_by: The field to use to sort the returned list. Optional, defaults to
+                None.
+            descending: Sort the returned list in descending order. Optional, defaults
+                to None.
+            search: Search for jobs using the Dioptra API's query language.
+                Optional, defaults to None.
+
+        Returns:
+            The response from the Dioptra API.
+        """
+        params: dict[str, Any] = {
+            "index": index,
+            "pageLength": page_length,
+        }
+
+        if sort_by is not None:
+            params["sortBy"] = sort_by
+
+        if descending is not None:
+            params["descending"] = descending
+
+        if search is not None:
+            params["search"] = search
+
+        return self._session.get(self.url, str(experiment_id), METRICS, params=params)
diff --git a/src/dioptra/client/jobs.py b/src/dioptra/client/jobs.py
index c14750f3e..aaf0e8506 100644
--- a/src/dioptra/client/jobs.py
+++ b/src/dioptra/client/jobs.py
@@ -20,7 +20,9 @@ from .snapshots import SnapshotsSubCollectionClient
 from .tags import TagsSubCollectionClient
 
+METRICS: Final[str] = "metrics"
 MLFLOW_RUN: Final[str] = "mlflowRun"
+SNAPSHOTS: Final[str] = "snapshots"
 STATUS: Final[str] = "status"
 
 T = TypeVar("T")
@@ -166,7 +168,7 @@ def delete_by_id(self, job_id: str | int) -> T:
         """
         return self._session.delete(self.url, str(job_id))
 
-    def get_mlflow_run_id(self, job_id: int) -> T:
+    def get_mlflow_run_id(self, job_id: str | int) -> T:
         """Gets the MLflow run id for a job.
 
         Args:
@@ -177,7 +179,7 @@
         """
         return self._session.get(self.url, str(job_id), MLFLOW_RUN)
 
-    def set_mlflow_run_id(self, job_id: int, mlflow_run_id: str) -> T:
+    def set_mlflow_run_id(self, job_id: str | int, mlflow_run_id: str) -> T:
         """Sets the MLflow run id for a job.
 
         Args:
@@ -193,7 +195,7 @@
         return self._session.post(self.url, str(job_id), MLFLOW_RUN, json_=json_)
 
-    def get_status(self, job_id: int) -> T:
+    def get_status(self, job_id: str | int) -> T:
         """Gets the status for a job.
 
         Args:
@@ -203,3 +205,68 @@
             The response from the Dioptra API.
         """
         return self._session.get(self.url, str(job_id), STATUS)
+
+    def get_metrics_by_id(self, job_id: str | int) -> T:
+        """Gets all the latest metrics for a job.
+
+        Args:
+            job_id: The job id, an integer.
+
+        Returns:
+            The response from the Dioptra API.
+ """ + return self._session.get(self.url, str(job_id), METRICS) + + def append_metric_by_id( + self, + job_id: str | int, + metric_name: str, + metric_value: float, + metric_step: int | None = None, + ) -> T: + """Posts a new metric to a job. + + Args: + job_id: The job id, an integer. + metric_name: The name of the metric. + metric_value: The value of the metric. + metric_step: The step number of the metric, optional. + + Returns: + The response from the Dioptra API. + """ + json_ = { + "name": metric_name, + "value": metric_value, + } + + if metric_step is not None: + json_["step"] = metric_step + + return self._session.post(self.url, str(job_id), METRICS, json_=json_) + + def get_metrics_snapshots_by_id( + self, + job_id: str | int, + metric_name: str | int, + index: int = 0, + page_length: int = 10, + ) -> T: + """Gets the metric history for a job with a specific metric name. + + Args: + job_id: The job id, an integer. + metric_name: The name of the metric. + index: The paging index. Optional, defaults to 0. + page_length: The maximum number of metrics to return in the paged + response. Optional, defaults to 10. + Returns: + The response from the Dioptra API. + """ + params: dict[str, Any] = { + "index": index, + "pageLength": page_length, + } + return self._session.get( + self.url, str(job_id), METRICS, metric_name, SNAPSHOTS, params=params + ) diff --git a/src/dioptra/restapi/errors.py b/src/dioptra/restapi/errors.py index 58fea2a30..1bb547ec9 100644 --- a/src/dioptra/restapi/errors.py +++ b/src/dioptra/restapi/errors.py @@ -332,6 +332,13 @@ def __init__(self, message: str): super().__init__(message) +class MLFlowError(DioptraError): + """MLFlow Error.""" + + def __init__(self, message: str): + super().__init__(message) + + def error_result( error: DioptraError, status: http.HTTPStatus, detail: dict[str, typing.Any] ) -> tuple[dict[str, typing.Any], int]: @@ -431,6 +438,11 @@ def handle_user_password_error(error: UserPasswordError): log.debug(error.to_message()) return error_result(error, http.HTTPStatus.UNAUTHORIZED, {}) + @api.errorhandler(MLFlowError) + def handle_mlflow_error(error: MLFlowError): + log.debug(error.to_message()) + return error_result(error, http.HTTPStatus.INTERNAL_SERVER_ERROR, {}) + @api.errorhandler(DioptraError) def handle_base_error(error: DioptraError): log.debug(error.to_message()) diff --git a/src/dioptra/restapi/v1/experiments/controller.py b/src/dioptra/restapi/v1/experiments/controller.py index 6ffb50f60..2dcaf0a71 100644 --- a/src/dioptra/restapi/v1/experiments/controller.py +++ b/src/dioptra/restapi/v1/experiments/controller.py @@ -35,6 +35,7 @@ from dioptra.restapi.v1.entrypoints.schema import EntrypointRefSchema from dioptra.restapi.v1.jobs.schema import ( ExperimentJobGetQueryParameters, + ExperimentJobsMetricsSchema, JobMlflowRunSchema, JobPageSchema, JobSchema, @@ -45,6 +46,7 @@ ExperimentJobIdService, ExperimentJobIdStatusService, ExperimentJobService, + ExperimentMetricsService, ) from dioptra.restapi.v1.schemas import IdListSchema, IdStatusResponseSchema from dioptra.restapi.v1.shared.drafts.controller import ( @@ -64,6 +66,7 @@ from .schema import ( ExperimentDraftSchema, ExperimentGetQueryParameters, + ExperimentMetricsGetQueryParameters, ExperimentMutableFieldsSchema, ExperimentPageSchema, ExperimentSchema, @@ -521,6 +524,71 @@ def post(self, id: int, jobId: int): return utils.build_artifact(artifact) +@api.route("//metrics") +@api.param("id", "ID for the Experiment resource.") +class ExperimentIdMetricsEndpoint(Resource): + @inject + def 
+    def __init__(
+        self,
+        experiment_metrics_service: ExperimentMetricsService,
+        *args,
+        **kwargs,
+    ) -> None:
+        """Initialize the Experiment Metrics resource.
+
+        All arguments are provided via dependency injection.
+
+        Args:
+            experiment_metrics_service: An ExperimentMetricsService object.
+        """
+        self._experiment_metrics_service = experiment_metrics_service
+        super().__init__(*args, **kwargs)
+
+    @login_required
+    @accepts(query_params_schema=ExperimentMetricsGetQueryParameters, api=api)
+    @responds(schema=ExperimentJobsMetricsSchema, api=api)
+    def get(self, id: int):
+        """Gets all of the latest metrics for every job in the experiment."""
+        log = LOGGER.new(
+            request_id=str(uuid.uuid4()),
+            resource="ExperimentIdMetricsEndpoint",
+            request_type="GET",
+            experiment_id=id,
+        )
+
+        parsed_query_params = request.parsed_query_params  # type: ignore
+        search_string = unquote(parsed_query_params["search"])
+        page_index = parsed_query_params["index"]
+        page_length = parsed_query_params["page_length"]
+        sort_by_string = unquote(parsed_query_params["sort_by"])
+        descending = parsed_query_params["descending"]
+
+        jobs_metrics, total_num_jobs = self._experiment_metrics_service.get(
+            experiment_id=id,
+            search_string=search_string,
+            page_index=page_index,
+            page_length=page_length,
+            sort_by_string=sort_by_string,
+            descending=descending,
+            error_if_not_found=True,
+            log=log,
+        )
+
+        return utils.build_paging_envelope(
+            f"/experiments/{id}/metrics",
+            build_fn=utils.build_metrics_snapshots,
+            data=jobs_metrics,
+            group_id=None,
+            query=None,
+            draft_type=None,
+            index=page_index,
+            length=page_length,
+            total_num_elements=total_num_jobs,
+            sort_by=None,
+            descending=None,
+        )
+
+
 @api.route("/<int:id>/entrypoints")
 @api.param("id", "ID for the Experiment resource.")
 class ExperimentIdEntrypointsEndpoint(Resource):
diff --git a/src/dioptra/restapi/v1/experiments/schema.py b/src/dioptra/restapi/v1/experiments/schema.py
index c078cf67a..ce4edc2a8 100644
--- a/src/dioptra/restapi/v1/experiments/schema.py
+++ b/src/dioptra/restapi/v1/experiments/schema.py
@@ -117,6 +117,15 @@ class ExperimentPageSchema(BasePageSchema):
     )
 
 
+class ExperimentMetricsGetQueryParameters(
+    PagingQueryParametersSchema,
+    SearchQueryParametersSchema,
+    SortByGetQueryParametersSchema,
+):
+    """The query parameters for the GET method of the
+    /experiments/{id}/metrics endpoint."""
+
+
 class ExperimentGetQueryParameters(
     PagingQueryParametersSchema,
     GroupIdQueryParametersSchema,
diff --git a/src/dioptra/restapi/v1/jobs/controller.py b/src/dioptra/restapi/v1/jobs/controller.py
index 11576034c..5cf5bcd27 100644
--- a/src/dioptra/restapi/v1/jobs/controller.py
+++ b/src/dioptra/restapi/v1/jobs/controller.py
@@ -48,10 +48,15 @@
     JobPageSchema,
     JobSchema,
     JobStatusSchema,
+    MetricsSchema,
+    MetricsSnapshotPageSchema,
+    MetricsSnapshotsGetQueryParameters,
 )
 from .service import (
     RESOURCE_TYPE,
     SEARCHABLE_FIELDS,
+    JobIdMetricsService,
+    JobIdMetricsSnapshotsService,
     JobIdMlflowrunService,
     JobIdService,
     JobIdStatusService,
@@ -201,7 +206,7 @@ def __init__(
         All arguments are provided via dependency injection.
 
         Args:
-            job_id_service: A JobIdStatusService object.
+            job_id_mlflowrun_service: A JobIdMlflowrunService object.
""" self._job_id_mlflowrun_service = job_id_mlflowrun_service super().__init__(*args, **kwargs) @@ -240,6 +245,123 @@ def post(self, id: int): ) +@api.route("//metrics") +@api.param("id", "ID for the Job resource.") +class JobIdMetricsEndpoint(Resource): + @inject + def __init__( + self, + job_id_metrics_service: JobIdMetricsService, + *args, + **kwargs, + ) -> None: + """Initialize the jobs resource. + + All arguments are provided via dependency injection. + + Args: + job_id_metrics_service: A JobIdMetricsService object. + """ + self._job_id_metrics_service = job_id_metrics_service + super().__init__(*args, **kwargs) + + @login_required + @responds(schema=MetricsSchema(many=True), api=api) + def get(self, id: int): + """Gets a Job resource's latest metrics.""" + log = LOGGER.new( + request_id=str(uuid.uuid4()), + resource="JobIdMetricsEndpoint", + request_type="GET", + job_id=id, + ) + + return self._job_id_metrics_service.get( + job_id=id, error_if_not_found=True, log=log + ) + + @login_required + @accepts(schema=MetricsSchema, api=api) + @responds(schema=MetricsSchema, api=api) + def post(self, id: int): + """Sets a metric for a Job""" + log = LOGGER.new( + request_id=str(uuid.uuid4()), + resource="JobIdMetricsEndpoint", + request_type="POST", + job_id=id, + ) + parsed_obj = request.parsed_obj # type: ignore + return self._job_id_metrics_service.update( + job_id=id, + metric_name=parsed_obj["name"], + metric_value=parsed_obj["value"], + metric_step=parsed_obj["step"], + error_if_not_found=True, + log=log, + ) + + +@api.route("//metrics//snapshots") +@api.param("id", "ID for the Job resource.") +@api.param("name", "Name of the metric.") +class JobIdMetricsSnapshotsEndpoint(Resource): + @inject + def __init__( + self, + job_id_metrics_snapshots_service: JobIdMetricsSnapshotsService, + *args, + **kwargs, + ) -> None: + """Initialize the jobs resource. + + All arguments are provided via dependency injection. + + Args: + job_id_metrics_snapshots_service: A JobIdMetricsSnapshotsService object. 
+ """ + self._job_id_metrics_snapshots_service = job_id_metrics_snapshots_service + super().__init__(*args, **kwargs) + + @login_required + @accepts(query_params_schema=MetricsSnapshotsGetQueryParameters, api=api) + @responds(schema=MetricsSnapshotPageSchema, api=api) + def get(self, id: int, name: str): + """Gets a Job resource's metric history.""" + log = LOGGER.new( + request_id=str(uuid.uuid4()), + resource="JobIdMetricsSnapshotsEndpoint", + request_type="GET", + job_id=id, + metric_name=name, + ) + parsed_query_params = request.parsed_query_params # type: ignore + page_index = parsed_query_params["index"] + page_length = parsed_query_params["page_length"] + metrics_page, total_num_metrics = self._job_id_metrics_snapshots_service.get( + job_id=id, + metric_name=name, + page_index=page_index, + page_length=page_length, + error_if_not_found=True, + log=log, + ) + + return utils.build_paging_envelope( + f"jobs/{id}/metrics/{name}/snapshots", + build_fn=utils.build_metrics_snapshots, + data=metrics_page, + group_id=None, + query=None, + draft_type=None, + index=page_index, + length=page_length, + total_num_elements=total_num_metrics, + sort_by=None, + descending=None, + ) + + JobSnapshotsResource = generate_resource_snapshots_endpoint( api=api, resource_model=models.Job, diff --git a/src/dioptra/restapi/v1/jobs/schema.py b/src/dioptra/restapi/v1/jobs/schema.py index 7a0fe01ea..de549befd 100644 --- a/src/dioptra/restapi/v1/jobs/schema.py +++ b/src/dioptra/restapi/v1/jobs/schema.py @@ -15,6 +15,8 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode """The schemas for serializing/deserializing Job resources.""" +import re + from marshmallow import Schema, fields, validate from dioptra.restapi.v1.artifacts.schema import ArtifactRefSchema @@ -28,6 +30,9 @@ generate_base_resource_schema, ) +ALLOWED_METRIC_NAME_REGEX = re.compile(r"^([A-Z]|[A-Z_][A-Z0-9_]+)$", flags=re.IGNORECASE) # noqa: B950; fmt: skip + + JobRefSchema = generate_base_resource_ref_schema("Job") JobSnapshotRefSchema = generate_base_resource_ref_schema("Job", keep_snapshot_id=True) @@ -41,14 +46,93 @@ class JobMlflowRunSchema(Schema): ) -class JobStatusSchema(Schema): - """The fields schema for the data in a Job status resource.""" - +class JobIdSchema(Schema): id = fields.Integer( attribute="id", metadata=dict(description="ID for the Job resource."), dump_only=True, ) + + +class MetricsSchema(Schema): + name = fields.String( + attribute="name", + metadata=dict(description="The name of the metric."), + required=True, + validate=validate.Regexp( + ALLOWED_METRIC_NAME_REGEX, + error=( + "'{input}' is not a compatible name for a metric. " + "A metric name must start with a letter or underscore, " + "followed by letters, numbers, or underscores. In " + "addition, '_' is not a valid metric name." 
+ ), + ), + ) + + value = fields.Float( + attribute="value", + metadata=dict(description="The value of the metric."), + required=True, + ) + step = fields.Integer( + attribute="step", + metadata=dict(description="The step value for the metric."), + load_only=True, + required=False, + load_default=0, + ) + + +class MetricsSnapshotSchema(Schema): + name = fields.String( + attribute="name", + metadata=dict(description="The name of the metric."), + ) + value = fields.Float( + attribute="value", + metadata=dict(description="The value of the metric."), + ) + step = fields.Integer( + attribute="step", + metadata=dict(description="The step value for the metric."), + ) + timestamp = fields.Integer( + attribute="timestamp", + metadata=dict(description="The timestamp of the metric in milliseconds."), + ) + + +class MetricsSnapshotPageSchema(BasePageSchema): + data = fields.Nested( + MetricsSnapshotSchema, + many=True, + metadata=dict(description="List of Metric Snapshots in the current page."), + ) + + +class JobIdMetricsSchema(JobIdSchema): + metrics = fields.Nested( + MetricsSchema, + attribute="metrics", + metadata=dict( + description="A list of the latest metrics associated with the job." + ), + many=True, + ) + + +class ExperimentJobsMetricsSchema(BasePageSchema): + data = fields.Nested( + JobIdMetricsSchema, + many=True, + metadata=dict(description="List of metrics for each job in the experiment"), + ) + + +class JobStatusSchema(JobIdSchema): + """The fields schema for the data in a Job status resource.""" + status = fields.String( attribute="status", validate=validate.OneOf( @@ -154,6 +238,13 @@ class JobPageSchema(BasePageSchema): ) +class MetricsSnapshotsGetQueryParameters( + PagingQueryParametersSchema, +): + """The query parameters for the GET method of the + /jobs/{id}/metrics/{name}/snapshots endpoint.""" + + class JobGetQueryParameters( PagingQueryParametersSchema, GroupIdQueryParametersSchema, diff --git a/src/dioptra/restapi/v1/jobs/service.py b/src/dioptra/restapi/v1/jobs/service.py index 81a6168bf..be90f77a5 100644 --- a/src/dioptra/restapi/v1/jobs/service.py +++ b/src/dioptra/restapi/v1/jobs/service.py @@ -18,10 +18,12 @@ from __future__ import annotations from typing import Any, Final, cast +from uuid import UUID import structlog from flask_login import current_user from injector import inject +from mlflow.exceptions import MlflowException from sqlalchemy import func, select from sqlalchemy.orm import aliased from structlog.stdlib import BoundLogger @@ -34,6 +36,7 @@ JobInvalidParameterNameError, JobInvalidStatusTransitionError, JobMlflowRunAlreadySetError, + MLFlowError, SortParameterValidationError, ) from dioptra.restapi.v1 import utils @@ -483,16 +486,12 @@ class JobIdStatusService(object): @inject def __init__( self, - job_id_service: JobIdService, ) -> None: """Initialize the job status service. All arguments are provided via dependency injection. - - Args: - job_id_service: A JobIdService object. """ - self._job_id_service = job_id_service + pass def get( self, @@ -527,6 +526,190 @@ def get( return {"status": job.status, "id": job.resource_id} +class JobIdMetricsService(object): + """The service methods for retrieving the metrics of a job by unique id.""" + + @inject + def __init__( + self, + job_id_mlflowrun_service: JobIdMlflowrunService, + ) -> None: + """Initialize the job metrics service. + + All arguments are provided via dependency injection. 
+ + """ + self._job_id_mlflowrun_service = job_id_mlflowrun_service + + def get( + self, + job_id: int, + **kwargs, + ) -> list[dict[str, Any]]: + """Fetch a job's metrics by its unique id. + + Args: + job_id: The unique id of the job. + + Returns: + The metrics for the requested job if found, otherwise an error message. + """ + from mlflow.tracking import MlflowClient + + log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.debug("Get job metrics by id", job_id=job_id) + + run_id: UUID | None = self._job_id_mlflowrun_service.get( + job_id=job_id, **kwargs + )["mlflow_run_id"] + + try: + client = MlflowClient() + except MlflowException as e: + raise MLFlowError(e.message) from e + + if run_id is None: + metrics = [] + else: + try: + run = client.get_run(run_id.hex) + metrics = [ + {"name": metric, "value": run.data.metrics[metric]} + for metric in run.data.metrics.keys() + ] + except MlflowException: + metrics = [] + + return metrics + + def update( + self, + job_id: int, + metric_name: str, + metric_value: float, + metric_step: int | None = None, + **kwargs, + ) -> dict[str, Any]: + """Update a job's metrics by its unique id. + + Args: + job_id: The unique id of the job. + metric_name: The name of the metric to create or update. + metric_value: The value of the metric being updated. + Returns: + The metric dictionary passed in if successful, otherwise an error message. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.debug("Update job metrics by id", job_id=job_id) + + from mlflow.tracking import MlflowClient + + run_id: UUID | None = self._job_id_mlflowrun_service.get( + job_id=job_id, **kwargs + )["mlflow_run_id"] + + if run_id is None: + raise EntityDoesNotExistError("MlFlowRun", run_id=None) + else: + try: + client = MlflowClient() + except MlflowException as e: + raise MLFlowError(e.message) from e + + # this is here just to raise an error if the run does not exist + try: + run = client.get_run(run_id.hex) # noqa: F841 + except MlflowException as e: + raise EntityDoesNotExistError("MlFlowRun", run_id=run_id.hex) from e + + try: + client.log_metric( + run_id.hex, + key=metric_name, + value=metric_value, + step=metric_step, + ) + except MlflowException as e: + raise MLFlowError(e.message) from e + + return {"name": metric_name, "value": metric_value} + + +class JobIdMetricsSnapshotsService(object): + """The service methods for retrieving the historical metrics of a + job by unique id and metric name.""" + + @inject + def __init__(self, job_id_mlflowrun_service: JobIdMlflowrunService) -> None: + """Initialize the job metrics snapshots service. + + All arguments are provided via dependency injection. + """ + self._job_id_mlflowrun_service = job_id_mlflowrun_service + + def get( + self, + job_id: int, + metric_name: str, + page_index: int, + page_length: int, + **kwargs, + ) -> tuple[list[dict[str, Any]], int]: + """Fetch a job's metrics by its unique id and metric name. + + Args: + job_id: The unique id of the job. + metric_name: The name of the metric. + page_index: The index of the first page to be returned. + page_length: The maximum number of experiments to be returned. + Returns: + The metric history for the requested job and metric if found, + otherwise an error message. 
+ """ + from mlflow.tracking import MlflowClient + + log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.debug( + "Get job metric history by id and name", + job_id=job_id, + metric_name=metric_name, + ) + + run_id: UUID | None = self._job_id_mlflowrun_service.get( + job_id=job_id, **kwargs + )["mlflow_run_id"] + + try: + client = MlflowClient() + except MlflowException as e: + raise MLFlowError(e.message) from e + + if run_id is None: + raise EntityDoesNotExistError("MlFlowRun", run_id=None) + + try: + history = client.get_metric_history(run_id=run_id.hex, key=metric_name) + except MlflowException as e: + raise MLFlowError(e.message) from e + + if history == []: + raise EntityDoesNotExistError("Metric not found", name=metric_name) + + metrics_page = [ + { + "name": metric.key, + "value": metric.value, + "step": metric.step, + "timestamp": metric.timestamp, + } + for metric in history[ + page_index * page_length : (page_index + 1) * page_length + ] + ] + + return metrics_page, len(history) + + class ExperimentJobService(object): """The service methods for submitting and retrieving jobs within an experiment namespace.""" @@ -993,7 +1176,7 @@ def get( job_id: The unique id of the job. Returns: - The status message job object if found, otherwise an error message. + The MlflowRun id of the job object if found, otherwise an error message. """ log: BoundLogger = kwargs.get("log", LOGGER.new()) log.debug("Get job status by id", job_id=job_id) @@ -1055,3 +1238,68 @@ def create( ) return {"mlflow_run_id": job.mlflow_run.mlflow_run_id} + + +class ExperimentMetricsService(object): + """The service methods for retrieving metrics attached to jobs in the experiment.""" + + @inject + def __init__( + self, + experiment_jobs_service: ExperimentJobService, + job_id_metrics_service: JobIdMetricsService, + ) -> None: + """Initialize the experiment service. + + All arguments are provided via dependency injection. + """ + self._job_id_metrics_service = job_id_metrics_service + self._experiment_jobs_service = experiment_jobs_service + + def get( + self, + experiment_id: int, + search_string: str, + page_index: int, + page_length: int, + sort_by_string: str, + descending: bool, + **kwargs, + ): + """Get a list of jobs and the latest metrics associated with each. + + Args: + experiment_id: The unique id of the experiment. + error_if_not_found: If True, raise an error if the experiment is not found. + Defaults to False. + + Returns: + The list of jobs and the metrics associated with them. + + Raises: + EntityDoesNotExistError: If the experiment is not found and + `error_if_not_found` is True. 
+ """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.debug( + "Get metrics for all jobs for an experiment by resource id", + resource_id=experiment_id, + ) + + jobs, num_jobs = self._experiment_jobs_service.get( + experiment_id, + search_string=search_string, + page_index=page_index, + page_length=page_length, + sort_by_string=sort_by_string, + descending=descending, + **kwargs, + ) + + job_ids = [job["job"].resource_id for job in jobs] + + metrics_for_jobs = [ + {"id": job_id, "metrics": self._job_id_metrics_service.get(job_id)} + for job_id in job_ids + ] + return metrics_for_jobs, num_jobs diff --git a/src/dioptra/restapi/v1/utils.py b/src/dioptra/restapi/v1/utils.py index 349737517..e003e79a1 100644 --- a/src/dioptra/restapi/v1/utils.py +++ b/src/dioptra/restapi/v1/utils.py @@ -631,11 +631,25 @@ def build_entrypoint(entrypoint_dict: EntrypointDict) -> dict[str, Any]: return data +def build_metrics_snapshots(metrics_snapshots_dict: dict[str, Any]) -> dict[str, Any]: + """Build a Metrics Snapshot response dictionary. + + Args: + metrics_snapshots_dict: The Metrics Snapshots object to convert + into a dictionary. + + Returns: + The Metric Snapshots response dictionary. + """ + # no changes currently + return metrics_snapshots_dict + + def build_job(job_dict: JobDict) -> dict[str, Any]: """Build a Job response dictionary. Args: - job: The Job object to convert into a job response dictionary. + job_dict: The Job object to convert into a job response dictionary. Returns: The Job response dictionary. diff --git a/tests/unit/restapi/lib/actions.py b/tests/unit/restapi/lib/actions.py index 95ccb48cc..b7f10c6f7 100644 --- a/tests/unit/restapi/lib/actions.py +++ b/tests/unit/restapi/lib/actions.py @@ -768,3 +768,71 @@ def remove_tag( f"/{V1_ROOT}/{resource_route}/{resource_id}/tags/{tag_id}", follow_redirects=True, ) + +def post_metrics( + client: FlaskClient, job_id: int, metric_name: str, metric_value: float +) -> TestResponse: + """Remove tag from the resource with the provided unique ID. + + Args: + client: The Flask test client. + job_id: The id of the Job to post metrics to. + metric_name: The name of the metric. + metric_value: The value of the metric. + + Returns: + The response from the API. + """ + + return client.post( + f"/{V1_ROOT}/{V1_JOBS_ROUTE}/{job_id}/metrics", + json={"name": metric_name, "value": metric_value}, + ) + + +def post_mlflowrun( + client: FlaskClient, job_id: int, mlflow_run_id: str +) -> TestResponse: + """Add an mlflow run id to a job. + + Args: + client: The Flask test client. + job_id: The id of the Job. + mlflow_run_id: The id of the mlflow run. + Returns: + The response from the API. + + """ + payload = {"mlflowRunId": mlflow_run_id} + response = client.post( + f"/{V1_ROOT}/{V1_JOBS_ROUTE}/{job_id}/mlflowRun", + json=payload, + follow_redirects=True, + ) + return response + + +def post_mlflowruns( + client: FlaskClient, mlflowruns: dict[str, Any], registered_jobs: dict[str, Any] +) -> dict[str, Any]: + """Add mlflow run ids to multiple jobs. + + Args: + client: The Flask test client. + mlflowruns: A dictionary mapping job key to mlflow run id. + registered_jobs: A dictionary of registered jobs. + + Returns: + The responses from the API. 
+ """ + + responses = {} + + for key in mlflowruns.keys(): + job_id = registered_jobs[key]["id"] + mlflowrun_response = post_mlflowrun( + client=client, job_id=job_id, mlflow_run_id=mlflowruns[key].hex + ).get_json() + responses[key] = mlflowrun_response + + return responses \ No newline at end of file diff --git a/tests/unit/restapi/lib/asserts.py b/tests/unit/restapi/lib/asserts.py index 47e6449d8..6d616b402 100644 --- a/tests/unit/restapi/lib/asserts.py +++ b/tests/unit/restapi/lib/asserts.py @@ -57,7 +57,7 @@ def assert_group_ref_contents_matches_expectations( assert group["id"] == expected_group_id -def assert_tag_ref_contents_matches_expectations(tags: dict[str, Any]) -> None: +def assert_tag_ref_contents_matches_expectations(tags: list[dict[str, Any]]) -> None: for tag in tags: assert isinstance(tag["id"], int) assert isinstance(tag["name"], str) diff --git a/tests/unit/restapi/lib/mock_mlflow.py b/tests/unit/restapi/lib/mock_mlflow.py new file mode 100644 index 000000000..a7fb7a851 --- /dev/null +++ b/tests/unit/restapi/lib/mock_mlflow.py @@ -0,0 +1,190 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from __future__ import annotations + +import time +from typing import Any, Optional + +import structlog +from structlog.stdlib import BoundLogger + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + +active_runs: dict[str, list[MockMlflowMetric]] = {} + + +class MockMlflowClient(object): + def __init__(self) -> None: + LOGGER.info( + "Mocking mlflow.tracking.MlflowClient instance", + ) + + def get_run(self, id: str) -> MockMlflowRun: + # Note: In Mlflow, this function would usually throw an MlflowException + # if the run id is not found. For simplicity this is left out in favor of + # assuming the run should exist in this mock instance. 
+ + LOGGER.info("Mocking MlflowClient.get_run() function") + if id not in active_runs: + active_runs[id] = [] + + run = MockMlflowRun(id) + metrics: list[MockMlflowMetric] = active_runs[id] + output_metrics: dict[str, MockMlflowMetric] = {} + for metric in metrics: + # find the latest metric for each metric name + if ( + metric.key not in output_metrics + or metric.timestamp > output_metrics[metric.key].timestamp + ): + output_metrics[metric.key] = metric + + # remove step and timestamp information + for output in output_metrics.keys(): + run.data.metrics[output] = output_metrics[output].value + return run + + def log_metric( + self, id: str, key: str, value: float, step: Optional[int] = None, timestamp: Optional[int] = None + ): + if id not in active_runs: + active_runs[id] = [] + if timestamp is None: + timestamp = time.time_ns() // 1000000 + active_runs[id] += [ + MockMlflowMetric( + key=key, + value=value, + step=0 if step is None else step, + timestamp=timestamp, + ) + ] + + def get_metric_history(self, run_id: str, key: str): + return [metric for metric in active_runs[run_id] if metric.key == key] + + +class MockMlflowRun(object): + def __init__( + self, + id: str, + ) -> None: + LOGGER.info("Mocking mlflow.entities.Run class") + self._id = id + self.data = MockMlflowRunData() + + @property + def id(self) -> str: + LOGGER.info("Mocking mlflow.entities.Run.id getter") + return self._id + + @id.setter + def id(self, value: str) -> None: + LOGGER.info("Mocking mlflow.entities.Run.id setter", value=value) + self._id = value + + @property + def data(self) -> MockMlflowRunData: + LOGGER.info("Mocking mlflow.entities.Run.data getter") + return self._data + + @data.setter + def data(self, value: MockMlflowRunData) -> None: + LOGGER.info("Mocking mlflow.entities.Run.data setter", value=value) + self._data = value + + +class MockMlflowRunData(object): + def __init__( + self, + ) -> None: + LOGGER.info("Mocking mlflow.entities.RunData class") + self._metrics: dict[str, Any] = {} + + @property + def metrics(self) -> dict[str, Any]: + LOGGER.info("Mocking mlflow.entities.RunData.metrics getter") + return self._metrics + + @metrics.setter + def metrics(self, value: dict[str, Any]) -> None: + LOGGER.info("Mocking mlflow.entities.RunData.metrics setter", value=value) + self._metrics = value + + +class MockMlflowMetric(object): + def __init__( + self, + key: str, + value: float, + step: int, + timestamp: int, + ) -> None: + LOGGER.info("Mocking mlflow.entities.Metric class") + self._key = key + self._value = value + self._step = step + self._timestamp = timestamp + + @property + def key(self) -> str: + LOGGER.info("Mocking mlflow.entities.Metric.key getter") + return self._key + + @key.setter + def key(self, value: str) -> None: + LOGGER.info("Mocking mlflow.entities.Metric.key setter", value=value) + self._key = value + + @property + def value(self) -> float: + LOGGER.info("Mocking mlflow.entities.Metric.value getter") + return self._value + + @value.setter + def value(self, value: float) -> None: + LOGGER.info("Mocking mlflow.entities.Metric.value setter", value=value) + self._value = value + + @property + def step(self) -> int: + LOGGER.info("Mocking mlflow.entities.Metric.step getter") + return self._step + + @step.setter + def step(self, value: int) -> None: + LOGGER.info("Mocking mlflow.entities.Metric.step setter", value=value) + self._step = value + + @property + def timestamp(self) -> int: + LOGGER.info("Mocking mlflow.entities.Metric.timestamp getter") + return self._timestamp + + 
@timestamp.setter + def timestamp(self, value: int) -> None: + LOGGER.info("Mocking mlflow.entities.Metric.timestamp setter", value=value) + self._timestamp = value + + +class MockMlflowException(Exception): + def __init__( + self, + text: str, + ) -> None: + LOGGER.info("Mocking mlflow.exceptions.MlflowException class") + super().__init__(text) diff --git a/tests/unit/restapi/test_utils.py b/tests/unit/restapi/test_utils.py index 513a2e62e..88cfe8a87 100644 --- a/tests/unit/restapi/test_utils.py +++ b/tests/unit/restapi/test_utils.py @@ -16,8 +16,6 @@ # https://creativecommons.org/licenses/by/4.0/legalcode from __future__ import annotations -import pytest - from dioptra.restapi.utils import find_non_unique diff --git a/tests/unit/restapi/v1/conftest.py b/tests/unit/restapi/v1/conftest.py index d83cd2f23..676725d26 100644 --- a/tests/unit/restapi/v1/conftest.py +++ b/tests/unit/restapi/v1/conftest.py @@ -21,6 +21,7 @@ from typing import Any, cast import pytest +import uuid from flask import Flask from flask.testing import FlaskClient from flask_sqlalchemy import SQLAlchemy @@ -684,3 +685,38 @@ def registered_jobs( "job2": job2_response, "job3": job3_response, } + + +@pytest.fixture +def registered_mlflowrun( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], + registered_jobs: dict[str, Any], +) -> dict[str, Any]: + mlflowruns = {"job1": uuid.uuid4(), "job2": uuid.uuid4(), "job3": uuid.uuid4()} + + responses = actions.post_mlflowruns( + client=client, mlflowruns=mlflowruns, registered_jobs=registered_jobs + ) + + return responses + + +@pytest.fixture +def registered_mlflowrun_incomplete( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], + registered_jobs: dict[str, Any], +) -> dict[str, Any]: + mlflowruns = { + "job1": uuid.uuid4(), + "job2": uuid.uuid4(), + } # leave job3 out so we can use that in test_mlflowrun() + + responses = actions.post_mlflowruns( + client=client, mlflowruns=mlflowruns, registered_jobs=registered_jobs + ) + + return responses diff --git a/tests/unit/restapi/v1/test_job.py b/tests/unit/restapi/v1/test_job.py index 3d06e4ed7..70af6537d 100644 --- a/tests/unit/restapi/v1/test_job.py +++ b/tests/unit/restapi/v1/test_job.py @@ -30,7 +30,7 @@ from dioptra.client.base import DioptraResponseProtocol from dioptra.client.client import DioptraClient -from ..lib import asserts, helpers, mock_rq, routines +from ..lib import asserts, helpers, mock_mlflow, mock_rq, routines # -- Assertions ------------------------------------------------------------------------ @@ -238,6 +238,89 @@ def assert_job_status_matches_expectations( ) +def assert_job_mlflowrun_matches_expectations( + dioptra_client: DioptraClient[DioptraResponseProtocol], job_id: int, expected: str +) -> None: + import uuid + + response = dioptra_client.jobs.get_mlflow_run_id(job_id=job_id) + assert ( + response.status_code == HTTPStatus.OK + and uuid.UUID(response.json()["mlflowRunId"]).hex == expected + ) + + +def assert_job_mlflowrun_already_set( + dioptra_client: DioptraClient[DioptraResponseProtocol], + job_id: int, + mlflow_run_id: str, +) -> None: + response = dioptra_client.jobs.set_mlflow_run_id( + job_id=job_id, mlflow_run_id=mlflow_run_id + ) + assert ( + response.status_code == HTTPStatus.BAD_REQUEST + and response.json()["error"] == "JobMlflowRunAlreadySetError" + ) + + +def assert_job_metrics_validation_error( + dioptra_client: DioptraClient[DioptraResponseProtocol], + job_id: int, + metric_name: str, + metric_value: float, +) -> None: + response = 
dioptra_client.jobs.append_metric_by_id(
+        job_id=job_id, metric_name=metric_name, metric_value=metric_value
+    )
+    assert response.status_code == HTTPStatus.BAD_REQUEST
+
+
+def assert_job_metrics_matches_expectations(
+    dioptra_client: DioptraClient[DioptraResponseProtocol],
+    job_id: int,
+    expected: list[dict[str, Any]],
+) -> None:
+    response = dioptra_client.jobs.get_metrics_by_id(job_id=job_id)
+    assert response.status_code == HTTPStatus.OK and response.json() == expected
+
+
+def assert_experiment_metrics_matches_expectations(
+    dioptra_client: DioptraClient[DioptraResponseProtocol],
+    experiment_id: int,
+    expected: list[dict[str, Any]],
+) -> None:
+    response = dioptra_client.experiments.get_metrics_by_id(experiment_id=experiment_id)
+    assert response.status_code == HTTPStatus.OK and response.json()["data"] == expected
+
+
+def assert_job_metrics_snapshots_matches_expectations(
+    dioptra_client: DioptraClient[DioptraResponseProtocol],
+    job_id: int,
+    metric_name: str,
+    expected: list[dict[str, Any]],
+) -> None:
+    response = dioptra_client.jobs.get_metrics_snapshots_by_id(
+        job_id=job_id, metric_name=metric_name
+    )
+    assert response.status_code == HTTPStatus.OK
+
+    history = response.json()["data"]
+
+    assert all(
+        [
+            "name" in m and "value" in m and "timestamp" in m and "step" in m
+            for m in history
+        ]
+    )
+    assert all(
+        [
+            any([e["name"] == m["name"] and e["value"] == m["value"] for e in expected])
+            for m in history
+        ]
+    )
+
+
 # -- Tests -----------------------------------------------------------------------------
 
 
@@ -338,6 +421,153 @@ def test_create_job(
     )
 
+
+def test_mlflowrun(
+    dioptra_client: DioptraClient[DioptraResponseProtocol],
+    db: SQLAlchemy,
+    auth_account: dict[str, Any],
+    registered_jobs: dict[str, Any],
+    registered_mlflowrun_incomplete: dict[str, Any],
+):
+    import uuid
+
+    job_uuid = uuid.uuid4().hex
+
+    # explicitly use job3 because we did not set a mlflowrun on this job
+
+    mlflowrun_response = dioptra_client.jobs.set_mlflow_run_id(  # noqa: F841
+        job_id=registered_jobs["job3"]["id"], mlflow_run_id=job_uuid
+    ).json()
+
+    assert_job_mlflowrun_matches_expectations(
+        dioptra_client, job_id=registered_jobs["job3"]["id"], expected=job_uuid
+    )
+
+    assert_job_mlflowrun_already_set(
+        dioptra_client, job_id=registered_jobs["job1"]["id"], mlflow_run_id=job_uuid
+    )
+
+
+def test_metrics(
+    dioptra_client: DioptraClient[DioptraResponseProtocol],
+    db: SQLAlchemy,
+    auth_account: dict[str, Any],
+    registered_jobs: dict[str, Any],
+    registered_experiments: dict[str, Any],
+    registered_mlflowrun: dict[str, Any],
+    monkeypatch: MonkeyPatch,
+) -> None:
+    import mlflow.exceptions
+    import mlflow.tracking
+
+    monkeypatch.setattr(mlflow.tracking, "MlflowClient", mock_mlflow.MockMlflowClient)
+    monkeypatch.setattr(
+        mlflow.exceptions, "MlflowException", mock_mlflow.MockMlflowException
+    )
+
+    experiment_id = registered_experiments["experiment1"]["snapshot"]
+    job1_id = registered_jobs["job1"]["id"]
+    job2_id = registered_jobs["job2"]["id"]
+    job3_id = registered_jobs["job3"]["id"]
+
+    metric_response = dioptra_client.jobs.append_metric_by_id(  # noqa: F841
+        job_id=job1_id,
+        metric_name="accuracy",
+        metric_value=4.0,
+    ).json()
+
+    assert_job_metrics_matches_expectations(
+        dioptra_client, job_id=job1_id, expected=[{"name": "accuracy", "value": 4.0}]
+    )
+
+    metric_response = dioptra_client.jobs.append_metric_by_id(  # noqa: F841
+        job_id=job1_id,
+        metric_name="accuracy",
+        metric_value=4.1,
+    ).json()
+
+    metric_response = 
dioptra_client.jobs.append_metric_by_id( # noqa: F841 + job_id=job1_id, + metric_name="accuracy", + metric_value=4.2, + ).json() + + metric_response = dioptra_client.jobs.append_metric_by_id( # noqa: F841 + job_id=job1_id, + metric_name="roc_auc", + metric_value=0.99, + ).json() + + metric_response = dioptra_client.jobs.append_metric_by_id( # noqa: F841 + job_id=job2_id, + metric_name="job_2_metric", + metric_value=0.11, + ).json() + + assert_job_metrics_matches_expectations( + dioptra_client, + job_id=job1_id, + expected=[ + {"name": "accuracy", "value": 4.2}, + {"name": "roc_auc", "value": 0.99}, + ], + ) + + assert_job_metrics_validation_error( + dioptra_client, + job_id=job1_id, + metric_name="!+_", + metric_value=4.0, + ) + + assert_job_metrics_validation_error( + dioptra_client, + job_id=job1_id, + metric_name="!!!!!", + metric_value=4.0, + ) + + assert_job_metrics_validation_error( + dioptra_client, + job_id=job1_id, + metric_name="$23", + metric_value=4.0, + ) + + assert_job_metrics_validation_error( + dioptra_client, + job_id=job1_id, + metric_name="abcdefghijk(lmnop)", + metric_value=4.0, + ) + + assert_experiment_metrics_matches_expectations( + dioptra_client, + experiment_id=experiment_id, + expected=[ + { + "id": job1_id, + "metrics": [ + {"name": "accuracy", "value": 4.2}, + {"name": "roc_auc", "value": 0.99}, + ], + }, + {"id": job2_id, "metrics": [{"name": "job_2_metric", "value": 0.11}]}, + {"id": job3_id, "metrics": []}, + ], + ) + + assert_job_metrics_snapshots_matches_expectations( + dioptra_client, + job_id=registered_jobs["job1"]["id"], + metric_name="accuracy", + expected=[ + {"name": "accuracy", "value": 4.2}, + {"name": "accuracy", "value": 4.1}, + {"name": "accuracy", "value": 4.0}, + ], + ) + + def test_job_get_all( dioptra_client: DioptraClient[DioptraResponseProtocol], db: SQLAlchemy,
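A closing note for reviewers, outside the diff itself: a minimal sketch of how the new metrics surface composes from the Python client, assuming a connected DioptraClient instance named `client` and a job that already has an MLflow run attached (the ids and metric names below are illustrative, not taken from the diff):

    # Append observations for one metric; `metric_step` orders the history.
    client.jobs.append_metric_by_id(
        job_id=1, metric_name="accuracy", metric_value=0.91, metric_step=0
    )
    client.jobs.append_metric_by_id(
        job_id=1, metric_name="accuracy", metric_value=0.97, metric_step=1
    )

    # Latest value for each metric name attached to the job.
    latest = client.jobs.get_metrics_by_id(job_id=1).json()

    # Paged history for a single metric of the job.
    history = client.jobs.get_metrics_snapshots_by_id(
        job_id=1, metric_name="accuracy", page_length=25
    ).json()["data"]

    # Latest metrics for every job in an experiment, paged and searchable.
    per_job = client.experiments.get_metrics_by_id(
        experiment_id=2, page_length=25
    ).json()["data"]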