Skip to content

Commit

Permalink
core: implement TODO and relative import cleanup (#24)
Browse files: browse the repository at this point in the history
* use relative imports

Signed-off-by: Isabella do Amaral <[email protected]>

* use dataclass where possible

Signed-off-by: Isabella do Amaral <[email protected]>

* remove TODO re oras-py#146

Signed-off-by: Isabella do Amaral <[email protected]>

* patch missing test metadata

Signed-off-by: Isabella do Amaral <[email protected]>

* update with suggestions

Signed-off-by: Isabella do Amaral <[email protected]>

* store entire response on push

Signed-off-by: Isabella do Amaral <[email protected]>

* skip 2xx check on Oras-py Registry.push response

Signed-off-by: Isabella do Amaral <[email protected]>

---------

Signed-off-by: Isabella do Amaral <[email protected]>
  • Loading branch information
isinyaaa authored Oct 30, 2024
1 parent 8b154c5 commit a1b2030
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 93 deletions.
34 changes: 18 additions & 16 deletions omlmd/cli.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
"""Command line interface for OMLMD."""

from __future__ import annotations

import logging
from pathlib import Path

import click
import cloup
import logging

from omlmd.helpers import Helper
from omlmd.model_metadata import deserialize_mdfile
from omlmd.provider import OMLMDRegistry

from .helpers import Helper
from .model_metadata import deserialize_mdfile

logger = logging.getLogger(__name__)

Expand All @@ -23,10 +23,6 @@
)


def get_OMLMDRegistry(plain_http: bool) -> OMLMDRegistry:
return OMLMDRegistry(insecure=plain_http)


@cloup.group()
def cli():
logging.basicConfig(level=logging.INFO)
Expand All @@ -45,7 +41,7 @@ def cli():
@click.option("--media-types", "-m", multiple=True, default=[])
def pull(plain_http: bool, target: str, output: Path, media_types: tuple[str]):
"""Pulls an OCI Artifact containing ML model and metadata, filtering if necessary."""
Helper(get_OMLMDRegistry(plain_http)).pull(target, output, media_types)
Helper.from_default_registry(plain_http).pull(target, output, media_types)


@cli.group()
Expand All @@ -58,15 +54,15 @@ def get():
@click.argument("target", required=True)
def config(plain_http: bool, target: str):
"""Outputs configuration of the given OCI Artifact for ML model and metadata."""
click.echo(Helper(get_OMLMDRegistry(plain_http)).get_config(target))
click.echo(Helper.from_default_registry(plain_http).get_config(target))


@cli.command()
@plain_http
@click.argument("targets", required=True, nargs=-1)
def crawl(plain_http: bool, targets: tuple[str]):
"""Crawls configuration for the given list of OCI Artifact for ML model and metadata."""
click.echo(Helper(get_OMLMDRegistry(plain_http)).crawl(targets))
click.echo(Helper.from_default_registry(plain_http).crawl(targets))


@cli.command()
Expand All @@ -83,15 +79,21 @@ def crawl(plain_http: bool, targets: tuple[str]):
"-m",
"--metadata",
type=click.Path(path_type=Path, exists=True, resolve_path=True),
help="Metadata file in JSON or YAML format"
help="Metadata file in JSON or YAML format",
),
cloup.option('--empty-metadata', help='Push with empty metadata', is_flag=True),
cloup.option("--empty-metadata", help="Push with empty metadata", is_flag=True),
constraint=cloup.constraints.require_one,
)
def push(plain_http: bool, target: str, path: Path, metadata: Path | None, empty_metadata: bool):
def push(
plain_http: bool,
target: str,
path: Path,
metadata: Path | None,
empty_metadata: bool,
):
"""Pushes an OCI Artifact containing ML model and metadata, supplying metadata from file as necessary"""

if empty_metadata:
logger.warning(f"Pushing to {target} with empty metadata.")
md = deserialize_mdfile(metadata) if metadata else {}
click.echo(Helper(get_OMLMDRegistry(plain_http)).push(target, path, **md))
click.echo(Helper.from_default_registry(plain_http).push(target, path, **md))
35 changes: 16 additions & 19 deletions omlmd/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,18 @@
import os
import urllib.request
from collections.abc import Sequence
from dataclasses import fields
from dataclasses import dataclass, field, fields
from pathlib import Path

from omlmd.constants import (
from .constants import (
FILENAME_METADATA_JSON,
FILENAME_METADATA_YAML,
MIME_APPLICATION_CONFIG,
MIME_APPLICATION_MLMODEL,
)
from omlmd.listener import Event, Listener, PushEvent
from omlmd.model_metadata import ModelMetadata
from omlmd.provider import OMLMDRegistry

from .listener import Event, Listener, PushEvent
from .model_metadata import ModelMetadata
from .provider import OMLMDRegistry

logger = logging.getLogger(__name__)

Expand All @@ -27,20 +26,16 @@ def download_file(uri: str):
return file_name


@dataclass
class Helper:
_listeners: list[Listener] = []

def __init__(self, registry: OMLMDRegistry | None = None):
if registry is None:
self._registry = OMLMDRegistry(
insecure=True
) # TODO: this is a bit limiting when used from CLI, to be refactored
else:
self._registry = registry
_registry: OMLMDRegistry = field(
default_factory=lambda: OMLMDRegistry(insecure=True)
)
_listeners: list[Listener] = field(default_factory=list)

@property
def registry(self):
return self._registry
@classmethod
def from_default_registry(cls, insecure: bool):
return cls(OMLMDRegistry(insecure=insecure))

def push(
self,
Expand Down Expand Up @@ -102,7 +97,9 @@ def push(
manifest_config=manifest_cfg,
do_chunked=True,
)
self.notify_listeners(PushEvent(target, model_metadata))
self.notify_listeners(
PushEvent.from_response(result, target, model_metadata)
)
return result
finally:
if owns_meta_files:
Expand Down
25 changes: 17 additions & 8 deletions omlmd/listener.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from __future__ import annotations

import typing as t
from abc import ABC, abstractmethod
from typing import Any
from dataclasses import dataclass

from omlmd.model_metadata import ModelMetadata
import requests

from .model_metadata import ModelMetadata


class Listener(ABC):
Expand All @@ -12,19 +15,25 @@ class Listener(ABC):
"""

@abstractmethod
def update(self, source: Any, event: Event) -> None:
def update(self, source: t.Any, event: Event) -> None:
"""
Receive update event.
"""
pass


class Event:
class Event(ABC):
pass


@dataclass
class PushEvent(Event):
def __init__(self, target: str, metadata: ModelMetadata):
# TODO: cannot just receive yet the push sha, waiting for: https://github.com/oras-project/oras-py/pull/146 in a release.
self.target = target
self.metadata = metadata
digest: str
target: str
metadata: ModelMetadata

@classmethod
def from_response(
cls, response: requests.Response, target: str, metadata: ModelMetadata
) -> "PushEvent":
return cls(response.headers["Docker-Content-Digest"], target, metadata)
38 changes: 6 additions & 32 deletions omlmd/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,15 @@
import os
import tempfile

import oras.defaults
import oras.oci
import oras.provider
import oras.schemas
import oras.utils
from oras import provider
from oras.decorator import ensure_container
from oras.provider import container_type
from oras.defaults import annotation_title as ANNOTATION_TITLE
from oras.utils import sanitize_path

logger = logging.getLogger(__name__)


class OMLMDRegistry(oras.provider.Registry):
class OMLMDRegistry(provider.Registry):
@ensure_container
def download_layers(self, package, download_dir, media_types):
"""
Expand All @@ -33,8 +30,8 @@ def download_layers(self, package, download_dir, media_types):
or len(media_types) == 0
or layer["mediaType"] in media_types
):
artifact = layer["annotations"]["org.opencontainers.image.title"]
outfile = oras.utils.sanitize_path(
artifact = layer["annotations"][ANNOTATION_TITLE]
outfile = sanitize_path(
download_dir, os.path.join(download_dir, artifact)
)
path = self.download_blob(package, layer["digest"], outfile)
Expand Down Expand Up @@ -74,26 +71,3 @@ def get_config(self, package) -> str:
os.rmdir(temp_dir)
# print("Temporary directory and its contents have been removed.")
raise RuntimeError("Unable to locate config layer")

@ensure_container
def get_manifest_response(
self,
container: container_type,
allowed_media_type: list | None = None,
refresh_headers: bool = True,
) -> dict:
"""
like get_manifest but return response,
temporary until https://github.com/oras-project/oras-py/pull/146 in a release.
"""
if not allowed_media_type:
allowed_media_type = [oras.defaults.default_manifest_media_type]
headers = {"Accept": ";".join(allowed_media_type)}

if not refresh_headers:
headers.update(self.headers)

get_manifest = f"{self.prefix}://{container.manifest_url()}" # type: ignore
response = self.do_request(get_manifest, "GET", headers=headers)
self._check_200_response(response)
return response
25 changes: 20 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ scikit-learn = "^1.5.0"
ipykernel = "^6.29.4"
nbconvert = "^7.16.4"
markdown-it-py = "^3.0.0"
model-registry = "^0.2.4a1"
model-registry = ">=0.2.9,<0.3.0"
ruff = "^0.6.1"
mypy = "^1.11.1"
types-pyyaml = "^6.0.12.20240808"
types-requests = "^2.32.0.20241016"

[tool.poetry.scripts]
omlmd = "omlmd.cli:cli"
Expand All @@ -50,7 +51,9 @@ target-version = "py39"
respect-gitignore = true

[tool.ruff.lint.per-file-ignores]
"*.ipynb" = ["E402"] # exclude https://docs.astral.sh/ruff/rules/module-import-not-at-top-of-file/#notebook-behavior from linting, especially for demos.
"*.ipynb" = [
"E402",
] # exclude https://docs.astral.sh/ruff/rules/module-import-not-at-top-of-file/#notebook-behavior from linting, especially for demos.

[tool.mypy]
python_version = "3.9"
Expand Down
Loading

0 comments on commit a1b2030

Please sign in to comment.