diff --git a/numalogic/config/_config.py b/numalogic/config/_config.py index 62405ef7..e3d965f6 100644 --- a/numalogic/config/_config.py +++ b/numalogic/config/_config.py @@ -34,6 +34,21 @@ class ModelInfo: # TODO add this in the right config +@dataclass +class JitterConf: + """ + Schema for defining the jitter config to solve the Thundering Herd problem. + + Args: + ---- + jitter_sec: Jitter in seconds + jitter_steps_sec: Step interval value (in secs) for jitter_sec value (default = 120 sec) + """ + + jitter_sec: int = 30 * 60 + jitter_steps_sec: int = 2 * 60 + + @dataclass class RegistryInfo: """Registry config base class. @@ -44,8 +59,10 @@ class RegistryInfo: conf: kwargs for instantiating the model class """ - name: str = MISSING - conf: dict[str, Any] = field(default_factory=dict) + name: str + model_expiry_sec: int + jitter_conf: JitterConf = field(default_factory=JitterConf) + extra_param: dict[str, Any] = field(default_factory=dict) @dataclass @@ -73,8 +90,7 @@ class TrainerConf: train_hours: int = 24 * 8 # 8 days worth of data min_train_size: int = 2000 retrain_freq_hr: int = 24 - model_expiry_sec: int = 172800 # 48 hrs # TODO: revisit this - retry_secs: int = 600 # 10 min # TODO: revisit this + retry_sec: int = 600 # 10 min batch_size: int = 64 pltrainer_conf: LightningTrainerConf = field(default_factory=LightningTrainerConf) diff --git a/numalogic/registry/artifact.py b/numalogic/registry/artifact.py index 3ba62742..6dedfa38 100644 --- a/numalogic/registry/artifact.py +++ b/numalogic/registry/artifact.py @@ -8,11 +8,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- - +import random from dataclasses import dataclass from typing import Any, Generic, TypeVar, Union, Optional +from numalogic.tools.exceptions import ConfigError from numalogic.tools.types import artifact_t, KEYS, META_T, META_VT, EXTRA_T, state_dict_t @@ -129,6 +129,21 @@ def construct_key(skeys: KEYS, dkeys: KEYS) -> str: return "::".join([_static_key, _dynamic_key]) +def _apply_jitter(ts: int, jitter_sec: int, jitter_steps_sec: int): + """ + Applies jitter to the ttl value to solve Thundering Herd problem. + + Note: Jitter is not applied if jitter_sec and jitter_steps_sec are both 0. + """ + if jitter_sec == jitter_steps_sec == 0: + return ts + if jitter_sec < jitter_steps_sec: + raise ConfigError("jitter_sec should be at least 60*jitter_steps_sec") + begin = ts if ts - jitter_sec < 0 else ts - jitter_sec + end = ts + jitter_sec + 1 + return random.randrange(begin, end, jitter_steps_sec) + + class ArtifactCache(Generic[M_K, A_D]): r"""Base class for all artifact caches. Caches support saving, loading and deletion, but not artifact versioning. 
@@ -137,15 +152,17 @@ class ArtifactCache(Generic[M_K, A_D]): ---- cachesize: size of the cache ttl: time to live for each item in the cache + jitter_sec: jitter in seconds to add to the ttl (to solve Thundering Herd problem) + jitter_steps_sec: Step interval value (in secs) for jitter_sec value """ _STORETYPE = "cache" - __slots__ = ("_cachesize", "_ttl") + __slots__ = ("_cachesize", "_ttl", "jitter_sec", "jitter_steps_sec") - def __init__(self, cachesize: int, ttl: int): + def __init__(self, cachesize: int, ttl: int, jitter_sec: int, jitter_steps_sec: int): self._cachesize = cachesize - self._ttl = ttl + self._ttl = _apply_jitter(ts=ttl, jitter_sec=jitter_sec, jitter_steps_sec=jitter_steps_sec) @property def cachesize(self): diff --git a/numalogic/registry/localcache.py b/numalogic/registry/localcache.py index b9993dda..a69187bd 100644 --- a/numalogic/registry/localcache.py +++ b/numalogic/registry/localcache.py @@ -12,7 +12,6 @@ from copy import deepcopy from threading import Lock from typing import Optional - from cachetools import TTLCache from numalogic.registry.artifact import ArtifactCache, ArtifactData @@ -24,15 +23,23 @@ class LocalLRUCache(ArtifactCache, metaclass=Singleton): Args: ---- - cachesize: Size of the cache, - i.e. number of elements the cache can hold - ttl: Time to live for each item in seconds + cachesize: Size of the cache, i.e. 
number of elements the cache can hold + ttl: Time to live for each item in seconds + jitter_sec: Jitter in seconds to add to the ttl (to solve Thundering Herd problem) + (default = 120 secs) + jitter_steps_sec: Step interval value (in secs) for jitter_sec value (default = 60 secs) """ __cache: Optional[TTLCache] = None - def __init__(self, cachesize: int = 512, ttl: int = 300): - super().__init__(cachesize, ttl) + def __init__( + self, + cachesize: int = 512, + ttl: int = 300, + jitter_sec: int = 120, + jitter_steps_sec: int = 60, + ): + super().__init__(cachesize, ttl, jitter_sec, jitter_steps_sec) if not self.__cache: self.__cache = TTLCache(maxsize=cachesize, ttl=ttl) self.__lock = Lock() diff --git a/numalogic/registry/redis_registry.py b/numalogic/registry/redis_registry.py index 0f4fe952..234233ac 100644 --- a/numalogic/registry/redis_registry.py +++ b/numalogic/registry/redis_registry.py @@ -18,7 +18,7 @@ from redis.exceptions import RedisError -from numalogic.registry.artifact import ArtifactManager, ArtifactData, ArtifactCache +from numalogic.registry.artifact import ArtifactManager, ArtifactData, ArtifactCache, _apply_jitter from numalogic.registry._serialize import loads, dumps from numalogic.tools.exceptions import ModelKeyNotFound, RedisRegistryError from numalogic.tools.types import artifact_t, redis_client_t, KEYS, META_T, META_VT, KeyedArtifact @@ -33,6 +33,9 @@ class RedisRegistry(ArtifactManager): ---- client: Take in the redis client already established/created ttl: Total Time to Live (in seconds) for the key when saving in redis (dafault = 604800) + jitter_sec: Jitter (in secs) added to model timestamp information to solve + Thundering Herd problem (default = 30 mins) + jitter_steps_sec: Step interval value (in sec) for jitter_sec value (default = 120 secs) cache_registry: Cache registry to use (default = None). transactional: Flag to indicate if the registry should be transactional or not (default = False). 
@@ -51,18 +54,29 @@ class RedisRegistry(ArtifactManager): >>> loaded_artifact = registry.load(skeys, dkeys) """ - __slots__ = ("client", "ttl", "cache_registry", "transactional") + __slots__ = ( + "client", + "ttl", + "jitter_sec", + "jitter_steps_sec", + "cache_registry", + "transactional", + ) def __init__( self, client: redis_client_t, ttl: int = 604800, + jitter_sec: int = 30 * 60, + jitter_steps_sec: int = 2 * 60, cache_registry: Optional[ArtifactCache] = None, transactional: bool = True, ): super().__init__("") self.client = client self.ttl = ttl + self.jitter_sec = jitter_sec + self.jitter_steps_sec = jitter_steps_sec self.cache_registry = cache_registry self.transactional = transactional @@ -116,9 +130,7 @@ def _save_in_cache(self, key: str, artifact_data: ArtifactData) -> None: def _clear_cache(self, key: Optional[str] = None) -> Optional[ArtifactData]: if self.cache_registry: - if key: - return self.cache_registry.delete(key) - return self.cache_registry.clear() + return self.cache_registry.delete(key) if key else self.cache_registry.clear() return None def __get_artifact_data( @@ -196,16 +208,15 @@ def __save_artifact( latest_key = self.__construct_latest_key(key) pipe.set(name=latest_key, value=new_version_key) _LOGGER.debug("Setting latest key : %s ,to this new key = %s", latest_key, new_version_key) - serialized_metadata = "" - if metadata: - serialized_metadata = orjson.dumps(metadata) + serialized_metadata = orjson.dumps(metadata) if metadata else b"" serialized_artifact = dumps(deserialized_object=artifact) + _cur_ts = int(time.time()) pipe.hset( name=new_version_key, mapping={ "artifact": serialized_artifact, - "version": str(version), - "timestamp": time.time(), + "version": version, + "timestamp": _apply_jitter(_cur_ts, self.jitter_sec, self.jitter_steps_sec), "metadata": serialized_metadata, }, ) @@ -403,7 +414,7 @@ def save_multiple( """ dict_model_ver = {} try: - for key, value in dict_artifacts.items(): + for value in 
dict_artifacts.values(): dict_model_ver[":".join(value.dkeys)] = self.save( skeys=skeys, dkeys=value.dkeys, diff --git a/numalogic/tools/exceptions.py b/numalogic/tools/exceptions.py index 0b22f563..76eced27 100644 --- a/numalogic/tools/exceptions.py +++ b/numalogic/tools/exceptions.py @@ -52,6 +52,12 @@ class ConfigNotFoundError(RuntimeError): pass +class ConfigError(RuntimeError): + """Raised when a config value has a problem.""" + + pass + + class ModelVersionError(Exception): """Raised when a model version is not found in the registry.""" diff --git a/numalogic/udfs/_config.py b/numalogic/udfs/_config.py index 0be4d519..da8d11a2 100644 --- a/numalogic/udfs/_config.py +++ b/numalogic/udfs/_config.py @@ -3,7 +3,8 @@ from omegaconf import OmegaConf -from numalogic.config import NumalogicConf +from numalogic.config import NumalogicConf, RegistryInfo + from numalogic.connectors import ( ConnectorType, RedisConf, @@ -26,6 +27,9 @@ class StreamConf: class PipelineConf: stream_confs: dict[str, StreamConf] = field(default_factory=dict) redis_conf: Optional[RedisConf] = None + registry_conf: Optional[RegistryInfo] = field( + default_factory=lambda: RegistryInfo(name="RedisRegistry", model_expiry_sec=172800) + ) prometheus_conf: Optional[PrometheusConf] = None druid_conf: Optional[DruidConf] = None diff --git a/numalogic/udfs/inference.py b/numalogic/udfs/inference.py index af7f0354..9d7e59ab 100644 --- a/numalogic/udfs/inference.py +++ b/numalogic/udfs/inference.py @@ -38,12 +38,18 @@ class InferenceUDF(NumalogicUDF): def __init__(self, r_client: redis_client_t, pl_conf: Optional[PipelineConf] = None): super().__init__(is_async=False) - model_registry_cls = RegistryFactory.get_cls("RedisRegistry") + self.pl_conf = pl_conf or PipelineConf() + self.registry_conf = self.pl_conf.registry_conf + model_registry_cls = RegistryFactory.get_cls(self.registry_conf.name) self.model_registry = model_registry_cls( client=r_client, - cache_registry=LocalLRUCache(ttl=LOCAL_CACHE_TTL, 
cachesize=LOCAL_CACHE_SIZE), + cache_registry=LocalLRUCache( + ttl=LOCAL_CACHE_TTL, + cachesize=LOCAL_CACHE_SIZE, + jitter_sec=self.registry_conf.jitter_conf.jitter_sec, + jitter_steps_sec=self.registry_conf.jitter_conf.jitter_steps_sec, + ), ) - self.pl_conf = pl_conf or PipelineConf() # TODO: remove, and have an update config method def register_conf(self, config_id: str, conf: StreamConf) -> None: diff --git a/numalogic/udfs/postprocess.py b/numalogic/udfs/postprocess.py index fd5d17d4..afdb985b 100644 --- a/numalogic/udfs/postprocess.py +++ b/numalogic/udfs/postprocess.py @@ -41,11 +41,18 @@ def __init__( pl_conf: Optional[PipelineConf] = None, ): super().__init__() - model_registry_cls = RegistryFactory.get_cls("RedisRegistry") + self.pl_conf = pl_conf or PipelineConf() + self.registry_conf = self.pl_conf.registry_conf + model_registry_cls = RegistryFactory.get_cls(self.registry_conf.name) self.model_registry = model_registry_cls( - client=r_client, cache_registry=LocalLRUCache(ttl=LOCAL_CACHE_TTL) + client=r_client, + cache_registry=LocalLRUCache( + ttl=LOCAL_CACHE_TTL, + cachesize=LOCAL_CACHE_SIZE, + jitter_sec=self.registry_conf.jitter_conf.jitter_sec, + jitter_steps_sec=self.registry_conf.jitter_conf.jitter_steps_sec, + ), ) - self.pl_conf = pl_conf or PipelineConf() self.postproc_factory = PostprocessFactory() def register_conf(self, config_id: str, conf: StreamConf) -> None: diff --git a/numalogic/udfs/preprocess.py b/numalogic/udfs/preprocess.py index 835d7792..3f79142e 100644 --- a/numalogic/udfs/preprocess.py +++ b/numalogic/udfs/preprocess.py @@ -37,11 +37,18 @@ class PreprocessUDF(NumalogicUDF): def __init__(self, r_client: redis_client_t, pl_conf: Optional[PipelineConf] = None): super().__init__() - model_registry_cls = RegistryFactory.get_cls("RedisRegistry") + self.pl_conf = pl_conf or PipelineConf() + self.registry_conf = self.pl_conf.registry_conf + model_registry_cls = RegistryFactory.get_cls(self.registry_conf.name) self.model_registry = 
model_registry_cls( - client=r_client, cache_registry=LocalLRUCache(ttl=LOCAL_CACHE_TTL) + client=r_client, + cache_registry=LocalLRUCache( + ttl=LOCAL_CACHE_TTL, + cachesize=LOCAL_CACHE_SIZE, + jitter_sec=self.registry_conf.jitter_conf.jitter_sec, + jitter_steps_sec=self.registry_conf.jitter_conf.jitter_steps_sec, + ), ) - self.pl_conf = pl_conf or PipelineConf() self.preproc_factory = PreprocessFactory() def register_conf(self, config_id: str, conf: StreamConf) -> None: diff --git a/numalogic/udfs/trainer.py b/numalogic/udfs/trainer.py index 767dca46..82ca6103 100644 --- a/numalogic/udfs/trainer.py +++ b/numalogic/udfs/trainer.py @@ -43,9 +43,18 @@ def __init__( ): super().__init__(is_async=False) self.r_client = r_client - model_registry_cls = RegistryFactory.get_cls("RedisRegistry") - self.model_registry = model_registry_cls(client=r_client) self.pl_conf = pl_conf or PipelineConf() + self.registry_conf = self.pl_conf.registry_conf + model_registry_cls = RegistryFactory.get_cls(self.registry_conf.name) + model_expiry_sec = self.pl_conf.registry_conf.model_expiry_sec + jitter_sec = self.registry_conf.jitter_conf.jitter_sec + jitter_steps_sec = self.registry_conf.jitter_conf.jitter_steps_sec + self.model_registry = model_registry_cls( + client=r_client, + ttl=model_expiry_sec, + jitter_sec=jitter_sec, + jitter_steps_sec=jitter_steps_sec, + ) self.druid_conf = self.pl_conf.druid_conf data_fetcher_cls = ConnectorFactory.get_cls("DruidFetcher") @@ -169,7 +178,7 @@ def exec(self, keys: list[str], datum: Datum) -> Messages: # set the retry and retrain_freq retrain_freq_ts = _conf.numalogic_conf.trainer.retrain_freq_hr - retry_ts = _conf.numalogic_conf.trainer.retry_secs + retry_ts = _conf.numalogic_conf.trainer.retry_sec if not self.train_msg_deduplicator.ack_read( key=payload.composite_keys, uuid=payload.uuid, diff --git a/pyproject.toml b/pyproject.toml index e25d9868..8720d696 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = 
"numalogic" -version = "0.6.0a9" +version = "0.6.0a10" description = "Collection of operational Machine Learning models and tools." authors = ["Numalogic Developers"] packages = [{ include = "numalogic" }] diff --git a/tests/config/test_optdeps.py b/tests/config/test_optdeps.py index 4d8115c5..ec4d0379 100644 --- a/tests/config/test_optdeps.py +++ b/tests/config/test_optdeps.py @@ -20,7 +20,9 @@ class TestOptionalDependencies(unittest.TestCase): def setUp(self) -> None: - self.regconf = RegistryInfo(name="RedisRegistry", conf=dict(ttl=50)) + self.regconf = RegistryInfo( + name="RedisRegistry", model_expiry_sec=10, extra_param=dict(ttl=50) + ) @patch("numalogic.config.factory.getattr", side_effect=AttributeError) def test_not_installed_dep_01(self, _): @@ -30,7 +32,7 @@ def test_not_installed_dep_01(self, _): server = fakeredis.FakeServer() redis_cli = fakeredis.FakeStrictRedis(server=server, decode_responses=False) with self.assertRaises(ImportError): - model_factory.get_cls("RedisRegistry")(redis_cli, **self.regconf.conf) + model_factory.get_cls("RedisRegistry")(redis_cli, **self.regconf.extra_param) @patch("numalogic.config.factory.getattr", side_effect=AttributeError) def test_not_installed_dep_02(self, _): @@ -40,13 +42,13 @@ def test_not_installed_dep_02(self, _): server = fakeredis.FakeServer() redis_cli = fakeredis.FakeStrictRedis(server=server, decode_responses=False) with self.assertRaises(ImportError): - model_factory.get_instance(self.regconf)(redis_cli, **self.regconf.conf) + model_factory.get_instance(self.regconf)(redis_cli, **self.regconf.extra_param) def test_unknown_registry(self): from numalogic.config.factory import RegistryFactory model_factory = RegistryFactory() - reg_conf = RegistryInfo(name="UnknownRegistry") + reg_conf = RegistryInfo(name="UnknownRegistry", model_expiry_sec=1) with self.assertRaises(UnknownConfigArgsError): model_factory.get_cls("UnknownRegistry") with self.assertRaises(UnknownConfigArgsError): diff --git 
a/tests/registry/test_cache.py b/tests/registry/test_cache.py index 1f2f5bc1..6a076e6e 100644 --- a/tests/registry/test_cache.py +++ b/tests/registry/test_cache.py @@ -3,13 +3,17 @@ from concurrent.futures import ThreadPoolExecutor from threading import Thread +from freezegun import freeze_time + from numalogic.models.autoencoder.variants import VanillaAE from numalogic.registry import LocalLRUCache, ArtifactData, ArtifactCache +from numalogic.registry.artifact import _apply_jitter +from numalogic.tools.exceptions import ConfigError class TestArtifactCache(unittest.TestCase): def test_cache(self): - cache_reg = ArtifactCache(cachesize=2, ttl=2) + cache_reg = ArtifactCache(cachesize=2, ttl=2, jitter_sec=180, jitter_steps_sec=2) with self.assertRaises(NotImplementedError): cache_reg.save("m1", ArtifactData(VanillaAE(10, 1), metadata={}, extras={})) with self.assertRaises(NotImplementedError): @@ -21,15 +25,19 @@ def test_cache(self): class TestLocalLRUCache(unittest.TestCase): + def setUp(self): + cache = LocalLRUCache(cachesize=4, ttl=1, jitter_sec=0, jitter_steps_sec=0) + cache.clear() + def test_cache_size(self): - cache_registry = LocalLRUCache(cachesize=2, ttl=1) + cache_registry = LocalLRUCache(cachesize=2, ttl=1, jitter_sec=0, jitter_steps_sec=0) cache_registry.save("m1", ArtifactData(VanillaAE(10, 1), metadata={}, extras={})) + time.sleep(1) cache_registry.save("m2", ArtifactData(VanillaAE(12, 1), metadata={}, extras={})) cache_registry.save("m3", ArtifactData(VanillaAE(14, 1), metadata={}, extras={})) - self.assertIsNone(cache_registry.load("m1")) self.assertIsInstance(cache_registry.load("m2"), ArtifactData) - self.assertEqual(2, cache_registry.cachesize) + self.assertEqual(4, cache_registry.cachesize) self.assertEqual(1, cache_registry.ttl) self.assertTrue("m2" in cache_registry) self.assertTrue("m3" in cache_registry) @@ -48,12 +56,13 @@ def test_cache_overwrite(self): self.assertDictEqual({"version": "2", "source": "cache"}, loaded_artifact.extras) def 
test_cache_ttl(self): - cache_registry = LocalLRUCache(cachesize=2, ttl=1) - cache_registry.save("m1", ArtifactData(VanillaAE(10, 1), metadata={}, extras={})) - self.assertIsInstance(cache_registry.load("m1"), ArtifactData) - - time.sleep(1) - self.assertIsNone(cache_registry.load("m1")) + ts = "2021-01-01 00:00:00" + with freeze_time(ts): + cache_registry = LocalLRUCache(cachesize=2, ttl=1, jitter_sec=0, jitter_steps_sec=0) + cache_registry.save("m1", ArtifactData(VanillaAE(10, 1), metadata={}, extras={})) + self.assertIsInstance(cache_registry.load("m1"), ArtifactData) + time.sleep(1) + self.assertIsNone(cache_registry.load("m1")) def test_singleton(self): cache_registry_1 = LocalLRUCache(cachesize=2, ttl=1) @@ -75,6 +84,10 @@ def test_clear(self): cache_registry.clear() self.assertIsNone(cache_registry.load("m1")) + def test_apply_jitter(self): + with self.assertRaises(ConfigError): + _apply_jitter(1, jitter_sec=30, jitter_steps_sec=60) + def test_multithread(self): def load_cache(idx): artifact_data = cache_reg.load(f"key_{idx}") diff --git a/tests/registry/test_redis_registry.py b/tests/registry/test_redis_registry.py index ae84bb4a..7e51c3e6 100644 --- a/tests/registry/test_redis_registry.py +++ b/tests/registry/test_redis_registry.py @@ -135,14 +135,14 @@ def test_load_model_with_version(self): self.assertEqual(data.extras["version"], version) def test_check_if_model_stale_true(self): - delta = datetime.today() - timedelta(days=5) + delta = datetime.now() - timedelta(days=5) with freeze_time(delta): self.registry.save(skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model) data = self.registry.load(skeys=self.skeys, dkeys=self.dkeys) self.assertTrue(self.registry.is_artifact_stale(data, 12)) def test_check_if_model_stale_false(self): - delta = datetime.today() + delta = datetime.now() with freeze_time(delta): self.registry.save(skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model) with freeze_time(delta + timedelta(hours=7)): @@ -169,21 
+169,27 @@ def test_load_model_when_no_model(self): self.registry.load(skeys=self.skeys, dkeys=self.dkeys) def test_load_latest_model_twice(self): - with freeze_time(datetime.today() - timedelta(days=5)): + with freeze_time(datetime.now() - timedelta(days=5)): self.registry.save(skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model) artifact_data_1 = self.registry.load(skeys=self.skeys, dkeys=self.dkeys) artifact_data_2 = self.registry.load(skeys=self.skeys, dkeys=self.dkeys) self.assertTrue(self.registry.is_artifact_stale(artifact_data_1, 4)) self.assertEqual("registry", artifact_data_1.extras["source"]) - self.assertEqual("cache", artifact_data_2.extras["source"]) + with freeze_time(datetime.now() - timedelta(minutes=60)): + self.assertEqual("cache", artifact_data_2.extras["source"]) def test_load_latest_cache_ttl_expire(self): - self.registry.save(skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model) - artifact_data_1 = self.registry.load(skeys=self.skeys, dkeys=self.dkeys) - time.sleep(1) - artifact_data_2 = self.registry.load(skeys=self.skeys, dkeys=self.dkeys) + cache = LocalLRUCache(cachesize=4, ttl=1, jitter_sec=0, jitter_steps_sec=0) + registry = RedisRegistry( + client=self.redis_client, + cache_registry=cache, + ) + registry.save(skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model) + artifact_data_1 = registry.load(skeys=self.skeys, dkeys=self.dkeys) self.assertEqual("registry", artifact_data_1.extras["source"]) + time.sleep(1) + artifact_data_2 = registry.load(skeys=self.skeys, dkeys=self.dkeys) self.assertEqual("registry", artifact_data_2.extras["source"]) def test_multiple_save(self): diff --git a/tests/udfs/test_trainer.py b/tests/udfs/test_trainer.py index 9500ebfc..50196796 100644 --- a/tests/udfs/test_trainer.py +++ b/tests/udfs/test_trainer.py @@ -16,6 +16,7 @@ from numalogic._constants import TESTS_DIR from numalogic.config import NumalogicConf, ModelInfo from numalogic.config import TrainerConf, 
LightningTrainerConf +from numalogic.connectors import RedisConf from numalogic.connectors.druid import DruidFetcher from numalogic.tools.exceptions import ConfigNotFoundError from numalogic.udfs import StreamConf, PipelineConf @@ -246,9 +247,12 @@ def test_trainer_do_not_train_3(self): ) def test_trainer_conf_err(self): - udf = TrainerUDF(REDIS_CLIENT) + self.udf = TrainerUDF( + REDIS_CLIENT, + pl_conf=PipelineConf(redis_conf=RedisConf(url="redis://localhost:6379", port=0)), + ) with self.assertRaises(ConfigNotFoundError): - udf(self.keys, self.datum) + self.udf(self.keys, self.datum) @patch.object(DruidFetcher, "fetch", Mock(return_value=mock_druid_fetch_data(nrows=10))) def test_trainer_data_insufficient(self):