Skip to content

Commit

Permalink
registry lint
Browse files Browse the repository at this point in the history
Signed-off-by: Avik Basu <[email protected]>
  • Loading branch information
ab93 committed Jun 6, 2024
1 parent 3303a7d commit d106fb7
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 9 deletions.
31 changes: 31 additions & 0 deletions libs/numalogic-registry/.pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
default_language_version:
python: python3.9
repos:
- repo: https://github.com/psf/black
rev: 24.4.2
hooks:
- id: black
language_version: python3.9
args: [--config=pyproject.toml, --diff, --color ]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.8
hooks:
- id: ruff
args: [ --fix ]
- id: ruff-format
- repo: https://github.com/adamchainz/blacken-docs
rev: "1.16.0"
hooks:
- id: blacken-docs
additional_dependencies:
- black==22.12.0
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-toml
- id: check-added-large-files
- id: check-ast
- id: check-case-conflict
- id: check-docstring-first
4 changes: 2 additions & 2 deletions libs/numalogic-registry/nlregistry/dynamodb_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ class DynamoDBRegistry(ArtifactManager):
Examples
--------
>>> from numalogic.models.autoencoder.variants import VanillaAE
>>> from numalogic.test_registry import DynamoDBRegistry
>>> from nlregistry import DynamoDBRegistry
>>> ...
>>> test_registry = DynamoDBRegistry(table="mytable", role="arn:aws:iam::1234567890:role/role-name")
>>> test_registry = DynamoDBRegistry(table="mytable", role="arn:aws:iam::123:role/name")
>>> skeys, dkeys = ("mymetric", "ae"), ("vanilla", "seq10")
>>> model = VanillaAE(seq_len=10)
>>> test_registry.save(skeys, dkeys, artifact=model, **{'lr': 0.01})
Expand Down
6 changes: 4 additions & 2 deletions libs/numalogic-registry/nlregistry/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def get_instance(self, object_info: Union[ModelInfo, RegistryInfo]):
except AttributeError as err:
if object_info.name in self._CLS_SET:
raise ImportError(
"Please install the required dependencies for the test_registry you want to use."
"Please install the required dependencies for the test_registry "
"you want to use."
) from err
raise UnknownConfigArgsError(
f"Invalid model info instance provided: {object_info}"
Expand All @@ -36,7 +37,8 @@ def get_cls(cls, name: str):
except AttributeError as err:
if name in cls._CLS_SET:
raise ImportError(
"Please install the required dependencies for the test_registry you want to use."
"Please install the required dependencies for the test_registry "
"you want to use."
) from err
raise UnknownConfigArgsError(
f"Invalid name provided for RegistryFactory: {name}"
Expand Down
6 changes: 4 additions & 2 deletions libs/numalogic-registry/nlregistry/mlflow_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ def load(
version: Optional[str] = None,
artifact_type: str = "pytorch",
) -> Optional[ArtifactData]:
"""Load the artifact from the test_registry. The artifact is loaded from the cache if available.
"""Load the artifact from the test_registry.
The artifact is loaded from the cache if available.
Args:
----
Expand All @@ -167,7 +169,7 @@ def load(
return cached_artifact
version_info = self.client.get_latest_versions(model_key, stages=[self.model_stage])
if not version_info:
raise ModelVersionError("Model version missing for key = %s" % model_key)
raise ModelVersionError(f"Model version missing for key = {model_key}")
version_info = version_info[-1]
else:
version_info = self.client.get_model_version(model_key, version)
Expand Down
9 changes: 6 additions & 3 deletions libs/numalogic-registry/nlregistry/redis_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,13 @@ def load(
latest: bool = True,
version: Optional[str] = None,
) -> Optional[ArtifactData]:
"""Loads the artifact from redis test_registry. Either latest or version (one of the arguments)
is needed to load the respective artifact.
"""Loads the artifact from redis test_registry.
If cache test_registry is provided, it will first check the cache test_registry for the artifact.
Either latest or version (one of the arguments)
is needed to load the respective artifact.
If cache test_registry is provided, it will first check the cache
test_registry for the artifact.
If latest is passed, latest key is saved otherwise version call saves the respective
version artifact in cache.
Expand Down

0 comments on commit d106fb7

Please sign in to comment.