Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add jitter #305

Merged
merged 11 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions numalogic/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@ class ModelInfo:


# TODO add this in the right config
@dataclass
class JitterConf:
s0nicboOm marked this conversation as resolved.
Show resolved Hide resolved
"""
Schema for defining the jitter config to solve the Thundering Herd problem.

Args:
----
jitter_secs: Jitter in seconds
jitter_steps: Jitter in steps
ab93 marked this conversation as resolved.
Show resolved Hide resolved
"""

jitter_secs: int = 30 * 60
jitter_steps: int = 2 * 60


@dataclass
class RegistryInfo:
"""Registry config base class.
Expand All @@ -44,8 +59,10 @@ class RegistryInfo:
conf: kwargs for instantiating the model class
"""

name: str = MISSING
conf: dict[str, Any] = field(default_factory=dict)
name: str
model_expiry_sec: int
jitter_conf: JitterConf = field(default_factory=JitterConf)
extra_param: dict[str, Any] = field(default_factory=dict)


@dataclass
Expand Down Expand Up @@ -73,8 +90,7 @@ class TrainerConf:
train_hours: int = 24 * 8 # 8 days worth of data
min_train_size: int = 2000
retrain_freq_hr: int = 24
model_expiry_sec: int = 172800 # 48 hrs # TODO: revisit this
retry_secs: int = 600 # 10 min # TODO: revisit this
retry_secs: int = 600 # 10 min
batch_size: int = 64
pltrainer_conf: LightningTrainerConf = field(default_factory=LightningTrainerConf)

Expand Down
11 changes: 7 additions & 4 deletions numalogic/registry/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import random
from dataclasses import dataclass
from typing import Any, Generic, TypeVar, Union, Optional

Expand Down Expand Up @@ -137,15 +136,19 @@ class ArtifactCache(Generic[M_K, A_D]):
----
cachesize: size of the cache
ttl: time to live for each item in the cache
jitter_secs: jitter in seconds to add to the ttl (to solve Thundering Herd problem)
jitter_steps: granularity of jitter_secs
"""

_STORETYPE = "cache"

__slots__ = ("_cachesize", "_ttl")

def __init__(self, cachesize: int, ttl: int):
def __init__(self, cachesize: int, ttl: int, jitter_secs: int = 0, jitter_steps: int = 1):
self._cachesize = cachesize
self._ttl = ttl
self._ttl = abs(
random.randrange(ttl - jitter_secs, ttl + jitter_secs + 1, jitter_steps * 60)
s0nicboOm marked this conversation as resolved.
Show resolved Hide resolved
)

@property
def cachesize(self):
Expand Down
11 changes: 7 additions & 4 deletions numalogic/registry/localcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from copy import deepcopy
from threading import Lock
from typing import Optional

from cachetools import TTLCache

from numalogic.registry.artifact import ArtifactCache, ArtifactData
Expand All @@ -27,14 +26,18 @@ class LocalLRUCache(ArtifactCache, metaclass=Singleton):
cachesize: Size of the cache,
i.e. number of elements the cache can hold
ttl: Time to live for each item in seconds
jitter_secs: Jitter in seconds to add to the ttl (to solve Thundering Herd problem)
jitter_steps: Granularity of the jitter_secs
"""

__cache: Optional[TTLCache] = None

def __init__(self, cachesize: int = 512, ttl: int = 300):
super().__init__(cachesize, ttl)
def __init__(
self, cachesize: int = 512, ttl: int = 120, jitter_secs: int = 0, jitter_steps: int = 1
):
super().__init__(cachesize, ttl, jitter_secs, jitter_steps)
if not self.__cache:
self.__cache = TTLCache(maxsize=cachesize, ttl=ttl)
self.__cache = TTLCache(maxsize=self.cachesize, ttl=self._ttl)
self.__lock = Lock()

def __contains__(self, artifact_key: str) -> bool:
Expand Down
29 changes: 19 additions & 10 deletions numalogic/registry/redis_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
import time
from datetime import datetime, timedelta
import random
from typing import Optional
import orjson
import redis.client
Expand All @@ -33,6 +34,9 @@ class RedisRegistry(ArtifactManager):
----
client: Take in the redis client already established/created
ttl: Total Time to Live (in seconds) for the key when saving in redis (dafault = 604800)
jitter_secs: Jitter (in secs) added to model timestamp information to solve
Thundering Herd problem (default = 30 mins)
jitter_steps: granularity of jitter secs in secs() (default = 2 mins)
cache_registry: Cache registry to use (default = None).
transactional: Flag to indicate if the registry should be transactional or
not (default = False).
Expand All @@ -51,18 +55,22 @@ class RedisRegistry(ArtifactManager):
>>> loaded_artifact = registry.load(skeys, dkeys)
"""

__slots__ = ("client", "ttl", "cache_registry", "transactional")
__slots__ = ("client", "ttl", "jitter_secs", "jitter_steps", "cache_registry", "transactional")

def __init__(
self,
client: redis_client_t,
ttl: int = 604800,
jitter_secs: int = 30 * 60,
jitter_steps: int = 2 * 60,
s0nicboOm marked this conversation as resolved.
Show resolved Hide resolved
cache_registry: Optional[ArtifactCache] = None,
transactional: bool = True,
):
super().__init__("")
self.client = client
self.ttl = ttl
self.jitter_secs = jitter_secs
self.jitter_steps = jitter_steps
self.cache_registry = cache_registry
self.transactional = transactional

Expand Down Expand Up @@ -116,9 +124,7 @@ def _save_in_cache(self, key: str, artifact_data: ArtifactData) -> None:

def _clear_cache(self, key: Optional[str] = None) -> Optional[ArtifactData]:
if self.cache_registry:
if key:
return self.cache_registry.delete(key)
return self.cache_registry.clear()
return self.cache_registry.delete(key) if key else self.cache_registry.clear()
return None

def __get_artifact_data(
Expand Down Expand Up @@ -196,16 +202,19 @@ def __save_artifact(
latest_key = self.__construct_latest_key(key)
pipe.set(name=latest_key, value=new_version_key)
_LOGGER.debug("Setting latest key : %s ,to this new key = %s", latest_key, new_version_key)
serialized_metadata = ""
if metadata:
serialized_metadata = orjson.dumps(metadata)
serialized_metadata = orjson.dumps(metadata) if metadata else ""
s0nicboOm marked this conversation as resolved.
Show resolved Hide resolved
serialized_artifact = dumps(deserialized_object=artifact)
_cur_ts = int(time.time())
pipe.hset(
name=new_version_key,
mapping={
"artifact": serialized_artifact,
"version": str(version),
"timestamp": time.time(),
"version": version,
"timestamp": random.randrange(
_cur_ts - self.jitter_secs,
_cur_ts + self.jitter_secs + 1,
60 * self.jitter_steps,
s0nicboOm marked this conversation as resolved.
Show resolved Hide resolved
),
"metadata": serialized_metadata,
},
)
Expand Down Expand Up @@ -403,7 +412,7 @@ def save_multiple(
"""
dict_model_ver = {}
try:
for key, value in dict_artifacts.items():
for value in dict_artifacts.values():
dict_model_ver[":".join(value.dkeys)] = self.save(
skeys=skeys,
dkeys=value.dkeys,
Expand Down
6 changes: 5 additions & 1 deletion numalogic/udfs/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from omegaconf import OmegaConf

from numalogic.config import NumalogicConf
from numalogic.config import NumalogicConf, RegistryInfo

from numalogic.connectors import (
ConnectorType,
RedisConf,
Expand All @@ -26,6 +27,9 @@ class StreamConf:
class PipelineConf:
stream_confs: dict[str, StreamConf] = field(default_factory=dict)
redis_conf: Optional[RedisConf] = None
registry_conf: Optional[RegistryInfo] = field(
default_factory=lambda: RegistryInfo(name="RedisRegistry", model_expiry_sec=172800)
)
prometheus_conf: Optional[PrometheusConf] = None
druid_conf: Optional[DruidConf] = None

Expand Down
12 changes: 9 additions & 3 deletions numalogic/udfs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,18 @@ class InferenceUDF(NumalogicUDF):

def __init__(self, r_client: redis_client_t, pl_conf: Optional[PipelineConf] = None):
super().__init__(is_async=False)
model_registry_cls = RegistryFactory.get_cls("RedisRegistry")
self.pl_conf = pl_conf or PipelineConf()
self.registry_conf = self.pl_conf.registry_conf
model_registry_cls = RegistryFactory.get_cls(self.registry_conf.name)
self.model_registry = model_registry_cls(
client=r_client,
cache_registry=LocalLRUCache(ttl=LOCAL_CACHE_TTL, cachesize=LOCAL_CACHE_SIZE),
cache_registry=LocalLRUCache(
ttl=LOCAL_CACHE_TTL,
cachesize=LOCAL_CACHE_SIZE,
jitter_secs=self.registry_conf.jitter_conf.jitter_secs,
jitter_steps=self.registry_conf.jitter_conf.jitter_steps,
),
)
self.pl_conf = pl_conf or PipelineConf()

# TODO: remove, and have an update config method
def register_conf(self, config_id: str, conf: StreamConf) -> None:
Expand Down
13 changes: 10 additions & 3 deletions numalogic/udfs/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,18 @@ def __init__(
pl_conf: Optional[PipelineConf] = None,
):
super().__init__()
model_registry_cls = RegistryFactory.get_cls("RedisRegistry")
self.pl_conf = pl_conf or PipelineConf()
self.registry_conf = self.pl_conf.registry_conf
model_registry_cls = RegistryFactory.get_cls(self.registry_conf.name)
self.model_registry = model_registry_cls(
client=r_client, cache_registry=LocalLRUCache(ttl=LOCAL_CACHE_TTL)
client=r_client,
cache_registry=LocalLRUCache(
ttl=LOCAL_CACHE_TTL,
cachesize=LOCAL_CACHE_SIZE,
jitter_secs=self.registry_conf.jitter_conf.jitter_secs,
jitter_steps=self.registry_conf.jitter_conf.jitter_steps,
),
)
self.pl_conf = pl_conf or PipelineConf()
self.postproc_factory = PostprocessFactory()

def register_conf(self, config_id: str, conf: StreamConf) -> None:
Expand Down
13 changes: 10 additions & 3 deletions numalogic/udfs/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,18 @@ class PreprocessUDF(NumalogicUDF):

def __init__(self, r_client: redis_client_t, pl_conf: Optional[PipelineConf] = None):
super().__init__()
model_registry_cls = RegistryFactory.get_cls("RedisRegistry")
self.pl_conf = pl_conf or PipelineConf()
self.registry_conf = self.pl_conf.registry_conf
model_registry_cls = RegistryFactory.get_cls(self.registry_conf.name)
self.model_registry = model_registry_cls(
client=r_client, cache_registry=LocalLRUCache(ttl=LOCAL_CACHE_TTL)
client=r_client,
cache_registry=LocalLRUCache(
ttl=LOCAL_CACHE_TTL,
cachesize=LOCAL_CACHE_SIZE,
jitter_secs=self.registry_conf.jitter_conf.jitter_secs,
jitter_steps=self.registry_conf.jitter_conf.jitter_steps,
),
)
self.pl_conf = pl_conf or PipelineConf()
self.preproc_factory = PreprocessFactory()

def register_conf(self, config_id: str, conf: StreamConf) -> None:
Expand Down
13 changes: 11 additions & 2 deletions numalogic/udfs/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,18 @@ def __init__(
):
super().__init__(is_async=False)
self.r_client = r_client
model_registry_cls = RegistryFactory.get_cls("RedisRegistry")
self.model_registry = model_registry_cls(client=r_client)
self.pl_conf = pl_conf or PipelineConf()
self.registry_conf = self.pl_conf.registry_conf
model_registry_cls = RegistryFactory.get_cls(self.registry_conf.name)
model_expiry_sec = self.pl_conf.registry_conf.model_expiry_sec
jitter_secs = self.registry_conf.jitter_conf.jitter_secs
jitter_steps = self.registry_conf.jitter_conf.jitter_steps
self.model_registry = model_registry_cls(
client=r_client,
ttl=model_expiry_sec,
jitter_secs=jitter_secs,
jitter_steps=jitter_steps,
)
self.druid_conf = self.pl_conf.druid_conf

data_fetcher_cls = ConnectorFactory.get_cls("DruidFetcher")
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "numalogic"
version = "0.6.0a8"
version = "0.6.0a9"
description = "Collection of operational Machine Learning models and tools."
authors = ["Numalogic Developers"]
packages = [{ include = "numalogic" }]
Expand Down
10 changes: 6 additions & 4 deletions tests/config/test_optdeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@

class TestOptionalDependencies(unittest.TestCase):
def setUp(self) -> None:
self.regconf = RegistryInfo(name="RedisRegistry", conf=dict(ttl=50))
self.regconf = RegistryInfo(
name="RedisRegistry", model_expiry_sec=10, extra_param=dict(ttl=50)
)

@patch("numalogic.config.factory.getattr", side_effect=AttributeError)
def test_not_installed_dep_01(self, _):
Expand All @@ -30,7 +32,7 @@ def test_not_installed_dep_01(self, _):
server = fakeredis.FakeServer()
redis_cli = fakeredis.FakeStrictRedis(server=server, decode_responses=False)
with self.assertRaises(ImportError):
model_factory.get_cls("RedisRegistry")(redis_cli, **self.regconf.conf)
model_factory.get_cls("RedisRegistry")(redis_cli, **self.regconf.extra_param)

@patch("numalogic.config.factory.getattr", side_effect=AttributeError)
def test_not_installed_dep_02(self, _):
Expand All @@ -40,13 +42,13 @@ def test_not_installed_dep_02(self, _):
server = fakeredis.FakeServer()
redis_cli = fakeredis.FakeStrictRedis(server=server, decode_responses=False)
with self.assertRaises(ImportError):
model_factory.get_instance(self.regconf)(redis_cli, **self.regconf.conf)
model_factory.get_instance(self.regconf)(redis_cli, **self.regconf.extra_param)

def test_unknown_registry(self):
from numalogic.config.factory import RegistryFactory

model_factory = RegistryFactory()
reg_conf = RegistryInfo(name="UnknownRegistry")
reg_conf = RegistryInfo(name="UnknownRegistry", model_expiry_sec=1)
with self.assertRaises(UnknownConfigArgsError):
model_factory.get_cls("UnknownRegistry")
with self.assertRaises(UnknownConfigArgsError):
Expand Down
Loading