
Bug in split_dataset
yanncalec committed Jun 18, 2024
1 parent 3e1b732 commit c966e99
Showing 3 changed files with 3,599 additions and 365 deletions.
5 changes: 4 additions & 1 deletion dpmhm/datasets/transformer.py
@@ -432,8 +432,11 @@ def _drop_meta(X):
        return {'feature': X['feature'], 'label': X['label']}

    ds = self.to_windows(self._dataset_origin, self._window_size, self._hop_size)
    # ds = utils.restore_shape(ds, 'feature', self.data_dim)

    return ds.map(_drop_meta, num_parallel_calls=tf.data.AUTOTUNE) if self._no_meta else ds
    if self._no_meta:
        ds = ds.map(_drop_meta, num_parallel_calls=tf.data.AUTOTUNE)
    return ds

# @property
# def full_label_dict(self) -> dict:
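For context, the fix replaces the conditional-expression return with an explicit `if` block, so the mapped dataset is assigned back to `ds` before being returned. A minimal sketch of the behaviour on a toy dataset; the `filename` field and the `no_meta` variable (standing in for `self._no_meta`) are illustrative, not from the repository:

```python
import tensorflow as tf

# Toy stand-in for the windowed dataset: each element carries 'feature',
# 'label', and an extra metadata field that `_drop_meta` removes.
ds = tf.data.Dataset.from_tensor_slices({
    'feature': tf.random.normal([4, 8]),
    'label': tf.constant([0, 1, 0, 1]),
    'filename': tf.constant(['a', 'b', 'c', 'd']),
})

def _drop_meta(X):
    # Keep only the fields needed for training.
    return {'feature': X['feature'], 'label': X['label']}

no_meta = True  # stands in for self._no_meta
if no_meta:
    ds = ds.map(_drop_meta, num_parallel_calls=tf.data.AUTOTUNE)

print(ds.element_spec.keys())  # dict_keys(['feature', 'label'])
```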
51 changes: 27 additions & 24 deletions dpmhm/datasets/utils.py
@@ -174,6 +174,8 @@ def random_split_dataset(ds:Dataset, splits:dict, *, shuffle_size:int=None, ds_s
        dictionary specifying the name and ratio of the splits.
    shuffle_size
        size of the shuffle buffer; 1 for no shuffle (deterministic), None for full shuffle.
    ds_size
        actual size of `ds`.
    kwargs
        other keyword arguments passed to the method `shuffle()`, e.g. `reshuffle_each_iteration=False`, `seed=1234`.
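The body of `random_split_dataset` sits outside this hunk. As a reading aid, here is a minimal sketch of the shuffle-then-take/skip pattern that a signature like this suggests; it is an assumption for illustration, not the repository's actual implementation:

```python
import tensorflow as tf
from tensorflow.data import Dataset

def _random_split_sketch(ds: Dataset, splits: dict, *, shuffle_size: int = None,
                         ds_size: int = None, **kwargs) -> dict:
    if ds_size is None:
        ds_size = int(ds.cardinality())  # fall back to the declared cardinality
    if shuffle_size is None:
        shuffle_size = ds_size  # full shuffle
    if shuffle_size > 1:
        # Pass `reshuffle_each_iteration=False` via kwargs so the resulting
        # splits stay disjoint across epochs.
        ds = ds.shuffle(shuffle_size, **kwargs)

    dp, taken = {}, 0
    for name, ratio in splits.items():
        n = int(ratio * ds_size)
        dp[name] = ds.skip(taken).take(n)
        taken += n
    return dp
```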
@@ -222,7 +224,7 @@ def split_dataset(ds:Dataset, splits:dict={'train':0.7, 'val':0.2, 'test':0.1},
    splits
        dictionary specifying the name and ratio of the splits.
    labels
        list of categories. If given, apply the few-shot style split (i.e. split per category); otherwise apply the normal split.
        list of categories. If given, apply the few-shot style split (i.e. split per category); otherwise apply the normal split. This is incompatible with the keyword argument `ds_size`.
    kwargs
        arguments for `random_split_dataset()`
@@ -241,7 +243,12 @@
    ds = extract_by_category(ds, labels)
    dp = {}
    for n, (k, v) in enumerate(ds.items()):
        dq = random_split_dataset(v, splits, **kwargs)
        try:
            dq = random_split_dataset(v, splits, ds_size=None, **kwargs)
        except TypeError as e:
            # A `ds_size` passed via kwargs collides with the explicit
            # `ds_size=None` above; the per-category split computes each
            # category's size itself.
            raise ValueError("`ds_size` is not supported in the per-category split") from e
        if n == 0:
            dp.update(dq)
        else:
@@ -251,7 +258,7 @@
    return dp
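A usage sketch of the per-category branch above, assuming `dpmhm` is installed and that the dataset's `label` values match the given categories; the toy data and label strings are illustrative:

```python
import tensorflow as tf
from dpmhm.datasets.utils import split_dataset

# Toy labelled dataset: 100 samples over two categories.
ds = tf.data.Dataset.from_tensor_slices({
    'feature': tf.random.normal([100, 8]),
    'label': tf.constant(['normal'] * 50 + ['faulty'] * 50),
})

# Few-shot style split: each category is split 70/20/10 separately, then
# the per-category parts are merged into common 'train'/'val'/'test' sets.
dp = split_dataset(ds, {'train': 0.7, 'val': 0.2, 'test': 0.1},
                   labels=['normal', 'faulty'])
print(list(dp.keys()))  # ['train', 'val', 'test']
```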


def restore_shape(ds:Dataset, key:int|str=None, shape:tuple[int]=None) -> Dataset:
def restore_shape(ds:Dataset, key:str|int=None, shape:tuple[int]=None) -> Dataset:
"""Restore the shape of a dataset.
    Parameters
    ----------
@@ -274,35 +281,31 @@ def restore_shape(ds:Dataset, key:int|str=None, shape:tuple[int]=None) -> Dataset:
        shape = list(ds.take(1).as_numpy_iterator())[0].shape
    except:
        shape = list(ds.take(1).as_numpy_iterator())[0][key].shape
    # print(shape, key)

    @tf.function
    def _mapper(X):
        try:
            # flat dataset
            Y = tf.ensure_shape(X, shape)
        except:
            # nested dataset
            Y = X.copy()
            Y[key] = tf.ensure_shape(Y[key], shape)  # will create an extra dimension if `key=None`
            Y[key] = tf.ensure_shape(X[key], shape)
        return Y
    # if key is None:
    #     if shape is None:
    #         shape = list(ds.take(1).as_numpy_iterator())[0].shape

    #     @tf.function
    #     def _mapper(X):
    #         Y = tf.ensure_shape(X, shape)
    #         return Y
    # else:
    #     if shape is None:
    #         shape = list(ds.take(1).as_numpy_iterator())[0][key].shape

    #     @tf.function
    #     def _mapper(X):
    #         Y = X.copy()
    #         Y[key] = tf.ensure_shape(Y[key], shape)
    #         return Y

    return ds.map(_mapper, num_parallel_calls=tf.data.AUTOTUNE)
    # return ds.map(lambda x,y: (tf.ensure_shape(x, shape), y), num_parallel_calls=tf.data.AUTOTUNE)

    @tf.function
    def _mapper_tuple(*X):
        # tuple dataset: here `key` is the integer position in the tuple.
        # This looks suspicious but actually works...
        Y = list(X)
        Y[key] = tf.ensure_shape(X[key], shape)
        return Y  # automatically converted back to a tuple by `map()`

    if type(ds.element_spec) is tuple:
        return ds.map(_mapper_tuple, num_parallel_calls=tf.data.AUTOTUNE)
    else:
        return ds.map(_mapper, num_parallel_calls=tf.data.AUTOTUNE)
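A usage sketch for `restore_shape`: transformations such as windowing can erase static shape information, and `tf.ensure_shape` re-attaches it. The field name and shape below are illustrative, and the example assumes the function is importable as shown:

```python
import tensorflow as tf
from dpmhm.datasets.utils import restore_shape

ds = tf.data.Dataset.from_tensor_slices({
    'feature': tf.random.normal([10, 3, 64]),
    'label': tf.zeros([10], tf.int32),
})

# Dict dataset: `key` is the field name; for a tuple dataset it would be
# the integer position handled by `_mapper_tuple`.
ds = restore_shape(ds, key='feature', shape=(3, 64))
print(ds.element_spec['feature'].shape)  # (3, 64)
```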


def restore_cardinality(ds:Dataset, card:int=None) -> Dataset:
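The body of `restore_cardinality` is folded in this view. A function with this signature would presumably wrap `tf.data.experimental.assert_cardinality`, the standard way to re-attach a cardinality lost by e.g. `filter()`; a hedged sketch, not the repository's actual code:

```python
import tensorflow as tf

def _restore_cardinality_sketch(ds: tf.data.Dataset, card: int = None) -> tf.data.Dataset:
    # Assumption: mirrors the folded implementation in dpmhm/datasets/utils.py.
    if card is None:
        return ds
    return ds.apply(tf.data.experimental.assert_cardinality(card))
```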
3,908 changes: 3,568 additions & 340 deletions notebooks/models/SimCLR.ipynb

Large diffs are not rendered by default.
