Skip to content

Commit

Permalink
[ODSC-56699] Update MD entities to pydantic models (#1059)
Browse files Browse the repository at this point in the history
  • Loading branch information
VipulMascarenhas authored Feb 5, 2025
2 parents 7ce9342 + 4fb0923 commit f5c3697
Show file tree
Hide file tree
Showing 7 changed files with 268 additions and 264 deletions.
16 changes: 15 additions & 1 deletion ads/aqua/common/entities.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
#!/usr/bin/env python
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from typing import Optional

from ads.aqua.config.utils.serializer import Serializable


class ContainerSpec:
"""
Expand All @@ -15,3 +19,13 @@ class ContainerSpec:
ENV_VARS = "envVars"
RESTRICTED_PARAMS = "restrictedParams"
EVALUATION_CONFIGURATION = "evaluationConfiguration"


class ShapeInfo(Serializable):
instance_shape: Optional[str] = None
instance_count: Optional[int] = None
ocpus: Optional[float] = None
memory_in_gbs: Optional[float] = None

class Config:
extra = "ignore"
70 changes: 3 additions & 67 deletions ads/aqua/extension/deployment_handler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from urllib.parse import urlparse
Expand All @@ -11,7 +11,7 @@
from ads.aqua.extension.errors import Errors
from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
from ads.aqua.modeldeployment.entities import ModelParams
from ads.config import COMPARTMENT_OCID, PROJECT_OCID
from ads.config import COMPARTMENT_OCID


class AquaDeploymentHandler(AquaAPIhandler):
Expand Down Expand Up @@ -98,71 +98,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
if not input_data:
raise HTTPError(400, Errors.NO_INPUT_DATA)

# required input parameters
display_name = input_data.get("display_name")
if not display_name:
raise HTTPError(
400, Errors.MISSING_REQUIRED_PARAMETER.format("display_name")
)
instance_shape = input_data.get("instance_shape")
if not instance_shape:
raise HTTPError(
400, Errors.MISSING_REQUIRED_PARAMETER.format("instance_shape")
)
model_id = input_data.get("model_id")
if not model_id:
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model_id"))

compartment_id = input_data.get("compartment_id", COMPARTMENT_OCID)
project_id = input_data.get("project_id", PROJECT_OCID)
log_group_id = input_data.get("log_group_id")
access_log_id = input_data.get("access_log_id")
predict_log_id = input_data.get("predict_log_id")
description = input_data.get("description")
instance_count = input_data.get("instance_count")
bandwidth_mbps = input_data.get("bandwidth_mbps")
web_concurrency = input_data.get("web_concurrency")
server_port = input_data.get("server_port")
health_check_port = input_data.get("health_check_port")
env_var = input_data.get("env_var")
container_family = input_data.get("container_family")
ocpus = input_data.get("ocpus")
memory_in_gbs = input_data.get("memory_in_gbs")
model_file = input_data.get("model_file")
private_endpoint_id = input_data.get("private_endpoint_id")
container_image_uri = input_data.get("container_image_uri")
cmd_var = input_data.get("cmd_var")
freeform_tags = input_data.get("freeform_tags")
defined_tags = input_data.get("defined_tags")

self.finish(
AquaDeploymentApp().create(
compartment_id=compartment_id,
project_id=project_id,
model_id=model_id,
display_name=display_name,
description=description,
instance_count=instance_count,
instance_shape=instance_shape,
log_group_id=log_group_id,
access_log_id=access_log_id,
predict_log_id=predict_log_id,
bandwidth_mbps=bandwidth_mbps,
web_concurrency=web_concurrency,
server_port=server_port,
health_check_port=health_check_port,
env_var=env_var,
container_family=container_family,
ocpus=ocpus,
memory_in_gbs=memory_in_gbs,
model_file=model_file,
private_endpoint_id=private_endpoint_id,
container_image_uri=container_image_uri,
cmd_var=cmd_var,
freeform_tags=freeform_tags,
defined_tags=defined_tags,
)
)
self.finish(AquaDeploymentApp().create(**input_data))

def read(self, id):
"""Read the information of an Aqua model deployment."""
Expand Down
Loading

0 comments on commit f5c3697

Please sign in to comment.