From be4b7977a0f7bc5a0ef9121626c9f2528bb5dcbd Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Mon, 2 Aug 2021 22:08:22 +0200 Subject: [PATCH 001/171] add datasets to vis --- experiment/exp/00_data_traversal/run.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/experiment/exp/00_data_traversal/run.py b/experiment/exp/00_data_traversal/run.py index 4fa2623f..1f8319fb 100644 --- a/experiment/exp/00_data_traversal/run.py +++ b/experiment/exp/00_data_traversal/run.py @@ -34,6 +34,8 @@ from disent.dataset.data import GroundTruthData from disent.dataset.data import Shapes3dData from disent.dataset.data import SmallNorbData +from disent.dataset.data import XYBlocksData +from disent.dataset.data import XYObjectData from disent.dataset.data import XYSquaresData from disent.util.seeds import TempNumpySeed @@ -121,6 +123,18 @@ def plot_dataset_traversals( seed=7, add_random_traversal=add_random_traversal, num_cols=num_cols ) + plot_dataset_traversals( + XYObjectData(), + rel_path=f'plots/xy-object-traversal', + seed=47, add_random_traversal=add_random_traversal, num_cols=num_cols + ) + + plot_dataset_traversals( + XYBlocksData(), + rel_path=f'plots/xy-blocks-traversal', + seed=47, add_random_traversal=add_random_traversal, num_cols=num_cols + ) + plot_dataset_traversals( Shapes3dData(), rel_path=f'plots/shapes3d-traversal', From 0224d8a5c99754786930a623209d77ab8b6034e6 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Tue, 3 Aug 2021 00:13:18 +0200 Subject: [PATCH 002/171] cherry-pick main -- remove disent.util.config --- disent/util/config.py | 73 ---------------------------------- experiment/run.py | 2 +- experiment/util/hydra_data.py | 2 +- experiment/util/hydra_utils.py | 29 ++++++++++++++ 4 files changed, 31 insertions(+), 75 deletions(-) delete mode 100644 disent/util/config.py diff --git a/disent/util/config.py b/disent/util/config.py deleted file mode 100644 index d6962ef4..00000000 --- a/disent/util/config.py +++ /dev/null @@ -1,73 +0,0 @@ -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -# MIT License -# -# Copyright (c) 2021 Nathan Juraj Michlo -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - - -from deprecated import deprecated - - -# ========================================================================= # -# Recursive Hydra Instantiation # -# TODO: use https://github.com/facebookresearch/hydra/pull/989 # -# I think this is quicker? Just doesn't perform checks... # -# ========================================================================= # - - -@deprecated('replace with hydra 1.1') -def call_recursive(config): - # import hydra - try: - import hydra - from omegaconf import DictConfig - from omegaconf import ListConfig - except ImportError: - raise ImportError('please install hydra-core for call_recursive/instantiate_recursive support') - # recurse - def _call_recursive(config): - if isinstance(config, (dict, DictConfig)): - c = {k: _call_recursive(v) for k, v in config.items() if k != '_target_'} - if '_target_' in config: - config = hydra.utils.instantiate({'_target_': config['_target_']}, **c) - elif isinstance(config, (tuple, list, ListConfig)): - config = [_call_recursive(v) for v in config] - return config - return _call_recursive(config) - - -# alias -@deprecated('replace with hydra 1.1') -def instantiate_recursive(config): - return call_recursive(config) - - -@deprecated('replace with hydra 1.1') -def instantiate_object_if_needed(config_or_object): - if isinstance(config_or_object, dict): - return instantiate_recursive(config_or_object) - else: - return config_or_object - - -# ========================================================================= # -# END # -# ========================================================================= # diff --git a/experiment/run.py b/experiment/run.py index 15ff09d9..537a73c2 100644 --- a/experiment/run.py +++ b/experiment/run.py @@ -29,7 +29,6 @@ import pytorch_lightning as pl import torch import torch.utils.data -from disent.util.config import instantiate_recursive from omegaconf import DictConfig from omegaconf import OmegaConf from pytorch_lightning.loggers import CometLogger @@ -49,6 +48,7 @@ from experiment.util.hydra_data import HydraDataModule from experiment.util.hydra_utils import make_non_strict from experiment.util.hydra_utils import merge_specializations +from experiment.util.hydra_utils import instantiate_recursive from experiment.util.run_utils import log_error_and_exit from experiment.util.run_utils import set_debug_logger from experiment.util.run_utils import set_debug_trainer diff --git a/experiment/util/hydra_data.py b/experiment/util/hydra_data.py index 634a8b0d..cf1d54b1 100644 --- a/experiment/util/hydra_data.py +++ b/experiment/util/hydra_data.py @@ -29,7 +29,7 @@ from disent.dataset import DisentDataset from disent.nn.transform import DisentDatasetTransform -from disent.util.config import instantiate_recursive +from experiment.util.hydra_utils import instantiate_recursive # ========================================================================= # diff --git a/experiment/util/hydra_utils.py b/experiment/util/hydra_utils.py index 02c0ec9f..8dfa3f33 100644 --- a/experiment/util/hydra_utils.py +++ b/experiment/util/hydra_utils.py @@ -24,14 +24,43 @@ import logging +import hydra from deprecated import deprecated from omegaconf import DictConfig +from omegaconf import ListConfig from omegaconf import OmegaConf log = logging.getLogger(__name__) +# ========================================================================= # +# Recursive Hydra Instantiation # +# TODO: use https://github.com/facebookresearch/hydra/pull/989 # +# I think this is quicker? Just doesn't perform checks... # +# ========================================================================= # + + +@deprecated('replace with hydra 1.1') +def call_recursive(config): + # recurse + def _call_recursive(config): + if isinstance(config, (dict, DictConfig)): + c = {k: _call_recursive(v) for k, v in config.items() if k != '_target_'} + if '_target_' in config: + config = hydra.utils.instantiate({'_target_': config['_target_']}, **c) + elif isinstance(config, (tuple, list, ListConfig)): + config = [_call_recursive(v) for v in config] + return config + return _call_recursive(config) + + +# alias +@deprecated('replace with hydra 1.1') +def instantiate_recursive(config): + return call_recursive(config) + + # ========================================================================= # # Better Specializations # # TODO: this might be replaced by recursive instantiation # From 960284919f9dd7dca10339ade71dfcee8b969f4d Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Tue, 3 Aug 2021 00:21:36 +0200 Subject: [PATCH 003/171] cherry-pick from commit c9bd6099 -- replace XySquaresData with XyObjectData --- docs/examples/overview_data.py | 4 +- docs/examples/overview_dataset_loader.py | 4 +- docs/examples/overview_dataset_pair.py | 4 +- .../examples/overview_dataset_pair_augment.py | 4 +- docs/examples/overview_dataset_single.py | 4 +- docs/examples/overview_framework_adagvae.py | 4 +- docs/examples/overview_framework_ae.py | 4 +- docs/examples/overview_framework_betavae.py | 4 +- .../overview_framework_betavae_scheduled.py | 4 +- tests/test_data.py | 28 +++++---- tests/test_math.py | 4 +- tests/test_metrics.py | 5 +- tests/test_samplers.py | 62 +++++++++---------- 13 files changed, 70 insertions(+), 65 deletions(-) diff --git a/docs/examples/overview_data.py b/docs/examples/overview_data.py index 6f8b87da..8da7e17d 100644 --- a/docs/examples/overview_data.py +++ b/docs/examples/overview_data.py @@ -1,6 +1,6 @@ -from disent.dataset.data import XYSquaresData +from disent.dataset.data import XYObjectData -data = XYSquaresData(square_size=1, image_size=2, num_squares=2) +data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb') print(f'Number of observations: {len(data)} == {data.size}') print(f'Observation shape: {data.observation_shape}') diff --git a/docs/examples/overview_dataset_loader.py b/docs/examples/overview_dataset_loader.py index 5565e889..5646df02 100644 --- a/docs/examples/overview_dataset_loader.py +++ b/docs/examples/overview_dataset_loader.py @@ -1,11 +1,11 @@ from torch.utils.data import DataLoader from disent.dataset import DisentDataset -from disent.dataset.data import XYSquaresData +from disent.dataset.data import XYObjectData from disent.dataset.sampling import GroundTruthPairOrigSampler from disent.nn.transform import ToStandardisedTensor # prepare the data -data = XYSquaresData(square_size=1, image_size=2, num_squares=2) +data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb') dataset = DisentDataset(data, sampler=GroundTruthPairOrigSampler(), transform=ToStandardisedTensor()) dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True) diff --git a/docs/examples/overview_dataset_pair.py b/docs/examples/overview_dataset_pair.py index 853061a8..96e1f155 100644 --- a/docs/examples/overview_dataset_pair.py +++ b/docs/examples/overview_dataset_pair.py @@ -1,11 +1,11 @@ from disent.dataset import DisentDataset -from disent.dataset.data import XYSquaresData +from disent.dataset.data import XYObjectData from disent.dataset.sampling import GroundTruthPairOrigSampler from disent.nn.transform import ToStandardisedTensor # prepare the data -data = XYSquaresData(square_size=1, image_size=2, num_squares=2) +data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb') dataset = DisentDataset(data, sampler=GroundTruthPairOrigSampler(), transform=ToStandardisedTensor()) # iterate over single epoch diff --git a/docs/examples/overview_dataset_pair_augment.py b/docs/examples/overview_dataset_pair_augment.py index 6cdb0178..81b494c7 100644 --- a/docs/examples/overview_dataset_pair_augment.py +++ b/docs/examples/overview_dataset_pair_augment.py @@ -1,11 +1,11 @@ from disent.dataset import DisentDataset -from disent.dataset.data import XYSquaresData +from disent.dataset.data import XYObjectData from disent.dataset.sampling import GroundTruthPairSampler from disent.nn.transform import ToStandardisedTensor, FftBoxBlur # prepare the data -data = XYSquaresData(square_size=1, image_size=2, num_squares=2) +data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb') dataset = DisentDataset(data, sampler=GroundTruthPairSampler(), transform=ToStandardisedTensor(), augment=FftBoxBlur(radius=1, p=1.0)) # iterate over single epoch diff --git a/docs/examples/overview_dataset_single.py b/docs/examples/overview_dataset_single.py index 34a68348..120b4268 100644 --- a/docs/examples/overview_dataset_single.py +++ b/docs/examples/overview_dataset_single.py @@ -1,11 +1,11 @@ -from disent.dataset.data import XYSquaresData +from disent.dataset.data import XYObjectData from disent.dataset import DisentDataset # prepare the data # - DisentDataset is a generic wrapper around torch Datasets that prepares # the data for the various frameworks according to some sampling strategy # by default this sampling strategy just returns the data at the given idx. -data = XYSquaresData(square_size=1, image_size=2, num_squares=2) +data = XYObjectData(grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb') dataset = DisentDataset(data, transform=None, augment=None) # iterate over single epoch diff --git a/docs/examples/overview_framework_adagvae.py b/docs/examples/overview_framework_adagvae.py index 55200b3e..ca908369 100644 --- a/docs/examples/overview_framework_adagvae.py +++ b/docs/examples/overview_framework_adagvae.py @@ -2,7 +2,7 @@ from torch.optim import Adam from torch.utils.data import DataLoader from disent.dataset import DisentDataset -from disent.dataset.data import XYSquaresData +from disent.dataset.data import XYObjectData from disent.dataset.sampling import GroundTruthPairOrigSampler from disent.frameworks.vae import AdaVae from disent.model import AutoEncoder @@ -12,7 +12,7 @@ # prepare the data -data = XYSquaresData() +data = XYObjectData() dataset = DisentDataset(data, GroundTruthPairOrigSampler(), transform=ToStandardisedTensor()) dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True) diff --git a/docs/examples/overview_framework_ae.py b/docs/examples/overview_framework_ae.py index adfbdc26..4e65a015 100644 --- a/docs/examples/overview_framework_ae.py +++ b/docs/examples/overview_framework_ae.py @@ -2,7 +2,7 @@ from torch.optim import Adam from torch.utils.data import DataLoader from disent.dataset import DisentDataset -from disent.dataset.data import XYSquaresData +from disent.dataset.data import XYObjectData from disent.dataset.sampling import SingleSampler from disent.frameworks.ae import Ae from disent.model import AutoEncoder @@ -12,7 +12,7 @@ # prepare the data -data = XYSquaresData() +data = XYObjectData() dataset = DisentDataset(data, transform=ToStandardisedTensor()) dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True) diff --git a/docs/examples/overview_framework_betavae.py b/docs/examples/overview_framework_betavae.py index 184f7338..28f0904a 100644 --- a/docs/examples/overview_framework_betavae.py +++ b/docs/examples/overview_framework_betavae.py @@ -2,7 +2,7 @@ from torch.optim import Adam from torch.utils.data import DataLoader from disent.dataset import DisentDataset -from disent.dataset.data import XYSquaresData +from disent.dataset.data import XYObjectData from disent.dataset.sampling import SingleSampler from disent.frameworks.vae import BetaVae from disent.model import AutoEncoder @@ -12,7 +12,7 @@ # prepare the data -data = XYSquaresData() +data = XYObjectData() dataset = DisentDataset(data, transform=ToStandardisedTensor()) dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True) diff --git a/docs/examples/overview_framework_betavae_scheduled.py b/docs/examples/overview_framework_betavae_scheduled.py index 95407500..682f8298 100644 --- a/docs/examples/overview_framework_betavae_scheduled.py +++ b/docs/examples/overview_framework_betavae_scheduled.py @@ -2,7 +2,7 @@ from torch.optim import Adam from torch.utils.data import DataLoader from disent.dataset import DisentDataset -from disent.dataset.data import XYSquaresData +from disent.dataset.data import XYObjectData from disent.dataset.sampling import SingleSampler from disent.frameworks.vae import BetaVae from disent.model import AutoEncoder @@ -12,7 +12,7 @@ from disent.util import is_test_run # you can ignore and remove this # prepare the data -data = XYSquaresData() +data = XYObjectData() dataset = DisentDataset(data, transform=ToStandardisedTensor()) dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True) diff --git a/tests/test_data.py b/tests/test_data.py index 09844c29..efbefe0e 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -32,11 +32,13 @@ import pytest from disent.dataset.data import Hdf5Dataset +from disent.dataset.data import XYObjectData from disent.dataset.data import XYSquaresData from disent.dataset.data import XYSquaresMinimalData from disent.dataset.util.hdf5 import hdf5_resave_file from disent.dataset.util.hdf5 import hdf5_test_speed from disent.util.inout.hashing import hash_file +from disent.util.function import wrapped_partial from tests.util import no_stderr from tests.util import no_stdout @@ -46,7 +48,6 @@ # TESTS # # ========================================================================= # - def test_xysquares_similarity(): data_org = XYSquaresData() data_min = XYSquaresMinimalData() @@ -68,15 +69,20 @@ def _iterate_over_data(data, indices): return i + 1 +# factors=(3, 3, 2, 3), len=54 +TestXYObjectData = wrapped_partial(XYObjectData, grid_size=4, min_square_size=1, max_square_size=2, square_size_spacing=1, palette='rgb') +_TEST_LEN = 54 + + @contextlib.contextmanager def create_temp_h5data(track_times=False, **h5py_dataset_kwargs): # generate data - data = np.stack([img for img in XYSquaresData(square_size=2, image_size=4)], axis=0) + data = np.stack([img for img in TestXYObjectData()], axis=0) # create temp file with NamedTemporaryFile('r') as out_file: # create temp files with h5py.File(out_file.name, 'w', libver='earliest') as file: - file.create_dataset(name='data', shape=(64, 4, 4, 3), dtype='uint8', data=data, track_times=track_times, **h5py_dataset_kwargs) + file.create_dataset(name='data', shape=(_TEST_LEN, 4, 4, 3), dtype='uint8', data=data, track_times=track_times, **h5py_dataset_kwargs) # return the data & file yield out_file.name, data @@ -89,13 +95,13 @@ def test_hdf5_pickle_dataset(): with Hdf5Dataset(tmp_path, 'data') as data: indices = list(range(len(data))) # test locally - assert _iterate_over_data(data=data, indices=indices) == 64 + assert _iterate_over_data(data=data, indices=indices) == _TEST_LEN # test multiprocessing executor = ProcessPoolExecutor(2) future_0 = executor.submit(_iterate_over_data, data=data, indices=indices[0::2]) future_1 = executor.submit(_iterate_over_data, data=data, indices=indices[1::2]) - assert future_0.result() == 32 - assert future_1.result() == 32 + assert future_0.result() == _TEST_LEN // 2 + assert future_1.result() == _TEST_LEN // 2 # test multiprocessing on invalid data with h5py.File(tmp_path, 'r', swmr=True) as file: with pytest.raises(TypeError, match='h5py objects cannot be pickled'): @@ -104,8 +110,8 @@ def test_hdf5_pickle_dataset(): @pytest.mark.parametrize(['hash_mode', 'target_hash'], [ - ('full', 'eec6e5d78b513f697f13bc5b43e3e361'), - ('fast', '598983f80047af65b9f85b5b1cf60e19'), + ('full', 'a3b60a9e248b4b66bdbf4f87a78bf7cc'), + ('fast', 'a20d554d4912a39e7654b4dc98207490'), ]) def test_hdf5_determinism(hash_mode: str, target_hash: str): # check hashing a @@ -134,7 +140,7 @@ def make_and_hash(track_times=False, **h5py_dataset_kwargs): def test_hdf5_resave_dataset(): with no_stdout(), no_stderr(): - with create_temp_h5data(chunks=(64, 4, 4, 3)) as (inp_path, raw_data), create_temp_h5data(chunks=None) as (out_path, _): + with create_temp_h5data(chunks=(_TEST_LEN, 4, 4, 3)) as (inp_path, raw_data), create_temp_h5data(chunks=None) as (out_path, _): # convert dataset hdf5_resave_file( inp_path=inp_path, @@ -152,14 +158,14 @@ def test_hdf5_resave_dataset(): # check datasets with h5py.File(inp_path, 'r') as inp: assert np.all(inp['data'][...] == raw_data) - assert inp['data'].chunks == (64, 4, 4, 3) + assert inp['data'].chunks == (_TEST_LEN, 4, 4, 3) with h5py.File(out_path, 'r') as out: assert np.all(out['data'][...] == raw_data) assert out['data'].chunks == (1, 4, 4, 3) def test_hdf5_speed_test(): - with create_temp_h5data(chunks=(64, 4, 4, 3)) as (path, _): + with create_temp_h5data(chunks=(_TEST_LEN, 4, 4, 3)) as (path, _): hdf5_test_speed(path, dataset_name='data', access_method='random') with create_temp_h5data(chunks=(1, 4, 4, 3)) as (path, _): hdf5_test_speed(path, dataset_name='data', access_method='sequential') diff --git a/tests/test_math.py b/tests/test_math.py index 902adc2e..1921c18a 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -29,7 +29,7 @@ from scipy.stats import hmean from disent.dataset import DisentDataset -from disent.dataset.data import XYSquaresData +from disent.dataset.data import XYObjectData from disent.dataset.sampling import RandomSampler from disent.nn.functional import torch_conv2d_channel_wise from disent.nn.functional import torch_conv2d_channel_wise_fft @@ -131,7 +131,7 @@ def test_dct(): def test_fft_conv2d(): - data = XYSquaresData() + data = XYObjectData() dataset = DisentDataset(data, RandomSampler(), transform=ToStandardisedTensor(), augment=None) # sample data factors = dataset.ground_truth_data.sample_random_factor_traversal(f_idx=2) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index cdd1d5c8..5b844292 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -25,9 +25,8 @@ import pytest import torch -from disent.dataset.data import XYSquaresMinimalData +from disent.dataset.data import XYObjectData from disent.dataset import DisentDataset -from disent.dataset.sampling import RandomSampler from disent.metrics import * from disent.nn.transform import ToStandardisedTensor from disent.util.function import wrapped_partial @@ -51,7 +50,7 @@ def test_metrics(metric_fn): z_size = 8 # ground truth data # TODO: DisentDataset should not be needed to compute metrics! - dataset = DisentDataset(XYSquaresMinimalData(), transform=ToStandardisedTensor()) + dataset = DisentDataset(XYObjectData(), transform=ToStandardisedTensor()) # randomly sampled representation get_repr = lambda x: torch.randn(len(x), z_size) # evaluate diff --git a/tests/test_samplers.py b/tests/test_samplers.py index 58bf9d05..d5f74593 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -32,7 +32,7 @@ from disent.dataset import DisentDataset from disent.dataset.data import BaseEpisodesData from disent.dataset.sampling import * -from disent.dataset.data import XYSquaresMinimalData +from disent.dataset.data import XYObjectData class TestEpisodesData(BaseEpisodesData): @@ -46,44 +46,44 @@ def _load_episode_observations(self) -> List[np.ndarray]: @pytest.mark.parametrize(['dataset', 'num_samples', 'check_mode', 'sampler'], [ - [XYSquaresMinimalData(), 1, 'first', SingleSampler()], - [XYSquaresMinimalData(), 1, 'first', GroundTruthSingleSampler()], + [XYObjectData(), 1, 'first', SingleSampler()], + [XYObjectData(), 1, 'first', GroundTruthSingleSampler()], # original AdaVAE sampling method - [XYSquaresMinimalData(), 2, 'first', GroundTruthPairOrigSampler(p_k=1)], - [XYSquaresMinimalData(), 2, 'first', GroundTruthPairOrigSampler(p_k=2)], - [XYSquaresMinimalData(), 2, 'first', GroundTruthPairOrigSampler(p_k=-1)], + [XYObjectData(), 2, 'first', GroundTruthPairOrigSampler(p_k=1)], + [XYObjectData(), 2, 'first', GroundTruthPairOrigSampler(p_k=2)], + [XYObjectData(), 2, 'first', GroundTruthPairOrigSampler(p_k=-1)], # TODO: consider removing the pair sampler... it is difficult to maintain and confusing - [XYSquaresMinimalData(), 2, 'first', GroundTruthPairSampler()], + [XYObjectData(), 2, 'first', GroundTruthPairSampler()], # TODO: consider removing the triplet sampler... it is difficult to maintain and confusing - [XYSquaresMinimalData(), 3, 'first', GroundTruthTripleSampler()], - [XYSquaresMinimalData(), 3, 'first', GroundTruthTripleSampler(swap_metric='k')], - [XYSquaresMinimalData(), 3, 'first', GroundTruthTripleSampler(swap_metric='manhattan')], - [XYSquaresMinimalData(), 3, 'first', GroundTruthTripleSampler(swap_metric='manhattan_norm')], - [XYSquaresMinimalData(), 3, 'first', GroundTruthTripleSampler(swap_metric='euclidean')], - [XYSquaresMinimalData(), 3, 'first', GroundTruthTripleSampler(swap_metric='euclidean_norm')], - # [XYSquaresMinimalData(), 3, 'first', GroundTruthTripleSampler(n_k_sample_mode='offset')], # TODO: these are broken - [XYSquaresMinimalData(), 3, 'first', GroundTruthTripleSampler(n_k_sample_mode='bounded_below')], - [XYSquaresMinimalData(), 3, 'first', GroundTruthTripleSampler(n_k_sample_mode='random')], - # [XYSquaresMinimalData(), 3, 'first', GroundTruthTripleSampler(n_radius_sample_mode='offset')], # TODO: these are broken - [XYSquaresMinimalData(), 3, 'first', GroundTruthTripleSampler(n_radius_sample_mode='bounded_below')], - [XYSquaresMinimalData(), 3, 'first', GroundTruthTripleSampler(n_radius_sample_mode='random')], + [XYObjectData(), 3, 'first', GroundTruthTripleSampler()], + [XYObjectData(), 3, 'first', GroundTruthTripleSampler(swap_metric='k')], + [XYObjectData(), 3, 'first', GroundTruthTripleSampler(swap_metric='manhattan')], + [XYObjectData(), 3, 'first', GroundTruthTripleSampler(swap_metric='manhattan_norm')], + [XYObjectData(), 3, 'first', GroundTruthTripleSampler(swap_metric='euclidean')], + [XYObjectData(), 3, 'first', GroundTruthTripleSampler(swap_metric='euclidean_norm')], + # [XYObjectData(), 3, 'first', GroundTruthTripleSampler(n_k_sample_mode='offset')], # TODO: these are broken + [XYObjectData(), 3, 'first', GroundTruthTripleSampler(n_k_sample_mode='bounded_below')], + [XYObjectData(), 3, 'first', GroundTruthTripleSampler(n_k_sample_mode='random')], + # [XYObjectData(), 3, 'first', GroundTruthTripleSampler(n_radius_sample_mode='offset')], # TODO: these are broken + [XYObjectData(), 3, 'first', GroundTruthTripleSampler(n_radius_sample_mode='bounded_below')], + [XYObjectData(), 3, 'first', GroundTruthTripleSampler(n_radius_sample_mode='random')], - [XYSquaresMinimalData(), 1, 'first', GroundTruthDistSampler(num_samples=1)], - [XYSquaresMinimalData(), 2, 'first', GroundTruthDistSampler(num_samples=2)], - [XYSquaresMinimalData(), 3, 'first', GroundTruthDistSampler(num_samples=3)], - [XYSquaresMinimalData(), 3, 'first', GroundTruthDistSampler(num_samples=3, triplet_sample_mode='random')], - [XYSquaresMinimalData(), 3, 'first', GroundTruthDistSampler(num_samples=3, triplet_sample_mode='factors')], - [XYSquaresMinimalData(), 3, 'first', GroundTruthDistSampler(num_samples=3, triplet_sample_mode='manhattan')], - [XYSquaresMinimalData(), 3, 'first', GroundTruthDistSampler(num_samples=3, triplet_sample_mode='manhattan_scaled')], - [XYSquaresMinimalData(), 3, 'first', GroundTruthDistSampler(num_samples=3, triplet_sample_mode='combined')], - [XYSquaresMinimalData(), 3, 'first', GroundTruthDistSampler(num_samples=3, triplet_sample_mode='combined_scaled')], + [XYObjectData(), 1, 'first', GroundTruthDistSampler(num_samples=1)], + [XYObjectData(), 2, 'first', GroundTruthDistSampler(num_samples=2)], + [XYObjectData(), 3, 'first', GroundTruthDistSampler(num_samples=3)], + [XYObjectData(), 3, 'first', GroundTruthDistSampler(num_samples=3, triplet_sample_mode='random')], + [XYObjectData(), 3, 'first', GroundTruthDistSampler(num_samples=3, triplet_sample_mode='factors')], + [XYObjectData(), 3, 'first', GroundTruthDistSampler(num_samples=3, triplet_sample_mode='manhattan')], + [XYObjectData(), 3, 'first', GroundTruthDistSampler(num_samples=3, triplet_sample_mode='manhattan_scaled')], + [XYObjectData(), 3, 'first', GroundTruthDistSampler(num_samples=3, triplet_sample_mode='combined')], + [XYObjectData(), 3, 'first', GroundTruthDistSampler(num_samples=3, triplet_sample_mode='combined_scaled')], - [XYSquaresMinimalData(), 1, 'first', RandomSampler(num_samples=1)], - [XYSquaresMinimalData(), 2, 'first', RandomSampler(num_samples=2)], - [XYSquaresMinimalData(), 3, 'first', RandomSampler(num_samples=3)], + [XYObjectData(), 1, 'first', RandomSampler(num_samples=1)], + [XYObjectData(), 2, 'first', RandomSampler(num_samples=2)], + [XYObjectData(), 3, 'first', RandomSampler(num_samples=3)], [TestEpisodesData(), 1, 'any', RandomEpisodeSampler(num_samples=1)], [TestEpisodesData(), 2, 'any', RandomEpisodeSampler(num_samples=2)], From 077b1bd60b1cb0f2ba8aea6bb4ef6ae31766cf6e Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Tue, 3 Aug 2021 00:29:17 +0200 Subject: [PATCH 004/171] revert removal of instantiate_object_if_needed --- .../frameworks/vae/experimental/_unsupervised__dorvae.py | 3 ++- .../frameworks/vae/experimental/_unsupervised__dotvae.py | 3 ++- .../vae/experimental/_weaklysupervised__augpostriplet.py | 9 ++++++--- experiment/util/hydra_utils.py | 8 ++++++++ 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/disent/frameworks/vae/experimental/_unsupervised__dorvae.py b/disent/frameworks/vae/experimental/_unsupervised__dorvae.py index 3b6d115e..1b1d0c24 100644 --- a/disent/frameworks/vae/experimental/_unsupervised__dorvae.py +++ b/disent/frameworks/vae/experimental/_unsupervised__dorvae.py @@ -36,7 +36,6 @@ from disent.frameworks.vae._weaklysupervised__adavae import AdaVae from disent.nn.loss.softsort import torch_mse_rank_loss from disent.nn.loss.softsort import spearman_rank_loss -from disent.util.config import instantiate_object_if_needed # ========================================================================= # @@ -78,6 +77,8 @@ def __init__(self, make_optimizer_fn, make_model_fn, batch_augment=None, cfg: cf if self.cfg.overlap_augment_mode != 'none': assert self.cfg.overlap_augment is not None, 'if cfg.overlap_augment_mode is not "none", then cfg.overlap_augment must be defined.' if self.cfg.overlap_augment is not None: + # TODO: this should not reference experiments! + from experiment.util.hydra_utils import instantiate_object_if_needed self._augment = instantiate_object_if_needed(self.cfg.overlap_augment) assert callable(self._augment), f'augment is not callable: {repr(self._augment)}' # get overlap loss diff --git a/disent/frameworks/vae/experimental/_unsupervised__dotvae.py b/disent/frameworks/vae/experimental/_unsupervised__dotvae.py index e1e64442..f14a955d 100644 --- a/disent/frameworks/vae/experimental/_unsupervised__dotvae.py +++ b/disent/frameworks/vae/experimental/_unsupervised__dotvae.py @@ -35,7 +35,6 @@ from disent.frameworks.helper.reconstructions import ReconLossHandler from disent.frameworks.vae.experimental._supervised__adaneg_tvae import AdaNegTripletVae from disent.nn.loss.triplet_mining import configured_idx_mine -from disent.util.config import instantiate_object_if_needed log = logging.getLogger(__name__) @@ -78,6 +77,8 @@ def init_data_overlap_mixin(self): if self.cfg.overlap_augment_mode != 'none': assert self.cfg.overlap_augment is not None, 'if cfg.overlap_augment_mode is not "none", then cfg.overlap_augment must be defined.' if self.cfg.overlap_augment is not None: + # TODO: this should not reference experiments! + from experiment.util.hydra_utils import instantiate_object_if_needed self._augment = instantiate_object_if_needed(self.cfg.overlap_augment) assert callable(self._augment), f'augment is not callable: {repr(self._augment)}' # get overlap loss diff --git a/disent/frameworks/vae/experimental/_weaklysupervised__augpostriplet.py b/disent/frameworks/vae/experimental/_weaklysupervised__augpostriplet.py index 3ba54796..9dfb402f 100644 --- a/disent/frameworks/vae/experimental/_weaklysupervised__augpostriplet.py +++ b/disent/frameworks/vae/experimental/_weaklysupervised__augpostriplet.py @@ -30,7 +30,6 @@ import torch from disent.frameworks.vae._supervised__tvae import TripletVae -from disent.util.config import instantiate_object_if_needed log = logging.getLogger(__name__) @@ -51,9 +50,13 @@ class cfg(TripletVae.cfg): def __init__(self, make_optimizer_fn, make_model_fn, batch_augment=None, cfg: cfg = None): super().__init__(make_optimizer_fn, make_model_fn, batch_augment=batch_augment, cfg=cfg) - self._augment = None # initialise & check augment - self._augment = instantiate_object_if_needed(self.cfg.overlap_augment) + self._augment = None + if self.cfg.overlap_augment is not None: + # TODO: this should not reference experiments! + from experiment.util.hydra_utils import instantiate_object_if_needed + self._augment = instantiate_object_if_needed(self.cfg.overlap_augment) + assert callable(self._augment), f'augment is not callable: {repr(self._augment)}' if self._augment is None: self._augment = torch.nn.Identity() warnings.warn(f'{self.__class__.__name__}, no overlap_augment was specified, defaulting to nn.Identity which WILL break things!') diff --git a/experiment/util/hydra_utils.py b/experiment/util/hydra_utils.py index 8dfa3f33..93e76792 100644 --- a/experiment/util/hydra_utils.py +++ b/experiment/util/hydra_utils.py @@ -61,6 +61,14 @@ def instantiate_recursive(config): return call_recursive(config) +@deprecated('replace with hydra 1.1') +def instantiate_object_if_needed(config_or_object): + if isinstance(config_or_object, dict): + return instantiate_recursive(config_or_object) + else: + return config_or_object + + # ========================================================================= # # Better Specializations # # TODO: this might be replaced by recursive instantiation # From 66a60054b75c5a9824e645532562c8de9b54397b Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Tue, 3 Aug 2021 00:37:58 +0200 Subject: [PATCH 005/171] cherry-pick readme from main v0.1.0 --- README.md | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 75 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f0f61336..427188e1 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,7 @@ * [Schedules & Annealing](#schedules--annealing) - [Examples](#examples) * [Python Example](#python-example) + * [Hydra Config Example](#hydra-config-example) - [Why?](#why) ---------------------- @@ -99,7 +100,7 @@ The disent directory structure: - `disent/model`: common encoder and decoder models used for VAE research - `disent/nn`: torch components for building models including layers, transforms, losses and general maths - `disent/schedule`: annealing schedules that can be registered to a framework -- `disent/util`: helper functions for the rest of the framework +- `disent/util`: helper classes, functions, callbacks, anything unrelated to a pytorch system/model/framework. **Please Note The API Is Still Unstable ⚠️** @@ -107,12 +108,23 @@ Disent is still under active development. Features and APIs are not considered s and should be expected to change! A limited set of tests currently exist which will be expanded upon in time. +**Hydra Experiment Directories** + +Easily run experiments with hydra config, these files +are not available from `pip install`. + +- `experiment/run.py`: entrypoint for running basic experiments with [hydra](https://github.com/facebookresearch/hydra) config +- `experiment/config`: root folder for [hydra](https://github.com/facebookresearch/hydra) config files +- `experiment/util`: various helper code for experiments + + ---------------------- ## Features -Disent includes implementations of modules, metrics and datasets -from various papers. As well as many custom experimental frameworks. +Disent includes implementations of modules, metrics and +datasets from various papers. Please note that items marked + with a "🧵" are introduced in and are unique to disent! ### Frameworks - **Unsupervised**: @@ -170,6 +182,13 @@ low-memory disk-based access. + SmallNORB + Shapes3D +- **Ground Truth Synthetic**: + + 🧵 XYObject: *A simplistic version of dSprites with a single square.* + +

+ XYObject Dataset Factor Traversals +

+ #### Input Transforms + Input/Target Augmentations - Input based transforms are supported. @@ -264,6 +283,59 @@ print('metrics:', metrics) Visit the [docs](https://disent.dontpanic.sh) for more examples! + +### Hydra Config Example + +The entrypoint for basic experiments is `experiment/run.py`. + +Some configuration will be required, but basic experiments can +be adjusted by modifying the [Hydra Config 1.0](https://github.com/facebookresearch/hydra) +files in `experiment/config` (Please note that hydra 1.1 is not yet supported). + +Modifying the main `experiment/config/config.yaml` is all you +need for most basic experiments. The main config file contains +a defaults list with entries corresponding to yaml configuration +files (config options) in the subfolders (config groups) in +`experiment/config//