Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(restapi): add metrics capabilities to RESTAPI #672

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions src/dioptra/client/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .tags import TagsSubCollectionClient

ARTIFACTS: Final[str] = "artifacts"
METRICS: Final[str] = "metrics"
MLFLOW_RUN: Final[str] = "mlflowRun"
STATUS: Final[str] = "status"

Expand Down Expand Up @@ -702,3 +703,47 @@ def delete_by_id(self, experiment_id: str | int) -> T:
The response from the Dioptra API.
"""
return self._session.delete(self.url, str(experiment_id))

def get_metrics_by_id(
self,
experiment_id: str | int,
index: int = 0,
page_length: int = 10,
sort_by: str | None = None,
descending: bool | None = None,
search: str | None = None,
) -> T:
"""Get the metrics for the jobs in this experiment.

Args:
experiment_id: The experiment id, an integer.
index: The paging index. Optional, defaults to 0.
page_length: The maximum number of experiments to return in the paged
response. Optional, defaults to 10.
sort_by: The field to use to sort the returned list. Optional, defaults to
None.
descending: Sort the returned list in descending order. Optional, defaults
to None.
search: Search for jobs using the Dioptra API's query language.
Optional, defaults to None.

Returns:
The response from the Dioptra API.
"""

params: dict[str, Any] = {
"experiment_id": experiment_id,
"index": index,
"pageLength": page_length,
}

if sort_by is not None:
params["sortBy"] = sort_by

if descending is not None:
params["descending"] = descending

if search is not None:
params["search"] = search

return self._session.get(self.url, str(experiment_id), METRICS, params=params)
73 changes: 70 additions & 3 deletions src/dioptra/client/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
from .snapshots import SnapshotsSubCollectionClient
from .tags import TagsSubCollectionClient

METRICS: Final[str] = "metrics"
MLFLOW_RUN: Final[str] = "mlflowRun"
SNAPSHOTS: Final[str] = "snapshots"
STATUS: Final[str] = "status"

T = TypeVar("T")
Expand Down Expand Up @@ -166,7 +168,7 @@ def delete_by_id(self, job_id: str | int) -> T:
"""
return self._session.delete(self.url, str(job_id))

def get_mlflow_run_id(self, job_id: int) -> T:
def get_mlflow_run_id(self, job_id: str | int) -> T:
"""Gets the MLflow run id for a job.

Args:
Expand All @@ -177,7 +179,7 @@ def get_mlflow_run_id(self, job_id: int) -> T:
"""
return self._session.get(self.url, str(job_id), MLFLOW_RUN)

def set_mlflow_run_id(self, job_id: int, mlflow_run_id: str) -> T:
def set_mlflow_run_id(self, job_id: str | int, mlflow_run_id: str) -> T:
"""Sets the MLflow run id for a job.

Args:
Expand All @@ -193,7 +195,7 @@ def set_mlflow_run_id(self, job_id: int, mlflow_run_id: str) -> T:

return self._session.post(self.url, str(job_id), MLFLOW_RUN, json_=json_)

def get_status(self, job_id: int) -> T:
def get_status(self, job_id: str | int) -> T:
"""Gets the status for a job.

Args:
Expand All @@ -203,3 +205,68 @@ def get_status(self, job_id: int) -> T:
The response from the Dioptra API.
"""
return self._session.get(self.url, str(job_id), STATUS)

def get_metrics_by_id(self, job_id: str | int) -> T:
"""Gets all the latest metrics for a job.

Args:
job_id: The job id, an integer.

Returns:
The response from the Dioptra API.
"""
return self._session.get(self.url, str(job_id), METRICS)

def append_metric_by_id(
self,
job_id: str | int,
metric_name: str,
metric_value: float,
metric_step: int | None = None,
) -> T:
"""Posts a new metric to a job.

Args:
job_id: The job id, an integer.
metric_name: The name of the metric.
metric_value: The value of the metric.
metric_step: The step number of the metric, optional.

Returns:
The response from the Dioptra API.
"""
json_ = {
"name": metric_name,
"value": metric_value,
}

if metric_step is not None:
json_["step"] = metric_step

return self._session.post(self.url, str(job_id), METRICS, json_=json_)

def get_metrics_snapshots_by_id(
self,
job_id: str | int,
metric_name: str | int,
index: int = 0,
page_length: int = 10,
) -> T:
"""Gets the metric history for a job with a specific metric name.

Args:
job_id: The job id, an integer.
metric_name: The name of the metric.
index: The paging index. Optional, defaults to 0.
page_length: The maximum number of metrics to return in the paged
response. Optional, defaults to 10.
Returns:
The response from the Dioptra API.
"""
params: dict[str, Any] = {
"index": index,
"pageLength": page_length,
}
return self._session.get(
self.url, str(job_id), METRICS, metric_name, SNAPSHOTS, params=params
)
12 changes: 12 additions & 0 deletions src/dioptra/restapi/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,13 @@ def __init__(self, message: str):
super().__init__(message)


class MLFlowError(DioptraError):
"""MLFlow Error."""

def __init__(self, message: str):
super().__init__(message)


def error_result(
error: DioptraError, status: http.HTTPStatus, detail: dict[str, typing.Any]
) -> tuple[dict[str, typing.Any], int]:
Expand Down Expand Up @@ -431,6 +438,11 @@ def handle_user_password_error(error: UserPasswordError):
log.debug(error.to_message())
return error_result(error, http.HTTPStatus.UNAUTHORIZED, {})

@api.errorhandler(MLFlowError)
def handle_mlflow_error(error: MLFlowError):
log.debug(error.to_message())
return error_result(error, http.HTTPStatus.INTERNAL_SERVER_ERROR, {})

@api.errorhandler(DioptraError)
def handle_base_error(error: DioptraError):
log.debug(error.to_message())
Expand Down
68 changes: 68 additions & 0 deletions src/dioptra/restapi/v1/experiments/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from dioptra.restapi.v1.entrypoints.schema import EntrypointRefSchema
from dioptra.restapi.v1.jobs.schema import (
ExperimentJobGetQueryParameters,
ExperimentJobsMetricsSchema,
JobMlflowRunSchema,
JobPageSchema,
JobSchema,
Expand All @@ -45,6 +46,7 @@
ExperimentJobIdService,
ExperimentJobIdStatusService,
ExperimentJobService,
ExperimentMetricsService,
)
from dioptra.restapi.v1.schemas import IdListSchema, IdStatusResponseSchema
from dioptra.restapi.v1.shared.drafts.controller import (
Expand All @@ -64,6 +66,7 @@
from .schema import (
ExperimentDraftSchema,
ExperimentGetQueryParameters,
ExperimentMetricsGetQueryParameters,
ExperimentMutableFieldsSchema,
ExperimentPageSchema,
ExperimentSchema,
Expand Down Expand Up @@ -521,6 +524,71 @@ def post(self, id: int, jobId: int):
return utils.build_artifact(artifact)


@api.route("/<int:id>/metrics")
@api.param("id", "ID for the Experiment resource.")
class ExperimentIdMetricsEndpoint(Resource):
@inject
def __init__(
self,
experiment_metrics_service: ExperimentMetricsService,
*args,
**kwargs,
) -> None:
"""Initialize the Experiment Metrics resource.

All arguments are provided via dependency injection.

Args:
experiment_metrics_service: A ExperimentMetricsService object.
"""
self._experiment_metrics_service = experiment_metrics_service
super().__init__(*args, **kwargs)

@login_required
@accepts(query_params_schema=ExperimentMetricsGetQueryParameters, api=api)
@responds(schema=ExperimentJobsMetricsSchema, api=api)
def get(self, id: int):
"""Gets all of the latest metrics for every job in the experiment."""
log = LOGGER.new(
request_id=str(uuid.uuid4()),
resource="ExperimentIdMetricsEndpoint",
request_type="GET",
experiment_id=id,
)

parsed_query_params = request.parsed_query_params # type: ignore
search_string = unquote(parsed_query_params["search"])
page_index = parsed_query_params["index"]
page_length = parsed_query_params["page_length"]
sort_by_string = unquote(parsed_query_params["sort_by"])
descending = parsed_query_params["descending"]

jobs_metrics, total_num_jobs = self._experiment_metrics_service.get(
experiment_id=id,
search_string=search_string,
page_index=page_index,
page_length=page_length,
sort_by_string=sort_by_string,
descending=descending,
error_if_not_found=True,
log=log,
)

return utils.build_paging_envelope(
f"/experiments/{id}/metrics",
build_fn=utils.build_metrics_snapshots,
data=jobs_metrics,
group_id=None,
query=None,
draft_type=None,
index=page_index,
length=page_length,
total_num_elements=total_num_jobs,
sort_by=None,
descending=None,
)


@api.route("/<int:id>/entrypoints")
@api.param("id", "ID for the Experiment resource.")
class ExperimentIdEntrypointsEndpoint(Resource):
Expand Down
8 changes: 8 additions & 0 deletions src/dioptra/restapi/v1/experiments/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,14 @@ class ExperimentPageSchema(BasePageSchema):
)


class ExperimentMetricsGetQueryParameters(
PagingQueryParametersSchema,
SearchQueryParametersSchema,
SortByGetQueryParametersSchema,
):
"""The query parameters for the GET method of the /experiments/{id}/metrics"""


class ExperimentGetQueryParameters(
PagingQueryParametersSchema,
GroupIdQueryParametersSchema,
Expand Down
Loading
Loading