diff --git a/tests/config/test_config.yaml b/tests/config/test_config.yaml index e4602a7..9ad5b51 100644 --- a/tests/config/test_config.yaml +++ b/tests/config/test_config.yaml @@ -39,7 +39,7 @@ trainer: num_warmup_steps: 500 dataloader: - batch_size: 8 + batch_size: 3 normalizer: action: diff --git a/tests/test_dataset.joblib b/tests/test_dataset.joblib index 5975e10..09c4dd2 100644 Binary files a/tests/test_dataset.joblib and b/tests/test_dataset.joblib differ diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 427a782..31ca4e1 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -24,10 +24,11 @@ def test_init(self): self.assertTrue(isinstance(dataset, PlanarQuadrotorStateDataset)) def test_iter(self): + batch_size = self.config["dataloader"]["batch_size"] dataset = PlanarQuadrotorStateDataset(dataset_path=self.dataset_path, config=self.config) dataloader = torch.utils.data.DataLoader( dataset, - batch_size=self.config["dataloader"]["batch_size"], + batch_size=batch_size, shuffle=True, pin_memory=True, ) @@ -35,9 +36,9 @@ def test_iter(self): # batch context matches expectecd shapes batch = next(iter(dataloader)) self.assertEqual( - batch["obs"].shape, (self.batch_size, self.obs_dim * self.obs_horizon + self.obstacle_encode_dim) + batch["obs"].shape, (batch_size, self.obs_dim * self.obs_horizon + self.obstacle_encode_dim) ) - self.assertEqual(batch["action"].shape, (self.batch_size, self.pred_horizon, self.action_dim)) + self.assertEqual(batch["action"].shape, (batch_size, self.pred_horizon, self.action_dim)) if __name__ == "__main__":