From 91d121715dc544b1b8c6740f68dbae2660e0c38d Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Mon, 9 Aug 2021 18:10:09 +0200 Subject: [PATCH 01/17] beta-vae comments --- disent/frameworks/vae/_unsupervised__betavae.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/disent/frameworks/vae/_unsupervised__betavae.py b/disent/frameworks/vae/_unsupervised__betavae.py index 3570922f..7426fe14 100644 --- a/disent/frameworks/vae/_unsupervised__betavae.py +++ b/disent/frameworks/vae/_unsupervised__betavae.py @@ -42,6 +42,9 @@ class BetaVae(Vae): + """ + beta-VAE: https://arxiv.org/abs/1312.6114 + """ REQUIRED_OBS = 1 @@ -55,16 +58,24 @@ class cfg(Vae.cfg): # loss = mean_recon_loss + beta * mean_kl_loss # -- for loss_reduction='mean_sum' we usually have: # loss = (H*W*C) * mean_recon_loss + beta * (z_size) * mean_kl_loss - # So when switching from one mode to the other, we need to scale beta to preserve these loss ratios. + # + # So when switching from one mode to the other, we need to scale beta to + # preserve these loss ratios: # -- 'mean_sum' to 'mean': # beta <- beta * (z_size) / (H*W*C) # -- 'mean' to 'mean_sum': # beta <- beta * (H*W*C) / (z_size) + # # We obtain an equivalent beta for 'mean_sum' to 'mean': # -- given values: beta=4 for 'mean_sum', with (H*W*C)=(64*64*3) and z_size=9 # beta = beta * ((z_size) / (H*W*C)) # ~= 4 * 0.0007324 # ~= 0,003 + # + # This is similar to appendix A.6: `INTERPRETING NORMALISED β` of the beta-Vae paper: + # - Published as a conference paper at ICLR 2017 (22 pages) + # - https://openreview.net/forum?id=Sy2fzU9gl + # beta: float = 0.003 # approximately equal to mean_sum beta of 4 def __init__(self, make_optimizer_fn, make_model_fn, batch_augment=None, cfg: cfg = None): From 90d496393c53a6d9996ae47bcf306165b84fa1f4 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Mon, 9 Aug 2021 23:50:48 +0200 Subject: [PATCH 02/17] cherry pick: efb2428abd680433f6a79ea1562f2aee4f9b9c61 -- replaced `make_optimizer_fn` and `make_model_fn` with direct `optimizer` and `model` --- README.md | 8 ++- disent/frameworks/_framework.py | 68 ++++++++++++++++--- disent/frameworks/ae/_unsupervised__ae.py | 7 +- .../frameworks/vae/_unsupervised__betavae.py | 4 +- .../frameworks/vae/_unsupervised__dfcvae.py | 4 +- .../frameworks/vae/_unsupervised__dipvae.py | 4 +- .../frameworks/vae/_unsupervised__infovae.py | 4 +- disent/frameworks/vae/_unsupervised__vae.py | 4 +- docs/examples/mnist_example.py | 8 ++- docs/examples/overview_framework_adagvae.py | 8 ++- docs/examples/overview_framework_ae.py | 5 +- docs/examples/overview_framework_betavae.py | 5 +- .../overview_framework_betavae_scheduled.py | 5 +- docs/examples/overview_metrics.py | 5 +- tests/test_frameworks.py | 8 ++- 15 files changed, 100 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index 427188e1..3f959e14 100644 --- a/README.md +++ b/README.md @@ -241,13 +241,15 @@ dataloader = DataLoader(dataset=dataset, batch_size=128, shuffle=True, num_worke # create the BetaVAE model # - adjusting the beta, learning rate, and representation size. 
module = BetaVae( - make_optimizer_fn=lambda params: Adam(params, lr=1e-4), - make_model_fn=lambda: AutoEncoder( + model=AutoEncoder( # z_multiplier is needed to output mu & logvar when parameterising normal distribution encoder=EncoderConv64(x_shape=data.x_shape, z_size=10, z_multiplier=2), decoder=DecoderConv64(x_shape=data.x_shape, z_size=10), ), - cfg=BetaVae.cfg(loss_reduction='mean_sum', beta=4) + cfg=BetaVae.cfg( + optimizer='adam', optimizer_kwargs=dict(lr=1e-3), + loss_reduction='mean_sum', beta=4, + ) ) # cyclic schedule for target 'beta' in the config/cfg. The initial value from the diff --git a/disent/frameworks/_framework.py b/disent/frameworks/_framework.py index 081575d1..53a76cc0 100644 --- a/disent/frameworks/_framework.py +++ b/disent/frameworks/_framework.py @@ -22,7 +22,6 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ -import logging from dataclasses import asdict from dataclasses import dataclass from dataclasses import fields @@ -31,7 +30,9 @@ from typing import Any from typing import Dict from typing import final +from typing import Optional from typing import Tuple +from typing import Type from typing import Union import logging @@ -71,6 +72,26 @@ def __init__(self, cfg: cfg = cfg()): self.cfg = cfg +# ========================================================================= # +# optimizers # +# ========================================================================= # + + +def _get_optimizer_list() -> Dict[str, Type[torch.optim.Optimizer]]: + # generate list of optimizers from torch + # - optimizer names are lowercase, eg. adam & rmsprop + optimizers = {} + for k in dir(torch.optim): + optim = getattr(torch.optim, k) + if isinstance(optim, type) and issubclass(optim, torch.optim.Optimizer) and (optim != torch.optim.Optimizer): + optimizers[k.lower()] = optim + return optimizers + + +# list of optimizers +_OPTIMIZERS = _get_optimizer_list() + + # ========================================================================= # # framework # # ========================================================================= # @@ -80,24 +101,51 @@ class DisentFramework(DisentConfigurable, DisentLightningModule): @dataclass class cfg(DisentConfigurable.cfg): - pass + # optimizer config + optimizer: Union[str, Type[torch.optim.Optimizer]] = torch.optim.Adam + optimizer_kwargs: Optional[Dict[str, Union[str, float, int]]] = None - def __init__(self, make_optimizer_fn, batch_augment=None, cfg: cfg = None): + def __init__( + self, + cfg: cfg = None, + # apply the batch augmentations on the GPU instead + batch_augment: callable = None, + ): + # save the config values to the class super().__init__(cfg=cfg) - # optimiser - assert callable(make_optimizer_fn) - self._make_optimiser_fn = make_optimizer_fn - # batch augmentations: not implemented as dataset transforms because we want to apply these on the GPU - assert (batch_augment is None) or callable(batch_augment) + # get the optimizer + if isinstance(self.cfg.optimizer, str): + if self.cfg.optimizer not in _OPTIMIZERS: + raise KeyError(f'invalid optimizer: {repr(self.cfg.optimizer)}, valid optimizers are: {sorted(_OPTIMIZERS.keys())}, otherwise pass a torch.optim.Optimizer class instead.') + self.cfg.optimizer = _OPTIMIZERS[self.cfg.optimizer] + # check the optimizer values + assert isinstance(self.cfg.optimizer, type) and issubclass(self.cfg.optimizer, torch.optim.Optimizer) and (self.cfg.optimizer != torch.optim.Optimizer) + assert isinstance(self.cfg.optimizer_kwargs, dict) or 
(self.cfg.optimizer_kwargs is None) + # set default values for optimizer + if self.cfg.optimizer_kwargs is None: + self.cfg.optimizer_kwargs = dict() + if 'lr' not in self.cfg.optimizer_kwargs: + self.cfg.optimizer_kwargs['lr'] = 1e-3 + log.info('lr not specified in `optimizer_kwargs`, setting to default value of `1e-3`') + # batch augmentations may not be implemented as dataset + # transforms so we can apply these on the GPU instead + assert callable(batch_augment) or (batch_augment is None) self._batch_augment = batch_augment # schedules + # - maybe add support for schedules in the config? self._registered_schedules = set() self._active_schedules: Dict[str, Tuple[Any, Schedule]] = {} @final def configure_optimizers(self): - # return optimizers - return self._make_optimiser_fn(self.parameters()) + optimizer = self.cfg.optimizer + # instantiate the optimizer! + if issubclass(optimizer, torch.optim.Optimizer): + optimizer = optimizer(self.parameters(), **self.cfg.optimizer_kwargs) + elif not isinstance(optimizer, torch.optim.Optimizer): + raise TypeError(f'unsupported optimizer type: {type(optimizer)}') + # return the optimizer + return optimizer @final def training_step(self, batch, batch_idx): diff --git a/disent/frameworks/ae/_unsupervised__ae.py b/disent/frameworks/ae/_unsupervised__ae.py index 7f74baa8..dedb9e2f 100644 --- a/disent/frameworks/ae/_unsupervised__ae.py +++ b/disent/frameworks/ae/_unsupervised__ae.py @@ -89,11 +89,10 @@ class cfg(DisentFramework.cfg): disable_rec_loss: bool = False disable_aug_loss: bool = False - def __init__(self, make_optimizer_fn, make_model_fn, batch_augment=None, cfg: cfg = None): - super().__init__(make_optimizer_fn, batch_augment=batch_augment, cfg=cfg) + def __init__(self, model: AutoEncoder, cfg: cfg = None, batch_augment=None): + super().__init__(cfg=cfg, batch_augment=batch_augment) # vae model - assert callable(make_model_fn) - self._model: AutoEncoder = make_model_fn() # TODO: move into property + self._model = model # check the model assert isinstance(self._model, AutoEncoder) assert self._model.z_multiplier == self.REQUIRED_Z_MULTIPLIER, f'model z_multiplier is {repr(self._model.z_multiplier)} but {self.__class__.__name__} requires that it is: {repr(self.REQUIRED_Z_MULTIPLIER)}' diff --git a/disent/frameworks/vae/_unsupervised__betavae.py b/disent/frameworks/vae/_unsupervised__betavae.py index 7426fe14..e57aff14 100644 --- a/disent/frameworks/vae/_unsupervised__betavae.py +++ b/disent/frameworks/vae/_unsupervised__betavae.py @@ -78,8 +78,8 @@ class cfg(Vae.cfg): # beta: float = 0.003 # approximately equal to mean_sum beta of 4 - def __init__(self, make_optimizer_fn, make_model_fn, batch_augment=None, cfg: cfg = None): - super().__init__(make_optimizer_fn, make_model_fn, batch_augment=batch_augment, cfg=cfg) + def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None): + super().__init__(model=model, cfg=cfg, batch_augment=batch_augment) assert self.cfg.beta >= 0, 'beta must be >= 0' # --------------------------------------------------------------------- # diff --git a/disent/frameworks/vae/_unsupervised__dfcvae.py b/disent/frameworks/vae/_unsupervised__dfcvae.py index 69f2a457..1bc93b0f 100644 --- a/disent/frameworks/vae/_unsupervised__dfcvae.py +++ b/disent/frameworks/vae/_unsupervised__dfcvae.py @@ -69,8 +69,8 @@ class cfg(BetaVae.cfg): feature_layers: Optional[List[Union[str, int]]] = None feature_inputs_mode: str = 'none' - def __init__(self, make_optimizer_fn, make_model_fn, batch_augment=None, cfg: cfg = None): - 
super().__init__(make_optimizer_fn, make_model_fn, batch_augment=batch_augment, cfg=cfg) + def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None): + super().__init__(model=model, cfg=cfg, batch_augment=batch_augment) # make dfc loss # TODO: this should be converted to a reconstruction loss handler that wraps another handler self._dfc_loss = DfcLossModule(feature_layers=self.cfg.feature_layers, input_mode=self.cfg.feature_inputs_mode) diff --git a/disent/frameworks/vae/_unsupervised__dipvae.py b/disent/frameworks/vae/_unsupervised__dipvae.py index f58c9a86..8a5313cc 100644 --- a/disent/frameworks/vae/_unsupervised__dipvae.py +++ b/disent/frameworks/vae/_unsupervised__dipvae.py @@ -55,8 +55,8 @@ class cfg(BetaVae.cfg): lambda_d: float = 10. lambda_od: float = 5. - def __init__(self, make_optimizer_fn, make_model_fn, batch_augment=None, cfg: cfg = None): - super().__init__(make_optimizer_fn, make_model_fn, batch_augment=batch_augment, cfg=cfg) + def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None): + super().__init__(model=model, cfg=cfg, batch_augment=batch_augment) # checks assert self.cfg.dip_mode in {'i', 'ii'}, f'unsupported dip_mode={repr(self.cfg.dip_mode)} for {self.__class__.__name__}. Must be one of: {{"i", "ii"}}' assert self.cfg.dip_beta >= 0, 'dip_beta must be >= 0' diff --git a/disent/frameworks/vae/_unsupervised__infovae.py b/disent/frameworks/vae/_unsupervised__infovae.py index 1c48b4ff..a855f1d2 100644 --- a/disent/frameworks/vae/_unsupervised__infovae.py +++ b/disent/frameworks/vae/_unsupervised__infovae.py @@ -63,8 +63,8 @@ class cfg(Vae.cfg): # this is optional maintain_reg_ratio: bool = True - def __init__(self, make_optimizer_fn, make_model_fn, batch_augment=None, cfg: cfg = None): - super().__init__(make_optimizer_fn, make_model_fn, batch_augment=batch_augment, cfg=cfg) + def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None): + super().__init__(model=model, cfg=cfg, batch_augment=batch_augment) # checks assert self.cfg.info_alpha <= 0, f'cfg.info_alpha must be <= zero, current value is: {self.cfg.info_alpha}' assert self.cfg.loss_reduction == 'mean', 'InfoVAE only supports cfg.loss_reduction == "mean"' diff --git a/disent/frameworks/vae/_unsupervised__vae.py b/disent/frameworks/vae/_unsupervised__vae.py index 17df7771..22a32d47 100644 --- a/disent/frameworks/vae/_unsupervised__vae.py +++ b/disent/frameworks/vae/_unsupervised__vae.py @@ -101,9 +101,9 @@ class cfg(Ae.cfg): disable_reg_loss: bool = False disable_posterior_scale: Optional[float] = None - def __init__(self, make_optimizer_fn, make_model_fn, batch_augment=None, cfg: cfg = None): + def __init__(self, model: 'AutoEncoder', cfg: cfg = None, batch_augment=None): # required_z_multiplier - super().__init__(make_optimizer_fn, make_model_fn, batch_augment=batch_augment, cfg=cfg) + super().__init__(model=model, cfg=cfg, batch_augment=batch_augment) # vae distribution self.__latents_handler = make_latent_distribution(self.cfg.latent_distribution, kl_mode=self.cfg.kl_loss_mode, reduction=self.cfg.loss_reduction) diff --git a/docs/examples/mnist_example.py b/docs/examples/mnist_example.py index 65e57290..f7e4d676 100644 --- a/docs/examples/mnist_example.py +++ b/docs/examples/mnist_example.py @@ -30,12 +30,14 @@ def __getitem__(self, index): # create the model module = AdaVae( - make_optimizer_fn=lambda params: Adam(params, lr=1e-3), - make_model_fn=lambda: AutoEncoder( + model=AutoEncoder( encoder=EncoderFC(x_shape=(1, 28, 28), z_size=9, 
z_multiplier=2), decoder=DecoderFC(x_shape=(1, 28, 28), z_size=9), ), - cfg=AdaVae.cfg(beta=4, recon_loss='mse', loss_reduction='mean_sum') # "mean_sum" is the traditional reduction, rather than "mean" + cfg=AdaVae.cfg( + optimizer='adam', optimizer_kwargs=dict(lr=1e-3), + beta=4, recon_loss='mse', loss_reduction='mean_sum', # "mean_sum" is the traditional loss reduction mode, rather than "mean" + ) ) # train the model diff --git a/docs/examples/overview_framework_adagvae.py b/docs/examples/overview_framework_adagvae.py index ca908369..11c88009 100644 --- a/docs/examples/overview_framework_adagvae.py +++ b/docs/examples/overview_framework_adagvae.py @@ -18,12 +18,14 @@ # create the pytorch lightning system module: pl.LightningModule = AdaVae( - make_optimizer_fn=lambda params: Adam(params, lr=1e-3), - make_model_fn=lambda: AutoEncoder( + model=AutoEncoder( encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2), decoder=DecoderConv64(x_shape=data.x_shape, z_size=6), ), - cfg=AdaVae.cfg(loss_reduction='mean_sum', beta=4, ada_average_mode='gvae', ada_thresh_mode='kl') + cfg=AdaVae.cfg( + optimizer='adam', optimizer_kwargs=dict(lr=1e-3), + loss_reduction='mean_sum', beta=4, ada_average_mode='gvae', ada_thresh_mode='kl', + ) ) # train the model diff --git a/docs/examples/overview_framework_ae.py b/docs/examples/overview_framework_ae.py index 4e65a015..d73124d1 100644 --- a/docs/examples/overview_framework_ae.py +++ b/docs/examples/overview_framework_ae.py @@ -18,12 +18,11 @@ # create the pytorch lightning system module: pl.LightningModule = Ae( - make_optimizer_fn=lambda params: Adam(params, lr=1e-3), - make_model_fn=lambda: AutoEncoder( + model=AutoEncoder( encoder=EncoderConv64(x_shape=data.x_shape, z_size=6), decoder=DecoderConv64(x_shape=data.x_shape, z_size=6), ), - cfg=Ae.cfg(loss_reduction='mean_sum') + cfg=Ae.cfg(optimizer='adam', optimizer_kwargs=dict(lr=1e-3), loss_reduction='mean_sum') ) # train the model diff --git a/docs/examples/overview_framework_betavae.py b/docs/examples/overview_framework_betavae.py index 28f0904a..14974506 100644 --- a/docs/examples/overview_framework_betavae.py +++ b/docs/examples/overview_framework_betavae.py @@ -18,12 +18,11 @@ # create the pytorch lightning system module: pl.LightningModule = BetaVae( - make_optimizer_fn=lambda params: Adam(params, lr=1e-3), - make_model_fn=lambda: AutoEncoder( + model=AutoEncoder( encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2), decoder=DecoderConv64(x_shape=data.x_shape, z_size=6), ), - cfg=BetaVae.cfg(loss_reduction='mean_sum', beta=4) + cfg=BetaVae.cfg(optimizer='adam', optimizer_kwargs=dict(lr=1e-3), loss_reduction='mean_sum', beta=4) ) # train the model diff --git a/docs/examples/overview_framework_betavae_scheduled.py b/docs/examples/overview_framework_betavae_scheduled.py index 682f8298..4b0d8550 100644 --- a/docs/examples/overview_framework_betavae_scheduled.py +++ b/docs/examples/overview_framework_betavae_scheduled.py @@ -18,12 +18,11 @@ # create the pytorch lightning system module: pl.LightningModule = BetaVae( - make_optimizer_fn=lambda params: Adam(params, lr=1e-3), - make_model_fn=lambda: AutoEncoder( + model=AutoEncoder( encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2), decoder=DecoderConv64(x_shape=data.x_shape, z_size=6), ), - cfg=BetaVae.cfg(loss_reduction='mean_sum', beta=4) + cfg=BetaVae.cfg(optimizer='adam', optimizer_kwargs=dict(lr=1e-3), loss_reduction='mean_sum', beta=4) ) # register the scheduler with the DisentFramework diff --git 
a/docs/examples/overview_metrics.py b/docs/examples/overview_metrics.py index 72de8864..afafff2b 100644 --- a/docs/examples/overview_metrics.py +++ b/docs/examples/overview_metrics.py @@ -17,12 +17,11 @@ def make_vae(beta): return BetaVae( - make_optimizer_fn=lambda params: Adam(params, lr=1e-3), - make_model_fn=lambda: AutoEncoder( + model=AutoEncoder( encoder=EncoderConv64(x_shape=data.x_shape, z_size=6, z_multiplier=2), decoder=DecoderConv64(x_shape=data.x_shape, z_size=6), ), - cfg=BetaVae.cfg(beta=beta) + cfg=BetaVae.cfg(optimizer='adam', optimizer_kwargs=dict(lr=1e-3), beta=beta) ) def train(module): diff --git a/tests/test_frameworks.py b/tests/test_frameworks.py index 24661b7b..8e9e9893 100644 --- a/tests/test_frameworks.py +++ b/tests/test_frameworks.py @@ -82,8 +82,7 @@ def test_frameworks(Framework, cfg_kwargs, Data): dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True) framework = Framework( - make_optimizer_fn=lambda params: Adam(params, lr=1e-3), - make_model_fn=lambda: AutoEncoder( + model=AutoEncoder( encoder=EncoderTest(x_shape=data.x_shape, z_size=6, z_multiplier=2 if issubclass(Framework, Vae) else 1), decoder=DecoderTest(x_shape=data.x_shape, z_size=6), ), @@ -95,8 +94,11 @@ def test_frameworks(Framework, cfg_kwargs, Data): def test_framework_config_defaults(): + import torch # we test that defaults are working recursively assert asdict(BetaVae.cfg()) == dict( + optimizer=torch.optim.adam.Adam, + optimizer_kwargs=None, recon_loss='mse', disable_aug_loss=False, disable_decoder=False, @@ -109,6 +111,8 @@ def test_framework_config_defaults(): beta=0.003, ) assert asdict(BetaVae.cfg(recon_loss='bce', kl_loss_mode='approx')) == dict( + optimizer=torch.optim.adam.Adam, + optimizer_kwargs=None, recon_loss='bce', disable_aug_loss=False, disable_decoder=False, From e29c3ef1d97fc0fa981a2dce2c1a4906c03a7b19 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Tue, 10 Aug 2021 02:17:42 +0200 Subject: [PATCH 03/17] cherry pick: 10cb8e811df1f76aaca | fix experiment runner --- disent/frameworks/_framework.py | 4 +-- experiment/config/optimizer/adabelief.yaml | 7 ++-- experiment/config/optimizer/adam.yaml | 9 +++--- experiment/config/optimizer/amsgrad.yaml | 9 +++--- experiment/config/optimizer/radam.yaml | 7 ++-- experiment/config/optimizer/rmsprop.yaml | 7 ++-- experiment/config/optimizer/sgd.yaml | 7 ++-- experiment/run.py | 37 +++++++++++++++------- experiment/util/hydra_utils.py | 20 ++++++------ tests/test_frameworks.py | 4 +-- 10 files changed, 59 insertions(+), 52 deletions(-) diff --git a/disent/frameworks/_framework.py b/disent/frameworks/_framework.py index 53a76cc0..99be4649 100644 --- a/disent/frameworks/_framework.py +++ b/disent/frameworks/_framework.py @@ -102,7 +102,7 @@ class DisentFramework(DisentConfigurable, DisentLightningModule): @dataclass class cfg(DisentConfigurable.cfg): # optimizer config - optimizer: Union[str, Type[torch.optim.Optimizer]] = torch.optim.Adam + optimizer: Union[str, Type[torch.optim.Optimizer]] = 'adam' optimizer_kwargs: Optional[Dict[str, Union[str, float, int]]] = None def __init__( @@ -120,7 +120,7 @@ def __init__( self.cfg.optimizer = _OPTIMIZERS[self.cfg.optimizer] # check the optimizer values assert isinstance(self.cfg.optimizer, type) and issubclass(self.cfg.optimizer, torch.optim.Optimizer) and (self.cfg.optimizer != torch.optim.Optimizer) - assert isinstance(self.cfg.optimizer_kwargs, dict) or (self.cfg.optimizer_kwargs is None) + assert isinstance(self.cfg.optimizer_kwargs, dict) or (self.cfg.optimizer_kwargs is 
None), f'invalid optimizer_kwargs type, got: {type(self.cfg.optimizer_kwargs)}' # set default values for optimizer if self.cfg.optimizer_kwargs is None: self.cfg.optimizer_kwargs = dict() diff --git a/experiment/config/optimizer/adabelief.yaml b/experiment/config/optimizer/adabelief.yaml index 4425fa33..db3d3b73 100644 --- a/experiment/config/optimizer/adabelief.yaml +++ b/experiment/config/optimizer/adabelief.yaml @@ -1,7 +1,6 @@ -# @package _group_ -name: adabelief -cls: - _target_: torch_optimizer.AdaBelief +# @package framework.module +optimizer: torch_optimizer.AdaBelief +optimizer_kwargs: lr: ${optimizer.lr} betas: [0.9, 0.999] eps: 1e-8 diff --git a/experiment/config/optimizer/adam.yaml b/experiment/config/optimizer/adam.yaml index 330e89fb..686a12c4 100644 --- a/experiment/config/optimizer/adam.yaml +++ b/experiment/config/optimizer/adam.yaml @@ -1,10 +1,9 @@ -# @package _group_ -name: adam -cls: - _target_: torch.optim.Adam +# @package framework.module +optimizer: torch.optim.Adam +optimizer_kwargs: lr: ${optimizer.lr} betas: [0.9, 0.999] eps: 1e-8 weight_decay: 0 - amsgrad: False \ No newline at end of file + amsgrad: False diff --git a/experiment/config/optimizer/amsgrad.yaml b/experiment/config/optimizer/amsgrad.yaml index 879c3c4a..ead824ca 100644 --- a/experiment/config/optimizer/amsgrad.yaml +++ b/experiment/config/optimizer/amsgrad.yaml @@ -1,10 +1,9 @@ -# @package _group_ -name: amsgrad -cls: - _target_: torch.optim.Adam +# @package framework.module +optimizer: torch.optim.Adam +optimizer_kwargs: lr: ${optimizer.lr} betas: [0.9, 0.999] eps: 1e-8 weight_decay: 0 - amsgrad: True \ No newline at end of file + amsgrad: True diff --git a/experiment/config/optimizer/radam.yaml b/experiment/config/optimizer/radam.yaml index d7281863..6ecfa5e8 100644 --- a/experiment/config/optimizer/radam.yaml +++ b/experiment/config/optimizer/radam.yaml @@ -1,7 +1,6 @@ -# @package _group_ -name: radam -cls: - _target_: torch_optimizer.RAdam +# @package framework.module +optimizer: torch_optimizer.RAdam +optimizer_kwargs: lr: ${optimizer.lr} betas: [0.9, 0.999] eps: 1e-8 diff --git a/experiment/config/optimizer/rmsprop.yaml b/experiment/config/optimizer/rmsprop.yaml index 35bec1bd..42e876c9 100644 --- a/experiment/config/optimizer/rmsprop.yaml +++ b/experiment/config/optimizer/rmsprop.yaml @@ -1,7 +1,6 @@ -# @package _group_ -name: rmsprop -cls: - _target_: torch.optim.RMSprop +# @package framework.module +optimizer: torch.optim.RMSprop +optimizer_kwargs: lr: ${optimizer.lr} # default was 1e-2 alpha: 0.99 eps: 1e-8 diff --git a/experiment/config/optimizer/sgd.yaml b/experiment/config/optimizer/sgd.yaml index f2c2fdfe..1dfe53da 100644 --- a/experiment/config/optimizer/sgd.yaml +++ b/experiment/config/optimizer/sgd.yaml @@ -1,7 +1,6 @@ -# @package _group_ -name: sgd -cls: - _target_: torch.optim.SGD +# @package framework.module +optimizer: torch.optim.SGD +optimizer_kwargs: lr: ${optimizer.lr} momentum: 0 dampening: 0 diff --git a/experiment/run.py b/experiment/run.py index 537a73c2..7ebad503 100644 --- a/experiment/run.py +++ b/experiment/run.py @@ -24,6 +24,7 @@ import logging import os +import sys import hydra import pytorch_lightning as pl @@ -209,12 +210,14 @@ def hydra_register_schedules(module: DisentFramework, cfg): def hydra_create_framework_config(cfg): # create framework config - this is also kinda hacky - framework_cfg: DisentConfigurable.cfg = hydra.utils.instantiate({ + # - we need instantiate_recursive because of optimizer_kwargs, + # otherwise the dictionary is left as an OmegaConf 
dict + framework_cfg: DisentConfigurable.cfg = instantiate_recursive({ **cfg.framework.module, **dict(_target_=cfg.framework.module._target_ + '.cfg') }) # warn if some of the cfg variables were not overridden - missing_keys = sorted(set(framework_cfg.get_keys()) - set(cfg.framework.module.keys())) + missing_keys = sorted(set(framework_cfg.get_keys()) - (set(cfg.framework.module.keys()))) if missing_keys: log.error(f'Framework {repr(cfg.framework.name)} is missing config keys for:') for k in missing_keys: @@ -225,11 +228,15 @@ def hydra_create_framework_config(cfg): return framework_cfg -def hydra_create_framework(framework_cfg, datamodule, cfg): +def hydra_create_framework(framework_cfg: DisentConfigurable.cfg, datamodule, cfg): + # specific handling for experiment, this is HACKY! + # - not supported normally, we need to instantiate to get the class (is there hydra support for this?) + framework_cfg.optimizer = hydra.utils.instantiate(dict(_target_=framework_cfg.optimizer), [torch.Tensor()]).__class__ + framework_cfg.optimizer_kwargs = dict(framework_cfg.optimizer_kwargs) + # instantiate return hydra.utils.instantiate( dict(_target_=cfg.framework.module._target_), - make_optimizer_fn=lambda params: hydra.utils.instantiate(cfg.optimizer.cls, params), - make_model_fn=lambda: init_model_weights( + model=init_model_weights( AutoEncoder( encoder=hydra.utils.instantiate(cfg.model.encoder), decoder=hydra.utils.instantiate(cfg.model.decoder) @@ -246,7 +253,7 @@ def hydra_create_framework(framework_cfg, datamodule, cfg): # ========================================================================= # -def run(cfg: DictConfig): +def run(cfg: DictConfig, config_path: str = None): # allow the cfg to be edited cfg = make_non_strict(cfg) @@ -265,7 +272,11 @@ def run(cfg: DictConfig): log.info(f"Orig working directory : {hydra.utils.get_original_cwd()}") # hydra config does not support variables in defaults lists, we handle this manually - cfg = merge_specializations(cfg, CONFIG_PATH, run) + print(os.getcwd()) + print(os.getcwd()) + print(os.getcwd()) + print(os.getcwd()) + cfg = merge_specializations(cfg, config_path=CONFIG_PATH if (config_path is None) else config_path) # check CUDA setting cfg.trainer.setdefault('cuda', 'try_cuda') @@ -330,19 +341,23 @@ def run(cfg: DictConfig): # ========================================================================= # +# path to root directory containing configs +CONFIG_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'config')) +# root config existing inside `CONFIG_ROOT`, with '.yaml' appended. 
+CONFIG_NAME = 'config' + + if __name__ == '__main__': - CONFIG_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), 'config')) - CONFIG_NAME = 'config' @hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME) - def main(cfg: DictConfig): + def hydra_main(cfg: DictConfig): try: run(cfg) except Exception as e: log_error_and_exit(err_type='experiment error', err_msg=str(e)) try: - main() + hydra_main() except KeyboardInterrupt as e: log_error_and_exit(err_type='interrupted', err_msg=str(e), exc_info=False) except Exception as e: diff --git a/experiment/util/hydra_utils.py b/experiment/util/hydra_utils.py index 8dfa3f33..2a426952 100644 --- a/experiment/util/hydra_utils.py +++ b/experiment/util/hydra_utils.py @@ -23,6 +23,7 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ import logging +from typing import Optional import hydra from deprecated import deprecated @@ -74,32 +75,29 @@ def make_non_strict(cfg: DictConfig): @deprecated('replace with hydra 1.1') -def merge_specializations(cfg: DictConfig, config_path: str, main_fn: callable, strict=True): +def merge_specializations(cfg: DictConfig, config_path: str, strict=True): + import os + # TODO: this should eventually be replaced with hydra recursive defaults # TODO: this makes config non-strict, allows setdefault to work even if key does not exist in config + assert os.path.isabs(config_path), f'config_path cannot be relative for merge_specializations: {repr(config_path)}, current working directory: {repr(os.getcwd())}' + # skip if we do not have any specializations if 'specializations' not in cfg: + log.warning('`specializations` key not found in `cfg`, skipping merging specializations') return + # we allow overwrites & missing values to be inserted if not strict: - # we allow overwrites & missing values to be inserted cfg = make_non_strict(cfg) - # imports - import os - from hydra._internal.utils import detect_calling_file_or_module_from_task_function - - # get hydra config root - calling_file, _, _ = detect_calling_file_or_module_from_task_function(main_fn) - config_root = os.path.join(os.path.dirname(calling_file), config_path) - # set and update specializations for group, specialization in cfg.specializations.items(): assert group not in cfg, f'group={repr(group)} already exists on cfg, specialization merging is not supported!' 
log.info(f'merging specialization: {repr(specialization)}') # load specialization config - specialization_cfg = OmegaConf.load(os.path.join(config_root, group, f'{specialization}.yaml')) + specialization_cfg = OmegaConf.load(os.path.join(config_path, group, f'{specialization}.yaml')) # create new config cfg = OmegaConf.merge(cfg, {group: specialization_cfg}) diff --git a/tests/test_frameworks.py b/tests/test_frameworks.py index 8e9e9893..0f917c63 100644 --- a/tests/test_frameworks.py +++ b/tests/test_frameworks.py @@ -97,7 +97,7 @@ def test_framework_config_defaults(): import torch # we test that defaults are working recursively assert asdict(BetaVae.cfg()) == dict( - optimizer=torch.optim.adam.Adam, + optimizer='adam', optimizer_kwargs=None, recon_loss='mse', disable_aug_loss=False, @@ -111,7 +111,7 @@ def test_framework_config_defaults(): beta=0.003, ) assert asdict(BetaVae.cfg(recon_loss='bce', kl_loss_mode='approx')) == dict( - optimizer=torch.optim.adam.Adam, + optimizer='adam', optimizer_kwargs=None, recon_loss='bce', disable_aug_loss=False, From aafbd68dab7cb074f9d6fe9317c4a42be11ab811 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Tue, 10 Aug 2021 02:18:27 +0200 Subject: [PATCH 04/17] cherry pick: ecbca23d6dba88 | test for experiment run --- experiment/__init__.py | 23 ++++++++++++ experiment/config/__init__.py | 3 ++ experiment/config/config_test.yaml | 48 +++++++++++++++++++++++++ experiment/config/run_length/test.yaml | 4 +++ experiment/util/__init__.py | 23 ++++++++++++ tests/test_experiment.py | 50 ++++++++++++++++++++++++++ tests/util.py | 19 ++++++++-- 7 files changed, 168 insertions(+), 2 deletions(-) create mode 100644 experiment/__init__.py create mode 100644 experiment/config/__init__.py create mode 100644 experiment/config/config_test.yaml create mode 100644 experiment/config/run_length/test.yaml create mode 100644 experiment/util/__init__.py create mode 100644 tests/test_experiment.py diff --git a/experiment/__init__.py b/experiment/__init__.py new file mode 100644 index 00000000..9a05a479 --- /dev/null +++ b/experiment/__init__.py @@ -0,0 +1,23 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ diff --git a/experiment/config/__init__.py b/experiment/config/__init__.py new file mode 100644 index 00000000..1ab71c1a --- /dev/null +++ b/experiment/config/__init__.py @@ -0,0 +1,3 @@ + +# for some unknown reason this file is required for tests/test_experiment.py to work +# this is very odd. Although it might be fixed in hydra 1.1? diff --git a/experiment/config/config_test.yaml b/experiment/config/config_test.yaml new file mode 100644 index 00000000..ca54be57 --- /dev/null +++ b/experiment/config/config_test.yaml @@ -0,0 +1,48 @@ +defaults: + # experiment + - framework: betavae + - model: vae_conv64 + - optimizer: adam + - dataset: xyobject + - dataset_sampling: full_bb + - augment: none + - schedule: none + - metrics: test + # runtime + - run_length: test + - run_location: local_cpu + - run_callbacks: vis_slow + - run_logging: none + # plugins + - hydra/job_logging: colorlog + - hydra/hydra_logging: colorlog + - hydra/launcher: submitit_slurm + +job: + user: invalid + project: invalid + name: '${framework.name}:${framework.module.recon_loss}|${dataset.name}:${dataset_sampling.name}|${trainer.steps}' + partition: invalid + seed: NULL + +framework: + beta: 0.001 + module: + recon_loss: mse + loss_reduction: mean + optional: + latent_distribution: normal # only used by VAEs + overlap_loss: NULL + +model: + z_size: 25 + +optimizer: + lr: 1e-3 + +# CUSTOM DEFAULTS SPECIALIZATION +# - This key is deleted on load and the correct key on the root config is set similar to defaults. +# - Unfortunately this hack needs to exists as hydra does not yet support this kinda of variable interpolation in defaults. +specializations: + dataset_sampler: ${dataset.data_type}_${framework.data_sample_mode} +# dataset_sampler: gt_dist_${framework.data_sample_mode} diff --git a/experiment/config/run_length/test.yaml b/experiment/config/run_length/test.yaml new file mode 100644 index 00000000..146d0153 --- /dev/null +++ b/experiment/config/run_length/test.yaml @@ -0,0 +1,4 @@ +# @package _global_ +trainer: + epochs: 1 + steps: 1 diff --git a/experiment/util/__init__.py b/experiment/util/__init__.py new file mode 100644 index 00000000..9a05a479 --- /dev/null +++ b/experiment/util/__init__.py @@ -0,0 +1,23 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ diff --git a/tests/test_experiment.py b/tests/test_experiment.py new file mode 100644 index 00000000..31d98169 --- /dev/null +++ b/tests/test_experiment.py @@ -0,0 +1,50 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import os +import os.path + +import hydra +import experiment.run as experiment_run +from tests.util import temp_sys_args + +# ========================================================================= # +# TESTS # +# ========================================================================= # + + +def test_experiment_run(): + # used by run() internally + experiment_run.CONFIG_PATH = os.path.join(os.path.dirname(experiment_run.__file__), 'config') + + os.environ['HYDRA_FULL_ERROR'] = '1' + with temp_sys_args([experiment_run.__file__]): + # why does this not work when config is absolute? + hydra_main = hydra.main(config_path='config', config_name='config_test')(experiment_run.run) + hydra_main() + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/tests/util.py b/tests/util.py index c4a1f0e7..bcff6941 100644 --- a/tests/util.py +++ b/tests/util.py @@ -21,8 +21,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ - - +import contextlib import os import sys from contextlib import contextmanager @@ -49,6 +48,22 @@ def no_stderr(): sys.stderr = old_stderr +@contextlib.contextmanager +def temp_wd(new_wd): + old_wd = os.getcwd() + os.chdir(new_wd) + yield + os.chdir(old_wd) + + +@contextlib.contextmanager +def temp_sys_args(new_argv): + old_argv = sys.argv + sys.argv = new_argv + yield + sys.argv = old_argv + + # ========================================================================= # # END # # ========================================================================= # From dba0ce33a85c86636acbba48e419432b95d1169e Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Mon, 4 Oct 2021 15:27:20 +0200 Subject: [PATCH 05/17] cherrypick a34ca6d99 --- disent/dataset/_wrapper.py | 43 ++++++++++++++++++---------- disent/model/_base.py | 2 +- disent/util/lightning/logger_util.py | 6 ++++ disent/util/seeds.py | 7 +++-- 4 files changed, 40 insertions(+), 18 deletions(-) diff --git a/disent/dataset/_wrapper.py b/disent/dataset/_wrapper.py index 4af85af2..15c05971 100644 --- a/disent/dataset/_wrapper.py +++ b/disent/dataset/_wrapper.py @@ -67,13 +67,21 @@ def wrapper(self: 'DisentDataset', *args, **kwargs): class DisentDataset(Dataset, LengthIter): - def __init__(self, dataset: Union[Dataset, GroundTruthData], sampler: Optional[BaseDisentSampler] = None, transform=None, augment=None): + def __init__( + self, + dataset: Union[Dataset, GroundTruthData], + sampler: Optional[BaseDisentSampler] = None, + transform=None, + augment=None, + return_indices: bool = False, + ): super().__init__() # save attributes self._dataset = dataset self._sampler = SingleSampler() if (sampler is None) else sampler self._transform = transform self._augment = augment + self._return_indices = return_indices # initialize sampler if not self._sampler.is_init: self._sampler.init(dataset) @@ -112,7 +120,7 @@ def __getitem__(self, idx): else: idxs = (idx,) # get the observations - return self.dataset_get_observation(*idxs) + return self._dataset_get_observation(*idxs) # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # Single Datapoints # @@ -177,19 +185,18 @@ def dataset_get(self, idx, mode: str): # Multiple Datapoints # # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # - def dataset_get_observation(self, *idxs): + def _dataset_get_observation(self, *idxs): xs, xs_targ = zip(*(self.dataset_get(idx, mode='pair') for idx in idxs)) # handle cases - if self._augment is None: - # makes 5-10% faster - return { - 'x_targ': xs_targ, - } - else: - return { - 'x': xs, - 'x_targ': xs_targ, - } + obs = {'x_targ': xs_targ} + # 5-10% faster + if self._augment is not None: + obs['x'] = xs + # add indices + if self._return_indices: + obs['idx'] = idxs + # done! 
+ return obs # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # Batches # @@ -199,7 +206,7 @@ def dataset_batch_from_indices(self, indices: Sequence[int], mode: str): """Get a batch of observations X from a batch of factors Y.""" return default_collate([self.dataset_get(idx, mode=mode) for idx in indices]) - def dataset_sample_batch(self, num_samples: int, mode: str, replace: bool = False): + def dataset_sample_batch(self, num_samples: int, mode: str, replace: bool = False, return_indices: bool = False): """Sample a batch of observations X.""" # create seeded pseudo random number generator # - built in np.random.choice cannot handle large values: https://github.com/numpy/numpy/issues/5299#issuecomment-497915672 @@ -208,7 +215,13 @@ def dataset_sample_batch(self, num_samples: int, mode: str, replace: bool = Fals g = np.random.Generator(np.random.PCG64(seed=np.random.randint(0, 2**32))) # sample indices indices = g.choice(len(self), num_samples, replace=replace) - return self.dataset_batch_from_indices(indices, mode=mode) + # return batch + batch = self.dataset_batch_from_indices(indices, mode=mode) + # return values + if return_indices: + return batch, default_collate(indices) + else: + return batch # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # # Batches -- Ground Truth Only # diff --git a/disent/model/_base.py b/disent/model/_base.py index da8bae7b..85acfa3c 100644 --- a/disent/model/_base.py +++ b/disent/model/_base.py @@ -146,7 +146,7 @@ def __init__(self, encoder: DisentEncoder, decoder: DisentDecoder): self._decoder = decoder def forward(self, x): - raise RuntimeError('This has been disabled') + raise RuntimeError(f'{self.__class__.__name__}.forward(...) has been disabled') def encode(self, x, chunk=False): z_raw = self._encoder(x, chunk=chunk) diff --git a/disent/util/lightning/logger_util.py b/disent/util/lightning/logger_util.py index 3882d31e..0daec37e 100644 --- a/disent/util/lightning/logger_util.py +++ b/disent/util/lightning/logger_util.py @@ -71,6 +71,12 @@ def wb_yield_loggers(logger: Optional[LightningLoggerBase]) -> Iterable[WandbLog yield from wb_yield_loggers(l) +def wb_has_logger(logger: Optional[LightningLoggerBase]) -> bool: + for l in wb_yield_loggers(logger): + return True + return False + + def wb_log_metrics(logger: Optional[LightningLoggerBase], metrics_dct: dict): """ Log the given values only to loggers that are an instance of WandbLogger diff --git a/disent/util/seeds.py b/disent/util/seeds.py index 3a0b3ac8..c1f2164b 100644 --- a/disent/util/seeds.py +++ b/disent/util/seeds.py @@ -22,9 +22,9 @@ # SOFTWARE. # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +import contextlib import logging import random - import numpy as np @@ -59,7 +59,7 @@ def seed(long=777): log.info(f'[SEEDED]: {long}') -class TempNumpySeed(object): +class TempNumpySeed(contextlib.ContextDecorator): def __init__(self, seed=None, offset=0): if seed is not None: try: @@ -81,6 +81,9 @@ def __exit__(self, *args, **kwargs): np.random.set_state(self._state) self._state = None + def _recreate_cm(self): + # TODO: do we need to override this? 
+ return self # ========================================================================= # # END # From de1a691f4bce9922a9996cdeca1eaa1e392460d5 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Mon, 4 Oct 2021 15:29:38 +0200 Subject: [PATCH 06/17] cherrypick 8c8b53782f - update progress callback --- .../util/lightning/callbacks/_callbacks_pl.py | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/disent/util/lightning/callbacks/_callbacks_pl.py b/disent/util/lightning/callbacks/_callbacks_pl.py index 42bba0ff..add9d960 100644 --- a/disent/util/lightning/callbacks/_callbacks_pl.py +++ b/disent/util/lightning/callbacks/_callbacks_pl.py @@ -39,20 +39,30 @@ class LoggerProgressCallback(BaseCallbackTimed): def do_interval(self, trainer: pl.Trainer, pl_module: pl.LightningModule, current_time, start_time): - # vars - batch, max_batches = trainer.batch_idx + 1, trainer.num_training_batches - epoch, max_epoch = trainer.current_epoch + 1, min(trainer.max_epochs, (trainer.max_steps + max_batches - 1) // max_batches) - global_step, global_steps = trainer.global_step + 1, min(trainer.max_epochs * max_batches, trainer.max_steps) - # computed - train_pct = global_step / global_steps + # get missing vars + trainer_max_epochs = trainer.max_epochs if (trainer.max_epochs is not None) else float('inf') + trainer_max_steps = trainer.max_steps if (trainer.max_steps is not None) else float('inf') + # get vars + batch = trainer.batch_idx + 1 + epoch = trainer.current_epoch + 1 + global_step = trainer.global_step + 1 + # compute vars + max_batches = trainer.num_training_batches + max_epochs = min(trainer_max_epochs, (trainer_max_steps + max_batches - 1) // max_batches) + max_steps = min(trainer_max_epochs * max_batches, trainer_max_steps) # completion + train_pct = global_step / max_steps train_remain_time = (current_time - start_time) * (1 - train_pct) / train_pct # info dict - info_dict = {k: f'{v:.4g}' if isinstance(v, (int, float)) else f'{v}' for k, v in trainer.progress_bar_dict.items() if k != 'v_num'} + info_dict = { + k: f'{v:.4g}' if isinstance(v, (int, float)) else f'{v}' + for k, v in trainer.progress_bar_dict.items() + if k != 'v_num' + } sorted_k = sorted(info_dict.keys(), key=lambda k: ('loss' != k.lower(), 'loss' not in k.lower(), k)) # log log.info( - f'EPOCH: {epoch}/{max_epoch} - {global_step:0{len(str(global_steps))}d}/{global_steps} ' + f'EPOCH: {epoch}/{max_epochs} - {global_step:0{len(str(max_steps))}d}/{max_steps} ' f'({int(train_pct * 100):02d}%) [{int(train_remain_time)}s] ' f'STEP: {batch:{len(str(max_batches))}d}/{max_batches} ({int(batch / max_batches * 100):02d}%) ' f'| {" ".join(f"{k}={info_dict[k]}" for k in sorted_k)}' From 8c238dd187d7a85fe2ad4ae65911ffdd1bea3828 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Mon, 4 Oct 2021 15:30:15 +0200 Subject: [PATCH 07/17] cherrypick 94e7cfe96 - fix progress --- .../util/lightning/callbacks/_callbacks_pl.py | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/disent/util/lightning/callbacks/_callbacks_pl.py b/disent/util/lightning/callbacks/_callbacks_pl.py index add9d960..d226f557 100644 --- a/disent/util/lightning/callbacks/_callbacks_pl.py +++ b/disent/util/lightning/callbacks/_callbacks_pl.py @@ -23,6 +23,8 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ import logging +import warnings + import pytorch_lightning as pl from disent.util.lightning.callbacks._callbacks_base import BaseCallbackTimed @@ -42,17 +44,28 @@ def 
do_interval(self, trainer: pl.Trainer, pl_module: pl.LightningModule, curren # get missing vars trainer_max_epochs = trainer.max_epochs if (trainer.max_epochs is not None) else float('inf') trainer_max_steps = trainer.max_steps if (trainer.max_steps is not None) else float('inf') - # get vars - batch = trainer.batch_idx + 1 - epoch = trainer.current_epoch + 1 - global_step = trainer.global_step + 1 + # compute vars max_batches = trainer.num_training_batches max_epochs = min(trainer_max_epochs, (trainer_max_steps + max_batches - 1) // max_batches) max_steps = min(trainer_max_epochs * max_batches, trainer_max_steps) + elapsed_sec = current_time - start_time + # get vars + global_step = trainer.global_step + 1 + epoch = trainer.current_epoch + 1 + if hasattr(trainer, 'batch_idx'): + batch = (trainer.batch_idx + 1) + else: + warnings.warn('batch_idx missing on pl.Trainer') + batch = global_step % max_batches # completion train_pct = global_step / max_steps - train_remain_time = (current_time - start_time) * (1 - train_pct) / train_pct + train_remain_time = elapsed_sec * (1 - train_pct) / train_pct # seconds + # get speed -- TODO: make this a moving average? + if global_step >= elapsed_sec: + step_speed_str = f'{global_step / elapsed_sec:4.2f}it/s' + else: + step_speed_str = f'{elapsed_sec / global_step:4.2f}s/it' # info dict info_dict = { k: f'{v:.4g}' if isinstance(v, (int, float)) else f'{v}' @@ -62,8 +75,9 @@ def do_interval(self, trainer: pl.Trainer, pl_module: pl.LightningModule, curren sorted_k = sorted(info_dict.keys(), key=lambda k: ('loss' != k.lower(), 'loss' not in k.lower(), k)) # log log.info( + f'[{int(elapsed_sec)}s, {step_speed_str}] ' f'EPOCH: {epoch}/{max_epochs} - {global_step:0{len(str(max_steps))}d}/{max_steps} ' - f'({int(train_pct * 100):02d}%) [{int(train_remain_time)}s] ' + f'({int(train_pct * 100):02d}%) [rem. 
{int(train_remain_time)}s] ' f'STEP: {batch:{len(str(max_batches))}d}/{max_batches} ({int(batch / max_batches * 100):02d}%) ' f'| {" ".join(f"{k}={info_dict[k]}" for k in sorted_k)}' ) From 1c68c7e851e10fe40b7518359a91b6ff29c3378c Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Wed, 18 Aug 2021 03:43:15 +0200 Subject: [PATCH 08/17] fix smallnorb --- disent/dataset/data/_groundtruth__norb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/disent/dataset/data/_groundtruth__norb.py b/disent/dataset/data/_groundtruth__norb.py index d898e37f..8dd74012 100644 --- a/disent/dataset/data/_groundtruth__norb.py +++ b/disent/dataset/data/_groundtruth__norb.py @@ -164,7 +164,7 @@ def __init__(self, data_root: Optional[str] = None, prepare: bool = False, is_te self._data, _ = read_norb_dataset(dat_path=dat_path, cat_path=cat_path, info_path=info_path) def _get_observation(self, idx): - return self._data[idx] + return self._data[idx][:, :, None] # data is missing channel dim @property def datafiles(self) -> Sequence[DataFileHashedDl]: From f9094ba377d6fee3e6e97711ca7fde5471621ff5 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Mon, 4 Oct 2021 15:33:01 +0200 Subject: [PATCH 09/17] cherrypick 9eaebd - fix for PL 1.4 --- experiment/run.py | 21 +++++++++++++++++---- experiment/util/hydra_data.py | 6 +++++- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/experiment/run.py b/experiment/run.py index 7ebad503..d000c69a 100644 --- a/experiment/run.py +++ b/experiment/run.py @@ -25,11 +25,13 @@ import logging import os import sys +from datetime import datetime import hydra import pytorch_lightning as pl import torch import torch.utils.data +import wandb from omegaconf import DictConfig from omegaconf import OmegaConf from pytorch_lightning.loggers import CometLogger @@ -254,6 +256,21 @@ def hydra_create_framework(framework_cfg: DisentConfigurable.cfg, datamodule, cf def run(cfg: DictConfig, config_path: str = None): + + # get the time the run started + time_string = datetime.today().strftime('%Y-%m-%d--%H-%M-%S') + log.info(f'Starting run at time: {time_string}') + + # -~-~-~-~-~-~-~-~-~-~-~-~- # + + # cleanup from old runs: + try: + wandb.finish() + except: + pass + + # -~-~-~-~-~-~-~-~-~-~-~-~- # + # allow the cfg to be edited cfg = make_non_strict(cfg) @@ -272,10 +289,6 @@ def run(cfg: DictConfig, config_path: str = None): log.info(f"Orig working directory : {hydra.utils.get_original_cwd()}") # hydra config does not support variables in defaults lists, we handle this manually - print(os.getcwd()) - print(os.getcwd()) - print(os.getcwd()) - print(os.getcwd()) cfg = merge_specializations(cfg, config_path=CONFIG_PATH if (config_path is None) else config_path) # check CUDA setting diff --git a/experiment/util/hydra_data.py b/experiment/util/hydra_data.py index cf1d54b1..1071c6a1 100644 --- a/experiment/util/hydra_data.py +++ b/experiment/util/hydra_data.py @@ -78,7 +78,11 @@ class HydraDataModule(pl.LightningDataModule): def __init__(self, hparams: DictConfig): super().__init__() - self.hparams = hparams + # support pytorch lightning < 1.4 + if not hasattr(self, 'hparams'): + self.hparams = {} + # set values + self.hparams.update(hparams) # transform: prepares data from datasets self.data_transform = instantiate_recursive(self.hparams.dataset.transform) assert (self.data_transform is None) or callable(self.data_transform) From cbedd99f860db5e36f88625b6246148c131c285d Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Mon, 4 Oct 2021 15:34:08 +0200 Subject: [PATCH 10/17] 
cherrypick 57154d0 - sample_random_obs_traversal --- disent/dataset/data/_groundtruth.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/disent/dataset/data/_groundtruth.py b/disent/dataset/data/_groundtruth.py index eddd41da..e33a5f99 100644 --- a/disent/dataset/data/_groundtruth.py +++ b/disent/dataset/data/_groundtruth.py @@ -25,6 +25,8 @@ import logging import os from abc import ABCMeta +from typing import Any +from typing import List from typing import Optional from typing import Sequence from typing import Tuple @@ -100,6 +102,22 @@ def __getitem__(self, idx): def _get_observation(self, idx): raise NotImplementedError + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + # EXTRAS # + # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # + + def sample_random_obs_traversal(self, f_idx: int = None, base_factors=None, num: int = None, mode='interval', obs_collect_fn=None) -> Tuple[np.ndarray, np.ndarray, Union[List[Any], Any]]: + """ + Same API as sample_random_factor_traversal, but also + returns the corresponding indices and uncollated list of observations + """ + factors = self.sample_random_factor_traversal(f_idx=f_idx, base_factors=base_factors, num=num, mode=mode) + indices = self.pos_to_idx(factors) + obs = [self[i] for i in indices] + if obs_collect_fn is not None: + obs = obs_collect_fn(obs) + return factors, indices, obs + # ========================================================================= # # Basic Array Ground Truth Dataset # From 05c063a2392fd431489a74ced5211168db0d04ec Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Mon, 4 Oct 2021 15:34:39 +0200 Subject: [PATCH 11/17] cherrypick 2d67ccd - fix lerp --- disent/schedule/lerp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/disent/schedule/lerp.py b/disent/schedule/lerp.py index 9fea0d14..5803c785 100644 --- a/disent/schedule/lerp.py +++ b/disent/schedule/lerp.py @@ -39,7 +39,7 @@ def scale(r, a, b): def lerp(r, a, b): """Linear interpolation between parameters, respects bounds when t is out of bounds [0, 1]""" - assert a < b + # assert a < b r = np.clip(r, 0., 1.) 
# precise method, guarantees v==b when t==1 | simplifies to: a + t*(b-a) return (1 - r) * a + r * b From e2892e4ee250c83603fa2f05004ac359cd7421f0 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Mon, 4 Oct 2021 15:36:04 +0200 Subject: [PATCH 12/17] cherrypick a2e5386 - renamed _wrapper to _base --- disent/dataset/__init__.py | 2 +- disent/dataset/{_wrapper.py => _base.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename disent/dataset/{_wrapper.py => _base.py} (100%) diff --git a/disent/dataset/__init__.py b/disent/dataset/__init__.py index 7ec11016..84f37c72 100644 --- a/disent/dataset/__init__.py +++ b/disent/dataset/__init__.py @@ -23,4 +23,4 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ # wrapper -from disent.dataset._wrapper import DisentDataset +from disent.dataset._base import DisentDataset diff --git a/disent/dataset/_wrapper.py b/disent/dataset/_base.py similarity index 100% rename from disent/dataset/_wrapper.py rename to disent/dataset/_base.py From dec13f2a026248c3b715f9ba96aa245f7e140b93 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Mon, 4 Oct 2021 15:37:36 +0200 Subject: [PATCH 13/17] minor fix --- disent/dataset/data/_groundtruth.py | 1 + 1 file changed, 1 insertion(+) diff --git a/disent/dataset/data/_groundtruth.py b/disent/dataset/data/_groundtruth.py index e33a5f99..0b5d7bcd 100644 --- a/disent/dataset/data/_groundtruth.py +++ b/disent/dataset/data/_groundtruth.py @@ -30,6 +30,7 @@ from typing import Optional from typing import Sequence from typing import Tuple +from typing import Union import numpy as np from torch.utils.data import Dataset From efc007ac2ac5f95d90c17704dce653788178134d Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Mon, 4 Oct 2021 15:42:25 +0200 Subject: [PATCH 14/17] cherrypick 9eea8b2 - remove Deprecate, fixes hydra submittit pickle bug --- disent/frameworks/helper/reconstructions.py | 2 +- disent/util/deprecate.py | 68 +++++++++++++++++++++ experiment/util/hydra_utils.py | 6 +- requirements.txt | 1 - 4 files changed, 73 insertions(+), 4 deletions(-) create mode 100644 disent/util/deprecate.py diff --git a/disent/frameworks/helper/reconstructions.py b/disent/frameworks/helper/reconstructions.py index 4aca1138..e6252799 100644 --- a/disent/frameworks/helper/reconstructions.py +++ b/disent/frameworks/helper/reconstructions.py @@ -30,7 +30,7 @@ import torch import torch.nn.functional as F -from deprecated import deprecated +from disent.util.deprecate import deprecated from disent.frameworks.helper.util import compute_ave_loss from disent.nn.modules import DisentModule diff --git a/disent/util/deprecate.py b/disent/util/deprecate.py new file mode 100644 index 00000000..97d8dc11 --- /dev/null +++ b/disent/util/deprecate.py @@ -0,0 +1,68 @@ +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ +# MIT License +# +# Copyright (c) 2021 Nathan Juraj Michlo +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ + +import logging +from functools import wraps + + +# ========================================================================= # +# Deprecate # +# ========================================================================= # + + +def deprecated(msg: str): + """ + Mark a function or class as deprecated, and print a warning the + first time it is used. + - This decorator wraps functions, but only replaces the __init__ + method of a class so that we can still inherit from a deprecated class! + """ + def _decorator(fn): + # we need to handle classes and function separately + is_class = isinstance(fn, type) and hasattr(fn, '__init__') + # backup the original function & data + call_fn = fn.__init__ if is_class else fn + dat = (fn.__module__, f'{fn.__module__}.{fn.__name__}', str(msg)) + # wrapper function + @wraps(call_fn) + def _caller(*args, **kwargs): + nonlocal dat + # print the message! + if dat is not None: + name, path, dsc = dat + logging.getLogger(name).warning(f'[DEPRECATED] {path} - {repr(dsc)}') + dat = None + return call_fn(*args, **kwargs) + # handle function or class + if is_class: + fn.__init__ = _caller + else: + fn = _caller + return fn + return _decorator + + +# ========================================================================= # +# END # +# ========================================================================= # diff --git a/experiment/util/hydra_utils.py b/experiment/util/hydra_utils.py index 2a426952..41b90a13 100644 --- a/experiment/util/hydra_utils.py +++ b/experiment/util/hydra_utils.py @@ -23,14 +23,15 @@ # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~ import logging -from typing import Optional +from copy import deepcopy import hydra -from deprecated import deprecated from omegaconf import DictConfig from omegaconf import ListConfig from omegaconf import OmegaConf +from disent.util.deprecate import deprecated + log = logging.getLogger(__name__) @@ -71,6 +72,7 @@ def instantiate_recursive(config): @deprecated('replace with hydra 1.1') def make_non_strict(cfg: DictConfig): + cfg = deepcopy(cfg) return OmegaConf.create({**cfg}) diff --git a/requirements.txt b/requirements.txt index 2858b562..639231c8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,4 +17,3 @@ h5py>=2.10.0 # as of tensorflow 2.4 it does not support h5py 3+ # UTILITY # ======= tqdm>=4.60.0 -Deprecated>=1.2.12 From fce492dde9b3d7f690c2f07d1787243b093a05ad Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Mon, 4 Oct 2021 15:47:05 +0200 Subject: [PATCH 15/17] update config defaults --- experiment/config/config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/experiment/config/config.yaml b/experiment/config/config.yaml index 072c8cdf..5a708472 100644 --- a/experiment/config/config.yaml +++ b/experiment/config/config.yaml @@ -48,7 +48,7 @@ optimizer: # - Unfortunately this hack needs to exists as hydra does not yet support this kinda of variable interpolation in defaults. 
specializations: # original samplers -- - # dataset_sampler: ${dataset.data_type}_${framework.data_sample_mode} + dataset_sampler: ${dataset.data_type}_${framework.data_sample_mode} # newer samplers -- only active for frameworks that require 3 observations, otherwise random for 2, and exact for 1 - dataset_sampler: gt_dist_${framework.data_sample_mode} + # dataset_sampler: gt_dist_${framework.data_sample_mode} From dd8cba76a99e9f89e066ff9f70d4bb2bb93247eb Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Mon, 4 Oct 2021 16:11:04 +0200 Subject: [PATCH 16/17] py38 to py39 tests --- .github/workflows/python-test.yml | 9 +++++---- requirements-exp.txt | 2 +- requirements.txt | 2 +- setup.py | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index c7034fb6..e2900e6a 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -5,10 +5,10 @@ name: test on: push: - branches: [ main, dev ] + branches: [ "main", "dev", "dev*", "feature*"] tags: [ '*' ] pull_request: - branches: [ main, dev ] + branches: [ "main", "dev", "dev*", "feature*"] jobs: build: @@ -16,7 +16,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] # [ubuntu-latest, windows-latest, macos-latest] - python-version: [3.8] + python-version: ["3.8", "3.9"] steps: - uses: actions/checkout@v2 @@ -31,6 +31,7 @@ jobs: python3 -m pip install --upgrade pip python3 -m pip install -r requirements.txt python3 -m pip install -r requirements-test.txt + python3 -m pip install -r requirements-exp.txt - name: Test with pytest run: | @@ -39,6 +40,6 @@ jobs: - uses: codecov/codecov-action@v1 with: token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: true + fail_ci_if_error: false # codecov automatically merges all generated files # if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.9 diff --git a/requirements-exp.txt b/requirements-exp.txt index 63b586f1..7eefc6d8 100644 --- a/requirements-exp.txt +++ b/requirements-exp.txt @@ -22,6 +22,6 @@ wandb>=0.10.32 # UTILITY # ======= -hydra-core==1.0.6 +hydra-core==1.0.7 hydra-colorlog==1.0.1 hydra-submitit-launcher==1.1.1 diff --git a/requirements.txt b/requirements.txt index 639231c8..b559ab01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ pip>=21.0 # DATA SCIENCE & ML # ================= -numpy>=1.21.0 +numpy>=1.19.0 torch>=1.9.0 torchvision>=0.10.0 pytorch-lightning>=1.3.7 diff --git a/setup.py b/setup.py index 4df25f28..837bd010 100644 --- a/setup.py +++ b/setup.py @@ -49,7 +49,7 @@ author_email="NathanJMichlo@gmail.com", version="0.1.0", - python_requires=">=3.8", + python_requires=">=3.8", # we make use of standard library features only in 3.8 packages=setuptools.find_packages(), install_requires=install_requires, From 68db787d7124cd19733632bfa56b5f5bc6502343 Mon Sep 17 00:00:00 2001 From: Nathan Michlo Date: Mon, 4 Oct 2021 16:15:39 +0200 Subject: [PATCH 17/17] version bump v0.2.0 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 837bd010..27a5a6bc 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ author="Nathan Juraj Michlo", author_email="NathanJMichlo@gmail.com", - version="0.1.0", + version="0.2.0", python_requires=">=3.8", # we make use of standard library features only in 3.8 packages=setuptools.find_packages(),
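
A minimal usage sketch for the sample_random_obs_traversal helper added to disent/dataset/data/_groundtruth.py above. The import path, the behaviour of the None defaults (which are simply forwarded to sample_random_factor_traversal), and the use of np.stack as a collate function are assumptions for illustration; XYObjectData stands in for any GroundTruthData subclass.

    import numpy as np
    from disent.dataset.data import XYObjectData  # assumed import path

    data = XYObjectData()

    # traverse a randomly chosen ground-truth factor and also fetch the observations
    factors, indices, obs = data.sample_random_obs_traversal(
        f_idx=None,               # assumed: None lets sample_random_factor_traversal choose a factor
        base_factors=None,        # assumed: None samples the remaining factor values randomly
        num=None,                 # assumed: None defaults to the size of the chosen factor
        mode='interval',
        obs_collect_fn=np.stack,  # collate the uncollated list into one array (assumes array observations)
    )

    # factors are positions in factor space, indices are the corresponding flat dataset indices
    assert len(factors) == len(indices) == len(obs)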
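
The lerp fix above removes the assert a < b guard, so interpolation now also works for descending schedules where the start value is larger than the end value. A quick check of the patched behaviour, with the expected values worked out by hand from (1 - r) * a + r * b:

    from disent.schedule.lerp import lerp

    lerp(0.0, 4.0, 0.001)  # -> 4.0     start of a descending schedule (a > b is now allowed)
    lerp(0.5, 4.0, 0.001)  # -> 2.0005  midpoint: 0.5 * 4.0 + 0.5 * 0.001
    lerp(1.5, 4.0, 0.001)  # -> 0.001   the ratio r is clipped to the bounds [0, 1]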
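
A usage sketch for the new disent.util.deprecate.deprecated decorator, following the implementation in the patch above; old_helper, OldThing and NewThing are hypothetical names used only for illustration.

    from disent.util.deprecate import deprecated

    @deprecated('use new_helper instead')
    def old_helper(x):
        return x * 2

    @deprecated('use NewThing instead')
    class OldThing:
        def __init__(self, value):
            self.value = value

    class NewThing(OldThing):  # inheriting from a deprecated class still works,
        pass                   # since only __init__ is wrapped, not the class itself

    old_helper(1)  # first call logs: [DEPRECATED] __main__.old_helper - 'use new_helper instead'
    old_helper(2)  # later calls are silent, the warning is only emitted once per decorated object
    OldThing(3)    # classes log their warning the first time they are instantiated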