Skip to content

Commit

Permalink
fix: save preproc + thresh model with conf keys (#107)
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 authored Mar 21, 2023
1 parent a6e6df5 commit 8b2f20e
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 11 deletions.
10 changes: 8 additions & 2 deletions numaprom/udf/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from numaprom import get_logger
from numaprom.entities import Status, StreamPayload, Header
from numaprom.tools import msg_forward, load_model
from numaprom.tools import msg_forward, load_model, get_metric_config

_LOGGER = get_logger(__name__)

Expand All @@ -18,10 +18,16 @@ def preprocess(_: str, datum: Datum) -> bytes:
payload = StreamPayload(**orjson.loads(_in_msg))
_LOGGER.info("%s - Received Payload: %r ", payload.uuid, payload)

# Load config
metric_config = get_metric_config(
metric=payload.composite_keys["name"], namespace=payload.composite_keys["namespace"]
)
preprocess_cfgs = metric_config.numalogic_conf.preprocess

# Load preprocess artifact
preproc_artifact = load_model(
skeys=[payload.composite_keys["namespace"], payload.composite_keys["name"]],
dkeys=["preproc"],
dkeys=[_conf.name for _conf in preprocess_cfgs],
artifact_type="sklearn",
)
if not preproc_artifact:
Expand Down
3 changes: 2 additions & 1 deletion numaprom/udf/threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def threshold(_: str, datum: Datum) -> list[tuple[str, bytes]]:
metric_config = get_metric_config(
metric=payload.composite_keys["name"], namespace=payload.composite_keys["namespace"]
)
thresh_cfg = metric_config.numalogic_conf.threshold

# Check if payload needs static inference
if payload.header == Header.STATIC_INFERENCE:
Expand All @@ -63,7 +64,7 @@ def threshold(_: str, datum: Datum) -> list[tuple[str, bytes]]:
# load threshold artifact
thresh_artifact = load_model(
skeys=[payload.composite_keys["namespace"], payload.composite_keys["name"]],
dkeys=["thresh"],
dkeys=[thresh_cfg.name],
artifact_type="sklearn",
)
if not thresh_artifact:
Expand Down
18 changes: 11 additions & 7 deletions numaprom/udsink/train_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ def _train_model(uuid, x, model_cfg, trainer_cfg):
return train_reconerr.numpy(), model, trainer


def _preprocess(x_raw, preproc_cfg: List[ModelInfo]):
def _preprocess(x_raw, preproc_cfgs: List[ModelInfo]):
preproc_factory = PreprocessFactory()
preproc_clfs = []
for _cfg in preproc_cfg:
for _cfg in preproc_cfgs:
_clf = preproc_factory.get_instance(_cfg)
preproc_clfs.append(_clf)
preproc_pl = make_pipeline(*preproc_clfs)
Expand Down Expand Up @@ -154,8 +154,8 @@ def train_rollout(datums: List[Datum]) -> Responses:
responses.append(Response.as_success(_datum.id))
continue

preproc_cfg = metric_config.numalogic_conf.preprocess
x_train, preproc_clf = _preprocess(train_df.to_numpy(), preproc_cfg)
preproc_cfgs = metric_config.numalogic_conf.preprocess
x_train, preproc_clf = _preprocess(train_df.to_numpy(), preproc_cfgs)

trainer_cfg = metric_config.numalogic_conf.trainer
x_reconerr, anomaly_model, trainer = _train_model(
Expand All @@ -173,7 +173,11 @@ def train_rollout(datums: List[Datum]) -> Responses:

# Save main model
version = save_model(
skeys=skeys, dkeys=[model_cfg.name], model=anomaly_model, uuid=payload.uuid
skeys=skeys,
dkeys=[model_cfg.name],
model=anomaly_model,
uuid=payload.uuid,
train_size=train_df.shape[0],
)
if version:
_LOGGER.info(
Expand All @@ -185,7 +189,7 @@ def train_rollout(datums: List[Datum]) -> Responses:
# Save preproc model
version = save_model(
skeys=skeys,
dkeys=["preproc"],
dkeys=[_conf.name for _conf in preproc_cfgs],
model=preproc_clf,
artifact_type="sklearn",
uuid=payload.uuid,
Expand All @@ -205,7 +209,7 @@ def train_rollout(datums: List[Datum]) -> Responses:
# Save threshold model
version = save_model(
skeys=skeys,
dkeys=["thresh"],
dkeys=[thresh_cfg.name],
model=thresh_clf,
artifact_type="sklearn",
uuid=payload.uuid,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic-prometheus"
version = "0.2.3"
version = "0.2.4"
description = "ML inference on numaflow using numalogic on Prometheus metrics"
authors = ["Numalogic developers"]
packages = [{ include = "numaprom" }]
Expand Down

0 comments on commit 8b2f20e

Please sign in to comment.