Skip to content

Commit

Permalink
apply formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
shaoanlu committed Apr 28, 2024
1 parent 3822fb4 commit bba045b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ The training script is provided as `train.ipynb`.
## Dependencies
The program was developed and tested in the following environment.
- Python 3.10
- `torch==0.13.1`
- `torch==1.13.1`
- `jax==0.4.23`
- `jaxlib==0.4.23`
- `diffusers==0.18.2`
Expand Down
13 changes: 7 additions & 6 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@ def setUp(self):
self.obs_dim = self.config["controller"]["networks"]["obs_dim"]
self.obstacle_encode_dim = self.config["controller"]["networks"]["obstacle_encode_dim"]


def test_init(self):
dataset = PlanarQuadrotorStateDataset(
dataset_path=self.dataset_path,
pred_horizon=self.pred_horizon,
obs_horizon=self.obs_horizon,
action_horizon=self.action_horizon
action_horizon=self.action_horizon,
)
self.assertTrue(isinstance(dataset, PlanarQuadrotorStateDataset))

Expand All @@ -34,19 +33,21 @@ def test_iter(self):
dataset_path=self.dataset_path,
pred_horizon=self.pred_horizon,
obs_horizon=self.obs_horizon,
action_horizon=self.action_horizon
action_horizon=self.action_horizon,
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=self.config["dataloader"]["batch_size"],
shuffle=True,
pin_memory=True,
)

# batch context matches expectecd shapes
batch = next(iter(dataloader))
self.assertEqual(batch['obs'].shape, (self.batch_size, self.obs_dim*self.obs_horizon+self.obstacle_encode_dim))
self.assertEqual(batch['action'].shape, (self.batch_size, self.pred_horizon, self.action_dim))
self.assertEqual(
batch["obs"].shape, (self.batch_size, self.obs_dim * self.obs_horizon + self.obstacle_encode_dim)
)
self.assertEqual(batch["action"].shape, (self.batch_size, self.pred_horizon, self.action_dim))


if __name__ == "__main__":
Expand Down
9 changes: 3 additions & 6 deletions tests/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

class TestConditionalUnet1D(unittest.TestCase):
def setUp(self):

with open("tests/config/test_config.yaml", "r") as file:
self.config = yaml.safe_load(file)

Expand All @@ -20,20 +19,18 @@ def setUp(self):

def test_init(self):
noise_pred_net = ConditionalUnet1D(
input_dim=self.action_dim,
global_cond_dim=self.obs_dim*self.obs_horizon + self.obstacle_encode_dim
input_dim=self.action_dim, global_cond_dim=self.obs_dim * self.obs_horizon + self.obstacle_encode_dim
)
self.assertTrue(isinstance(noise_pred_net, ConditionalUnet1D))

def test_inference(self):
net = ConditionalUnet1D(
input_dim=self.action_dim,
global_cond_dim=self.obs_dim*self.obs_horizon + self.obstacle_encode_dim
input_dim=self.action_dim, global_cond_dim=self.obs_dim * self.obs_horizon + self.obstacle_encode_dim
)

# example inputs
noised_action = torch.randn((1, self.pred_horizon, self.action_dim))
obs = torch.zeros((1, self.obs_horizon*self.obs_dim+self.obstacle_encode_dim))
obs = torch.zeros((1, self.obs_horizon * self.obs_dim + self.obstacle_encode_dim))
diffusion_iter = torch.zeros((1,))

# the noise prediction network
Expand Down

0 comments on commit bba045b

Please sign in to comment.