From 8b2f20ebe1321ce51f6fd7241c1f70771fe9971c Mon Sep 17 00:00:00 2001 From: Avik Basu <3485425+ab93@users.noreply.github.com> Date: Tue, 21 Mar 2023 15:27:01 -0700 Subject: [PATCH] fix: save preproc + thresh model with conf keys (#107) Signed-off-by: Avik Basu --- numaprom/udf/preprocess.py | 10 ++++++++-- numaprom/udf/threshold.py | 3 ++- numaprom/udsink/train_rollout.py | 18 +++++++++++------- pyproject.toml | 2 +- 4 files changed, 22 insertions(+), 11 deletions(-) diff --git a/numaprom/udf/preprocess.py b/numaprom/udf/preprocess.py index 571326c..25c6e0d 100644 --- a/numaprom/udf/preprocess.py +++ b/numaprom/udf/preprocess.py @@ -5,7 +5,7 @@ from numaprom import get_logger from numaprom.entities import Status, StreamPayload, Header -from numaprom.tools import msg_forward, load_model +from numaprom.tools import msg_forward, load_model, get_metric_config _LOGGER = get_logger(__name__) @@ -18,10 +18,16 @@ def preprocess(_: str, datum: Datum) -> bytes: payload = StreamPayload(**orjson.loads(_in_msg)) _LOGGER.info("%s - Received Payload: %r ", payload.uuid, payload) + # Load config + metric_config = get_metric_config( + metric=payload.composite_keys["name"], namespace=payload.composite_keys["namespace"] + ) + preprocess_cfgs = metric_config.numalogic_conf.preprocess + # Load preprocess artifact preproc_artifact = load_model( skeys=[payload.composite_keys["namespace"], payload.composite_keys["name"]], - dkeys=["preproc"], + dkeys=[_conf.name for _conf in preprocess_cfgs], artifact_type="sklearn", ) if not preproc_artifact: diff --git a/numaprom/udf/threshold.py b/numaprom/udf/threshold.py index 6bb75ca..41c347f 100644 --- a/numaprom/udf/threshold.py +++ b/numaprom/udf/threshold.py @@ -47,6 +47,7 @@ def threshold(_: str, datum: Datum) -> list[tuple[str, bytes]]: metric_config = get_metric_config( metric=payload.composite_keys["name"], namespace=payload.composite_keys["namespace"] ) + thresh_cfg = metric_config.numalogic_conf.threshold # Check if payload needs static inference if payload.header == Header.STATIC_INFERENCE: @@ -63,7 +64,7 @@ def threshold(_: str, datum: Datum) -> list[tuple[str, bytes]]: # load threshold artifact thresh_artifact = load_model( skeys=[payload.composite_keys["namespace"], payload.composite_keys["name"]], - dkeys=["thresh"], + dkeys=[thresh_cfg.name], artifact_type="sklearn", ) if not thresh_artifact: diff --git a/numaprom/udsink/train_rollout.py b/numaprom/udsink/train_rollout.py index ad7d54f..840b8cf 100644 --- a/numaprom/udsink/train_rollout.py +++ b/numaprom/udsink/train_rollout.py @@ -68,10 +68,10 @@ def _train_model(uuid, x, model_cfg, trainer_cfg): return train_reconerr.numpy(), model, trainer -def _preprocess(x_raw, preproc_cfg: List[ModelInfo]): +def _preprocess(x_raw, preproc_cfgs: List[ModelInfo]): preproc_factory = PreprocessFactory() preproc_clfs = [] - for _cfg in preproc_cfg: + for _cfg in preproc_cfgs: _clf = preproc_factory.get_instance(_cfg) preproc_clfs.append(_clf) preproc_pl = make_pipeline(*preproc_clfs) @@ -154,8 +154,8 @@ def train_rollout(datums: List[Datum]) -> Responses: responses.append(Response.as_success(_datum.id)) continue - preproc_cfg = metric_config.numalogic_conf.preprocess - x_train, preproc_clf = _preprocess(train_df.to_numpy(), preproc_cfg) + preproc_cfgs = metric_config.numalogic_conf.preprocess + x_train, preproc_clf = _preprocess(train_df.to_numpy(), preproc_cfgs) trainer_cfg = metric_config.numalogic_conf.trainer x_reconerr, anomaly_model, trainer = _train_model( @@ -173,7 +173,11 @@ def train_rollout(datums: List[Datum]) -> Responses: # Save main model version = save_model( - skeys=skeys, dkeys=[model_cfg.name], model=anomaly_model, uuid=payload.uuid + skeys=skeys, + dkeys=[model_cfg.name], + model=anomaly_model, + uuid=payload.uuid, + train_size=train_df.shape[0], ) if version: _LOGGER.info( @@ -185,7 +189,7 @@ def train_rollout(datums: List[Datum]) -> Responses: # Save preproc model version = save_model( skeys=skeys, - dkeys=["preproc"], + dkeys=[_conf.name for _conf in preproc_cfgs], model=preproc_clf, artifact_type="sklearn", uuid=payload.uuid, @@ -205,7 +209,7 @@ def train_rollout(datums: List[Datum]) -> Responses: # Save threshold model version = save_model( skeys=skeys, - dkeys=["thresh"], + dkeys=[thresh_cfg.name], model=thresh_clf, artifact_type="sklearn", uuid=payload.uuid, diff --git a/pyproject.toml b/pyproject.toml index bc73f5c..604f0db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "numalogic-prometheus" -version = "0.2.3" +version = "0.2.4" description = "ML inference on numaflow using numalogic on Prometheus metrics" authors = ["Numalogic developers"] packages = [{ include = "numaprom" }]