-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* removed the <Cache/Load>MethodExecutors * replaced them with a registry, which allows users to decorate custom functions and register their own * reworked cache and load methods into data_io, and allowed people to register their own
- Loading branch information
Showing
26 changed files
with
257 additions
and
266 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
import abc | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
import numpy as np | ||
|
||
from bridge.primitives.element.data.load_mechanism import LoadMechanism | ||
from bridge.primitives.element.data.uri_components import URIComponents | ||
from bridge.primitives.element.element_data_type import ELEMENT_DATA_TYPE | ||
|
||
REGISTRY = {} | ||
|
||
|
||
def register(cls): | ||
if cls.category in REGISTRY: | ||
raise ValueError(f"Category {cls.category} is already registered.") | ||
REGISTRY[cls.category] = cls | ||
return cls | ||
|
||
|
||
class DataIO(abc.ABC): | ||
@property | ||
@abc.abstractmethod | ||
def category(self): | ||
pass | ||
|
||
@property | ||
@abc.abstractmethod | ||
def extension(self): | ||
pass | ||
|
||
@classmethod | ||
@abc.abstractmethod | ||
def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE: | ||
pass | ||
|
||
@classmethod | ||
@abc.abstractmethod | ||
def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism: | ||
pass | ||
|
||
|
||
@register | ||
class JPEGDataIO(DataIO): | ||
category = "image" | ||
extension = ".jpg" | ||
|
||
@classmethod | ||
def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE: | ||
if not isinstance(url_or_data, URIComponents): | ||
return np.array(url_or_data) # assumes object is a PIL image or np.ndarray | ||
|
||
if url_or_data.scheme not in ["http", "https", "file", ""]: | ||
raise NotImplementedError("Only loading from local or http(s) URLs is supported for now.") | ||
from skimage.io import imread | ||
|
||
return imread(str(url_or_data)) | ||
|
||
@classmethod | ||
def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism: | ||
import PIL.Image | ||
from skimage.io import imsave | ||
|
||
if url is None: | ||
return LoadMechanism(PIL.Image.fromarray(data), cls.category) | ||
|
||
if url.scheme not in ["", "file"]: | ||
raise NotImplementedError("Only saving locally is supported for now.") | ||
|
||
path = Path(str(url)).expanduser() | ||
|
||
Path.mkdir(path.parent, parents=True, exist_ok=True) | ||
imsave(path, data) | ||
return LoadMechanism.from_url_string(str(url), cls.category) | ||
|
||
|
||
@register | ||
class TorchDataIO(DataIO): | ||
category = "torch" | ||
extension = ".pt" | ||
|
||
@classmethod | ||
def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE: | ||
if not isinstance(url_or_data, URIComponents): | ||
return url_or_data # assume that is already torch tensor | ||
import torch | ||
|
||
return torch.load(str(url_or_data)) | ||
|
||
@classmethod | ||
def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism: | ||
if url is None: | ||
return LoadMechanism(data, cls.category) | ||
|
||
if url.scheme not in ["", "file"]: | ||
raise NotImplementedError("Only saving locally is supported for now.") | ||
import torch | ||
|
||
path = Path(str(url)) | ||
|
||
Path.mkdir(path.parent, parents=True, exist_ok=True) | ||
torch.save(data, path) | ||
return LoadMechanism.from_url_string(str(url), cls.category) | ||
|
||
|
||
@register | ||
class NumpyDataIO(DataIO): | ||
category = "numpy" | ||
extension = ".npy" | ||
|
||
@classmethod | ||
def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE: | ||
if not isinstance(url_or_data, URIComponents): | ||
return url_or_data | ||
return np.load(str(url_or_data)) | ||
|
||
@classmethod | ||
def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism: | ||
raise NotImplementedError() | ||
|
||
|
||
@register | ||
class TextDataIO(DataIO): | ||
category = "text" | ||
extension = ".txt" | ||
|
||
@classmethod | ||
def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE: | ||
if not isinstance(url_or_data, URIComponents): | ||
return url_or_data | ||
return open(str(url_or_data), "r").read() | ||
|
||
@classmethod | ||
def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism: | ||
raise NotImplementedError() | ||
|
||
|
||
@register | ||
class ObjDataIO(DataIO): | ||
category = "obj" | ||
extension = ".pkl" | ||
|
||
@classmethod | ||
def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE: | ||
if not isinstance(url_or_data, URIComponents): | ||
return url_or_data | ||
raise NotImplementedError() | ||
|
||
@classmethod | ||
def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism: | ||
if url is None: | ||
return LoadMechanism(data, cls.category) | ||
|
||
if url.scheme not in ["", "file"]: | ||
raise NotImplementedError("Only saving locally is supported for now.") | ||
import pickle | ||
|
||
path = Path(str(url)) | ||
|
||
Path.mkdir(path.parent, parents=True, exist_ok=True) | ||
with open(path, "wb") as f: | ||
pickle.dump(data, f) | ||
return LoadMechanism.from_url_string(str(url), cls.category) | ||
|
||
|
||
def store(data: Any, url: URIComponents | None, category: str) -> LoadMechanism: | ||
return REGISTRY[category].store(data, url) | ||
|
||
|
||
def load(url_or_data: URIComponents | ELEMENT_DATA_TYPE, category: str) -> ELEMENT_DATA_TYPE: | ||
return REGISTRY[category].load(url_or_data) | ||
|
||
|
||
def extension(category: str) -> str: | ||
return REGISTRY[category].extension | ||
|
||
|
||
def is_registered(category: str) -> bool: | ||
return category in REGISTRY | ||
|
||
|
||
def list_registered_categories() -> list[str]: | ||
return list(REGISTRY.keys()) | ||
|
||
|
||
__all__ = ["register", "store", "load", "extension", "is_registered", "list_registered_categories"] |
Empty file.
Oops, something went wrong.