Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add jitter #305

Merged
merged 11 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions numalogic/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@ class ModelInfo:


# TODO add this in the right config
@dataclass
class JitterConf:
s0nicboOm marked this conversation as resolved.
Show resolved Hide resolved
"""
Schema for defining the jitter config to solve the Thundering Herd problem.

Args:
----
jitter_sec: Jitter in seconds
jitter_steps_sec: Step interval value (in secs) for jitter_sec value (default = 120 sec)
"""

jitter_sec: int = 30 * 60
jitter_steps_sec: int = 2 * 60


@dataclass
class RegistryInfo:
"""Registry config base class.
Expand All @@ -44,8 +59,10 @@ class RegistryInfo:
conf: kwargs for instantiating the model class
"""

name: str = MISSING
conf: dict[str, Any] = field(default_factory=dict)
name: str
model_expiry_sec: int
jitter_conf: JitterConf = field(default_factory=JitterConf)
extra_param: dict[str, Any] = field(default_factory=dict)


@dataclass
Expand Down Expand Up @@ -73,8 +90,7 @@ class TrainerConf:
train_hours: int = 24 * 8 # 8 days worth of data
min_train_size: int = 2000
retrain_freq_hr: int = 24
model_expiry_sec: int = 172800 # 48 hrs # TODO: revisit this
retry_secs: int = 600 # 10 min # TODO: revisit this
retry_sec: int = 600 # 10 min
batch_size: int = 64
pltrainer_conf: LightningTrainerConf = field(default_factory=LightningTrainerConf)

Expand Down
27 changes: 22 additions & 5 deletions numalogic/registry/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import random
from dataclasses import dataclass
from typing import Any, Generic, TypeVar, Union, Optional

from numalogic.tools.exceptions import ConfigError
from numalogic.tools.types import artifact_t, KEYS, META_T, META_VT, EXTRA_T, state_dict_t


Expand Down Expand Up @@ -129,6 +129,21 @@ def construct_key(skeys: KEYS, dkeys: KEYS) -> str:
return "::".join([_static_key, _dynamic_key])


def _apply_jitter(ts: int, jitter_sec: int, jitter_steps_sec: int):
"""
Applies jitter to the ttl value to solve Thundering Herd problem.
z
Note: Jitter izs not applied if jitter_sec and jitter_steps_sec are both 0.
"""
if jitter_sec == jitter_steps_sec == 0:
return ts
if jitter_sec < jitter_steps_sec:
raise ConfigError("jitter_sec should be at least 60*jitter_steps_sec")
begin = ts if ts - jitter_sec < 0 else ts - jitter_sec
end = ts + jitter_sec + 1
return random.randrange(begin, end, jitter_steps_sec)


class ArtifactCache(Generic[M_K, A_D]):
r"""Base class for all artifact caches.
Caches support saving, loading and deletion, but not artifact versioning.
Expand All @@ -137,15 +152,17 @@ class ArtifactCache(Generic[M_K, A_D]):
----
cachesize: size of the cache
ttl: time to live for each item in the cache
jitter_sec: jitter in seconds to add to the ttl (to solve Thundering Herd problem)
jitter_steps_sec: Step interval value (in mins) for jitter_sec value
"""

_STORETYPE = "cache"

__slots__ = ("_cachesize", "_ttl")
__slots__ = ("_cachesize", "_ttl", "jitter_sec", "jitter_steps_sec")

def __init__(self, cachesize: int, ttl: int):
def __init__(self, cachesize: int, ttl: int, jitter_sec: int, jitter_steps_sec: int):
self._cachesize = cachesize
self._ttl = ttl
self._ttl = _apply_jitter(ts=ttl, jitter_sec=jitter_sec, jitter_steps_sec=jitter_steps_sec)

@property
def cachesize(self):
Expand Down
19 changes: 13 additions & 6 deletions numalogic/registry/localcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from copy import deepcopy
from threading import Lock
from typing import Optional

from cachetools import TTLCache

from numalogic.registry.artifact import ArtifactCache, ArtifactData
Expand All @@ -24,15 +23,23 @@ class LocalLRUCache(ArtifactCache, metaclass=Singleton):

Args:
----
cachesize: Size of the cache,
i.e. number of elements the cache can hold
ttl: Time to live for each item in seconds
cachesize: Size of the cache, i.e. number of elements the cache can hold
ttl: Time to live for each item in seconds
jitter_sec: Jitter in seconds to add to the ttl (to solve Thundering Herd problem)
(default = 120 secs)
jitter_steps_sec: Step interval value (in secs) for jitter_sec value (default = 60 secs)
"""

__cache: Optional[TTLCache] = None

def __init__(self, cachesize: int = 512, ttl: int = 300):
super().__init__(cachesize, ttl)
def __init__(
self,
cachesize: int = 512,
ttl: int = 300,
jitter_sec: int = 120,
jitter_steps_sec: int = 60,
):
super().__init__(cachesize, ttl, jitter_sec, jitter_steps_sec)
if not self.__cache:
self.__cache = TTLCache(maxsize=cachesize, ttl=ttl)
self.__lock = Lock()
Expand Down
33 changes: 22 additions & 11 deletions numalogic/registry/redis_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from redis.exceptions import RedisError

from numalogic.registry.artifact import ArtifactManager, ArtifactData, ArtifactCache
from numalogic.registry.artifact import ArtifactManager, ArtifactData, ArtifactCache, _apply_jitter
from numalogic.registry._serialize import loads, dumps
from numalogic.tools.exceptions import ModelKeyNotFound, RedisRegistryError
from numalogic.tools.types import artifact_t, redis_client_t, KEYS, META_T, META_VT, KeyedArtifact
Expand All @@ -33,6 +33,9 @@ class RedisRegistry(ArtifactManager):
----
client: Take in the redis client already established/created
ttl: Total Time to Live (in seconds) for the key when saving in redis (dafault = 604800)
jitter_sec: Jitter (in secs) added to model timestamp information to solve
Thundering Herd problem (default = 30 mins)
jitter_steps_sec: Step interval value (in sec) for jitter_sec value (default = 120 secs)
cache_registry: Cache registry to use (default = None).
transactional: Flag to indicate if the registry should be transactional or
not (default = False).
Expand All @@ -51,18 +54,29 @@ class RedisRegistry(ArtifactManager):
>>> loaded_artifact = registry.load(skeys, dkeys)
"""

__slots__ = ("client", "ttl", "cache_registry", "transactional")
__slots__ = (
"client",
"ttl",
"jitter_sec",
"jitter_steps_sec",
"cache_registry",
"transactional",
)

def __init__(
self,
client: redis_client_t,
ttl: int = 604800,
jitter_sec: int = 30 * 60,
jitter_steps_sec: int = 2 * 60,
cache_registry: Optional[ArtifactCache] = None,
transactional: bool = True,
):
super().__init__("")
self.client = client
self.ttl = ttl
self.jitter_sec = jitter_sec
self.jitter_steps_sec = jitter_steps_sec
self.cache_registry = cache_registry
self.transactional = transactional

Expand Down Expand Up @@ -116,9 +130,7 @@ def _save_in_cache(self, key: str, artifact_data: ArtifactData) -> None:

def _clear_cache(self, key: Optional[str] = None) -> Optional[ArtifactData]:
if self.cache_registry:
if key:
return self.cache_registry.delete(key)
return self.cache_registry.clear()
return self.cache_registry.delete(key) if key else self.cache_registry.clear()
return None

def __get_artifact_data(
Expand Down Expand Up @@ -196,16 +208,15 @@ def __save_artifact(
latest_key = self.__construct_latest_key(key)
pipe.set(name=latest_key, value=new_version_key)
_LOGGER.debug("Setting latest key : %s ,to this new key = %s", latest_key, new_version_key)
serialized_metadata = ""
if metadata:
serialized_metadata = orjson.dumps(metadata)
serialized_metadata = orjson.dumps(metadata) if metadata else b""
serialized_artifact = dumps(deserialized_object=artifact)
_cur_ts = int(time.time())
pipe.hset(
name=new_version_key,
mapping={
"artifact": serialized_artifact,
"version": str(version),
"timestamp": time.time(),
"version": version,
"timestamp": _apply_jitter(_cur_ts, self.jitter_sec, self.jitter_steps_sec),
"metadata": serialized_metadata,
},
)
Expand Down Expand Up @@ -403,7 +414,7 @@ def save_multiple(
"""
dict_model_ver = {}
try:
for key, value in dict_artifacts.items():
for value in dict_artifacts.values():
dict_model_ver[":".join(value.dkeys)] = self.save(
skeys=skeys,
dkeys=value.dkeys,
Expand Down
6 changes: 6 additions & 0 deletions numalogic/tools/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ class ConfigNotFoundError(RuntimeError):
pass


class ConfigError(RuntimeError):
"""Raised when a config value has a problem."""

pass


class ModelVersionError(Exception):
"""Raised when a model version is not found in the registry."""

Expand Down
6 changes: 5 additions & 1 deletion numalogic/udfs/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from omegaconf import OmegaConf

from numalogic.config import NumalogicConf
from numalogic.config import NumalogicConf, RegistryInfo

from numalogic.connectors import (
ConnectorType,
RedisConf,
Expand All @@ -26,6 +27,9 @@ class StreamConf:
class PipelineConf:
stream_confs: dict[str, StreamConf] = field(default_factory=dict)
redis_conf: Optional[RedisConf] = None
registry_conf: Optional[RegistryInfo] = field(
default_factory=lambda: RegistryInfo(name="RedisRegistry", model_expiry_sec=172800)
)
prometheus_conf: Optional[PrometheusConf] = None
druid_conf: Optional[DruidConf] = None

Expand Down
12 changes: 9 additions & 3 deletions numalogic/udfs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,18 @@ class InferenceUDF(NumalogicUDF):

def __init__(self, r_client: redis_client_t, pl_conf: Optional[PipelineConf] = None):
super().__init__(is_async=False)
model_registry_cls = RegistryFactory.get_cls("RedisRegistry")
self.pl_conf = pl_conf or PipelineConf()
self.registry_conf = self.pl_conf.registry_conf
model_registry_cls = RegistryFactory.get_cls(self.registry_conf.name)
self.model_registry = model_registry_cls(
client=r_client,
cache_registry=LocalLRUCache(ttl=LOCAL_CACHE_TTL, cachesize=LOCAL_CACHE_SIZE),
cache_registry=LocalLRUCache(
ttl=LOCAL_CACHE_TTL,
cachesize=LOCAL_CACHE_SIZE,
jitter_sec=self.registry_conf.jitter_conf.jitter_sec,
jitter_steps_sec=self.registry_conf.jitter_conf.jitter_steps_sec,
),
)
self.pl_conf = pl_conf or PipelineConf()

# TODO: remove, and have an update config method
def register_conf(self, config_id: str, conf: StreamConf) -> None:
Expand Down
13 changes: 10 additions & 3 deletions numalogic/udfs/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,18 @@ def __init__(
pl_conf: Optional[PipelineConf] = None,
):
super().__init__()
model_registry_cls = RegistryFactory.get_cls("RedisRegistry")
self.pl_conf = pl_conf or PipelineConf()
self.registry_conf = self.pl_conf.registry_conf
model_registry_cls = RegistryFactory.get_cls(self.registry_conf.name)
self.model_registry = model_registry_cls(
client=r_client, cache_registry=LocalLRUCache(ttl=LOCAL_CACHE_TTL)
client=r_client,
cache_registry=LocalLRUCache(
ttl=LOCAL_CACHE_TTL,
cachesize=LOCAL_CACHE_SIZE,
jitter_sec=self.registry_conf.jitter_conf.jitter_sec,
jitter_steps_sec=self.registry_conf.jitter_conf.jitter_steps_sec,
),
)
self.pl_conf = pl_conf or PipelineConf()
self.postproc_factory = PostprocessFactory()

def register_conf(self, config_id: str, conf: StreamConf) -> None:
Expand Down
13 changes: 10 additions & 3 deletions numalogic/udfs/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,18 @@ class PreprocessUDF(NumalogicUDF):

def __init__(self, r_client: redis_client_t, pl_conf: Optional[PipelineConf] = None):
super().__init__()
model_registry_cls = RegistryFactory.get_cls("RedisRegistry")
self.pl_conf = pl_conf or PipelineConf()
self.registry_conf = self.pl_conf.registry_conf
model_registry_cls = RegistryFactory.get_cls(self.registry_conf.name)
self.model_registry = model_registry_cls(
client=r_client, cache_registry=LocalLRUCache(ttl=LOCAL_CACHE_TTL)
client=r_client,
cache_registry=LocalLRUCache(
ttl=LOCAL_CACHE_TTL,
cachesize=LOCAL_CACHE_SIZE,
jitter_sec=self.registry_conf.jitter_conf.jitter_sec,
jitter_steps_sec=self.registry_conf.jitter_conf.jitter_steps_sec,
),
)
self.pl_conf = pl_conf or PipelineConf()
self.preproc_factory = PreprocessFactory()

def register_conf(self, config_id: str, conf: StreamConf) -> None:
Expand Down
15 changes: 12 additions & 3 deletions numalogic/udfs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,18 @@ def __init__(
):
super().__init__(is_async=False)
self.r_client = r_client
model_registry_cls = RegistryFactory.get_cls("RedisRegistry")
self.model_registry = model_registry_cls(client=r_client)
self.pl_conf = pl_conf or PipelineConf()
self.registry_conf = self.pl_conf.registry_conf
model_registry_cls = RegistryFactory.get_cls(self.registry_conf.name)
model_expiry_sec = self.pl_conf.registry_conf.model_expiry_sec
jitter_sec = self.registry_conf.jitter_conf.jitter_sec
jitter_steps_sec = self.registry_conf.jitter_conf.jitter_steps_sec
self.model_registry = model_registry_cls(
client=r_client,
ttl=model_expiry_sec,
jitter_sec=jitter_sec,
jitter_steps_sec=jitter_steps_sec,
)
self.druid_conf = self.pl_conf.druid_conf

data_fetcher_cls = ConnectorFactory.get_cls("DruidFetcher")
Expand Down Expand Up @@ -169,7 +178,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages:

# set the retry and retrain_freq
retrain_freq_ts = _conf.numalogic_conf.trainer.retrain_freq_hr
retry_ts = _conf.numalogic_conf.trainer.retry_secs
retry_ts = _conf.numalogic_conf.trainer.retry_sec
if not self.train_msg_deduplicator.ack_read(
key=payload.composite_keys,
uuid=payload.uuid,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.6.0a9"
version = "0.6.0a10"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
Loading