Skip to content

Commit

Permalink
cache and load registry (#11)
Browse files Browse the repository at this point in the history
* removed the <Cache/Load>MethodExecutors
* replaced them with a registry, which allows users to decorate custom functions and register their own
* reworked cache and load methods into data_io, and allowed people to register their own
  • Loading branch information
guybuk authored Jul 16, 2024
1 parent f954653 commit ce71e13
Show file tree
Hide file tree
Showing 26 changed files with 257 additions and 266 deletions.
2 changes: 1 addition & 1 deletion bridge/primitives/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

if TYPE_CHECKING:
from bridge.display.display_engine import DisplayEngine
from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism
from bridge.primitives.element.data.cache_mechanism import CacheMechanism
from bridge.primitives.element.element import Element
from bridge.primitives.sample.transform import SampleTransform

Expand Down
2 changes: 1 addition & 1 deletion bridge/primitives/dataset/sample_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

if TYPE_CHECKING:
from bridge.display.display_engine import DisplayEngine
from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism
from bridge.primitives.element.data.cache_mechanism import CacheMechanism
from bridge.primitives.element.element import Element
from bridge.primitives.element.element_type import ElementType
from bridge.primitives.sample import Sample
Expand Down
2 changes: 1 addition & 1 deletion bridge/primitives/dataset/singular_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

if TYPE_CHECKING:
from bridge.display import DisplayEngine
from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism
from bridge.primitives.element.data.cache_mechanism import CacheMechanism
from bridge.primitives.element.element import Element
from bridge.primitives.element.element_type import ElementType
from bridge.primitives.sample.transform import SampleTransform
Expand Down
Empty file.
87 changes: 0 additions & 87 deletions bridge/primitives/element/data/cache/cache_methods.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,14 @@

import pandas as pd

from bridge.primitives.element.data.cache.cache_methods import CachingMethodExecutor
from bridge.primitives.element.data.category import DataCategory
from bridge.primitives.element.data import data_io
from bridge.primitives.element.data.uri_components import URIComponents

if TYPE_CHECKING:
from bridge.primitives.element.data.load.load_mechanism import LoadMechanism
from bridge.primitives.element.data.load_mechanism import LoadMechanism
from bridge.primitives.element.element import Element
from bridge.primitives.element.element_data_type import ELEMENT_DATA_TYPE

CATEGORY_TO_EXTENSION = {
DataCategory.image: ".jpg",
DataCategory.torch: ".pt",
DataCategory.text: ".txt",
DataCategory.obj: ".pkl",
}


class CacheMechanism:
def __init__(self, root_uri: URIComponents | None = None):
Expand All @@ -33,24 +25,24 @@ def store(
self,
element: Element,
data: ELEMENT_DATA_TYPE,
as_category: DataCategory | None = None,
as_category: str | None = None,
should_update_elements: bool = False,
) -> LoadMechanism:
if as_category is None:
as_category = element.category
uri = self._build_uri(element, as_category)
new_provider = CachingMethodExecutor().store(data, uri, as_category) # noqa
new_provider = data_io.store(data, uri, as_category)
if should_update_elements and self._elements is not None:
self._update_samples_with_new_provider(element.id, new_provider)
return new_provider

def _build_uri(self, element: Element, category: DataCategory) -> URIComponents | None:
def _build_uri(self, element: Element, category: str) -> URIComponents | None:
if self._root_uri is None:
return None

uri = URIComponents(
scheme=self._root_uri.scheme,
path=self._root_uri.path + f"/{element.id}{CATEGORY_TO_EXTENSION[category]}",
path=self._root_uri.path + f"/{element.id}{data_io.extension(category)}",
)
return uri

Expand Down
9 changes: 0 additions & 9 deletions bridge/primitives/element/data/category.py

This file was deleted.

186 changes: 186 additions & 0 deletions bridge/primitives/element/data/data_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import abc
from pathlib import Path
from typing import Any

import numpy as np

from bridge.primitives.element.data.load_mechanism import LoadMechanism
from bridge.primitives.element.data.uri_components import URIComponents
from bridge.primitives.element.element_data_type import ELEMENT_DATA_TYPE

REGISTRY = {}


def register(cls):
if cls.category in REGISTRY:
raise ValueError(f"Category {cls.category} is already registered.")
REGISTRY[cls.category] = cls
return cls


class DataIO(abc.ABC):
@property
@abc.abstractmethod
def category(self):
pass

@property
@abc.abstractmethod
def extension(self):
pass

@classmethod
@abc.abstractmethod
def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE:
pass

@classmethod
@abc.abstractmethod
def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism:
pass


@register
class JPEGDataIO(DataIO):
category = "image"
extension = ".jpg"

@classmethod
def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE:
if not isinstance(url_or_data, URIComponents):
return np.array(url_or_data) # assumes object is a PIL image or np.ndarray

if url_or_data.scheme not in ["http", "https", "file", ""]:
raise NotImplementedError("Only loading from local or http(s) URLs is supported for now.")
from skimage.io import imread

return imread(str(url_or_data))

@classmethod
def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism:
import PIL.Image
from skimage.io import imsave

if url is None:
return LoadMechanism(PIL.Image.fromarray(data), cls.category)

if url.scheme not in ["", "file"]:
raise NotImplementedError("Only saving locally is supported for now.")

path = Path(str(url)).expanduser()

Path.mkdir(path.parent, parents=True, exist_ok=True)
imsave(path, data)
return LoadMechanism.from_url_string(str(url), cls.category)


@register
class TorchDataIO(DataIO):
category = "torch"
extension = ".pt"

@classmethod
def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE:
if not isinstance(url_or_data, URIComponents):
return url_or_data # assume that is already torch tensor
import torch

return torch.load(str(url_or_data))

@classmethod
def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism:
if url is None:
return LoadMechanism(data, cls.category)

if url.scheme not in ["", "file"]:
raise NotImplementedError("Only saving locally is supported for now.")
import torch

path = Path(str(url))

Path.mkdir(path.parent, parents=True, exist_ok=True)
torch.save(data, path)
return LoadMechanism.from_url_string(str(url), cls.category)


@register
class NumpyDataIO(DataIO):
category = "numpy"
extension = ".npy"

@classmethod
def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE:
if not isinstance(url_or_data, URIComponents):
return url_or_data
return np.load(str(url_or_data))

@classmethod
def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism:
raise NotImplementedError()


@register
class TextDataIO(DataIO):
category = "text"
extension = ".txt"

@classmethod
def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE:
if not isinstance(url_or_data, URIComponents):
return url_or_data
return open(str(url_or_data), "r").read()

@classmethod
def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism:
raise NotImplementedError()


@register
class ObjDataIO(DataIO):
category = "obj"
extension = ".pkl"

@classmethod
def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE:
if not isinstance(url_or_data, URIComponents):
return url_or_data
raise NotImplementedError()

@classmethod
def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism:
if url is None:
return LoadMechanism(data, cls.category)

if url.scheme not in ["", "file"]:
raise NotImplementedError("Only saving locally is supported for now.")
import pickle

path = Path(str(url))

Path.mkdir(path.parent, parents=True, exist_ok=True)
with open(path, "wb") as f:
pickle.dump(data, f)
return LoadMechanism.from_url_string(str(url), cls.category)


def store(data: Any, url: URIComponents | None, category: str) -> LoadMechanism:
return REGISTRY[category].store(data, url)


def load(url_or_data: URIComponents | ELEMENT_DATA_TYPE, category: str) -> ELEMENT_DATA_TYPE:
return REGISTRY[category].load(url_or_data)


def extension(category: str) -> str:
return REGISTRY[category].extension


def is_registered(category: str) -> bool:
return category in REGISTRY


def list_registered_categories() -> list[str]:
return list(REGISTRY.keys())


__all__ = ["register", "store", "load", "extension", "is_registered", "list_registered_categories"]
Empty file.
Loading

0 comments on commit ce71e13

Please sign in to comment.