Skip to content

Commit

Permalink
chore: use pzmm for open source models
Browse files Browse the repository at this point in the history
  • Loading branch information
jlwalke2 committed Aug 14, 2024
1 parent 4b0016b commit b5c0e06
Showing 1 changed file with 117 additions and 80 deletions.
197 changes: 117 additions & 80 deletions src/sasctl/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,88 +458,125 @@ def register_model(
)
return model

# If the model is a scikit-learn model, generate the model dictionary
# from it and pickle the model for storage
if all(hasattr(model, attr) for attr in ["_estimator_type", "get_params"]):
# Pickle the model so we can store it
model_pkl = pickle.dumps(model)
files.append({"name": "model.pkl", "file": model_pkl, "role": "Python Pickle"})

target_funcs = [f for f in ("predict", "predict_proba") if hasattr(model, f)]

# Extract model properties
model = _sklearn_to_dict(model)
model["name"] = name

# Get package versions in environment
packages = installed_packages()
if record_packages and packages is not None:
model.setdefault("properties", [])

# Define a custom property to capture each package version
# NOTE: some packages may not conform to the 'name==version' format
# expected here (e.g. those installed with pip install -e). Such
# packages also generally contain characters that are not allowed
# in custom properties, so they are excluded here.
for p in packages:
if "==" in p:
n, v = p.split("==")
model["properties"].append(_property("env_%s" % n, v))

# Generate and upload a requirements.txt file
files.append({"name": "requirements.txt", "file": "\n".join(packages)})

# Generate PyMAS wrapper
if not isinstance(model, dict):
try:
mas_module = from_pickle(
model_pkl, target_funcs, input_types=input, array_input=True
)

# Include score code files from ESP and MAS
files.append(
{
"name": "dmcas_packagescorecode.sas",
"file": mas_module.score_code(),
"role": "Score Code",
}
)
files.append(
{
"name": "dmcas_epscorecode.sas",
"file": mas_module.score_code(dest="CAS"),
"role": "score",
}
)
files.append(
{
"name": "python_wrapper.py",
"file": mas_module.score_code(dest="Python"),
}
)
info = utils.get_model_info(model, X=input)
except ValueError as e:
logger.debug("Model of type %s could not be inspected: %s", type(model), e)
raise

from .pzmm import ImportModel, JSONFiles, PickleModel

pzmm_files = {}

pickled_model = PickleModel.pickle_trained_model("model", info.model)
# Returns dict with "prefix.pickle": bytes
assert len(pickled_model) == 1

input_vars = JSONFiles.write_var_json(info.X, is_input=True)
output_vars = JSONFiles.write_var_json(info.y, is_input=False)
metadata = JSONFiles.write_file_metadata_json(model_prefix=name)
properties = JSONFiles.write_model_properties_json(
model_name=name,
model_desc=info.description,
model_algorithm=info.algorithm,
target_variable=info.target_column,
target_values=info.target_values
)

model["inputVariables"] = [
var.as_model_metadata() for var in mas_module.variables if not var.out
]

model["outputVariables"] = [
var.as_model_metadata() for var in mas_module.variables if var.out
]
except ValueError:
# PyMAS creation failed, most likely because input data wasn't
# provided
logger.exception("Unable to inspect model %s", model)

warn(
"Unable to determine input/output variables. "
" Model variables will not be specified and some "
"model functionality may not be available."
)
else:
# Otherwise, the model better be a dictionary of metadata
if not isinstance(model, dict):
raise TypeError(
"Expected an instance of '%r' but received '%r'." % ({}, model)
)
pzmm_files.update(pickled_model)
pzmm_files.update(input_vars)
pzmm_files.update(output_vars)
pzmm_files.update(metadata)
pzmm_files.update(properties)

model_obj, _ = ImportModel.import_model(pzmm_files, name, project)
return model_obj

# # If the model is a scikit-learn model, generate the model dictionary
# # from it and pickle the model for storage
# if all(hasattr(model, attr) for attr in ["_estimator_type", "get_params"]):
# # Pickle the model so we can store it
# model_pkl = pickle.dumps(model)
# files.append({"name": "model.pkl", "file": model_pkl, "role": "Python Pickle"})
#
# target_funcs = [f for f in ("predict", "predict_proba") if hasattr(model, f)]
#
# # Extract model properties
# model = _sklearn_to_dict(model)
# model["name"] = name
#
# # Get package versions in environment
# packages = installed_packages()
# if record_packages and packages is not None:
# model.setdefault("properties", [])
#
# # Define a custom property to capture each package version
# # NOTE: some packages may not conform to the 'name==version' format
# # expected here (e.g. those installed with pip install -e). Such
# # packages also generally contain characters that are not allowed
# # in custom properties, so they are excluded here.
# for p in packages:
# if "==" in p:
# n, v = p.split("==")
# model["properties"].append(_property("env_%s" % n, v))
#
# # Generate and upload a requirements.txt file
# files.append({"name": "requirements.txt", "file": "\n".join(packages)})
#
# # Generate PyMAS wrapper
# try:
# mas_module = from_pickle(
# model_pkl, target_funcs, input_types=input, array_input=True
# )
#
# # Include score code files from ESP and MAS
# files.append(
# {
# "name": "dmcas_packagescorecode.sas",
# "file": mas_module.score_code(),
# "role": "Score Code",
# }
# )
# files.append(
# {
# "name": "dmcas_epscorecode.sas",
# "file": mas_module.score_code(dest="CAS"),
# "role": "score",
# }
# )
# files.append(
# {
# "name": "python_wrapper.py",
# "file": mas_module.score_code(dest="Python"),
# }
# )
#
# model["inputVariables"] = [
# var.as_model_metadata() for var in mas_module.variables if not var.out
# ]
#
# model["outputVariables"] = [
# var.as_model_metadata() for var in mas_module.variables if var.out
# ]
# except ValueError:
# # PyMAS creation failed, most likely because input data wasn't
# # provided
# logger.exception("Unable to inspect model %s", model)
#
# warn(
# "Unable to determine input/output variables. "
# " Model variables will not be specified and some "
# "model functionality may not be available."
# )
# else:
# # Otherwise, the model better be a dictionary of metadata
# if not isinstance(model, dict):
# raise TypeError(
# "Expected an instance of '%r' but received '%r'." % ({}, model)
# )

# If we got this far, then `model` is a dictionary of model metadata.

if create_project:
project = _create_project(project, model, repo_obj)
Expand Down

0 comments on commit b5c0e06

Please sign in to comment.