From 95fc759c6e02276061b89cfdba46711d82af381f Mon Sep 17 00:00:00 2001 From: BirkhoffG <26811230+BirkhoffG@users.noreply.github.com> Date: Thu, 1 Feb 2024 22:24:23 -0500 Subject: [PATCH] Update test cases for testing num of batches --- jax_dataloader/tests.py | 4 ++++ nbs/tests.ipynb | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/jax_dataloader/tests.py b/jax_dataloader/tests.py index 40af62d..6687350 100644 --- a/jax_dataloader/tests.py +++ b/jax_dataloader/tests.py @@ -18,6 +18,7 @@ def get_batch(batch): # %% ../nbs/tests.ipynb 4 def test_no_shuffle(cls, ds, batch_size: int, feats, labels): dl = cls(ds, batch_size=batch_size, shuffle=False) + assert len(dl) == len(feats) // batch_size + 1 for _ in range(2): X_list, Y_list = [], [] for batch in dl: @@ -31,6 +32,7 @@ def test_no_shuffle(cls, ds, batch_size: int, feats, labels): # %% ../nbs/tests.ipynb 5 def test_no_shuffle_drop_last(cls, ds, batch_size: int, feats, labels): dl = cls(ds, batch_size=batch_size, shuffle=False, drop_last=True) + assert len(dl) == len(feats) // batch_size for _ in range(2): X_list, Y_list = [], [] for batch in dl: @@ -46,6 +48,7 @@ def test_no_shuffle_drop_last(cls, ds, batch_size: int, feats, labels): def test_shuffle(cls, ds, batch_size: int, feats, labels): dl = cls(ds, batch_size=batch_size, shuffle=True, drop_last=False) last_X, last_Y = jnp.array([]), jnp.array([]) + assert len(dl) == len(feats) // batch_size + 1 for _ in range(2): X_list, Y_list = [], [] for batch in dl: @@ -65,6 +68,7 @@ def test_shuffle(cls, ds, batch_size: int, feats, labels): # %% ../nbs/tests.ipynb 7 def test_shuffle_drop_last(cls, ds, batch_size: int, feats, labels): dl = cls(ds, batch_size=batch_size, shuffle=True, drop_last=True) + assert len(dl) == len(feats) // batch_size for _ in range(2): X_list, Y_list = [], [] for batch in dl: diff --git a/nbs/tests.ipynb b/nbs/tests.ipynb index 7c2f8ba..fd16efe 100644 --- a/nbs/tests.ipynb +++ b/nbs/tests.ipynb @@ -58,6 +58,7 @@ "#| exporti\n", "def test_no_shuffle(cls, ds, batch_size: int, feats, labels):\n", " dl = cls(ds, batch_size=batch_size, shuffle=False)\n", + " assert len(dl) == len(feats) // batch_size + 1\n", " for _ in range(2):\n", " X_list, Y_list = [], []\n", " for batch in dl:\n", @@ -78,6 +79,7 @@ "#| exporti\n", "def test_no_shuffle_drop_last(cls, ds, batch_size: int, feats, labels):\n", " dl = cls(ds, batch_size=batch_size, shuffle=False, drop_last=True)\n", + " assert len(dl) == len(feats) // batch_size\n", " for _ in range(2):\n", " X_list, Y_list = [], []\n", " for batch in dl:\n", @@ -100,6 +102,7 @@ "def test_shuffle(cls, ds, batch_size: int, feats, labels):\n", " dl = cls(ds, batch_size=batch_size, shuffle=True, drop_last=False)\n", " last_X, last_Y = jnp.array([]), jnp.array([])\n", + " assert len(dl) == len(feats) // batch_size + 1\n", " for _ in range(2):\n", " X_list, Y_list = [], []\n", " for batch in dl:\n", @@ -126,6 +129,7 @@ "#| exporti\n", "def test_shuffle_drop_last(cls, ds, batch_size: int, feats, labels):\n", " dl = cls(ds, batch_size=batch_size, shuffle=True, drop_last=True)\n", + " assert len(dl) == len(feats) // batch_size\n", " for _ in range(2):\n", " X_list, Y_list = [], []\n", " for batch in dl:\n",