From 86f31705a3bdc171f3b9e590294989aa8269fccd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Frederik=20Hvilsh=C3=B8j?= <93145535+frederik-encord@users.noreply.github.com> Date: Fri, 1 Dec 2023 15:12:07 +0100 Subject: [PATCH] feat: classification dataset for pytorch (#675) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: classification dataset for pytorch * fix: update doc string and make tags more flexible * misc: upgrade encord version New properties have been added to the Object class and the attributes classes have moved to their own file. * fix: bad identifier query when no tags were applied * feat: add object dataset for pytorch * misc: improve error handling + fix pylint * fix: remove shape check for now and fix error log --------- Co-authored-by: Eloy Pérez Torres --- poetry.lock | 27 +- pyproject.toml | 2 +- src/encord_active/public/dataset.py | 384 ++++++++++++++++++++++++++++ 3 files changed, 396 insertions(+), 17 deletions(-) create mode 100644 src/encord_active/public/dataset.py diff --git a/poetry.lock b/poetry.lock index 746396745..3bf4a7041 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "aiofiles" @@ -828,17 +828,18 @@ files = [ [[package]] name = "encord" -version = "0.1.83" +version = "0.1.98" description = "Encord Python SDK Client" optional = false -python-versions = ">=3.7,<4.0" +python-versions = ">=3.8,<4.0" files = [ - {file = "encord-0.1.83-py3-none-any.whl", hash = "sha256:c3116c75f60f413ec552024a18755524cc4c7d5b6b774fce883c680c0d747316"}, - {file = "encord-0.1.83.tar.gz", hash = "sha256:7a923c47ca21cf3980ef618ff049383d61ecc82641671fabececc9ba0a32e46c"}, + {file = "encord-0.1.98-py3-none-any.whl", hash = "sha256:e9b1758c8ad4803e82af61bc6644ab60fad52649afea5eea18f51a04ea1f8f95"}, + {file = "encord-0.1.98.tar.gz", hash = "sha256:f2b8447daa270000ff3e401eda6acf427b13b37471bd782c4b87ad4f7d5de60f"}, ] [package.dependencies] cryptography = ">=3.4.8" +pydantic = ">=1.7.0" python-dateutil = ">=2.8.2,<3.0.0" requests = ">=2.25.0,<3.0.0" tqdm = ">=4.32.1,<5.0.0" @@ -969,7 +970,6 @@ files = [ {file = "greenlet-2.0.2-cp27-cp27m-win32.whl", hash = "sha256:6c3acb79b0bfd4fe733dff8bc62695283b57949ebcca05ae5c129eb606ff2d74"}, {file = "greenlet-2.0.2-cp27-cp27m-win_amd64.whl", hash = "sha256:283737e0da3f08bd637b5ad058507e578dd462db259f7f6e4c5c365ba4ee9343"}, {file = "greenlet-2.0.2-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:d27ec7509b9c18b6d73f2f5ede2622441de812e7b1a80bbd446cb0633bd3d5ae"}, - {file = "greenlet-2.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d967650d3f56af314b72df7089d96cda1083a7fc2da05b375d2bc48c82ab3f3c"}, {file = "greenlet-2.0.2-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:30bcf80dda7f15ac77ba5af2b961bdd9dbc77fd4ac6105cee85b0d0a5fcf74df"}, {file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:26fbfce90728d82bc9e6c38ea4d038cba20b7faf8a0ca53a9c07b67318d46088"}, {file = "greenlet-2.0.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9190f09060ea4debddd24665d6804b995a9c122ef5917ab26e1566dcc712ceeb"}, @@ -978,7 +978,6 @@ files = [ {file = "greenlet-2.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:76ae285c8104046b3a7f06b42f29c7b73f77683df18c49ab5af7983994c2dd91"}, {file = "greenlet-2.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:2d4686f195e32d36b4d7cf2d166857dbd0ee9f3d20ae349b6bf8afc8485b3645"}, {file = "greenlet-2.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:c4302695ad8027363e96311df24ee28978162cdcdd2006476c43970b384a244c"}, - {file = "greenlet-2.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d4606a527e30548153be1a9f155f4e283d109ffba663a15856089fb55f933e47"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c48f54ef8e05f04d6eff74b8233f6063cb1ed960243eacc474ee73a2ea8573ca"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a1846f1b999e78e13837c93c778dcfc3365902cfb8d1bdb7dd73ead37059f0d0"}, {file = "greenlet-2.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a06ad5312349fec0ab944664b01d26f8d1f05009566339ac6f63f56589bc1a2"}, @@ -1008,7 +1007,6 @@ files = [ {file = "greenlet-2.0.2-cp37-cp37m-win32.whl", hash = "sha256:3f6ea9bd35eb450837a3d80e77b517ea5bc56b4647f5502cd28de13675ee12f7"}, {file = "greenlet-2.0.2-cp37-cp37m-win_amd64.whl", hash = "sha256:7492e2b7bd7c9b9916388d9df23fa49d9b88ac0640db0a5b4ecc2b653bf451e3"}, {file = "greenlet-2.0.2-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b864ba53912b6c3ab6bcb2beb19f19edd01a6bfcbdfe1f37ddd1778abfe75a30"}, - {file = "greenlet-2.0.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:1087300cf9700bbf455b1b97e24db18f2f77b55302a68272c56209d5587c12d1"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:ba2956617f1c42598a308a84c6cf021a90ff3862eddafd20c3333d50f0edb45b"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc3a569657468b6f3fb60587e48356fe512c1754ca05a564f11366ac9e306526"}, {file = "greenlet-2.0.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8eab883b3b2a38cc1e050819ef06a7e6344d4a990d24d45bc6f2cf959045a45b"}, @@ -1017,7 +1015,6 @@ files = [ {file = "greenlet-2.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b0ef99cdbe2b682b9ccbb964743a6aca37905fda5e0452e5ee239b1654d37f2a"}, {file = "greenlet-2.0.2-cp38-cp38-win32.whl", hash = "sha256:b80f600eddddce72320dbbc8e3784d16bd3fb7b517e82476d8da921f27d4b249"}, {file = "greenlet-2.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:4d2e11331fc0c02b6e84b0d28ece3a36e0548ee1a1ce9ddde03752d9b79bba40"}, - {file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8512a0c38cfd4e66a858ddd1b17705587900dd760c6003998e9472b77b56d417"}, {file = "greenlet-2.0.2-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:88d9ab96491d38a5ab7c56dd7a3cc37d83336ecc564e4e8816dbed12e5aaefc8"}, {file = "greenlet-2.0.2-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:561091a7be172ab497a3527602d467e2b3fbe75f9e783d8b8ce403fa414f71a6"}, {file = "greenlet-2.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:971ce5e14dc5e73715755d0ca2975ac88cfdaefcaab078a284fea6cfabf866df"}, @@ -2548,10 +2545,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.21.2", markers = "python_version >= \"3.10\" or python_version >= \"3.6\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""}, - {version = ">=1.19.3", markers = "python_version >= \"3.6\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\""}, - {version = ">=1.14.5", markers = "python_version >= \"3.7\""}, - {version = ">=1.17.3", markers = "python_version >= \"3.8\""}, + {version = ">=1.21.2", markers = "python_version >= \"3.10\" or python_version >= \"3.9\" and platform_system == \"Darwin\" and platform_machine == \"arm64\""}, + {version = ">=1.19.3", markers = "python_version >= \"3.8\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" and python_version < \"3.10\" or python_version >= \"3.9\" and python_version < \"3.10\" and platform_system != \"Darwin\" or python_version >= \"3.9\" and python_version < \"3.10\" and platform_machine != \"arm64\""}, ] [[package]] @@ -2659,7 +2654,7 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, ] python-dateutil = ">=2.8.1" @@ -4066,7 +4061,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\")"} +greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"} [package.extras] aiomysql = ["aiomysql", "greenlet (!=0.4.17)"] @@ -5155,4 +5150,4 @@ notebooks = ["ipywidgets", "jupyterlab"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.9.7 || >3.9.7,<3.12" -content-hash = "b2e962857f73a5f3cfaac8dc8d2889626c0b673cdb377ebeef39405950f9b4db" +content-hash = "45f69595f91b056dd4192b35fe160e6b93e7e47abba904d7a08ee3e081f05eca" diff --git a/pyproject.toml b/pyproject.toml index 8787d3c89..e8ba1fad9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ encord-active = "encord_active.cli.main:cli" [tool.poetry.dependencies] python = ">=3.9,<3.9.7 || >3.9.7,<3.12" -encord = "^0.1.83" +encord = "^0.1.95" numpy = ">=1.23.5,<1.24.0" opencv-python = "4.5.5.64" natsort = "^8.1.0" diff --git a/src/encord_active/public/dataset.py b/src/encord_active/public/dataset.py new file mode 100644 index 000000000..3311802d5 --- /dev/null +++ b/src/encord_active/public/dataset.py @@ -0,0 +1,384 @@ +from pathlib import Path +from typing import Optional, Union +from uuid import UUID + +from encord.objects import Object, OntologyStructure, RadioAttribute +from PIL import Image +from sqlalchemy.sql.operators import in_op +from sqlmodel import Session, select +from torch.utils.data import Dataset + +from encord_active.db.models import ( + Project, + ProjectDataMetadata, + ProjectDataUnitMetadata, + ProjectTag, + ProjectTaggedDataUnit, + get_engine, +) +from encord_active.lib.common.data_utils import url_to_file_path + +P = Project +T = ProjectTaggedDataUnit +D = ProjectDataUnitMetadata +L = ProjectDataMetadata + + +class ActiveDataset(Dataset): + def __init__( + self, + database_path: Path, + project_hash: Union[set[str], str], + tag_name: Optional[Union[set[str], str]] = None, + ontology_hashes: Optional[list[str]] = None, + transform=None, + target_transform=None, + ): + database_path = database_path.expanduser().resolve() + if not database_path.is_file(): + raise FileNotFoundError(f"The database file does not exist at the specified path: '{database_path}'") + self.root_path = database_path.parent + self.engine = get_engine(database_path, use_alembic=False) + self.project_hash = {UUID(project_hash)} if isinstance(project_hash, str) else set(map(UUID, project_hash)) + self.tag_name = {tag_name} if isinstance(tag_name, str) else tag_name + self.ontology_hashes = ontology_hashes + + self.identifiers: list[tuple[UUID, int]] = [] + + self.setup() + + self.transform = transform + self.target_transform = target_transform + + def get_identifier_query(self, sess: Session): + identifier_query = select(D.du_hash, D.frame) + if self.tag_name is not None: + in_uids = set() + in_names = set() + for name in self.tag_name: + try: + in_uids.add(UUID(name)) + except ValueError: + in_names.add(name) + + where_clauses = [ + in_op(col, vals) for col, vals in [(ProjectTag.tag_hash, in_uids), (ProjectTag.name, in_names)] if vals + ] + tag_query = select(ProjectTag.tag_hash).where( + in_op(ProjectTag.project_hash, self.project_hash), *where_clauses + ) + tag_hash = set(sess.exec(tag_query).all()) + + if tag_hash is None: + valid_tag_names = sess.exec( + select(ProjectTag.name).where(in_op(ProjectTag.project_hash, self.project_hash)) + ).all() + raise ValueError( + f"Couldn't find a data tag with either name or tag_hash `{self.tag_name}`. Valid tags for the specified project(s) are {valid_tag_names}." + ) + + identifier_query = identifier_query.join( + T, onclause=((T.du_hash == D.du_hash) & (T.frame == D.frame)) + ).where( + in_op(D.project_hash, self.project_hash), + in_op(T.project_hash, self.project_hash), + in_op(T.tag_hash, tag_hash), + ) + else: + identifier_query = identifier_query.where(in_op(D.project_hash, self.project_hash)) + + return identifier_query + + def setup(self): + with Session(self.engine) as sess: + # Check that data is available locally + identifier_query = self.get_identifier_query(sess) + probe = sess.exec(identifier_query.add_columns(D.data_uri).limit(1)).first() + if ( + probe is not None + and probe[-1] is not None + and url_to_file_path(probe[-1], self.root_path) is None # type: ignore + ): + raise ValueError("Couldn't find data locally. Please execute `encord-active download-data` first.") + + # Check for videos + video_probe = sess.exec(identifier_query.where(D.data_uri_is_video).limit(1)).first() + if video_probe is not None: + raise ValueError("Dataset contains videos. This is currently not supported for this dataloader.") + + # Load and validate ontology + if len(self.project_hash) > 1: + ontologies = list(map(OntologyStructure.from_dict, sess.exec(select(P.project_ontology)).all())) # type: ignore + first, *rest = [ + tuple( + [o.feature_node_hash for o in ont.objects] + [c.feature_node_hash for c in ont.classifications] + ) + for ont in ontologies + ] + assert all( + [first == next_ for next_ in rest] + ), "Ontologies must match if you select multiple projects at once" + + ontology_dict = sess.exec( + select(P.project_ontology).where(in_op(P.project_hash, self.project_hash)).limit(1) + ).first() + if ontology_dict is None: + raise ValueError("Couldn't read project ontology") + self.ontology = OntologyStructure.from_dict(ontology_dict) # type: ignore + + def __len__(self): + with Session(self.engine) as sess: + return sess.query(self.get_identifier_query(sess)).count() + + def __getitem__(self, idx): + ... + + +class ActiveClassificationDataset(ActiveDataset): + def __init__( + self, + database_path: Path, + project_hash: Union[set[str], str], + tag_name: Optional[Union[set[str], str]] = None, + ontology_hashes: Optional[list[str]] = None, + transform=None, + target_transform=None, + ): + """ + A dataset hooked up to an Encord Active database. + The dataset can filter (image) data from Encord Active based on both `project_hash`es + and `tag_hash`/`tag_name`s. For example, if you have added a tag called + "train" within Encord Active, you can use that tag name by setting + `tag_name="train"` option. You can also combine multiple tags into one + dataset by providing a set of names. Similarly, you can use multiple + projects if they share the same ontology. Just provide the relevant + project hashes. + + Note: that this dataset requires that you have downloaded the data locally. + This can be done with `encord-active project download-data`. + + ⚠️ Videos are not yet supported. + + Args: + database_path: Path to where the `encord_active.sqlite` database lives + on your system. + project_hash: The project hash (or set of hashes) of the project(s) + to load data from. + tag_name: tag names (or hashes) for the tags you want to include. + If no tags are specified, all images with labels from the project + will be included. + ontology_hashes: The `feature_node_hash` of the radio button classification + question used for labels. The first radiobutton within that + classification will be used. + transform (): Data transform applied to PIL images + target_transform (): Target transform applied to the uint label tensor. + """ + assert ( + ontology_hashes is None or len(ontology_hashes) == 1 + ), "Either don't define ontology hashes to use first radio button in ontology or specify the feature node hash of the classification you want." + + super().__init__( + database_path, + project_hash, + tag_name, + ontology_hashes, + transform, + target_transform, + ) + + def setup(self): + super().setup() + + with Session(self.engine) as sess: + identifier_query = self.get_identifier_query(sess) + identifier_query = identifier_query.join(L, onclause=(L.data_hash == D.data_hash)).where( + in_op(L.project_hash, self.project_hash) + ) + identifier_query = identifier_query.add_columns(D.data_uri, D.classifications, L.label_row_json) + + identifiers = sess.exec(identifier_query).all() + ontology_pairs = [ + (c, a) + for c in self.ontology.classifications + for a in c.attributes + if isinstance(a, RadioAttribute) + and ((not self.ontology_hashes) or c.feature_node_hash in self.ontology_hashes) + ] + if len(ontology_pairs) == 0: + raise ValueError("No ontology classifications were found to use for labels") + classification, attribute = ontology_pairs[0] + indices = {o.feature_node_hash: i for i, o in enumerate(attribute.options)} + self.class_names = [o.title for o in attribute.options] + + self.uris = [] + self.labels = [] + for ( # type: ignore + *_, + data_uri, + classifications, + label_row_json, + ) in identifiers: + classification_answers = label_row_json["classification_answers"] + + clf_instance = next( + (c for c in classifications if c["featureHash"] == classification.feature_node_hash), + None, + ) + if clf_instance is None: + continue + clf_hash = clf_instance["classificationHash"] + clf_classifications = classification_answers[clf_hash]["classifications"] + clf_answers = next( + (a for a in clf_classifications if a["featureHash"] == attribute.feature_node_hash), + None, + ) + if clf_answers is None: + continue + clf_opt = next( + (o for o in clf_answers["answers"] if o["featureHash"] in indices), + None, + ) + if clf_opt is None: + continue + + self.uris.append(url_to_file_path(data_uri, self.root_path)) # type: ignore + self.labels.append(indices[clf_opt["featureHash"]]) + + def __getitem__(self, idx): + data_uri = self.uris[idx] + + img = Image.open(data_uri) + label = self.labels[idx] + + if self.transform: + img = self.transform(img) + + if self.target_transform: + label = self.target_transform(label) + + return img, label + + def __len__(self): + return len(self.labels) + + +class ActiveObjectDataset(ActiveDataset): + def __init__( + self, + database_path: Path, + project_hash: Union[set[str], str], + tag_name: Optional[Union[set[str], str]] = None, + ontology_hashes: Optional[list[str]] = None, + transform=None, + target_transform=None, + ): + """ + A dataset hooked up to an Encord Active database. + The dataset can filter (image) data from Encord Active based on both `project_hash`es + and `tag_hash`/`tag_name`s. For example, if you have added a tag called + "train" within Encord Active, you can use that tag name by setting + `tag_name="train"` option. You can also combine multiple tags into one + dataset by providing a set of names. Similarly, you can use multiple + projects if they share the same ontology. Just provide the relevant + project hashes. + + Note: This dataset requires that you have downloaded the data locally. + This can be done with `encord-active project download-data`. + + ⚠️ Videos are not yet supported. + + Args: + database_path: Path to where the `encord_active.sqlite` database lives + on your system. + project_hash: The project hash (or set of hashes) of the project(s) + to load data from. + tag_name: tag names (or hashes) for the tags you want to include. + If no tags are specified, all images with labels from the project + will be included. + ontology_hashes: The `feature_node_hash` of the objects used for labels. + transform (): Data transform applied to PIL images + target_transform (): Target transform applied to the uint label tensor. # TODO update the type + """ + super().__init__( + database_path, + project_hash, + tag_name, + ontology_hashes, + transform, + target_transform, + ) + + def setup(self): + super().setup() + + with Session(self.engine) as sess: + identifier_query = self.get_identifier_query(sess) + identifier_query = identifier_query.join(L, onclause=(L.data_hash == D.data_hash)).where( + in_op(L.project_hash, self.project_hash) + ) + identifier_query = identifier_query.add_columns(D.data_uri, D.objects, L.label_row_json) + identifiers = sess.exec(identifier_query).all() + + feature_hash_to_ontology_object: dict[str, Object] = { + o.feature_node_hash: o + for o in self.ontology.objects + if (self.ontology_hashes is None or (o.feature_node_hash in self.ontology_hashes)) + } + + if self.ontology_hashes is not None and len(feature_hash_to_ontology_object) != len(self.ontology_hashes): + missing_feature_hashes = set(self.ontology_hashes).difference(set(feature_hash_to_ontology_object)) + error_log = [] + for feature_hash in missing_feature_hashes: + ontology_object = feature_hash_to_ontology_object.get(feature_hash) + shape = "UNKNOWN" if ontology_object is None else ontology_object.shape.name + error_log.append(f"{(ontology_object and ontology_object.title) or feature_hash}: {shape}") + raise ValueError( + "Mismatch between objects with specified `ontology_hashes` and supported shapes: " + + ", ".join(error_log) + ) + if len(feature_hash_to_ontology_object) == 0: + raise ValueError("No ontology objects with supported shapes were found to use for labels") + + self.class_names = [o.title for o in feature_hash_to_ontology_object.values()] + + self.data_unit_paths: list[Path] = [] + self.labels_per_data_unit: list[list[dict]] = [] + self.label_attributes_per_data_unit: list[dict] = [] + for ( # type: ignore + *_, + data_uri, + all_objects, + label_row_json, + ) in identifiers: + object_hash_to_object = { + o["objectHash"]: o for o in all_objects if o["featureHash"] in feature_hash_to_ontology_object + } + object_attributes = { + k: v for k, v in label_row_json["object_answers"].items() if k in object_hash_to_object + } + + data_unit_path = url_to_file_path(data_uri, self.root_path) + if data_unit_path is None: + # Skip file as it's missing + continue + + self.data_unit_paths.append(data_unit_path) # TODO check "type: ignore" + self.labels_per_data_unit.append(list(object_hash_to_object.values())) + self.label_attributes_per_data_unit.append(object_attributes) + + def __getitem__(self, idx): + data_unit_path = self.data_unit_paths[idx] + + img = Image.open(data_unit_path) + labels = self.labels_per_data_unit[idx] + + if self.transform: + img = self.transform(img) + + if self.target_transform: + labels = self.target_transform(labels) + + return img, labels + + def __len__(self): + return len(self.data_unit_paths)