Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cache and load registry #11

Merged
merged 2 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bridge/primitives/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

if TYPE_CHECKING:
from bridge.display.display_engine import DisplayEngine
from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism
from bridge.primitives.element.data.cache_mechanism import CacheMechanism
from bridge.primitives.element.element import Element
from bridge.primitives.sample.transform import SampleTransform

Expand Down
2 changes: 1 addition & 1 deletion bridge/primitives/dataset/sample_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

if TYPE_CHECKING:
from bridge.display.display_engine import DisplayEngine
from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism
from bridge.primitives.element.data.cache_mechanism import CacheMechanism
from bridge.primitives.element.element import Element
from bridge.primitives.element.element_type import ElementType
from bridge.primitives.sample import Sample
Expand Down
2 changes: 1 addition & 1 deletion bridge/primitives/dataset/singular_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

if TYPE_CHECKING:
from bridge.display import DisplayEngine
from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism
from bridge.primitives.element.data.cache_mechanism import CacheMechanism
from bridge.primitives.element.element import Element
from bridge.primitives.element.element_type import ElementType
from bridge.primitives.sample.transform import SampleTransform
Expand Down
Empty file.
87 changes: 0 additions & 87 deletions bridge/primitives/element/data/cache/cache_methods.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,14 @@

import pandas as pd

from bridge.primitives.element.data.cache.cache_methods import CachingMethodExecutor
from bridge.primitives.element.data.category import DataCategory
from bridge.primitives.element.data import data_io
from bridge.primitives.element.data.uri_components import URIComponents

if TYPE_CHECKING:
from bridge.primitives.element.data.load.load_mechanism import LoadMechanism
from bridge.primitives.element.data.load_mechanism import LoadMechanism
from bridge.primitives.element.element import Element
from bridge.primitives.element.element_data_type import ELEMENT_DATA_TYPE

CATEGORY_TO_EXTENSION = {
DataCategory.image: ".jpg",
DataCategory.torch: ".pt",
DataCategory.text: ".txt",
DataCategory.obj: ".pkl",
}


class CacheMechanism:
def __init__(self, root_uri: URIComponents | None = None):
Expand All @@ -33,24 +25,24 @@ def store(
self,
element: Element,
data: ELEMENT_DATA_TYPE,
as_category: DataCategory | None = None,
as_category: str | None = None,
should_update_elements: bool = False,
) -> LoadMechanism:
if as_category is None:
as_category = element.category
uri = self._build_uri(element, as_category)
new_provider = CachingMethodExecutor().store(data, uri, as_category) # noqa
new_provider = data_io.store(data, uri, as_category)
if should_update_elements and self._elements is not None:
self._update_samples_with_new_provider(element.id, new_provider)
return new_provider

def _build_uri(self, element: Element, category: DataCategory) -> URIComponents | None:
def _build_uri(self, element: Element, category: str) -> URIComponents | None:
if self._root_uri is None:
return None

uri = URIComponents(
scheme=self._root_uri.scheme,
path=self._root_uri.path + f"/{element.id}{CATEGORY_TO_EXTENSION[category]}",
path=self._root_uri.path + f"/{element.id}{data_io.extension(category)}",
)
return uri

Expand Down
9 changes: 0 additions & 9 deletions bridge/primitives/element/data/category.py

This file was deleted.

186 changes: 186 additions & 0 deletions bridge/primitives/element/data/data_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import abc
from pathlib import Path
from typing import Any

import numpy as np

from bridge.primitives.element.data.load_mechanism import LoadMechanism
from bridge.primitives.element.data.uri_components import URIComponents
from bridge.primitives.element.element_data_type import ELEMENT_DATA_TYPE

REGISTRY = {}


def register(cls):
if cls.category in REGISTRY:
raise ValueError(f"Category {cls.category} is already registered.")
REGISTRY[cls.category] = cls
return cls


class DataIO(abc.ABC):
@property
@abc.abstractmethod
def category(self):
pass

@property
@abc.abstractmethod
def extension(self):
pass

@classmethod
@abc.abstractmethod
def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE:
pass

@classmethod
@abc.abstractmethod
def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism:
pass


@register
class JPEGDataIO(DataIO):
category = "image"
extension = ".jpg"

@classmethod
def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE:
if not isinstance(url_or_data, URIComponents):
return np.array(url_or_data) # assumes object is a PIL image or np.ndarray

if url_or_data.scheme not in ["http", "https", "file", ""]:
raise NotImplementedError("Only loading from local or http(s) URLs is supported for now.")
from skimage.io import imread

return imread(str(url_or_data))

@classmethod
def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism:
import PIL.Image
from skimage.io import imsave

if url is None:
return LoadMechanism(PIL.Image.fromarray(data), cls.category)

if url.scheme not in ["", "file"]:
raise NotImplementedError("Only saving locally is supported for now.")

path = Path(str(url)).expanduser()

Path.mkdir(path.parent, parents=True, exist_ok=True)
imsave(path, data)
return LoadMechanism.from_url_string(str(url), cls.category)


@register
class TorchDataIO(DataIO):
category = "torch"
extension = ".pt"

@classmethod
def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE:
if not isinstance(url_or_data, URIComponents):
return url_or_data # assume that is already torch tensor
import torch

return torch.load(str(url_or_data))

@classmethod
def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism:
if url is None:
return LoadMechanism(data, cls.category)

if url.scheme not in ["", "file"]:
raise NotImplementedError("Only saving locally is supported for now.")
import torch

path = Path(str(url))

Path.mkdir(path.parent, parents=True, exist_ok=True)
torch.save(data, path)
return LoadMechanism.from_url_string(str(url), cls.category)


@register
class NumpyDataIO(DataIO):
category = "numpy"
extension = ".npy"

@classmethod
def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE:
if not isinstance(url_or_data, URIComponents):
return url_or_data
return np.load(str(url_or_data))

@classmethod
def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism:
raise NotImplementedError()


@register
class TextDataIO(DataIO):
category = "text"
extension = ".txt"

@classmethod
def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE:
if not isinstance(url_or_data, URIComponents):
return url_or_data
return open(str(url_or_data), "r").read()

@classmethod
def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism:
raise NotImplementedError()


@register
class ObjDataIO(DataIO):
category = "obj"
extension = ".pkl"

@classmethod
def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE:
if not isinstance(url_or_data, URIComponents):
return url_or_data
raise NotImplementedError()

@classmethod
def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism:
if url is None:
return LoadMechanism(data, cls.category)

if url.scheme not in ["", "file"]:
raise NotImplementedError("Only saving locally is supported for now.")
import pickle

path = Path(str(url))

Path.mkdir(path.parent, parents=True, exist_ok=True)
with open(path, "wb") as f:
pickle.dump(data, f)
return LoadMechanism.from_url_string(str(url), cls.category)


def store(data: Any, url: URIComponents | None, category: str) -> LoadMechanism:
return REGISTRY[category].store(data, url)


def load(url_or_data: URIComponents | ELEMENT_DATA_TYPE, category: str) -> ELEMENT_DATA_TYPE:
return REGISTRY[category].load(url_or_data)


def extension(category: str) -> str:
return REGISTRY[category].extension


def is_registered(category: str) -> bool:
return category in REGISTRY


def list_registered_categories() -> list[str]:
return list(REGISTRY.keys())


__all__ = ["register", "store", "load", "extension", "is_registered", "list_registered_categories"]
Empty file.
Loading
Loading