Skip to content

Commit

Permalink
Updated and fixed data splitting
Browse files Browse the repository at this point in the history
  • Loading branch information
amogh7joshi committed Nov 28, 2021
1 parent fb3be7e commit 43ac469
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 18 deletions.
36 changes: 23 additions & 13 deletions agml/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from decimal import getcontext, Decimal

import numpy as np
Expand Down Expand Up @@ -107,12 +108,15 @@ def __iter__(self):
for indx in range(len(self)):
yield self[indx]

def __str__(self):
def __repr__(self):
out = f"<AgMLDataLoader: (dataset={self.name}"
out += f", task={self.task}"
out += f") at {hex(id(self))}>"
return out

def __str__(self):
return repr(self)

def copy(self):
"""Returns a deep copy of the data loader's contents."""
return self.__copy__()
Expand Down Expand Up @@ -167,8 +171,7 @@ def _generate_split_loader(self, contents, split):
contents = contents,
info = self.info,
root = self.dataset_root)
current_manager = self._manager.__getstate__()
new_manager = DataManager.__new__(DataManager)
current_manager = copy.deepcopy(self._manager.__getstate__())
current_manager.pop('builder')
current_manager['builder'] = builder

Expand All @@ -177,19 +180,26 @@ def _generate_split_loader(self, contents, split):
if self._manager._shuffle:
np.random.shuffle(accessors)
current_manager['accessors'] = accessors
batch_size = current_manager.pop('batch_size')
current_manager['batch_size'] = None
new_manager = DataManager.__new__(DataManager)
new_manager.__setstate__(current_manager)

# After the builder and accessors have been generated, we need
# to generate a new list of `DataObject`s.
new_manager._create_objects(
new_manager._builder, self.task)

# Batching data needs to be done independently.
if current_manager['batch_size'] is not None:
new_manager.batch_data(
batch_size = current_manager['batch_size'])
if batch_size is not None:
new_manager.batch_data(batch_size = batch_size)

# Instantiate a new `AgMLDataLoader` from the contents.
loader = self.copy().__getstate__()
loader['builder'] = builder
loader['manager'] = new_manager
loader_state = self.copy().__getstate__()
loader_state['builder'] = builder
loader_state['manager'] = new_manager
cls = super(AgMLDataLoader, self).__new__(AgMLDataLoader)
cls.__setstate__(loader)
cls.__setstate__(loader_state)
for attr in ['train', 'val', 'test']:
setattr(cls, f'_{attr}_data', None)
cls._is_split = True
Expand All @@ -201,7 +211,7 @@ def train_data(self):
if isinstance(self._train_data, AgMLDataLoader):
return self._train_data
self._train_data = self._generate_split_loader(
self._train_data, split = 'val')
self._train_data, split = 'train')
return self._train_data

@property
Expand Down Expand Up @@ -294,7 +304,7 @@ def on_epoch_end(self):
"""
self._manager._maybe_shuffle()

def as_keras_sequence(self):
def as_keras_sequence(self) -> "AgMLDataLoader":
"""Sets the `DataLoader` in TensorFlow mode.
This TensorFlow extension converts the loader into a TensorFlow mode,
Expand Down Expand Up @@ -324,7 +334,7 @@ def as_keras_sequence(self):
self._manager.update_train_state('tf')
return self

def as_torch_dataset(self):
def as_torch_dataset(self) -> "AgMLDataLoader":
"""Sets the `DataLoader` in PyTorch mode.
This PyTorch extension converts the loader into a PyTorch mode, adding
Expand Down
13 changes: 8 additions & 5 deletions agml/data/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,20 @@ def __init__(self, image, annotation, root):
# dictionary doesn't contain the full path, only the base.
self._dataset_root = root

@staticmethod
def _parse_image(path):
with imread_context(os.path.realpath(path)) as image:
return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

def __len__(self):
return 2

def __getitem__(self, i):
return self.get()[i]

def __repr__(self):
return f"<DataObject: {self._image_path}, {self._annotation_obj}>"

@staticmethod
def _parse_image(path):
with imread_context(os.path.realpath(path)) as image:
return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

def get(self):
"""Returns the image and annotation pair with applied transforms.
Expand Down

0 comments on commit 43ac469

Please sign in to comment.