Skip to content

Commit

Permalink
Merge pull request #16 from martius-lab/dev/visu_cube_experiment
Browse files Browse the repository at this point in the history
dev - new visu script
  • Loading branch information
AndReGeist authored Mar 14, 2024
2 parents 76174db + faa916a commit 7df7bbc
Show file tree
Hide file tree
Showing 18 changed files with 334 additions and 64 deletions.
14 changes: 14 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
fail_fast: false

repos:
- repo: https://github.com/psf/black
rev: 23.10.1
hooks:
- id: black
args: [--line-length=120]

- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
hooks:
- id: flake8
args: [--max-line-length=120, "--ignore=W291,E731"]
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ List of each experiment as in paper and how to reproduce it
```shell
pip3 install black==23.10
cd hitchhiking_rotations && black --line-length 120 ./

# Using precommit
pip3 install pre-commit
cd hitchhiking_rotations && python3 -m pre_commit install
cd hitchhiking_rotations && python3 -m pre_commit run

```
### Add License Headers
```shell
Expand Down
128 changes: 119 additions & 9 deletions hitchhiking_rotations/cfgs/cfg_cube_image_to_pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def get_cfg_cube_image_to_pose(device):
shared_trainer_cfg = {
"_target_": "hitchhiking_rotations.utils.Trainer",
"lr": 0.001,
"optimizer": "SGD",
"optimizer": "Adam",
"logger": "${logger}",
"verbose": "${verbose}",
"device": device,
Expand All @@ -17,7 +17,7 @@ def get_cfg_cube_image_to_pose(device):
return {
"verbose": True,
"batch_size": 32,
"epochs": 100,
"epochs": 1000,
"training_data": {
"_target_": "hitchhiking_rotations.datasets.CubeImageToPoseDataset",
"mode": "train",
Expand Down Expand Up @@ -45,7 +45,7 @@ def get_cfg_cube_image_to_pose(device):
"metrics": ["l1", "l2", "geodesic_distance", "chordal_distance"],
},
"trainers": {
"r9_l1": {
"r9_svd_l1": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
Expand All @@ -55,7 +55,7 @@ def get_cfg_cube_image_to_pose(device):
"model": "${model9}",
},
},
"r9_l2": {
"r9_svd_l2": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
Expand All @@ -65,7 +65,7 @@ def get_cfg_cube_image_to_pose(device):
"model": "${model9}",
},
},
"r9_geodesic_distance": {
"r9_svd_geodesic_distance": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
Expand All @@ -75,7 +75,7 @@ def get_cfg_cube_image_to_pose(device):
"model": "${model9}",
},
},
"r9_chordal_distance": {
"r9_svd_chordal_distance": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
Expand All @@ -85,7 +85,67 @@ def get_cfg_cube_image_to_pose(device):
"model": "${model9}",
},
},
"r9_l1": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:flatten}",
"postprocess_pred_loss": "${u:flatten}",
"postprocess_pred_logging": "${u:procrustes_to_rotmat}",
"loss": "${u:l1}",
"model": "${model9}",
},
},
"r9_l2": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:flatten}",
"postprocess_pred_loss": "${u:flatten}",
"postprocess_pred_logging": "${u:procrustes_to_rotmat}",
"loss": "${u:l2}",
"model": "${model9}",
},
},
"r9_geodesic_distance": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
"postprocess_pred_loss": "${u:n_3x3}",
"postprocess_pred_logging": "${u:procrustes_to_rotmat}",
"loss": "${u:geodesic_distance}",
"model": "${model9}",
},
},
"r9_chordal_distance": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
"postprocess_pred_loss": "${u:n_3x3}",
"postprocess_pred_logging": "${u:procrustes_to_rotmat}",
"loss": "${u:chordal_distance}",
"model": "${model9}",
},
},
"r6_l1": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:rotmat_to_gramschmidt_f}",
"postprocess_pred_loss": "${u:flatten}",
"postprocess_pred_logging": "${u:gramschmidt_to_rotmat}",
"loss": "${u:l1}",
"model": "${model6}",
},
},
"r6_l2": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:rotmat_to_gramschmidt_f}",
"postprocess_pred_loss": "${u:flatten}",
"postprocess_pred_logging": "${u:gramschmidt_to_rotmat}",
"loss": "${u:l2}",
"model": "${model6}",
},
},
"r6_gso_l1": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
Expand All @@ -95,7 +155,7 @@ def get_cfg_cube_image_to_pose(device):
"model": "${model6}",
},
},
"r6_l2": {
"r6_gso_l2": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
Expand All @@ -105,7 +165,7 @@ def get_cfg_cube_image_to_pose(device):
"model": "${model6}",
},
},
"r6_geodesic_distance": {
"r6_gso_geodesic_distance": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
Expand All @@ -115,7 +175,7 @@ def get_cfg_cube_image_to_pose(device):
"model": "${model6}",
},
},
"r6_chordal_distance": {
"r6_gso_chordal_distance": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
Expand All @@ -125,6 +185,16 @@ def get_cfg_cube_image_to_pose(device):
"model": "${model6}",
},
},
"quat_c_geodesic_distance": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:passthrough}",
"postprocess_pred_loss": "${u:quaternion_to_rotmat}",
"postprocess_pred_logging": "${u:quaternion_to_rotmat}",
"loss": "${u:geodesic_distance}",
"model": "${model4}",
},
},
"quat_c_chordal_distance": {
**shared_trainer_cfg,
**{
Expand Down Expand Up @@ -175,6 +245,46 @@ def get_cfg_cube_image_to_pose(device):
"model": "${model4}",
},
},
"quat_rf_cosine_distance": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:rotmat_to_quaternion_rand_flip}",
"postprocess_pred_loss": "${u:passthrough}",
"postprocess_pred_logging": "${u:quaternion_to_rotmat}",
"loss": "${u:cosine_distance}",
"model": "${model4}",
},
},
"quat_rf_l2": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:rotmat_to_quaternion_rand_flip}",
"postprocess_pred_loss": "${u:passthrough}",
"postprocess_pred_logging": "${u:quaternion_to_rotmat}",
"loss": "${u:l2}",
"model": "${model4}",
},
},
"quat_rf_l1": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:rotmat_to_quaternion_rand_flip}",
"postprocess_pred_loss": "${u:passthrough}",
"postprocess_pred_logging": "${u:quaternion_to_rotmat}",
"loss": "${u:l1}",
"model": "${model4}",
},
},
"quat_rf_l2_dp": {
**shared_trainer_cfg,
**{
"preprocess_target": "${u:rotmat_to_quaternion_rand_flip}",
"postprocess_pred_loss": "${u:passthrough}",
"postprocess_pred_logging": "${u:quaternion_to_rotmat}",
"loss": "${u:l2_dp}",
"model": "${model4}",
},
},
"rotvec_l1": {
**shared_trainer_cfg,
**{
Expand Down
2 changes: 1 addition & 1 deletion hitchhiking_rotations/cfgs/cfg_pose_to_cube_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_cfg_pose_to_cube_image(device):
return {
"verbose": False,
"batch_size": 128,
"epochs": 100,
"epochs": 1000,
"training_data": {
"_target_": "hitchhiking_rotations.datasets.PoseToCubeImageDataset",
"mode": "train",
Expand Down
17 changes: 8 additions & 9 deletions hitchhiking_rotations/datasets/cube_data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# See LICENSE file in the project root for details.
#
import mujoco
import torch
import numpy as np
from PIL import Image

Expand All @@ -15,15 +14,15 @@ def __init__(self, height: int, width: int):
<mujoco>
<worldbody>
<light name="top" pos="0 0 0"/>
<body name="cube" euler="0 0 0">
<body name="cube" euler="0 0 0" pos="0 0 0">
<joint type="ball" stiffness="0" damping="0" frictionloss="0" armature="0"/>
<geom type="box" size="0.1 0.1 0.1" pos="0 0 0" rgba="0.5 0.5 0.5 1"/>
<geom type="box" size="1 1 0.01" pos="0 0 0.9" rgba="1 0 0 1"/>
<geom type="box" size="1 1 0.01" pos="0 0 -0.99" rgba="0 0 1 1"/>
<geom type="box" size="0.01 1 1" pos="0.99 0 0" rgba="0 1 0 1"/>
<geom type="box" size="0.01 1 1" pos="-0.99 0 0" rgba="0 0.6 0.6 1"/>
<geom type="box" size="1 0.01 1" pos="0 0.99 0" rgba="0.6 0.6 0 1"/>
<geom type="box" size="1 0.01 1" pos="0 -0.99 0" rgba="0.6 0 0.6 1"/>
<geom type="box" size="0.1 0.1 0.1" pos="0 0 0" rgba="0.5 0.5 0.5 1"/>
<geom type="box" size="1 1 0.01" pos="0 0 0.9" rgba="1 0 0 1"/>
<geom type="box" size="1 1 0.01" pos="0 0 -0.99" rgba="0 0 1 1"/>
<geom type="box" size="0.01 1 1" pos="0.99 0 0" rgba="0 1 0 1"/>
<geom type="box" size="0.01 1 1" pos="-0.99 0 0" rgba="0 0.6 0.6 1"/>
<geom type="box" size="1 0.01 1" pos="0 0.99 0" rgba="0.6 0.6 0 1"/>
<geom type="box" size="1 0.01 1" pos="0 -0.99 0" rgba="0.6 0 0.6 1"/>
</body>
</worldbody>
</mujoco>
Expand Down
12 changes: 11 additions & 1 deletion hitchhiking_rotations/datasets/cube_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from hitchhiking_rotations.utils import save_pickle, load_pickle
import os
from os.path import join
import pickle
from scipy.spatial.transform import Rotation
import torch
import roma
Expand Down Expand Up @@ -55,3 +54,14 @@ def __init__(self, mode, dataset_size, device):

def __getitem__(self, idx):
return roma.unitquat_to_rotmat(self.quats[idx]).type(torch.float32), self.imgs[idx].type(torch.float32) / 255


if __name__ == "__main__":
from PIL import Image
import numpy as np

dataset = CubeImageToPoseDataset("train", 2048, "cpu")
for i in range(10):
img, quat = dataset[i]
img = Image.fromarray(np.uint8(img.cpu().numpy() * 255))
img.save(join(HITCHHIKING_ROOT_DIR, "results", f"example_img_{i}.png"))
4 changes: 0 additions & 4 deletions hitchhiking_rotations/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
# All rights reserved. Licensed under the MIT license.
# See LICENSE file in the project root for details.
#
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F


class MLP(nn.Module):
Expand All @@ -31,8 +29,6 @@ def __init__(self, input_dim, width, height):
IMAGE_CHANNEL = 3
Z_DIM = 10
G_HIDDEN = 64
X_DIM = 64
D_HIDDEN = 64

self.INP_SIZE = 5
self.input_dim = input_dim
Expand Down
2 changes: 1 addition & 1 deletion hitchhiking_rotations/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
from .logger import OrientationLogger
from .trainer import Trainer
from .loading import *
from .helper import passthrough, flatten
from .helper import passthrough, flatten, n_3x3
from .notation import RotRep
2 changes: 0 additions & 2 deletions hitchhiking_rotations/utils/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import roma
import torch

# x to rotmat


def euler_to_rotmat(inp: torch.Tensor) -> torch.Tensor:
return euler_angles_to_matrix(inp.reshape(-1, 3), convention="XZY")
Expand Down
4 changes: 4 additions & 0 deletions hitchhiking_rotations/utils/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ def passthrough(*x):

def flatten(x):
return x.reshape(x.shape[0], -1)


def n_3x3(x):
return x.reshape(-1, 3, 3)
Loading

0 comments on commit 7df7bbc

Please sign in to comment.