Enable custom metrics (#162)
Signed-off-by: Saravanan Balasubramanian <[email protected]>
Signed-off-by: Avik Basu <[email protected]>
Co-authored-by: Avik Basu <[email protected]>
3 people authored Aug 9, 2023
1 parent 5653128 commit e922195
Showing 12 changed files with 94 additions and 8 deletions.
11 changes: 11 additions & 0 deletions numaprom/metrics/__init__.py
@@ -0,0 +1,11 @@
from numaprom.metrics._metrics import (
increase_redis_conn_error,
inc_inference_count,
start_metrics_server,
)

__all__ = [
"increase_redis_conn_error",
"inc_inference_count",
"start_metrics_server",
]
25 changes: 25 additions & 0 deletions numaprom/metrics/_metrics.py
@@ -0,0 +1,25 @@
from prometheus_client import start_http_server
from prometheus_client import Counter

from numaprom import LOGGER

# Metrics
REDIS_CONN_ERROR_COUNT = Counter("numaprom_redis_conn_error_count", "", ["vertex"])
INFERENCE_COUNT = Counter(
"numaprom_inference_count", "", ["model", "namespace", "app", "metric", "status"]
)


def increase_redis_conn_error(vertex: str) -> None:
global REDIS_CONN_ERROR_COUNT
REDIS_CONN_ERROR_COUNT.labels(vertex).inc()


def inc_inference_count(model: str, namespace: str, app: str, metric: str, status: str) -> None:
global INFERENCE_COUNT
INFERENCE_COUNT.labels(model, namespace, app, metric, status).inc()


def start_metrics_server(port: int) -> None:
LOGGER.info("Starting Prometheus metrics server on port: {port}", port=port)
start_http_server(port)
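
For context, a hedged usage sketch of the new helpers (not part of this commit; the port matches the one hard-coded in starter.py below, while the label values are invented placeholders):

    # Illustrative only: exercise the helpers exported by numaprom.metrics.
    from numaprom.metrics import (
        inc_inference_count,
        increase_redis_conn_error,
        start_metrics_server,
    )

    start_metrics_server(8490)              # serves metrics over HTTP in a daemon thread
    increase_redis_conn_error("inference")  # bumps numaprom_redis_conn_error_count{vertex="inference"}
    inc_inference_count(
        model="vae-v1",                     # placeholder label values
        namespace="demo-ns",
        app="demo-app",
        metric="error_rate",
        status="MODEL_INFERENCE",
    )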
13 changes: 13 additions & 0 deletions numaprom/udf/inference.py
@@ -1,5 +1,7 @@
import os
import time
from typing import Final

from numalogic.config import NumalogicConf
from numalogic.models.autoencoder import AutoencoderTrainer
from numalogic.registry import ArtifactData, RedisRegistry, LocalLRUCache
@@ -13,9 +15,12 @@
from numaprom.clients.sentinel import get_redis_client_from_conf
from numaprom.entities import PayloadFactory
from numaprom.entities import Status, StreamPayload, Header
from numaprom.metrics import increase_redis_conn_error, inc_inference_count
from numaprom.tools import msg_forward
from numaprom.watcher import ConfigManager


_VERTEX: Final[str] = "inference"
REDIS_CLIENT = get_redis_client_from_conf(master_node=False)
LOCAL_CACHE_TTL = int(os.getenv("LOCAL_CACHE_TTL", 3600)) # default ttl set to 1 hour

@@ -84,6 +89,7 @@ def inference(_: list[str], datum: Datum) -> bytes:
)
payload.set_header(Header.STATIC_INFERENCE)
payload.set_status(Status.RUNTIME_ERROR)
increase_redis_conn_error(_VERTEX)
return orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY)

if not artifact_data:
@@ -123,6 +129,13 @@ def inference(_: list[str], datum: Datum) -> bytes:
payload.set_status(Status.RUNTIME_ERROR)
return orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY)

inc_inference_count(
model=payload.get_metadata("version"),
namespace=payload.composite_keys.get("namespace"),
app=payload.composite_keys.get("app"),
metric=payload.composite_keys.get("name"),
status=payload.header,
)
LOGGER.info("{uuid} - Sending Payload: {payload} ", uuid=payload.uuid, payload=payload)
LOGGER.debug(
"{uuid} - Time taken in inference: {time} sec",
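
As an aside, a self-contained sketch (assumptions: an isolated registry and invented label values) of how a labeled counter like INFERENCE_COUNT renders in Prometheus exposition format:

    # Demo only: render a counter shaped like INFERENCE_COUNT in text format.
    from prometheus_client import CollectorRegistry, Counter, generate_latest

    registry = CollectorRegistry()  # keep the demo off the global registry
    inference_count = Counter(
        "numaprom_inference_count",
        "Inference count by model/namespace/app/metric/status",
        ["model", "namespace", "app", "metric", "status"],
        registry=registry,
    )
    inference_count.labels("vae-v1", "demo-ns", "demo-app", "error_rate", "MODEL_INFERENCE").inc()

    # Prints a line like:
    # numaprom_inference_count_total{app="demo-app",metric="error_rate",model="vae-v1",...} 1.0
    print(generate_latest(registry).decode())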
6 changes: 6 additions & 0 deletions numaprom/udf/postprocess.py
@@ -1,5 +1,6 @@
import os
import time
from typing import Final

import numpy as np
from orjson import orjson
@@ -10,9 +11,13 @@
from numaprom import LOGGER, UnifiedConf
from numaprom.clients.sentinel import get_redis_client_from_conf
from numaprom.entities import Status, PrometheusPayload, StreamPayload, Header
from numaprom.metrics import increase_redis_conn_error
from numaprom.tools import msgs_forward, WindowScorer

from numaprom.watcher import ConfigManager


_VERTEX: Final[str] = "postprocess"
AUTH = os.getenv("REDIS_AUTH")
SCORE_PRECISION = int(os.getenv("SCORE_PRECISION", 3))
UNDEFINED_SCORE = -1.0
@@ -148,6 +153,7 @@ def _publish(final_score: float, payload: StreamPayload) -> list[bytes]:
uuid=payload.uuid,
warn=warn,
)
increase_redis_conn_error(_VERTEX)
unified_anomaly, anomalies = __save_to_redis(
payload=payload, final_score=final_score, recreate=True, unified_config=unified_config
)
5 changes: 4 additions & 1 deletion numaprom/udf/preprocess.py
@@ -1,5 +1,6 @@
import os
import time
from typing import Final

import orjson
from numalogic.registry import RedisRegistry, LocalLRUCache
@@ -10,9 +11,10 @@
from numaprom.clients.sentinel import get_redis_client
from numaprom.entities import Status, StreamPayload, Header
from numaprom.tools import msg_forward
from numaprom.metrics import increase_redis_conn_error
from numaprom.watcher import ConfigManager


_VERTEX: Final[str] = "preprocess"
AUTH = os.getenv("REDIS_AUTH")
REDIS_CONF = ConfigManager.get_redis_config()
REDIS_CLIENT = get_redis_client(
@@ -56,6 +58,7 @@ def preprocess(_: list[str], datum: Datum) -> bytes:
)
payload.set_header(Header.STATIC_INFERENCE)
payload.set_status(Status.RUNTIME_ERROR)
increase_redis_conn_error(_VERTEX)
return orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY)
except Exception as ex:
LOGGER.exception(
6 changes: 6 additions & 0 deletions numaprom/udf/threshold.py
@@ -1,6 +1,7 @@
import os
import time
from collections import OrderedDict
from typing import Final

from numalogic.registry import RedisRegistry, LocalLRUCache
from numalogic.tools.exceptions import RedisRegistryError
@@ -11,9 +12,12 @@
from numaprom._constants import TRAIN_VTX_KEY, POSTPROC_VTX_KEY
from numaprom.clients.sentinel import get_redis_client_from_conf
from numaprom.entities import Status, TrainerPayload, PayloadFactory, Header
from numaprom.metrics import increase_redis_conn_error
from numaprom.tools import conditional_forward, calculate_static_thresh
from numaprom.watcher import ConfigManager


_VERTEX: Final[str] = "threshold"
LOCAL_CACHE_TTL = int(os.getenv("LOCAL_CACHE_TTL", 3600)) # default ttl set to 1 hour


@@ -78,6 +82,7 @@ def threshold(_: list[str], datum: Datum) -> list[tuple[str, bytes]]:
keys=payload.composite_keys,
err=err,
)
increase_redis_conn_error(_VERTEX)
payload.set_header(Header.STATIC_INFERENCE)
payload.set_status(Status.RUNTIME_ERROR)
return [
@@ -98,6 +103,7 @@ def threshold(_: list[str], datum: Datum) -> list[tuple[str, bytes]]:
(TRAIN_VTX_KEY, orjson.dumps(train_payload)),
(POSTPROC_VTX_KEY, _get_static_thresh_payload(payload, metric_config)),
]

if not thresh_artifact:
LOGGER.info(
"{uuid} - Threshold artifact not found, performing static thresholding. Keys: {keys}",
5 changes: 5 additions & 0 deletions numaprom/udf/window.py
@@ -1,6 +1,7 @@
import os
import time
import uuid
from typing import Final

import numpy as np
import numpy.typing as npt
@@ -11,9 +12,12 @@
from numaprom import LOGGER
from numaprom.clients.sentinel import get_redis_client_from_conf
from numaprom.entities import StreamPayload, Status, Header
from numaprom.metrics import increase_redis_conn_error
from numaprom.tools import msg_forward, create_composite_keys
from numaprom.watcher import ConfigManager

_VERTEX: Final[str] = "window"


# TODO get the replacement value from config
def _clean_arr(
@@ -84,6 +88,7 @@ def window(_: list[str], datum: Datum) -> bytes | None:
)
except (RedisError, RedisClusterException) as warn:
LOGGER.warning("Redis connection failed, recreating the redis client, err: {err}", err=warn)
increase_redis_conn_error(_VERTEX)
elements = __aggregate_window(
unique_key, msg["timestamp"], value, win_size, buff_size, recreate=True
)
16 changes: 15 additions & 1 deletion poetry.lock

Generated file; diff not rendered.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic-prometheus"
version = "0.4.15"
version = "0.5.0"
description = "ML inference on numaflow using numalogic on Prometheus metrics"
authors = ["Numalogic developers"]
packages = [{ include = "numaprom" }]
@@ -26,6 +26,7 @@ orjson = "^3.8.4"
omegaconf = "^2.3.0"
watchdog = "^3.0.0"
loguru = "^0.7.0"
prometheus-client = "^0.17"

[tool.poetry.group.dev]
optional = true
2 changes: 2 additions & 0 deletions starter.py
@@ -7,6 +7,7 @@
from numaprom._constants import CONFIG_PATHS
from numaprom.factory import HandlerFactory
from numaprom.watcher import Watcher, ConfigHandler
from numaprom.metrics import start_metrics_server


def run_watcher():
@@ -20,6 +21,7 @@ def run_watcher():
background_thread.start()

step_handler = HandlerFactory.get_handler(sys.argv[2])
start_metrics_server(8490)
server_type = sys.argv[1]

if server_type == "udsink":
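
Assuming the pod is port-forwarded or running locally, a quick way to sanity-check the endpoint started above (sketch, not part of the commit):

    # Scrape http://localhost:8490 and show only the numaprom counters.
    import urllib.request

    with urllib.request.urlopen("http://localhost:8490/metrics", timeout=5) as resp:
        for line in resp.read().decode().splitlines():
            if line.startswith("numaprom_"):
                print(line)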
8 changes: 4 additions & 4 deletions tests/test_trainer.py
@@ -4,27 +4,27 @@
from datetime import datetime
from unittest.mock import patch, Mock

-from pynumaflow.sink import Datum
from numalogic.tools.exceptions import InvalidDataShapeError
+from pynumaflow.sink import Datum

from numaprom._constants import TESTS_DIR
from numaprom.clients.prometheus import Prometheus
+from tests import train, redis_client, train_rollout
from tests.tools import (
mock_argocd_query_metric,
mock_rollout_query_metric,
mock_rollout_query_metric2,
mock_rollout_query_metric3,
)
-from tests import train, redis_client, train_rollout

DATA_DIR = os.path.join(TESTS_DIR, "resources", "data")
STREAM_DATA_PATH = os.path.join(DATA_DIR, "stream.json")


def as_datum(data: str | bytes | dict, msg_id="1") -> Datum:
-    if type(data) is not bytes:
+    if not isinstance(data, bytes):
        data = json.dumps(data).encode("utf-8")
-    elif type(data) == dict:
+    elif isinstance(data, dict):
data = json.dumps(data)

return Datum(
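
The type(...) checks replaced with isinstance(...) here and in tests/tools.py behave the same for plain bytes, but isinstance() also accepts subclasses; a quick illustration:

    # Why isinstance() beats an exact type() identity check:
    class FramedBytes(bytes):  # hypothetical bytes subclass
        pass

    data = FramedBytes(b"payload")
    print(type(data) is bytes)      # False: exact-type check misses the subclass
    print(isinstance(data, bytes))  # True: a subclass instance still counts as bytes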
2 changes: 1 addition & 1 deletion tests/tools.py
@@ -28,7 +28,7 @@ def mockenv(**envvars):


def get_datum(data: str or bytes) -> Datum:
-    if type(data) is not bytes:
+    if not isinstance(data, bytes):
data = json.dumps(data).encode("utf-8")

return Datum(
