diff --git a/README.md b/README.md index b6087cc..6e5a503 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ The training script is provided as `train.ipynb`. ## Dependencies The program was developed and tested in the following environment. - Python 3.10 -- `torch==0.13.1` +- `torch==1.13.1` - `jax==0.4.23` - `jaxlib==0.4.23` - `diffusers==0.18.2` diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 24e347e..221f433 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -19,13 +19,12 @@ def setUp(self): self.obs_dim = self.config["controller"]["networks"]["obs_dim"] self.obstacle_encode_dim = self.config["controller"]["networks"]["obstacle_encode_dim"] - def test_init(self): dataset = PlanarQuadrotorStateDataset( dataset_path=self.dataset_path, pred_horizon=self.pred_horizon, obs_horizon=self.obs_horizon, - action_horizon=self.action_horizon + action_horizon=self.action_horizon, ) self.assertTrue(isinstance(dataset, PlanarQuadrotorStateDataset)) @@ -34,7 +33,7 @@ def test_iter(self): dataset_path=self.dataset_path, pred_horizon=self.pred_horizon, obs_horizon=self.obs_horizon, - action_horizon=self.action_horizon + action_horizon=self.action_horizon, ) dataloader = torch.utils.data.DataLoader( dataset, @@ -42,11 +41,13 @@ def test_iter(self): shuffle=True, pin_memory=True, ) - + # batch context matches expectecd shapes batch = next(iter(dataloader)) - self.assertEqual(batch['obs'].shape, (self.batch_size, self.obs_dim*self.obs_horizon+self.obstacle_encode_dim)) - self.assertEqual(batch['action'].shape, (self.batch_size, self.pred_horizon, self.action_dim)) + self.assertEqual( + batch["obs"].shape, (self.batch_size, self.obs_dim * self.obs_horizon + self.obstacle_encode_dim) + ) + self.assertEqual(batch["action"].shape, (self.batch_size, self.pred_horizon, self.action_dim)) if __name__ == "__main__": diff --git a/tests/test_networks.py b/tests/test_networks.py index 57b9bcc..7efed8d 100644 --- a/tests/test_networks.py +++ b/tests/test_networks.py @@ -7,7 +7,6 @@ class TestConditionalUnet1D(unittest.TestCase): def setUp(self): - with open("tests/config/test_config.yaml", "r") as file: self.config = yaml.safe_load(file) @@ -20,20 +19,18 @@ def setUp(self): def test_init(self): noise_pred_net = ConditionalUnet1D( - input_dim=self.action_dim, - global_cond_dim=self.obs_dim*self.obs_horizon + self.obstacle_encode_dim + input_dim=self.action_dim, global_cond_dim=self.obs_dim * self.obs_horizon + self.obstacle_encode_dim ) self.assertTrue(isinstance(noise_pred_net, ConditionalUnet1D)) def test_inference(self): net = ConditionalUnet1D( - input_dim=self.action_dim, - global_cond_dim=self.obs_dim*self.obs_horizon + self.obstacle_encode_dim + input_dim=self.action_dim, global_cond_dim=self.obs_dim * self.obs_horizon + self.obstacle_encode_dim ) # example inputs noised_action = torch.randn((1, self.pred_horizon, self.action_dim)) - obs = torch.zeros((1, self.obs_horizon*self.obs_dim+self.obstacle_encode_dim)) + obs = torch.zeros((1, self.obs_horizon * self.obs_dim + self.obstacle_encode_dim)) diffusion_iter = torch.zeros((1,)) # the noise prediction network