diff --git a/bridge/primitives/dataset/dataset.py b/bridge/primitives/dataset/dataset.py index 9de0c67..f99764d 100644 --- a/bridge/primitives/dataset/dataset.py +++ b/bridge/primitives/dataset/dataset.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from bridge.display.display_engine import DisplayEngine - from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism + from bridge.primitives.element.data.cache_mechanism import CacheMechanism from bridge.primitives.element.element import Element from bridge.primitives.sample.transform import SampleTransform diff --git a/bridge/primitives/dataset/sample_api.py b/bridge/primitives/dataset/sample_api.py index 14ebb86..084b0bb 100644 --- a/bridge/primitives/dataset/sample_api.py +++ b/bridge/primitives/dataset/sample_api.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from bridge.display.display_engine import DisplayEngine - from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism + from bridge.primitives.element.data.cache_mechanism import CacheMechanism from bridge.primitives.element.element import Element from bridge.primitives.element.element_type import ElementType from bridge.primitives.sample import Sample diff --git a/bridge/primitives/dataset/singular_dataset.py b/bridge/primitives/dataset/singular_dataset.py index aa7ce4f..c591183 100644 --- a/bridge/primitives/dataset/singular_dataset.py +++ b/bridge/primitives/dataset/singular_dataset.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: from bridge.display import DisplayEngine - from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism + from bridge.primitives.element.data.cache_mechanism import CacheMechanism from bridge.primitives.element.element import Element from bridge.primitives.element.element_type import ElementType from bridge.primitives.sample.transform import SampleTransform diff --git a/bridge/primitives/element/data/cache/__init__.py b/bridge/primitives/element/data/cache/__init__.py deleted file mode 100644 index 
e69de29..0000000 diff --git a/bridge/primitives/element/data/cache/cache_methods.py b/bridge/primitives/element/data/cache/cache_methods.py deleted file mode 100644 index 9ab652f..0000000 --- a/bridge/primitives/element/data/cache/cache_methods.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from pathlib import Path -from typing import TYPE_CHECKING, Dict - -from bridge.primitives.element.data.category import DataCategory -from bridge.primitives.element.data.load.load_mechanism import LoadMechanism - -if TYPE_CHECKING: - from bridge.primitives.element.data.uri_components import URIComponents - from bridge.primitives.element.element_data_type import ELEMENT_DATA_TYPE - - -class CacheCategory(ABC): - @staticmethod - @abstractmethod - def store( - data: ELEMENT_DATA_TYPE, - url: URIComponents | None, - category: DataCategory, - ) -> LoadMechanism: ... - - -class CacheImage(CacheCategory): - @staticmethod - def store(data: ELEMENT_DATA_TYPE, url: URIComponents | None, category: DataCategory) -> LoadMechanism: - import PIL.Image - from skimage.io import imsave - - if url is None: - return LoadMechanism(PIL.Image.fromarray(data), category) - - if url.scheme not in ["", "file"]: - raise NotImplementedError() - - path = Path(str(url)).expanduser() - - Path.mkdir(path.parent, parents=True, exist_ok=True) - imsave(path, data) - return LoadMechanism.from_url_string(str(url), category) - - -class CacheTorch(CacheCategory): - @staticmethod - def store(data: ELEMENT_DATA_TYPE, url: URIComponents | None, category: DataCategory) -> LoadMechanism: - if url is None: - return LoadMechanism(data, category) - - if url.scheme not in ["", "file"]: - raise NotImplementedError() - import torch - - path = Path(str(url)) - - Path.mkdir(path.parent, parents=True, exist_ok=True) - torch.save(data, path) - return LoadMechanism.from_url_string(str(url), category) - - -class CacheObj(CacheCategory): - @staticmethod - def store(data: 
ELEMENT_DATA_TYPE, url: URIComponents | None, category: DataCategory) -> LoadMechanism: - if url is None: - return LoadMechanism(data, category) - - if url.scheme not in ["", "file"]: - raise NotImplementedError() - import pickle - - path = Path(str(url)) - - Path.mkdir(path.parent, parents=True, exist_ok=True) - with open(path, "wb") as f: - pickle.dump(data, f) - return LoadMechanism.from_url_string(str(url), category) - - -class CachingMethodExecutor: - CACHE_DATA_METHODS: Dict[DataCategory, CacheCategory] = { - DataCategory.image: CacheImage, - DataCategory.torch: CacheTorch, - DataCategory.obj: CacheObj, - } - - def store(self, data: ELEMENT_DATA_TYPE, url: URIComponents | None, category: DataCategory) -> LoadMechanism: - return self.CACHE_DATA_METHODS[category].store(data, url, category) diff --git a/bridge/primitives/element/data/cache/cache_mechanism.py b/bridge/primitives/element/data/cache_mechanism.py similarity index 66% rename from bridge/primitives/element/data/cache/cache_mechanism.py rename to bridge/primitives/element/data/cache_mechanism.py index 6ab7035..7cc24ae 100644 --- a/bridge/primitives/element/data/cache/cache_mechanism.py +++ b/bridge/primitives/element/data/cache_mechanism.py @@ -4,22 +4,14 @@ import pandas as pd -from bridge.primitives.element.data.cache.cache_methods import CachingMethodExecutor -from bridge.primitives.element.data.category import DataCategory +from bridge.primitives.element.data import data_io from bridge.primitives.element.data.uri_components import URIComponents if TYPE_CHECKING: - from bridge.primitives.element.data.load.load_mechanism import LoadMechanism + from bridge.primitives.element.data.load_mechanism import LoadMechanism from bridge.primitives.element.element import Element from bridge.primitives.element.element_data_type import ELEMENT_DATA_TYPE -CATEGORY_TO_EXTENSION = { - DataCategory.image: ".jpg", - DataCategory.torch: ".pt", - DataCategory.text: ".txt", - DataCategory.obj: ".pkl", -} - class 
CacheMechanism: def __init__(self, root_uri: URIComponents | None = None): @@ -33,24 +25,24 @@ def store( self, element: Element, data: ELEMENT_DATA_TYPE, - as_category: DataCategory | None = None, + as_category: str | None = None, should_update_elements: bool = False, ) -> LoadMechanism: if as_category is None: as_category = element.category uri = self._build_uri(element, as_category) - new_provider = CachingMethodExecutor().store(data, uri, as_category) # noqa + new_provider = data_io.store(data, uri, as_category) if should_update_elements and self._elements is not None: self._update_samples_with_new_provider(element.id, new_provider) return new_provider - def _build_uri(self, element: Element, category: DataCategory) -> URIComponents | None: + def _build_uri(self, element: Element, category: str) -> URIComponents | None: if self._root_uri is None: return None uri = URIComponents( scheme=self._root_uri.scheme, - path=self._root_uri.path + f"/{element.id}{CATEGORY_TO_EXTENSION[category]}", + path=self._root_uri.path + f"/{element.id}{data_io.extension(category)}", ) return uri diff --git a/bridge/primitives/element/data/category.py b/bridge/primitives/element/data/category.py deleted file mode 100644 index 77a135c..0000000 --- a/bridge/primitives/element/data/category.py +++ /dev/null @@ -1,9 +0,0 @@ -from bridge.utils import StrEnum - - -class DataCategory(StrEnum): - image = "image" - torch = "torch" - numpy = "numpy" - text = "text" - obj = "obj" diff --git a/bridge/primitives/element/data/data_io.py b/bridge/primitives/element/data/data_io.py new file mode 100644 index 0000000..f563d9e --- /dev/null +++ b/bridge/primitives/element/data/data_io.py @@ -0,0 +1,186 @@ +import abc +from pathlib import Path +from typing import Any + +import numpy as np + +from bridge.primitives.element.data.load_mechanism import LoadMechanism +from bridge.primitives.element.data.uri_components import URIComponents +from bridge.primitives.element.element_data_type import 
ELEMENT_DATA_TYPE + +REGISTRY = {} + + +def register(cls): + if cls.category in REGISTRY: + raise ValueError(f"Category {cls.category} is already registered.") + REGISTRY[cls.category] = cls + return cls + + +class DataIO(abc.ABC): + @property + @abc.abstractmethod + def category(self): + pass + + @property + @abc.abstractmethod + def extension(self): + pass + + @classmethod + @abc.abstractmethod + def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE: + pass + + @classmethod + @abc.abstractmethod + def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism: + pass + + +@register +class JPEGDataIO(DataIO): + category = "image" + extension = ".jpg" + + @classmethod + def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE: + if not isinstance(url_or_data, URIComponents): + return np.array(url_or_data) # assumes object is a PIL image or np.ndarray + + if url_or_data.scheme not in ["http", "https", "file", ""]: + raise NotImplementedError("Only loading from local or http(s) URLs is supported for now.") + from skimage.io import imread + + return imread(str(url_or_data)) + + @classmethod + def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism: + import PIL.Image + from skimage.io import imsave + + if url is None: + return LoadMechanism(PIL.Image.fromarray(data), cls.category) + + if url.scheme not in ["", "file"]: + raise NotImplementedError("Only saving locally is supported for now.") + + path = Path(str(url)).expanduser() + + Path.mkdir(path.parent, parents=True, exist_ok=True) + imsave(path, data) + return LoadMechanism.from_url_string(str(url), cls.category) + + +@register +class TorchDataIO(DataIO): + category = "torch" + extension = ".pt" + + @classmethod + def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE: + if not isinstance(url_or_data, URIComponents): + return url_or_data # assume that is already torch tensor + import torch + + return 
torch.load(str(url_or_data)) + + @classmethod + def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism: + if url is None: + return LoadMechanism(data, cls.category) + + if url.scheme not in ["", "file"]: + raise NotImplementedError("Only saving locally is supported for now.") + import torch + + path = Path(str(url)) + + Path.mkdir(path.parent, parents=True, exist_ok=True) + torch.save(data, path) + return LoadMechanism.from_url_string(str(url), cls.category) + + +@register +class NumpyDataIO(DataIO): + category = "numpy" + extension = ".npy" + + @classmethod + def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE: + if not isinstance(url_or_data, URIComponents): + return url_or_data + return np.load(str(url_or_data)) + + @classmethod + def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism: + raise NotImplementedError() + + +@register +class TextDataIO(DataIO): + category = "text" + extension = ".txt" + + @classmethod + def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE: + if not isinstance(url_or_data, URIComponents): + return url_or_data + return open(str(url_or_data), "r").read() + + @classmethod + def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism: + raise NotImplementedError() + + +@register +class ObjDataIO(DataIO): + category = "obj" + extension = ".pkl" + + @classmethod + def load(cls, url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE: + if not isinstance(url_or_data, URIComponents): + return url_or_data + raise NotImplementedError() + + @classmethod + def store(cls, data: Any, url: URIComponents | None) -> LoadMechanism: + if url is None: + return LoadMechanism(data, cls.category) + + if url.scheme not in ["", "file"]: + raise NotImplementedError("Only saving locally is supported for now.") + import pickle + + path = Path(str(url)) + + Path.mkdir(path.parent, parents=True, exist_ok=True) + with open(path, "wb") as f: + 
pickle.dump(data, f) + return LoadMechanism.from_url_string(str(url), cls.category) + + +def store(data: Any, url: URIComponents | None, category: str) -> LoadMechanism: + return REGISTRY[category].store(data, url) + + +def load(url_or_data: URIComponents | ELEMENT_DATA_TYPE, category: str) -> ELEMENT_DATA_TYPE: + return REGISTRY[category].load(url_or_data) + + +def extension(category: str) -> str: + return REGISTRY[category].extension + + +def is_registered(category: str) -> bool: + return category in REGISTRY + + +def list_registered_categories() -> list[str]: + return list(REGISTRY.keys()) + + +__all__ = ["register", "store", "load", "extension", "is_registered", "list_registered_categories"] diff --git a/bridge/primitives/element/data/load/__init__.py b/bridge/primitives/element/data/load/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/bridge/primitives/element/data/load/load_methods.py b/bridge/primitives/element/data/load/load_methods.py deleted file mode 100644 index be900a8..0000000 --- a/bridge/primitives/element/data/load/load_methods.py +++ /dev/null @@ -1,80 +0,0 @@ -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict - -import numpy as np - -from bridge.primitives.element.data.category import DataCategory -from bridge.primitives.element.data.uri_components import URIComponents - -if TYPE_CHECKING: - from bridge.primitives.element.element_data_type import ELEMENT_DATA_TYPE - - -class LoadCategory(ABC): - @staticmethod - @abstractmethod - def load(url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE: - pass - - -class LoadImage(LoadCategory): - @staticmethod - def load(url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE: - if not isinstance(url_or_data, URIComponents): - return np.array(url_or_data) # assumes object is a PIL image or np.ndarray - - if url_or_data.scheme not in ["http", "https", "file", ""]: - raise NotImplementedError() 
- from skimage.io import imread - - return imread(str(url_or_data)) - - -class LoadText(LoadCategory): - @staticmethod - def load(url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE: - if not isinstance(url_or_data, URIComponents): - return url_or_data # assumes object is a string - - return open(str(url_or_data), "r").read() - - -class LoadTorch(LoadCategory): - @staticmethod - def load(url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE: - if not isinstance(url_or_data, URIComponents): - return url_or_data - import torch - - return torch.load(str(url_or_data)) - - -class LoadNumpy(LoadCategory): - @staticmethod - def load(url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE: - if not isinstance(url_or_data, URIComponents): - return url_or_data # assume that is already np array - return np.load(str(url_or_data)) - - -class LoadObj(LoadCategory): - @staticmethod - def load(url_or_data: URIComponents | ELEMENT_DATA_TYPE) -> ELEMENT_DATA_TYPE: - if not isinstance(url_or_data, URIComponents): - return url_or_data # assume that is already np array - raise NotImplementedError() - - -class LoadingMethodExecutor: - LOAD_DATA_METHODS: Dict[DataCategory, LoadCategory] = { - DataCategory.image: LoadImage, - DataCategory.text: LoadText, - DataCategory.torch: LoadTorch, - DataCategory.numpy: LoadNumpy, - DataCategory.obj: LoadObj, - } - - def load(self, url_or_data: URIComponents | ELEMENT_DATA_TYPE, category: DataCategory) -> ELEMENT_DATA_TYPE: - return self.LOAD_DATA_METHODS[category].load(url_or_data) diff --git a/bridge/primitives/element/data/load/load_mechanism.py b/bridge/primitives/element/data/load_mechanism.py similarity index 78% rename from bridge/primitives/element/data/load/load_mechanism.py rename to bridge/primitives/element/data/load_mechanism.py index 18904d1..7a9b0a8 100644 --- a/bridge/primitives/element/data/load/load_mechanism.py +++ b/bridge/primitives/element/data/load_mechanism.py @@ -4,20 +4,20 @@ from 
typing_extensions import Self -from bridge.primitives.element.data.load.load_methods import LoadingMethodExecutor +from bridge.primitives.element.data import data_io from bridge.primitives.element.data.uri_components import URIComponents from bridge.utils import Dictable from bridge.utils.constants import ELEMENT_COLS if TYPE_CHECKING: - from bridge.primitives.element.data.category import DataCategory from bridge.primitives.element.element_data_type import ELEMENT_DATA_TYPE class LoadMechanism(Dictable): keys = ELEMENT_COLS.LOAD_MECHANISM.list() - def __init__(self, url_or_data: URIComponents | ELEMENT_DATA_TYPE, category: DataCategory) -> None: + def __init__(self, url_or_data: URIComponents | ELEMENT_DATA_TYPE, category: str) -> None: + assert data_io.is_registered(category), f"Category {category} is not registered." self._url_or_data = url_or_data self._category = category @@ -26,11 +26,11 @@ def url_or_data(self) -> URIComponents | ELEMENT_DATA_TYPE: return self._url_or_data @property - def category(self) -> DataCategory: + def category(self) -> str: return self._category def load_data(self) -> Any: - return LoadingMethodExecutor().load(self._url_or_data, self._category) + return data_io.load(self._url_or_data, self._category) def to_dict(self) -> Dict[str, Any]: return { @@ -47,6 +47,6 @@ def from_dict(cls, dic: Dict[str, Any], **kwargs) -> Self: ) @classmethod - def from_url_string(cls, url_string: str, category: DataCategory) -> Self: + def from_url_string(cls, url_string: str, category: str) -> Self: components = URIComponents.from_str(url_string) return cls(components, category) diff --git a/bridge/primitives/element/element.py b/bridge/primitives/element/element.py index 1a7cdca..8bed26f 100644 --- a/bridge/primitives/element/element.py +++ b/bridge/primitives/element/element.py @@ -4,15 +4,14 @@ import pandas as pd -from bridge.primitives.element.data.load.load_mechanism import LoadMechanism +from bridge.primitives.element.data.load_mechanism import 
LoadMechanism from bridge.primitives.utils import validate_metadata from bridge.utils.constants import ELEMENT_COLS from bridge.utils.helper import Displayable if TYPE_CHECKING: from bridge.display import DisplayEngine - from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism - from bridge.primitives.element.data.category import DataCategory + from bridge.primitives.element.data.cache_mechanism import CacheMechanism from bridge.primitives.element.element_data_type import ELEMENT_DATA_TYPE from bridge.primitives.element.element_type import ElementType @@ -60,7 +59,7 @@ def etype(self) -> ElementType: return self._etype @property - def category(self) -> DataCategory: + def category(self) -> str: return self._load_mechanism.category @property diff --git a/bridge/primitives/sample/sample.py b/bridge/primitives/sample/sample.py index 41db7b7..6113fa6 100644 --- a/bridge/primitives/sample/sample.py +++ b/bridge/primitives/sample/sample.py @@ -5,7 +5,7 @@ import pandas as pd -from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism +from bridge.primitives.element.data.cache_mechanism import CacheMechanism from bridge.primitives.element.element import Element from bridge.primitives.element.element_type import ElementType from bridge.utils.constants import ELEMENT_COLS, INDICES diff --git a/bridge/primitives/sample/singular_sample.py b/bridge/primitives/sample/singular_sample.py index df86d64..2ca780f 100644 --- a/bridge/primitives/sample/singular_sample.py +++ b/bridge/primitives/sample/singular_sample.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from bridge.display import DisplayEngine - from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism + from bridge.primitives.element.data.cache_mechanism import CacheMechanism from bridge.primitives.element.element import Element from bridge.primitives.element.element_data_type import ELEMENT_DATA_TYPE from bridge.primitives.element.element_type import ElementType 
diff --git a/bridge/primitives/sample/transform/sample_transform.py b/bridge/primitives/sample/transform/sample_transform.py index fe3b90d..4489669 100644 --- a/bridge/primitives/sample/transform/sample_transform.py +++ b/bridge/primitives/sample/transform/sample_transform.py @@ -5,7 +5,7 @@ if TYPE_CHECKING: from bridge.display import DisplayEngine - from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism + from bridge.primitives.element.data.cache_mechanism import CacheMechanism from bridge.primitives.element.element_type import ElementType from bridge.primitives.sample import Sample diff --git a/bridge/primitives/sample/transform/vision.py b/bridge/primitives/sample/transform/vision.py index 4bf93cc..e211191 100644 --- a/bridge/primitives/sample/transform/vision.py +++ b/bridge/primitives/sample/transform/vision.py @@ -7,7 +7,6 @@ import numpy as np from PIL.Image import Image -from bridge.primitives.element.data.category import DataCategory from bridge.primitives.element.element import Element from bridge.primitives.element.element_type import ElementType from bridge.primitives.sample import Sample @@ -17,7 +16,7 @@ if TYPE_CHECKING: from bridge.display import DisplayEngine - from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism + from bridge.primitives.element.data.cache_mechanism import CacheMechanism class AlbumentationsCompose(SampleTransform): @@ -104,15 +103,15 @@ def _update_element_with_transformed_data( if curr_element.etype == ElementType.bbox: albm_data = np.array(albm_data[0]) new_element_data = BoundingBox(albm_data[:4], class_label=albm_data[4]) # noqa - new_category = DataCategory.obj + new_category = "obj" elif curr_element.etype == ElementType.image: if isinstance(albm_data, np.ndarray): - new_category = DataCategory.image + new_category = "image" else: with optional_dependencies(error="raise"): import torch if isinstance(albm_data, torch.Tensor): - new_category = DataCategory.torch + 
new_category = "torch" else: raise NotImplementedError(f"invalid data type: {type(albm_data)}") new_element_data = albm_data diff --git a/bridge/providers/dataset_provider.py b/bridge/providers/dataset_provider.py index 4f5435d..43358c2 100644 --- a/bridge/providers/dataset_provider.py +++ b/bridge/providers/dataset_provider.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from bridge.display import DisplayEngine - from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism + from bridge.primitives.element.data.cache_mechanism import CacheMechanism from bridge.primitives.element.element_type import ElementType diff --git a/bridge/providers/text.py b/bridge/providers/text.py index 6f8fc5f..05bf0f1 100644 --- a/bridge/providers/text.py +++ b/bridge/providers/text.py @@ -6,8 +6,7 @@ from bridge.display.basic import SimplePrints from bridge.primitives.dataset.singular_dataset import SingularDataset -from bridge.primitives.element.data.category import DataCategory -from bridge.primitives.element.data.load.load_mechanism import LoadMechanism +from bridge.primitives.element.data.load_mechanism import LoadMechanism from bridge.primitives.element.element import Element from bridge.primitives.element.element_type import ElementType from bridge.primitives.sample.singular_sample import SingularSample @@ -17,7 +16,7 @@ if TYPE_CHECKING: from bridge.display import DisplayEngine - from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism + from bridge.primitives.element.data.cache_mechanism import CacheMechanism class LargeMovieReviewDataset(DatasetProvider[SingularDataset, SingularSample]): @@ -44,14 +43,14 @@ def build_dataset( class_dir_list = [d for d in list(self._split_root.iterdir()) if d.is_dir()] for class_idx, class_dir in enumerate(sorted(class_dir_list)): for textfile in class_dir.iterdir(): - load_mechanism = LoadMechanism.from_url_string(str(textfile), DataCategory.text) + load_mechanism = LoadMechanism.from_url_string(str(textfile), 
"text") text_element = Element( element_id=f"text_{textfile.stem}", sample_id=textfile.stem, etype=ElementType.text, load_mechanism=load_mechanism, ) - load_mechanism = LoadMechanism(ClassLabel(class_idx, class_dir.name), category=DataCategory.obj) + load_mechanism = LoadMechanism(ClassLabel(class_idx, class_dir.name), category="obj") label_element = Element( element_id=f"label_{textfile.stem}", sample_id=textfile.stem, diff --git a/bridge/providers/vision.py b/bridge/providers/vision.py index 1586461..8a8d046 100644 --- a/bridge/providers/vision.py +++ b/bridge/providers/vision.py @@ -9,8 +9,7 @@ from bridge.display.basic import SimplePrints from bridge.primitives.dataset import SingularDataset -from bridge.primitives.element.data.category import DataCategory -from bridge.primitives.element.data.load.load_mechanism import LoadMechanism +from bridge.primitives.element.data.load_mechanism import LoadMechanism from bridge.primitives.element.element import Element from bridge.primitives.element.element_type import ElementType from bridge.primitives.sample.singular_sample import SingularSample @@ -20,7 +19,7 @@ if TYPE_CHECKING: from bridge.display import DisplayEngine - from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism + from bridge.primitives.element.data.cache_mechanism import CacheMechanism class ImageFolder(DatasetProvider[SingularDataset, SingularSample]): @@ -39,14 +38,14 @@ def build_dataset( element_id=f"image_{sample_id}", sample_id=sample_id, etype=ElementType.image, - load_mechanism=LoadMechanism.from_url_string(str(img_file), category=DataCategory.image), + load_mechanism=LoadMechanism.from_url_string(str(img_file), category="image"), metadata={"filename": img_file.name}, ) class_element = Element( element_id=f"class_{i}", sample_id=sample_id, etype=ElementType.class_label, - load_mechanism=LoadMechanism(ClassLabel(i, class_dir.name), category=DataCategory.obj), + load_mechanism=LoadMechanism(ClassLabel(i, class_dir.name), 
category="obj"), metadata={"filename": img_file.name}, ) images.append(img_element) @@ -109,7 +108,7 @@ def build_dataset( url = coco_img["coco_url"] # noqa else: url = str(img_file) - load_mechanism = LoadMechanism.from_url_string(url, category=DataCategory.image) # noqa + load_mechanism = LoadMechanism.from_url_string(url, category="image") # noqa img_element = Element( element_id=f"{img_id}_img", sample_id=img_id, @@ -122,7 +121,7 @@ def build_dataset( for coco_ann_dict in coco_annotations: category_id = coco_ann_dict["category_id"] bbox_data = BoundingBox(coords=(np.array(coco_ann_dict["bbox"])), class_label=(ClassLabel(category_id))) - load_mechanism = LoadMechanism(bbox_data, category=DataCategory.obj) + load_mechanism = LoadMechanism(bbox_data, category="obj") bbox_element = Element( element_id=f"{img_id}_{coco_ann_dict['id']}", sample_id=img_id, @@ -159,7 +158,7 @@ def build_dataset( element_id=i, etype=ElementType.image, sample_id=i, - load_mechanism=LoadMechanism(url_or_data=img, category=DataCategory.image), + load_mechanism=LoadMechanism(url_or_data=img, category="image"), ) label_element = Element( element_id=f"label_{i}", @@ -167,7 +166,7 @@ def build_dataset( sample_id=i, load_mechanism=LoadMechanism( url_or_data=ClassLabel(class_idx=target, class_name=self._ds.classes[target]), - category=DataCategory.obj, + category="obj", ), ) sample_list.append(img_element) diff --git a/docs/source/notebooks/vision/custom_data/dataset_provider.ipynb b/docs/source/notebooks/vision/custom_data/dataset_provider.ipynb index 71ffea4..d86be11 100644 --- a/docs/source/notebooks/vision/custom_data/dataset_provider.ipynb +++ b/docs/source/notebooks/vision/custom_data/dataset_provider.ipynb @@ -209,8 +209,7 @@ "from pathlib import Path\n", "\n", "from bridge.primitives.dataset.singular_dataset import SingularDataset\n", - "from bridge.primitives.element.data.category import DataCategory\n", - "from bridge.primitives.element.data.load.load_mechanism import LoadMechanism\n", 
+ "from bridge.primitives.element.data.load_mechanism import LoadMechanism\n", "from bridge.primitives.element.element import Element\n", "from bridge.primitives.element.element_type import ElementType\n", "from bridge.utils.data_objects import ClassLabel\n", @@ -240,14 +239,14 @@ " class_dir_list = [d for d in list(self._split_root.iterdir()) if d.is_dir()]\n", " for class_idx, class_dir in enumerate(sorted(class_dir_list)):\n", " for textfile in class_dir.iterdir():\n", - " load_mechanism = LoadMechanism.from_url_string(str(textfile), DataCategory.text)\n", + " load_mechanism = LoadMechanism.from_url_string(str(textfile), \"text\")\n", " text_element = Element(\n", " element_id=f\"text_{textfile.stem}\",\n", " sample_id=textfile.stem,\n", " etype=ElementType.text,\n", " load_mechanism=load_mechanism,\n", " )\n", - " load_mechanism = LoadMechanism(ClassLabel(class_idx, class_dir.name), category=DataCategory.obj)\n", + " load_mechanism = LoadMechanism(ClassLabel(class_idx, class_dir.name), category=\"obj\")\n", " label_element = Element(\n", " element_id=f\"label_{textfile.stem}\",\n", " sample_id=textfile.stem,\n", @@ -282,7 +281,7 @@ "\n", "#### Create Text Element\n", "```python\n", - "load_mechanism = LoadMechanism.from_url_string(str(textfile), DataCategory.text)\n", + "load_mechanism = LoadMechanism.from_url_string(str(textfile), 'text')\n", "text_element = Element(\n", " element_id=f\"text_{textfile.stem}\",\n", " sample_id=textfile.stem,\n", @@ -297,7 +296,7 @@ "### Create Class Element\n", "\n", "```python\n", - "load_mechanism = LoadMechanism(ClassLabel(class_idx, class_dir.name), category=DataCategory.obj)\n", + "load_mechanism = LoadMechanism(ClassLabel(class_idx, class_dir.name), category='obj')\n", "label_element = Element(\n", " element_id=f\"label_{textfile.stem}\",\n", " sample_id=textfile.stem,\n", @@ -424,7 +423,7 @@ ], "metadata": { "kernelspec": { - "display_name": "python3", + "display_name": "Python 3 (ipykernel)", "language": "python", 
"name": "python3" }, diff --git a/docs/source/notebooks/vision/custom_data/load_mechanism.ipynb b/docs/source/notebooks/vision/custom_data/load_mechanism.ipynb index 776f68a..5f24724 100644 --- a/docs/source/notebooks/vision/custom_data/load_mechanism.ipynb +++ b/docs/source/notebooks/vision/custom_data/load_mechanism.ipynb @@ -150,7 +150,7 @@ "metadata": {}, "source": [ "- **url_or_data**, as its name suggests, contains either a url that references the object (url broadly speaking - including local paths, s3 paths, etc.), or contains the actual object, in case we want to store it directly in-memory.\n", - "- **category** - accepts values of the enum `DataCategory`. This is used to determine which logic is used to load the object. Should we load the image using PIL? or a text file using simple `with open()`? this value determines that." + "- **category** - accepts a string that determines which logic is used to load the object: should we load the image using PIL, or a text file using a simple `with open()`? To list the supported categories, call `data_io.list_registered_categories()` (from `bridge.primitives.element.data`)." 
] }, { @@ -166,7 +166,7 @@ ], "metadata": { "kernelspec": { - "display_name": "python3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, diff --git a/docs/source/notebooks/vision/processing_data/cache_mechanism.ipynb b/docs/source/notebooks/vision/processing_data/cache_mechanism.ipynb index bfdefa9..fe998f3 100644 --- a/docs/source/notebooks/vision/processing_data/cache_mechanism.ipynb +++ b/docs/source/notebooks/vision/processing_data/cache_mechanism.ipynb @@ -138,7 +138,7 @@ "metadata": {}, "outputs": [], "source": [ - "from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism\n", + "from bridge.primitives.element.data.cache_mechanism import CacheMechanism\n", "from bridge.primitives.element.data.uri_components import URIComponents\n", "from bridge.primitives.element.element_type import ElementType\n", "\n", @@ -228,7 +228,7 @@ " self,\n", " element,\n", " data,\n", - " as_category: DataCategory | None = None,\n", + " as_category: str | None = None,\n", " should_update_elements: bool = False,\n", ") -> LoadMechanism:\n", " ...\n", @@ -243,7 +243,7 @@ "def data(self) -> Any:\n", " data = self._load_mechanism.load_data()\n", " if self._cache_mechanism:\n", - " new_load_mechanism = self._cache_mechanism.store(self.id, self.type, data)\n", + " new_load_mechanism = self._cache_mechanism.store(self.id, self.type, data)\n", " self._load_mechanism = new_load_mechanism\n", " return data\n", " return data\n", @@ -309,7 +309,7 @@ ], "metadata": { "kernelspec": { - "display_name": "python3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, diff --git a/docs/source/notebooks/vision/processing_data/source2tensors_demo.ipynb b/docs/source/notebooks/vision/processing_data/source2tensors_demo.ipynb index 9326f35..2798f95 100644 --- a/docs/source/notebooks/vision/processing_data/source2tensors_demo.ipynb +++ b/docs/source/notebooks/vision/processing_data/source2tensors_demo.ipynb 
@@ -46,7 +46,7 @@ "import panel as pn\n", "\n", "from bridge.display.vision import Holoviews\n", - "from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism\n", + "from bridge.primitives.element.data.cache_mechanism import CacheMechanism\n", "from bridge.primitives.element.data.uri_components import URIComponents\n", "from bridge.primitives.element.element_type import ElementType\n", "from bridge.utils import pmap\n", diff --git a/tests/core/test_dataset.py b/tests/core/test_dataset.py index 28279c7..524c1ee 100644 --- a/tests/core/test_dataset.py +++ b/tests/core/test_dataset.py @@ -4,9 +4,8 @@ import pytest from bridge.primitives.dataset import Dataset -from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism -from bridge.primitives.element.data.category import DataCategory -from bridge.primitives.element.data.load.load_mechanism import LoadMechanism +from bridge.primitives.element.data.cache_mechanism import CacheMechanism +from bridge.primitives.element.data.load_mechanism import LoadMechanism from bridge.primitives.element.data.uri_components import URIComponents from bridge.primitives.element.element import Element from bridge.primitives.element.element_type import ElementType @@ -28,16 +27,14 @@ def dummy_elements(): etype=ElementType.image, load_mechanism=LoadMechanism( url_or_data=np.random.randint(0, 255, size=(100, 100, 3)).astype("uint8"), - category=DataCategory.obj, + category="obj", ), ) lbl_element = Element( element_id=f"label_{i}", sample_id=i, etype=ElementType.class_label, - load_mechanism=LoadMechanism( - url_or_data=ClassLabel(class_idx=np.random.randint(0, 10)), category=DataCategory.obj - ), + load_mechanism=LoadMechanism(url_or_data=ClassLabel(class_idx=np.random.randint(0, 10)), category="obj"), ) elements.extend([img_element, lbl_element]) return elements @@ -57,16 +54,14 @@ def dummy_elements_2(): etype=ElementType.image, load_mechanism=LoadMechanism( url_or_data=np.random.randint(0, 255, 
size=(100, 100, 3)).astype("uint8"), - category=DataCategory.obj, + category="obj", ), ) lbl_element = Element( element_id=f"label_{100+i}", sample_id=50 + i, # Adjusting sample_id to create overlap etype=ElementType.class_label, - load_mechanism=LoadMechanism( - url_or_data=ClassLabel(class_idx=np.random.randint(0, 10)), category=DataCategory.obj - ), + load_mechanism=LoadMechanism(url_or_data=ClassLabel(class_idx=np.random.randint(0, 10)), category="obj"), ) elements.extend([img_element, lbl_element]) return elements diff --git a/tests/core/test_dictable.py b/tests/core/test_dictable.py index 226277e..06fb394 100644 --- a/tests/core/test_dictable.py +++ b/tests/core/test_dictable.py @@ -2,9 +2,8 @@ import pytest from bridge.display import DisplayEngine -from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism -from bridge.primitives.element.data.category import DataCategory -from bridge.primitives.element.data.load.load_mechanism import LoadMechanism +from bridge.primitives.element.data.cache_mechanism import CacheMechanism +from bridge.primitives.element.data.load_mechanism import LoadMechanism from bridge.primitives.element.element import Element from bridge.primitives.element.element_type import ElementType from bridge.utils.constants import ELEMENT_COLS @@ -14,32 +13,32 @@ @pytest.fixture( params=[ { - ELEMENT_COLS.LOAD_MECHANISM.CATEGORY: DataCategory.obj, + ELEMENT_COLS.LOAD_MECHANISM.CATEGORY: "obj", ELEMENT_COLS.LOAD_MECHANISM.URL_OR_DATA: BoundingBox(np.array([0, 0, 1, 1]), class_label=ClassLabel(0)), }, { - ELEMENT_COLS.LOAD_MECHANISM.CATEGORY: DataCategory.obj, + ELEMENT_COLS.LOAD_MECHANISM.CATEGORY: "obj", ELEMENT_COLS.LOAD_MECHANISM.URL_OR_DATA: ClassLabel(class_idx=0, class_name="some_class"), }, { - ELEMENT_COLS.LOAD_MECHANISM.CATEGORY: DataCategory.image, + ELEMENT_COLS.LOAD_MECHANISM.CATEGORY: "image", ELEMENT_COLS.LOAD_MECHANISM.URL_OR_DATA: "http://example.com/image.jpg", }, { - ELEMENT_COLS.LOAD_MECHANISM.CATEGORY: 
DataCategory.image, + ELEMENT_COLS.LOAD_MECHANISM.CATEGORY: "image", ELEMENT_COLS.LOAD_MECHANISM.URL_OR_DATA: "dummy_path.jpg", }, { - ELEMENT_COLS.LOAD_MECHANISM.CATEGORY: DataCategory.obj, + ELEMENT_COLS.LOAD_MECHANISM.CATEGORY: "obj", ELEMENT_COLS.LOAD_MECHANISM.URL_OR_DATA: Keypoint(np.array([0, 0])), }, ], ids=[ - f"{DataCategory.obj}_memory_bbox", - f"{DataCategory.obj}_memory_class_label", - f"{DataCategory.image}_http", - f"{DataCategory.image}_file", - f"{DataCategory.obj}_memory_keypoint", + "obj_memory_bbox", + "obj_memory_class_label", + "image_http", + "image_file", + "obj_memory_keypoint", ], ) def load_mechanism_dict(request): diff --git a/tests/core/test_element.py b/tests/core/test_element.py index 692692c..22c3ed9 100644 --- a/tests/core/test_element.py +++ b/tests/core/test_element.py @@ -1,8 +1,8 @@ import pytest from bridge.display import DisplayEngine -from bridge.primitives.element.data.cache.cache_mechanism import CacheMechanism -from bridge.primitives.element.data.load.load_mechanism import LoadMechanism +from bridge.primitives.element.data.cache_mechanism import CacheMechanism +from bridge.primitives.element.data.load_mechanism import LoadMechanism from bridge.primitives.element.element import Element from bridge.primitives.element.element_type import ElementType