
Publish, performance monitor, and score test model versions #217

Closed (wants to merge 3 commits)
68 changes: 65 additions & 3 deletions src/sasctl/_services/model_management.py
@@ -28,7 +28,13 @@ class ModelManagement(Service):
# TODO: set ds2MultiType
@classmethod
def publish_model(
cls,
model,
destination,
model_version="latest",
name=None,
force=False,
reload_model_table=False,
):
"""

@@ -38,6 +44,8 @@ def publish_model(
The name or id of the model, or a dictionary representation of the model.
destination : str
Name of destination to publish the model to.
model_version : str or dict, optional
The version to publish, given as a version name, version id, or dictionary representation of the version. Defaults to 'latest'.
name : str, optional
Provide a custom name for the published model. Defaults to None.
force : bool, optional
@@ -68,6 +76,23 @@

# TODO: Verify allowed formats by destination type.
# As of 19w04 MAS throws HTTP 500 if name is in invalid format.
if model_version != "latest":
    if isinstance(model_version, dict) and "modelVersionName" in model_version:
        model_version_name = model_version["modelVersionName"]
    elif isinstance(model_version, dict):
        raise ValueError("Model version is not recognized.")
    elif isinstance(model_version, str) and cls.is_uuid(model_version):
        # Resolve a version id to its version name.
        model_version_name = mr.get_model_or_version(model, model_version)[
            "modelVersionName"
        ]
    else:
        model_version_name = model_version
else:
    model_version_name = ""

model_name = name or "{}_{}".format(
model_obj["name"].replace(" ", ""), model_obj["id"]
).replace("-", "")
@@ -79,6 +104,7 @@
{
"modelName": mp._publish_name(model_name),
"sourceUri": model_uri.get("uri"),
"modelVersionID": model_version_name,
"publishLevel": "model",
}
],
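
A minimal usage sketch of the version-aware publish (hedged: assumes an active Session, a registered model named "My Model", and an existing publishing destination named "maslocal"):

    from sasctl import Session
    from sasctl.services import model_management as mm

    with Session("example.com", "username", "password"):
        # Default behavior is unchanged: publish the latest version.
        mm.publish_model("My Model", "maslocal")

        # Pin the publish to a specific version by name, id, or
        # dictionary representation of the version.
        mm.publish_model("My Model", "maslocal", model_version="1.0")
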
@@ -104,6 +130,7 @@ def create_performance_definition(
table_prefix,
project=None,
models=None,
modelVersions=None,
library_name="Public",
name=None,
description=None,
@@ -136,6 +163,9 @@
The name or id of the model(s), or a dictionary representation of the model(s). For
multiple models, input a list of model names, or a list of dictionaries. If no models are specified, all
models in the project specified will be used. Defaults to None.
modelVersions : str, dict, or list, optional
The model version(s) to monitor, given as a version name, a dictionary representation of a version, or a list of
either, matched positionally to the models. Models without a specified version use their latest version. Defaults to None.
library_name : str
The library containing the input data, default is 'Public'.
name : str, optional
@@ -238,11 +268,44 @@
"Project %s must have the 'predictionVariable' "
"property set." % project.name
)
print("sup")
if not modelVersions:
updated_models = [model.id for model in models]
else:
updated_models = []
if not isinstance(modelVersions, list):
modelVersions = [modelVersions]

if len(models) < len(modelVersions):
raise ValueError(
"There are too many versions for the amount of models specified."
)

modelVersions = modelVersions + [""] * (len(models) - len(modelVersions))
    for model, modelVersionName in zip(models, modelVersions):
        if (
            isinstance(modelVersionName, dict)
            and "modelVersionName" in modelVersionName
        ):
            modelVersionName = modelVersionName["modelVersionName"]
        elif isinstance(modelVersionName, dict):
            raise ValueError("Model version is not recognized.")

        if modelVersionName != "":
            updated_models.append(model.id + ":" + modelVersionName)
        else:
            updated_models.append(model.id)

request = {
"projectId": project.id,
"name": name or project.name + " Performance",
"modelIds": [model.id for model in models],
"modelIds": updated_models,
"championMonitored": monitor_champion,
"challengerMonitored": monitor_challenger,
"maxBins": max_bins,
@@ -279,7 +342,6 @@
for v in project.get("variables", [])
if v.get("role") == "output"
]

return cls.post(
"/performanceTasks",
json=request,
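
A hedged sketch of how the new modelVersions argument maps onto the request's modelIds (model names, library, and table prefix are placeholders; assumes an active session and an existing project):

    from sasctl.services import model_management as mm

    # model1 is monitored at version 1.0; model2 gets no version suffix and
    # falls back to its latest version, so the request body's modelIds end
    # up as ["<model1-id>:1.0", "<model2-id>"].
    mm.create_performance_definition(
        table_prefix="TestData",
        models=["model1", "model2"],
        modelVersions=["1.0"],
        library_name="Public",
    )
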
3 changes: 2 additions & 1 deletion src/sasctl/_services/model_publish.py
@@ -10,6 +10,7 @@

from .model_repository import ModelRepository
from .service import Service
from ..utils.decorators import deprecated


class ModelPublish(Service):
@@ -90,7 +91,7 @@ def delete_destination(cls, item):

return cls.delete("/destinations/{name}".format(name=item))

@classmethod
@deprecated("Use publish_model in model_management.py instead.", "1.11.5")
def publish_model(cls, model, destination, name=None, code=None, notes=None):
"""Publish a model to an existing publishing destination.

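
The intended migration path for callers, sketched under the assumption that both services are imported from sasctl.services and a session is active:

    from sasctl.services import model_management as mm
    from sasctl.services import model_publish as mp

    # Deprecated as of 1.11.5; still works but is expected to emit a
    # DeprecationWarning.
    mp.publish_model("My Model", "maslocal")

    # Preferred replacement, which also supports pinning a version.
    mm.publish_model("My Model", "maslocal", model_version="latest")
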
36 changes: 30 additions & 6 deletions src/sasctl/_services/score_definitions.py
@@ -69,7 +69,7 @@ def create_score_definition(
library_name: str, optional
The library within the CAS server the table exists in. Defaults to "Public".
model_version : str or dict, optional
The model version to score with, given as a version name, version id, or dictionary representation of the version. Defaults to the latest version.

Returns
-------
@@ -116,7 +116,7 @@
table = cls._cas_management.get_table(table_name, library_name, server_name)
if not table and not table_file:
raise HTTPError(
f"This table may not exist in CAS. Please include the `table_file` argument in the function call if it doesn't exist."
"This table may not exist in CAS. Please include the `table_file` argument in the function call if it doesn't exist."
)
elif not table and table_file:
cls._cas_management.upload_file(
@@ -125,16 +125,40 @@
table = cls._cas_management.get_table(table_name, library_name, server_name)
if not table:
raise HTTPError(
f"The file failed to upload properly or another error occurred."
"The file failed to upload properly or another error occurred."
)
# Check whether the input table exists in CAS; if not, upload a file to create it

if model_version != "latest":
    if isinstance(model_version, dict) and "modelVersionName" in model_version:
        model_version = model_version["modelVersionName"]
    elif isinstance(model_version, dict):
        raise ValueError(
            "Model version cannot be found. Please check the model version that was provided."
        )
    elif isinstance(model_version, str) and cls.is_uuid(model_version):
        # Resolve a version id to its version name.
        model_version = cls._model_repository.get_model_or_version(
            model_id, model_version
        )["modelVersionName"]

    object_uri = f"/modelManagement/models/{model_id}/versions/@{model_version}"
else:
    object_uri = f"/modelManagement/models/{model_id}"

save_score_def = {
"name": model_name, # used to be score_def_name
"description": description,
"objectDescriptor": {
"uri": f"/modelManagement/models/{model_id}",
"name": f"{model_name}({model_version})",
"uri": object_uri,
"name": f"{model_name} ({model_version})",
"type": f"{object_descriptor_type}",
},
"inputData": {
Expand All @@ -149,7 +173,7 @@ def create_score_definition(
"projectUri": f"/modelRepository/projects/{model_project_id}",
"projectVersionUri": f"/modelRepository/projects/{model_project_id}/projectVersions/{model_project_version_id}",
"publishDestination": "",
"versionedModel": f"{model_name}({model_version})",
"versionedModel": f"{model_name} ({model_version})",
},
"mappings": inputMapping,
}
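
A hedged sketch of scoring against a pinned version (the positional argument order is assumed from the docstring; the score definition name, model, and table names are placeholders):

    from sasctl.services import score_definitions as sd

    # With model_version="2.0" the objectDescriptor uri becomes
    # /modelManagement/models/<model-id>/versions/@2.0 rather than the
    # unversioned /modelManagement/models/<model-id>.
    sd.create_score_definition(
        "my_score_def",     # score definition name
        "My Model",         # model name, id, or dictionary representation
        "my_input_table",   # CAS table containing the input data
        model_version="2.0",
    )
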
150 changes: 122 additions & 28 deletions tests/unit/test_model_management.py
@@ -23,6 +23,8 @@ def test_create_performance_definition():
RestObj({"name": "Test Model 2", "id": "67890", "projectId": PROJECT["id"]}),
]
USER = "username"
VERSION_MOCK = {"modelVersionName": "1.0"}
VERSION_MOCK_NONAME = {}

with mock.patch("sasctl.core.Session._get_authorization_token"):
current_session("example.com", USER, "password")
@@ -111,6 +113,32 @@
table_prefix="TestData",
)

with pytest.raises(ValueError):
# Model versions exceed the number of models
get_model.side_effect = copy.deepcopy(MODELS)
_ = mm.create_performance_definition(
models=["model1", "model2"],
modelVersions=["1.0", "2.0", "3.0"],
library_name="TestLibrary",
table_prefix="TestData",
max_bins=3,
monitor_challenger=True,
monitor_champion=True,
)

with pytest.raises(ValueError):
# Model version dictionary missing modelVersionName
get_model.side_effect = copy.deepcopy(MODELS)
_ = mm.create_performance_definition(
models=["model1", "model2"],
modelVersions=VERSION_MOCK_NONAME,
library_name="TestLibrary",
table_prefix="TestData",
max_bins=3,
monitor_challenger=True,
monitor_champion=True,
)

get_project.return_value = copy.deepcopy(PROJECT)
get_project.return_value["targetVariable"] = "target"
get_project.return_value["targetLevel"] = "interval"
@@ -125,21 +153,68 @@
monitor_challenger=True,
monitor_champion=True,
)

assert post_models.call_count == 1
url, data = post_models.call_args

assert PROJECT["id"] == data["json"]["projectId"]
assert MODELS[0]["id"] in data["json"]["modelIds"]
assert MODELS[1]["id"] in data["json"]["modelIds"]
assert "TestLibrary" == data["json"]["dataLibrary"]
assert "TestData" == data["json"]["dataPrefix"]
assert "cas-shared-default" == data["json"]["casServerId"]
assert data["json"]["name"]
assert data["json"]["description"]
assert data["json"]["maxBins"] == 3
assert data["json"]["championMonitored"] is True
assert data["json"]["challengerMonitored"] is True

get_model.side_effect = copy.deepcopy(MODELS)
# One model version given as a string name
_ = mm.create_performance_definition(
    models=["model1", "model2"],
    modelVersions="1.0",
    library_name="TestLibrary",
    table_prefix="TestData",
    max_bins=3,
    monitor_challenger=True,
    monitor_champion=True,
)

assert post_models.call_count == 2
url, data = post_models.call_args
assert f"{MODELS[0]['id']}:1.0" in data["json"]["modelIds"]
assert MODELS[1]["id"] in data["json"]["modelIds"]

get_model.side_effect = copy.deepcopy(MODELS)
# List of string type model versions
_ = mm.create_performance_definition(
models=["model1", "model2"],
modelVersions=["1.0", "2.0"],
library_name="TestLibrary",
table_prefix="TestData",
max_bins=3,
monitor_challenger=True,
monitor_champion=True,
)
assert post_models.call_count == 3
url, data = post_models.call_args
assert f"{MODELS[0]['id']}:1.0" in data["json"]["modelIds"]
assert f"{MODELS[1]['id']}:2.0" in data["json"]["modelIds"]

get_model.side_effect = copy.deepcopy(MODELS)
# List of dictionary type and string type model versions
_ = mm.create_performance_definition(
models=["model1", "model2"],
modelVersions=[VERSION_MOCK, "2.0"],
library_name="TestLibrary",
table_prefix="TestData",
max_bins=3,
monitor_challenger=True,
monitor_champion=True,
)
assert post_models.call_count == 4
url, data = post_models.call_args
assert f"{MODELS[0]['id']}:1.0" in data["json"]["modelIds"]
assert f"{MODELS[1]['id']}:2.0" in data["json"]["modelIds"]

with mock.patch(
"sasctl._services.model_management.ModelManagement" ".post"
@@ -160,20 +235,39 @@
monitor_champion=True,
)

assert post_project.call_count == 1
url, data = post_project.call_args

assert PROJECT["id"] == data["json"]["projectId"]
assert MODELS[0]["id"] in data["json"]["modelIds"]
assert MODELS[1]["id"] in data["json"]["modelIds"]
assert "TestLibrary" == data["json"]["dataLibrary"]
assert "TestData" == data["json"]["dataPrefix"]
assert "cas-shared-default" == data["json"]["casServerId"]
assert data["json"]["name"]
assert data["json"]["description"]
assert data["json"]["maxBins"] == 3
assert data["json"]["championMonitored"] is True
assert data["json"]["challengerMonitored"] is True

get_model.side_effect = copy.deepcopy(MODELS)
# Project with model version
_ = mm.create_performance_definition(
project="project",
modelVersions="2.0",
library_name="TestLibrary",
table_prefix="TestData",
max_bins=3,
monitor_challenger=True,
monitor_champion=True,
)

assert post_project.call_count == 2
url, data = post_project.call_args
assert f"{MODELS[0]['id']}:2.0" in data["json"]["modelIds"]
assert MODELS[1]["id"] in data["json"]["modelIds"]

def test_table_prefix_format():
with pytest.raises(ValueError):