From b870e55714ea38f825193d97e6d6982b185cd8bb Mon Sep 17 00:00:00 2001
From: yallup
Date: Mon, 15 Jul 2024 17:40:08 +0100
Subject: [PATCH] fix tests

---
 clax/clax.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/clax/clax.py b/clax/clax.py
index 9fc8c39..27a1f0f 100644
--- a/clax/clax.py
+++ b/clax/clax.py
@@ -60,8 +60,8 @@ def loss(self, params, batch_stats, batch, labels, rng):
     def _train(self, samples, labels, batches_per_epoch, **kwargs):
         """Internal wrapping of training loop."""
         self.trace = Trace()
-        batch_size = kwargs.get("batch_size", 1024)
-        epochs = kwargs.get("epochs", 10)
+        batch_size = kwargs.get("batch_size")
+        epochs = kwargs.get("epochs")
         # epochs *= batches_per_epoch
 
         @jit
@@ -114,8 +114,10 @@ def _init_state(self, **kwargs):
         self.schedule = optax.warmup_cosine_decay_schedule(
             init_value=0.0,
             peak_value=lr,
-            warmup_steps=int(warmup_fraction * target_batches_per_epoch * epochs),
-            decay_steps=int((1 - cold_fraction) * target_batches_per_epoch * epochs),
+            warmup_steps=int(warmup_fraction * target_batches_per_epoch * epochs + 1),
+            decay_steps=int(
+                (1 - cold_fraction) * target_batches_per_epoch * epochs + 1
+            ),
             end_value=lr * cold_lr,
             exponent=1.0,
         )
@@ -134,7 +136,7 @@ def _init_state(self, **kwargs):
             tx=optimizer,
         )
 
-    def fit(self, samples, labels, epochs=10, **kwargs):
+    def fit(self, samples, labels, epochs=10, batch_size=1024, **kwargs):
         """Fit the classifier on provided samples.
 
         Args:
@@ -156,8 +158,10 @@
             cold_lr (float): The factor to reduce learning rate to use during the cold phase. Defaults to 1e-3.
         """
         restart = kwargs.get("restart", False)
-        batch_size = kwargs.get("batch_size", 1024)
+
         data_size = samples.shape[0]
+        batch_size = min(batch_size, data_size)
+        kwargs["batch_size"] = batch_size
         batches_per_epoch = data_size // batch_size
         self.ndims = samples.shape[-1]
         kwargs["epochs"] = epochs
@@ -210,7 +214,7 @@ def loss(self, params, batch_stats, batch, labels, rng):
         loss = self.loss_fn(output.squeeze(), labels).mean()
         return loss, updates
 
-    def fit(self, samples_a, samples_b, epochs=10, **kwargs):
+    def fit(self, samples_a, samples_b, epochs=10, batch_size=1024, **kwargs):
         """Fit the classifier on provided samples.
 
         Args:
@@ -233,9 +237,11 @@
             cold_lr (float): The factor to reduce learning rate to use during the cold phase. Defaults to 1e-3.
         """
         restart = kwargs.get("restart", False)
-        batch_size = kwargs.get("batch_size", 1024)
         self.ndims = kwargs.get("ndims", samples_a.shape[-1])
         data_size = samples_a.shape[0]
+
+        batch_size = min(batch_size, data_size)
+        kwargs["batch_size"] = batch_size
         batches_per_epoch = data_size // batch_size
         kwargs["epochs"] = epochs
         if (not self.state) | restart:
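
Note (sketch, not part of the patch): the central fix is clamping batch_size to the
dataset size before computing batches_per_epoch. Previously the default of 1024 was
read straight from kwargs, so any dataset smaller than 1024 samples gave
data_size // batch_size == 0 and the training loop never ran. A minimal illustration
of the arithmetic, with the fixture size chosen purely for illustration:

    data_size = 100                            # e.g. a small test fixture
    batch_size = 1024                          # old default, taken straight from kwargs

    print(data_size // batch_size)             # 0 -> zero batches per epoch, loop never runs

    batch_size = min(batch_size, data_size)    # the clamp this patch adds in fit()
    print(data_size // batch_size)             # 1 -> at least one batch per epoch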
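
The "+ 1" in the schedule hunk guards the same degenerate case: with a tiny dataset,
int(warmup_fraction * target_batches_per_epoch * epochs) can truncate to 0, producing
an optax schedule with zero warmup/decay steps. A self-contained sketch of the patched
call; the fraction and learning-rate values below are illustrative, not necessarily
clax's defaults:

    import optax

    lr, cold_lr = 1e-2, 1e-3                   # illustrative values
    warmup_fraction, cold_fraction = 0.1, 0.1  # illustrative values
    target_batches_per_epoch, epochs = 1, 1    # one-batch fixture

    # Old: int(0.1 * 1 * 1) == 0 warmup steps; the +1 keeps both counts positive.
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=lr,
        warmup_steps=int(warmup_fraction * target_batches_per_epoch * epochs + 1),
        decay_steps=int((1 - cold_fraction) * target_batches_per_epoch * epochs + 1),
        end_value=lr * cold_lr,
        exponent=1.0,
    )
    print(schedule(0))                         # learning rate at step 0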