diff --git a/numalogic/config/__init__.py b/numalogic/config/__init__.py index c9bb5e67..dd7e8aff 100644 --- a/numalogic/config/__init__.py +++ b/numalogic/config/__init__.py @@ -18,6 +18,7 @@ TrainerConf, ScoreConf, AggregatorConf, + ScoreAdjustConf, ) from numalogic.config.factory import ( ModelFactory, @@ -41,6 +42,7 @@ "RegistryFactory", "TrainerConf", "ScoreConf", + "ScoreAdjustConf", "AggregatorConf", "AggregatorFactory", ] diff --git a/numalogic/tools/adjust.py b/numalogic/tools/adjust.py deleted file mode 100644 index 32786c67..00000000 --- a/numalogic/tools/adjust.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2022 The Numaproj Authors. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import numpy.typing as npt - -from numalogic.models.threshold import SigmoidThreshold - - -class ScoreAdjuster: - """ - Adjusts the model output based on the metric input. - - Args: - ---- - adjust_weight: weight given to static thresholding output - (between 0 and 1) - metric_weights: weights given to each kpi/metric - upper_limits: upper limits for each metric - kwargs: kwargs for SigmoidThreshold - - Raises - ------ - ValueError: if adjust_weight is not between 0 and 1 - """ - - __slots__ = ("_adjust_wt", "_kpi_wts", "_thresholder") - - def __init__( - self, adjust_weight: float, metric_weights: list[float], upper_limits: list[float], **kwargs - ): - if adjust_weight <= 0.0 or adjust_weight >= 1: - raise ValueError("adjust_weight needs to be between 0 and 1") - self._adjust_wt = adjust_weight - self._kpi_wts = np.asarray(metric_weights, dtype=np.float32).reshape(-1, 1) - self._thresholder = SigmoidThreshold(*upper_limits, **kwargs) - - def adjust( - self, metric_in: npt.NDArray[float], model_scores: npt.NDArray[float] - ) -> npt.NDArray[float]: - """ - Adjusts the model output based on the metric input. - - Args: - ---- - metric_in: metric input to the model - model_scores: model output scores - - Returns - ------- - adjusted_scores: adjusted scores - """ - model_scores = np.reshape(-1, 1) - feature_scores = self._thresholder.score_samples(metric_in) - weighted_scores = np.dot(feature_scores, self._kpi_wts) - return (self._adjust_wt * weighted_scores) + ((1 - self._adjust_wt) * model_scores) - - # @classmethod - # def from_conf(cls, conf: ScoreAdjustConf) -> Self: - # """ - # Creates an instance of ScoreAdjuster from ScoreAdjustConf. - # - # Args: - # ---- - # conf: ScoreAdjustConf - # - # Returns - # ------- - # ScoreAdjuster instance - # """ - # return cls(conf.weight, conf.metric_weights, conf.upper_limits) diff --git a/numalogic/udfs/inference.py b/numalogic/udfs/inference.py index 29c80027..94a073a7 100644 --- a/numalogic/udfs/inference.py +++ b/numalogic/udfs/inference.py @@ -23,7 +23,12 @@ _increment_counter, ) from numalogic.udfs.entities import StreamPayload, Status -from numalogic.udfs.tools import _load_artifact, _update_info_metric, get_trainer_message +from numalogic.udfs.tools import ( + _load_artifact, + _update_info_metric, + get_trainer_message, + get_static_thresh_message, +) _LOGGER = logging.getLogger(__name__) @@ -131,7 +136,10 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: # Send training request if artifact loading is not successful if not artifact_data: - return Messages(get_trainer_message(keys, _stream_conf, payload, *_metric_label_values)) + msgs = Messages(get_trainer_message(keys, _stream_conf, payload)) + if _conf.numalogic_conf.score.adjust: + msgs.append(get_static_thresh_message(keys, payload)) + return msgs # Perform inference try: @@ -145,7 +153,11 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: payload.composite_keys, payload.metrics, ) - return Messages(get_trainer_message(keys, _stream_conf, payload, *_metric_label_values)) + # Send training request if inference fails + msgs = Messages(get_trainer_message(keys, _stream_conf, payload)) + if _conf.numalogic_conf.score.adjust: + msgs.append(get_static_thresh_message(keys, payload)) + return msgs msgs = Messages() status = ( @@ -162,6 +174,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: **payload.metadata, }, ) + # Send trainer message if artifact is stale if status == Status.ARTIFACT_STALE: msgs.append(get_trainer_message(keys, _stream_conf, payload, *_metric_label_values)) diff --git a/numalogic/udfs/postprocess.py b/numalogic/udfs/postprocess.py index 0ddb844f..16a5c9a6 100644 --- a/numalogic/udfs/postprocess.py +++ b/numalogic/udfs/postprocess.py @@ -30,7 +30,7 @@ _increment_counter, ) from numalogic.udfs.entities import StreamPayload, Header, Status, OutputPayload -from numalogic.udfs.tools import _load_artifact, get_trainer_message +from numalogic.udfs.tools import _load_artifact, get_trainer_message, get_static_thresh_message # TODO: move to config LOCAL_CACHE_TTL = int(os.getenv("LOCAL_CACHE_TTL", "3600")) @@ -123,7 +123,11 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: payload = replace( payload, status=Status.ARTIFACT_NOT_FOUND, header=Header.TRAIN_REQUEST ) - return Messages(get_trainer_message(keys, _stream_conf, payload, *_metric_label_values)) + # Send training request if artifact loading is not successful + msgs = Messages(get_trainer_message(keys, _stream_conf, payload)) + if _conf.numalogic_conf.score.adjust: + msgs.append(get_static_thresh_message(keys, payload)) + return msgs if payload.header == Header.STATIC_INFERENCE: _LOGGER.warning("Static inference not supported in postprocess yet") @@ -154,7 +158,11 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: payload.composite_keys, payload.metrics, ) - return Messages(get_trainer_message(keys, _stream_conf, payload, *_metric_label_values)) + # Send training request if postprocess fails + msgs = Messages(get_trainer_message(keys, _stream_conf, payload)) + if _conf.numalogic_conf.score.adjust: + msgs.append(get_static_thresh_message(keys, payload)) + return msgs payload = replace( payload, diff --git a/numalogic/udfs/preprocess.py b/numalogic/udfs/preprocess.py index 4ed16094..7638a0d0 100644 --- a/numalogic/udfs/preprocess.py +++ b/numalogic/udfs/preprocess.py @@ -32,6 +32,7 @@ _load_artifact, _update_info_metric, get_trainer_message, + get_static_thresh_message, ) # TODO: move to config @@ -136,7 +137,6 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: # Drop message if dataframe shape conditions are not met if raw_df.shape[0] < _stream_conf.window_size or raw_df.shape[1] != len(_conf.metrics): _LOGGER.critical("Dataframe shape: (%f, %f) error ", raw_df.shape[0], raw_df.shape[1]) - print(_metric_label_values) _increment_counter( counter=DATASHAPE_ERROR_COUNTER, labels=_metric_label_values, @@ -174,7 +174,10 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: ) payload = replace(payload, status=Status.ARTIFACT_FOUND) else: - return Messages(get_trainer_message(keys, _stream_conf, payload)) + msgs = Messages(get_trainer_message(keys, _stream_conf, payload)) + if _conf.numalogic_conf.score.adjust: + msgs.append(get_static_thresh_message(keys, payload)) + return msgs # Model will not be in registry else: # Load configuration for the config_id @@ -220,7 +223,13 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: payload, status=Status.RUNTIME_ERROR, ) - return Messages(get_trainer_message(keys, _stream_conf, payload, *_metric_label_values)) + msgs = Messages( + get_trainer_message(keys, _stream_conf, payload, *_metric_label_values), + ) + if _conf.numalogic_conf.score.adjust: + msgs.append(get_static_thresh_message(keys, payload)) + return msgs + _increment_counter( counter=MSG_PROCESSED_COUNTER, labels=_metric_label_values, diff --git a/numalogic/udfs/staticthresh.py b/numalogic/udfs/staticthresh.py index e69de29b..d2d02a5b 100644 --- a/numalogic/udfs/staticthresh.py +++ b/numalogic/udfs/staticthresh.py @@ -0,0 +1,170 @@ +import logging + +from orjson import orjson +from pynumaflow.mapper import Datum, Messages, Message + +from numalogic.config import AggregatorFactory, ScoreAdjustConf, AggregatorConf +from numalogic.models.threshold import SigmoidThreshold +from numalogic.tools.aggregators import aggregate_window, aggregate_features +from numalogic.udfs import NumalogicUDF, PipelineConf +import numpy.typing as npt + +from numalogic.udfs.entities import StreamPayload, OutputPayload + + +_LOGGER = logging.getLogger(__name__) + + +class StaticThresholdUDF(NumalogicUDF): + """ + Static thresholding UDF, which computes the static anomaly scores. + + Args: + pl_conf: PipelineConf object + """ + + def __init__(self, pl_conf: PipelineConf): + super().__init__(pl_conf=pl_conf, _vtx="staticthresh") + + def exec(self, keys: list[str], datum: Datum) -> Messages: + """ + Processes the input data and computes the static anomaly scores. + + Args: + ------- + keys: List of keys + datum: Datum object. + + Returns + ------- + Messages instance + """ + payload = StreamPayload(**orjson.loads(datum.value)) + conf = self.get_ml_pipeline_conf(payload.config_id, payload.pipeline_id) + adjust_conf = conf.numalogic_conf.score.adjust + + if not adjust_conf: + _LOGGER.warning( + "%s - No score adjust config found for config_id: %s, pipeline_id: %s", + ) + return Messages(Message.to_drop()) + + try: + y_features = self.compute( + input_=payload.get_data(original=True, metrics=list(adjust_conf.upper_limits)), + adjust_conf=adjust_conf, + ) + y_unified = self.compute_unified_score(y_features, adjust_conf.feature_agg) + except RuntimeError: + _LOGGER.exception( + "%s - Error occurred while computing static anomaly scores", + payload.uuid, + ) + return Messages(Message.to_drop()) + + out_payload = OutputPayload( + uuid=payload.uuid, + config_id=payload.config_id, + pipeline_id=payload.pipeline_id, + composite_keys=payload.composite_keys, + timestamp=payload.end_ts, + unified_anomaly=y_unified, + data=self._additional_scores(payload.metrics, y_features, y_unified), + metadata=payload.metadata, + ) + return Messages(Message(keys=keys, value=out_payload.to_json(), tags=["output"])) + + @staticmethod + def _additional_scores( + feat_names: list[str], y_features: npt.NDArray[float], y_unified: float + ) -> dict[str, float]: + """ + Additional scores to be computed. + + Args: + ------- + feat_names: List of feature names + y_features: Anomaly scores + y_unified: Unified anomaly score + + Returns + ------- + Additional scores + """ + scores_payload = {"unified_ST": y_unified} + if (scores_len := len(y_features)) == len(feat_names): + _LOGGER.debug( + "Scores length: %s does not match feat_names: %s", + scores_len, + feat_names, + ) + scores_payload |= dict(zip(feat_names, y_features)) + return scores_payload + + @classmethod + def compute( + cls, input_: npt.NDArray[float], adjust_conf: ScoreAdjustConf, **_ + ) -> npt.NDArray[float]: + """ + Compute static thresholding over the raw input features. + + Args: + input_: Input data + adjust_conf: Score adjust Config + + Returns + ------- + npt.NDArray[float] + """ + scorer = SigmoidThreshold(*adjust_conf.upper_limits.values()) + try: + return cls.compute_feature_scores(scorer.score_samples(input_), adjust_conf.window_agg) + except Exception as err: + raise RuntimeError("Static Thresholding failed!") from err + + @classmethod + def compute_feature_scores( + cls, scores: npt.NDArray[float], win_agg_conf: AggregatorConf + ) -> npt.NDArray[float]: + """ + Aggregate scores over the window length. + + Args: + ------- + scores: anomaly scores (Shape: seq_len x n_features) + win_agg_conf: Window aggregator Config + + Returns + ------- + Aggregated scores of shape (n_features,) + """ + return aggregate_window( + scores, + agg_func=AggregatorFactory.get_func(win_agg_conf.method), + **win_agg_conf.conf, + ) + + @classmethod + def compute_unified_score( + cls, scores: npt.NDArray[float], feat_agg_conf: AggregatorConf + ) -> float: + """ + Aggregate scores over the features to get a unified score. + + Args: + ------- + scores: anomaly scores (Shape: n_features, ) + feat_agg_conf: Feature aggregator Config + + Returns + ------- + Unified score (float) + """ + try: + return aggregate_features( + scores.reshape(1, -1), + agg_func=AggregatorFactory.get_func(feat_agg_conf.method), + **feat_agg_conf.conf, + ).item() + except Exception as err: + raise RuntimeError("Unified Score computation failed!") from err diff --git a/numalogic/udfs/tools.py b/numalogic/udfs/tools.py index 856875ee..4384ba44 100644 --- a/numalogic/udfs/tools.py +++ b/numalogic/udfs/tools.py @@ -454,3 +454,25 @@ def get_trainer_message( train_payload.composite_keys, ) return Message(keys=keys, value=train_payload.to_json(), tags=["train"]) + + +def get_static_thresh_message(keys: list[str], payload: StreamPayload) -> Message: + """ + Get message for static thresholding request. + + Args: + ------- + keys: List of keys + stream_conf: StreamConf instance + payload: StreamPayload object + + Returns + ------- + Mapper Message instance + """ + _LOGGER.info( + "%s - Sending static thresholding request for: %s", + payload.uuid, + payload.composite_keys, + ) + return Message(keys=keys, value=payload.to_json(), tags=["staticthresh"]) diff --git a/tests/udfs/test_inference.py b/tests/udfs/test_inference.py index 03a26c42..6ca1eb00 100644 --- a/tests/udfs/test_inference.py +++ b/tests/udfs/test_inference.py @@ -7,7 +7,14 @@ from orjson import orjson from pynumaflow.mapper import Datum -from numalogic.config import NumalogicConf, ModelInfo, TrainerConf, LightningTrainerConf +from numalogic.config import ( + NumalogicConf, + ModelInfo, + TrainerConf, + LightningTrainerConf, + ScoreConf, + ScoreAdjustConf, +) from numalogic.models.autoencoder.variants import VanillaAE from numalogic.registry import RedisRegistry, ArtifactData from numalogic.tools.exceptions import RedisRegistryError @@ -101,6 +108,28 @@ def udf(): REDIS_CLIENT.flushall() +@pytest.fixture +def udf_with_adjust(): + udf = InferenceUDF(REDIS_CLIENT) + udf.register_conf( + "conf1", + StreamConf( + ml_pipelines={ + "pipeline1": MLPipelineConf( + pipeline_id="pipeline1", + numalogic_conf=NumalogicConf( + model=ModelInfo(name="VanillaAE", conf={"seq_len": 12, "n_features": 1}), + trainer=TrainerConf(pltrainer_conf=LightningTrainerConf(max_epochs=1)), + score=ScoreConf(adjust=ScoreAdjustConf(upper_limits={"failed": 20})), + ), + ) + } + ), + ) + yield udf + REDIS_CLIENT.flushall() + + @pytest.fixture() def udf_args(): return KEYS, Datum( @@ -194,7 +223,7 @@ def test_registry_error(udf, udf_args, mocker): assert msgs[0].tags == ["train"] -def test_compute_err(udf, udf_args, mocker): +def test_compute_err_01(udf, udf_args, mocker): mocker.patch.object( RedisRegistry, "load", @@ -212,7 +241,26 @@ def test_compute_err(udf, udf_args, mocker): assert msgs[0].tags == ["train"] -def test_model_pass_error(udf, udf_args, mocker): +def test_compute_err_02(udf_with_adjust, udf_args, mocker): + mocker.patch.object( + RedisRegistry, + "load", + return_value=ArtifactData( + artifact=VanillaAE(seq_len=12, n_features=2), + extras=dict(version="0", timestamp=time.time(), source="registry"), + metadata={}, + ), + ) + mocker.patch.object(InferenceUDF, "compute", side_effect=RuntimeError) + msgs = udf_with_adjust(*udf_args) + assert len(msgs) == 2 + payload = TrainerPayload(**orjson.loads(msgs[0].value)) + assert Header.TRAIN_REQUEST == payload.header + assert msgs[0].tags == ["train"] + assert msgs[1].tags == ["staticthresh"] + + +def test_model_pass_error_01(udf, udf_args, mocker): mocker.patch.object( RedisRegistry, "load", @@ -227,3 +275,21 @@ def test_model_pass_error(udf, udf_args, mocker): payload = TrainerPayload(**orjson.loads(msgs[0].value)) assert Header.TRAIN_REQUEST == payload.header assert msgs[0].tags == ["train"] + + +def test_model_pass_error_02(udf_with_adjust, udf_args, mocker): + mocker.patch.object( + RedisRegistry, + "load", + return_value=ArtifactData( + artifact=VanillaAE(seq_len=12, n_features=1), + extras=dict(version="0", timestamp=time.time(), source="registry"), + metadata={}, + ), + ) + msgs = udf_with_adjust(*udf_args) + assert len(msgs) == 2 + payload = TrainerPayload(**orjson.loads(msgs[0].value)) + assert Header.TRAIN_REQUEST == payload.header + assert msgs[0].tags == ["train"] + assert msgs[1].tags == ["staticthresh"] diff --git a/tests/udfs/test_postprocess.py b/tests/udfs/test_postprocess.py index f20977da..e369b94f 100644 --- a/tests/udfs/test_postprocess.py +++ b/tests/udfs/test_postprocess.py @@ -142,30 +142,33 @@ def test_postprocess(udf, mocker, artifact, data): def test_postprocess_no_artifact(udf): - msg = udf(KEYS, Datum(keys=KEYS, value=orjson.dumps(DATA), **DATUM_KW)) - assert len(msg) == 1 - payload = TrainerPayload(**orjson.loads(msg[0].value)) + msgs = udf(KEYS, Datum(keys=KEYS, value=orjson.dumps(DATA), **DATUM_KW)) + assert len(msgs) == 2 + payload = TrainerPayload(**orjson.loads(msgs[0].value)) assert payload.header == Header.TRAIN_REQUEST - assert msg[0].tags == ["train"] + assert msgs[0].tags == ["train"] + assert msgs[1].tags == ["staticthresh"] def test_postprocess_runtime_err_01(udf, mocker, artifact): mocker.patch.object(RedisRegistry, "load", return_value=artifact) mocker.patch.object(PostprocessUDF, "compute", side_effect=RuntimeError) - msg = udf(KEYS, Datum(keys=KEYS, value=orjson.dumps(DATA), **DATUM_KW)) - assert len(msg) == 1 - assert msg[0].tags == ["train"] - payload = TrainerPayload(**orjson.loads(msg[0].value)) + msgs = udf(KEYS, Datum(keys=KEYS, value=orjson.dumps(DATA), **DATUM_KW)) + assert len(msgs) == 2 + assert msgs[0].tags == ["train"] + payload = TrainerPayload(**orjson.loads(msgs[0].value)) assert payload.header == Header.TRAIN_REQUEST + assert msgs[1].tags == ["staticthresh"] def test_postprocess_runtime_err_02(udf, mocker, bad_artifact): mocker.patch.object(RedisRegistry, "load", return_value=bad_artifact) - msg = udf(KEYS, Datum(keys=KEYS, value=orjson.dumps(DATA), **DATUM_KW)) - assert len(msg) == 1 - payload = TrainerPayload(**orjson.loads(msg[0].value)) + msgs = udf(KEYS, Datum(keys=KEYS, value=orjson.dumps(DATA), **DATUM_KW)) + assert len(msgs) == 2 + payload = TrainerPayload(**orjson.loads(msgs[0].value)) assert payload.header == Header.TRAIN_REQUEST - assert msg[0].tags == ["train"] + assert msgs[0].tags == ["train"] + assert msgs[1].tags == ["staticthresh"] def test_compute(udf, artifact): diff --git a/tests/udfs/test_preprocess.py b/tests/udfs/test_preprocess.py index 18e0ad24..2676670e 100644 --- a/tests/udfs/test_preprocess.py +++ b/tests/udfs/test_preprocess.py @@ -100,9 +100,13 @@ def test_preprocess_run_time_error(setup, mocker): mocker.patch.object(PreprocessUDF, "compute", side_effect=RuntimeError) udf1, _ = setup msg = udf1(KEYS, DATUM) - payload = TrainerPayload(**orjson.loads(msg[0].value)) - assert len(msg) == 1 - assert payload.header == Header.TRAIN_REQUEST + assert len(msg) == 2 + payload_1 = TrainerPayload(**orjson.loads(msg[0].value)) + assert payload_1.header == Header.TRAIN_REQUEST + assert msg[0].tags == ["train"] + payload_2 = StreamPayload(**orjson.loads(msg[1].value)) + assert msg[1].tags == ["staticthresh"] + assert payload_2.status == Status.RUNTIME_ERROR def test_preprocess_data_error(setup): diff --git a/tests/udfs/test_staticthresh.py b/tests/udfs/test_staticthresh.py new file mode 100644 index 00000000..88960953 --- /dev/null +++ b/tests/udfs/test_staticthresh.py @@ -0,0 +1,120 @@ +import logging +import os +from datetime import datetime + +from numpy.testing import assert_array_almost_equal +import pytest +from omegaconf import OmegaConf +from orjson import orjson +from pynumaflow.mapper import Datum + +from numalogic._constants import TESTS_DIR +from numalogic.udfs import PipelineConf +from numalogic.udfs.entities import Status, Header, OutputPayload +from numalogic.udfs.staticthresh import StaticThresholdUDF + +logging.basicConfig(level=logging.DEBUG) +KEYS = ["service-mesh", "1", "2"] +DATUM_KW = { + "event_time": datetime.now(), + "watermark": datetime.now(), +} +DATA = { + "uuid": "dd7dfb43-532b-49a3-906e-f78f82ad9c4b", + "config_id": "druid-config", + "pipeline_id": "pipeline2", + "composite_keys": ["service-mesh", "1", "2"], + "data": [ + [2.055191, 2.205468], + [2.4223375, 1.4583645], + [2.8268616, 2.4160783], + [2.1107504, 1.458458], + [2.446076, 2.2556527], + [2.7057548, 2.579097], + [3.034152, 2.521946], + [1.7857871, 1.8762474], + [1.4797148, 2.4363635], + [1.526145, 2.6486845], + [1.0459993, 1.3363016], + [1.6239338, 1.4365934], + ], + "raw_data": [ + [11.0, 14.0], + [17.0, 4.0], + [22.0, 13.0], + [17.0, 7.0], + [23.0, 18.0], + [15.0, 15.0], + [16.0, 9.0], + [10.0, 10.0], + [3.0, 12.0], + [6.0, 21.0], + [5.0, 7.0], + [10.0, 8.0], + ], + "metrics": ["col1", "col2"], + "timestamps": [ + 1691623140000.0, + 1691623200000.0, + 1691623260000.0, + 1691623320000.0, + 1691623380000.0, + 1691623440000.0, + 1691623500000.0, + 1691623560000.0, + 1691623620000.0, + 1691623680000.0, + 1691623740000.0, + 1691623800000.0, + ], + "status": Status.RUNTIME_ERROR, + "header": Header.MODEL_INFERENCE, + "artifact_versions": {"pipeline2:StdDevThreshold": "0", "pipeline2:VanillaAE": "0"}, + "metadata": { + "tags": {"asset_alias": "data", "asset_id": "123456789", "env": "prd"}, + }, +} + + +@pytest.fixture +def conf() -> PipelineConf: + _given_conf = OmegaConf.load(os.path.join(TESTS_DIR, "udfs", "resources", "_config.yaml")) + schema = OmegaConf.structured(PipelineConf) + return PipelineConf(**OmegaConf.merge(schema, _given_conf)) + + +@pytest.fixture +def udf(conf) -> StaticThresholdUDF: + return StaticThresholdUDF(pl_conf=conf) + + +@pytest.fixture +def udf_args(): + return KEYS, Datum(keys=KEYS, value=orjson.dumps(DATA), **DATUM_KW) + + +def test_staticthresh(udf, udf_args): + msgs = udf(*udf_args) + assert len(msgs) == 1 + assert msgs[0].tags == ["output"] + payload = OutputPayload(**orjson.loads(msgs[0].value)) + + assert_array_almost_equal(payload.unified_anomaly, 1.666, decimal=3) + assert payload.data + assert_array_almost_equal(payload.data["col1"], 1.666, decimal=3) + assert_array_almost_equal(payload.data["col2"], 1.25, decimal=3) + assert payload.unified_anomaly == payload.data.get("unified_ST") + + +def test_err_01(udf, udf_args, mocker): + mocker.patch.object(StaticThresholdUDF, "compute_feature_scores", side_effect=RuntimeError) + msgs = udf(*udf_args) + assert len(msgs) == 1 + assert not msgs[0].value + + +def test_err_02(udf, udf_args, mocker): + mocker.patch.object(StaticThresholdUDF, "compute_unified_score", side_effect=RuntimeError) + msgs = udf(*udf_args) + assert len(msgs) == 1 + assert not msgs[0].value