Skip to content

Commit

Permalink
update unittest fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
shaoanlu committed Apr 28, 2024
1 parent 6fccd8e commit 934d610
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion tests/config/test_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ trainer:
num_warmup_steps: 500

dataloader:
batch_size: 8
batch_size: 3

normalizer:
action:
Expand Down
Binary file modified tests/test_dataset.joblib
Binary file not shown.
7 changes: 4 additions & 3 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,21 @@ def test_init(self):
self.assertTrue(isinstance(dataset, PlanarQuadrotorStateDataset))

def test_iter(self):
batch_size = self.config["dataloader"]["batch_size"]
dataset = PlanarQuadrotorStateDataset(dataset_path=self.dataset_path, config=self.config)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=self.config["dataloader"]["batch_size"],
batch_size=batch_size,
shuffle=True,
pin_memory=True,
)

# batch context matches expectecd shapes
batch = next(iter(dataloader))
self.assertEqual(
batch["obs"].shape, (self.batch_size, self.obs_dim * self.obs_horizon + self.obstacle_encode_dim)
batch["obs"].shape, (batch_size, self.obs_dim * self.obs_horizon + self.obstacle_encode_dim)
)
self.assertEqual(batch["action"].shape, (self.batch_size, self.pred_horizon, self.action_dim))
self.assertEqual(batch["action"].shape, (batch_size, self.pred_horizon, self.action_dim))


if __name__ == "__main__":
Expand Down

0 comments on commit 934d610

Please sign in to comment.