diff --git a/numalogic/registry/artifact.py b/numalogic/registry/artifact.py index 6547401e..3ba62742 100644 --- a/numalogic/registry/artifact.py +++ b/numalogic/registry/artifact.py @@ -48,6 +48,8 @@ class ArtifactManager(Generic[KEYS, A_D]): uri: server/connection uri """ + _STORETYPE = "registry" + __slots__ = ("uri",) def __init__(self, uri: str): @@ -137,6 +139,8 @@ class ArtifactCache(Generic[M_K, A_D]): ttl: time to live for each item in the cache """ + _STORETYPE = "cache" + __slots__ = ("_cachesize", "_ttl") def __init__(self, cachesize: int, ttl: int): diff --git a/numalogic/registry/dynamodb_registry.py b/numalogic/registry/dynamodb_registry.py index 520c68e8..4ff4432d 100644 --- a/numalogic/registry/dynamodb_registry.py +++ b/numalogic/registry/dynamodb_registry.py @@ -1,3 +1,14 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import logging import time from typing import Any, Optional diff --git a/numalogic/registry/localcache.py b/numalogic/registry/localcache.py index 26b3e8e1..17db8f5b 100644 --- a/numalogic/registry/localcache.py +++ b/numalogic/registry/localcache.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from copy import deepcopy from typing import Optional from cachetools import TTLCache @@ -34,14 +35,55 @@ def __init__(self, cachesize: int = 512, ttl: int = 300): if not self.__cache: self.__cache = TTLCache(maxsize=cachesize, ttl=ttl) - def load(self, artifact_key: str) -> ArtifactData: + def __contains__(self, artifact_key: str) -> bool: + """Check if an artifact is in the cache.""" + return artifact_key in self.__cache + + def load(self, artifact_key: str) -> Optional[ArtifactData]: + """ + Load an artifact from the cache. + + Args: + ---- + artifact_key: The key of the artifact to load. + + Returns + ------- + The artifact data instance if found, None otherwise. + """ return self.__cache.get(artifact_key) def save(self, key: str, artifact: ArtifactData) -> None: + """ + Save an artifact to the cache. + + Args: + ---- + key: The key of the artifact to save. + artifact: The artifact data instance to save. + """ + artifact = deepcopy(artifact) + artifact.extras["source"] = self._STORETYPE self.__cache[key] = artifact def delete(self, key: str) -> Optional[ArtifactData]: + """ + Delete an artifact from the cache. + + Args: + ---- + key: The key of the artifact to delete. + + Returns + ------- + The deleted artifact data instance if found, None otherwise. + """ return self.__cache.pop(key, default=None) def clear(self) -> None: + """Clears the whole cache.""" self.__cache.clear() + + def keys(self) -> list[str]: + """Returns the current keys of the cache.""" + return list(self.__cache) diff --git a/numalogic/registry/redis_registry.py b/numalogic/registry/redis_registry.py index 6016c2dd..ce0508b9 100644 --- a/numalogic/registry/redis_registry.py +++ b/numalogic/registry/redis_registry.py @@ -1,3 +1,14 @@ +# Copyright 2022 The Numaproj Authors. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import logging import time from datetime import datetime, timedelta @@ -18,7 +29,7 @@ class RedisRegistry(ArtifactManager): Args: ---- - client: Take in the reids client already established/created + client: Take in the redis client already established/created ttl: Total Time to Live (in seconds) for the key when saving in redis (dafault = 604800) cache_registry: Cache registry to use (default = None). @@ -94,6 +105,7 @@ def _load_from_cache(self, key: str) -> Optional[ArtifactData]: def _save_in_cache(self, key: str, artifact_data: ArtifactData) -> None: if self.cache_registry: + _LOGGER.debug("Saving artifact in cache with key: %s", key) self.cache_registry.save(key, artifact_data) def _clear_cache(self, key: Optional[str] = None) -> Optional[ArtifactData]: @@ -123,20 +135,38 @@ def __get_artifact_data( extras={ "timestamp": float(artifact_timestamp.decode()), "version": artifact_version.decode(), + "source": self._STORETYPE, }, ) - def __load_latest_artifact(self, key: str) -> ArtifactData: + def __load_latest_artifact(self, key: str) -> tuple[ArtifactData, bool]: + """ + Load the latest artifact from the registry. + + Args: + key: full model key. + + Returns + ------- + ArtifactData and a boolean flag indicating if the artifact was loaded from cache. + + Raises + ------ + ModelKeyNotFound: If the model key is not found in the registry. 
+ """ cached_artifact = self._load_from_cache(key) if cached_artifact: _LOGGER.debug("Found cached artifact for key: %s", key) - return cached_artifact + return cached_artifact, True latest_key = self.__construct_latest_key(key) if not self.client.exists(latest_key): raise ModelKeyNotFound(f"latest key: {latest_key}, Not Found !!!") model_key = self.client.get(latest_key) - _LOGGER.info("latest key, %s, is pointing to the key : %s", latest_key, model_key) - return self.__load_version_artifact(version=self.get_version(model_key.decode()), key=key) + _LOGGER.debug("latest key, %s, is pointing to the key : %s", latest_key, model_key) + return ( + self.__load_version_artifact(version=self.get_version(model_key.decode()), key=key), + False, + ) def __load_version_artifact(self, version: str, key: str) -> ArtifactData: model_key = self.__construct_version_key(key, version) @@ -152,7 +182,7 @@ def __save_artifact( new_version_key = self.__construct_version_key(key, version) latest_key = self.__construct_latest_key(key) pipe.set(name=latest_key, value=new_version_key) - _LOGGER.info("Setting latest key : %s ,to this new key = %s", latest_key, new_version_key) + _LOGGER.debug("Setting latest key : %s ,to this new key = %s", latest_key, new_version_key) serialized_metadata = "" if metadata: serialized_metadata = dumps(deserialized_object=metadata) @@ -178,6 +208,8 @@ def load( """Loads the artifact from redis registry. Either latest or version (one of the arguments) is needed to load the respective artifact. + If cache registry is provided, it will first check the cache registry for the artifact. + Args: ---- skeys: static key fields as list/tuple of strings @@ -188,19 +220,26 @@ def load( Returns ------- ArtifactData instance + + Raises + ------ + ValueError: If both latest and version are provided or none of them are provided. + RedisRegistryError: If any redis error occurs. 
""" if (latest and version) or (not latest and not version): raise ValueError("Either One of 'latest' or 'version' needed in load method call") key = self.construct_key(skeys, dkeys) + is_cached = False try: if latest: - artifact_data = self.__load_latest_artifact(key) - self._save_in_cache(key, artifact_data) + artifact_data, is_cached = self.__load_latest_artifact(key) else: artifact_data = self.__load_version_artifact(version, key) except RedisError as err: raise RedisRegistryError(f"{err.__class__.__name__} raised") from err else: + if (not is_cached) and latest: + self._save_in_cache(key, artifact_data) return artifact_data def save( @@ -221,24 +260,28 @@ def save( Returns ------- - model version + Model version (str) + + Raises + ------ + RedisRegistryError: If there is any RedisError while saving the artifact. """ key = self.construct_key(skeys, dkeys) latest_key = self.__construct_latest_key(key) version = 0 try: if self.client.exists(latest_key): - _LOGGER.debug("latest key exists for the model") + _LOGGER.debug("Latest key: %s exists for the model", latest_key) version_key = self.client.get(name=latest_key) version = int(self.get_version(version_key.decode())) + 1 with self.client.pipeline() as pipe: new_version_key = self.__save_artifact(pipe, artifact, metadata, key, str(version)) pipe.expire(name=new_version_key, time=self.ttl) - _LOGGER.info("Model with the key = %s, loaded successfully.", new_version_key) pipe.execute() except RedisError as err: raise RedisRegistryError(f"{err.__class__.__name__} raised") from err else: + _LOGGER.info("Model with the key = %s, saved successfully.", new_version_key) return str(version) def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None: @@ -249,21 +292,25 @@ def delete(self, skeys: KEYS, dkeys: KEYS, version: str) -> None: skeys: static key fields as list/tuple of strings dkeys: dynamic key fields as list/tuple of strings version: model version to delete. 
+ + Raises + ------ + ModelKeyNotFound: If the model version is not found in registry. + RedisRegistryError: If there is any RedisError while deleting the artifact. """ key = self.construct_key(skeys, dkeys) del_key = self.__construct_version_key(key, version) try: if self.client.exists(del_key): self.client.delete(del_key) - _LOGGER.info("Model with the key = %s, deleted successfully", del_key) else: - _LOGGER.debug("Key to delete: %s, Not Found !!!\n Exiting.....", del_key) raise ModelKeyNotFound( - "Key to delete: %s, Not Found !!!\n Exiting....." % del_key, + "Key to delete: %s, Not Found!" % del_key, ) except RedisError as err: raise RedisRegistryError(f"{err.__class__.__name__} raised") from err else: + _LOGGER.info("Model with the key = %s, deleted successfully", del_key) self._clear_cache(del_key) @staticmethod @@ -276,6 +323,13 @@ def is_artifact_stale(artifact_data: ArtifactData, freq_hr: int) -> bool: artifact_data: ArtifactData object to look into freq_hr: Frequency of retraining in hours. + Returns + ------- + True if artifact is stale, False otherwise. + + Raises + ------ + RedisRegistryError: If there is any error while fetching timestamp information. 
""" try: artifact_ts = float(artifact_data.extras["timestamp"]) diff --git a/tests/registry/test_cache.py b/tests/registry/test_cache.py index f7edc9df..26882de2 100644 --- a/tests/registry/test_cache.py +++ b/tests/registry/test_cache.py @@ -27,9 +27,11 @@ def test_cache_size(self): self.assertIsNone(cache_registry.load("m1")) self.assertIsInstance(cache_registry.load("m2"), ArtifactData) - self.assertIsInstance(cache_registry.load("m3"), ArtifactData) self.assertEqual(2, cache_registry.cachesize) self.assertEqual(1, cache_registry.ttl) + self.assertIn("m2", cache_registry) + self.assertIn("m3", cache_registry) + self.assertListEqual(["m2", "m3"], cache_registry.keys()) def test_cache_overwrite(self): cache_registry = LocalLRUCache(cachesize=2, ttl=1) @@ -41,7 +43,7 @@ def test_cache_overwrite(self): ) loaded_artifact = cache_registry.load("m1") - self.assertDictEqual({"version": "2"}, loaded_artifact.extras) + self.assertDictEqual({"version": "2", "source": "cache"}, loaded_artifact.extras) def test_cache_ttl(self): cache_registry = LocalLRUCache(cachesize=2, ttl=1) diff --git a/tests/registry/test_redis_registry.py b/tests/registry/test_redis_registry.py index f80c7093..7e821472 100644 --- a/tests/registry/test_redis_registry.py +++ b/tests/registry/test_redis_registry.py @@ -1,3 +1,5 @@ +import logging +import time import unittest from datetime import datetime, timedelta from unittest.mock import Mock, patch @@ -12,6 +14,8 @@ from numalogic.registry import RedisRegistry, LocalLRUCache, ArtifactData from numalogic.tools.exceptions import ModelKeyNotFound, RedisRegistryError +logging.basicConfig(level=logging.DEBUG) + class TestRedisRegistry(unittest.TestCase): @classmethod @@ -25,7 +29,7 @@ def setUpClass(cls) -> None: cls.redis_client = fakeredis.FakeStrictRedis(server=server, decode_responses=False) def setUp(self): - self.cache = LocalLRUCache(cachesize=4, ttl=300) + self.cache = LocalLRUCache(cachesize=4, ttl=1) self.registry = RedisRegistry( 
client=self.redis_client, cache_registry=self.cache, @@ -162,19 +166,44 @@ def test_load_model_when_no_model(self): with self.assertRaises(ModelKeyNotFound): self.registry.load(skeys=self.skeys, dkeys=self.dkeys) - def test_load_model_when_model_stale(self): - with self.assertRaises(ModelKeyNotFound): - version = self.registry.save( - skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model - ) - self.registry.delete(skeys=self.skeys, dkeys=self.dkeys, version=str(version)) - self.registry.load(skeys=self.skeys, dkeys=self.dkeys) + def test_load_latest_model_twice(self): + with freeze_time(datetime.today() - timedelta(days=5)): + self.registry.save(skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model) + + artifact_data_1 = self.registry.load(skeys=self.skeys, dkeys=self.dkeys) + artifact_data_2 = self.registry.load(skeys=self.skeys, dkeys=self.dkeys) + self.assertTrue(self.registry.is_artifact_stale(artifact_data_1, 4)) + self.assertEqual("registry", artifact_data_1.extras["source"]) + self.assertEqual("cache", artifact_data_2.extras["source"]) + + def test_load_latest_cache_ttl_expire(self): + self.registry.save(skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model) + artifact_data_1 = self.registry.load(skeys=self.skeys, dkeys=self.dkeys) + time.sleep(1) + artifact_data_2 = self.registry.load(skeys=self.skeys, dkeys=self.dkeys) + self.assertEqual("registry", artifact_data_1.extras["source"]) + self.assertEqual("registry", artifact_data_2.extras["source"]) + + def test_load_non_latest_model_twice(self): + old_version = self.registry.save( + skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model + ) + self.registry.save(skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model) + + artifact_data_1 = self.registry.load( + skeys=self.skeys, dkeys=self.dkeys, latest=False, version=old_version + ) + artifact_data_2 = self.registry.load( + skeys=self.skeys, dkeys=self.dkeys, latest=False, version=old_version + ) + 
self.assertEqual("registry", artifact_data_1.extras["source"]) + self.assertEqual("registry", artifact_data_2.extras["source"]) def test_delete_version(self): + version = self.registry.save( + skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model + ) with self.assertRaises(ModelKeyNotFound): - version = self.registry.save( - skeys=self.skeys, dkeys=self.dkeys, artifact=self.pytorch_model - ) self.registry.delete(skeys=self.skeys, dkeys=self.dkeys, version=str(version)) self.registry.load(skeys=self.skeys, dkeys=self.dkeys)