Skip to content

Commit

Permalink
Reformatted code with black
Browse files Browse the repository at this point in the history
  • Loading branch information
AndReGeist committed Mar 4, 2024
1 parent 661ee31 commit 7f1cf0e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
1 change: 0 additions & 1 deletion hitchhiking_rotations/cfgs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,3 @@
from .cfg_pcd_to_pose import get_cfg_pcd_to_pose
from .cfg_pose_to_cube_image import get_cfg_pose_to_cube_image
from .cfg_pose_to_fourier import get_cfg_pose_to_fourier

24 changes: 16 additions & 8 deletions hitchhiking_rotations/datasets/fourier_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@
from hitchhiking_rotations import HITCHHIKING_ROOT_DIR
from hitchhiking_rotations.utils import save_pickle, load_pickle


class PoseToFourierDataset(Dataset):
"""
Loads data from fourier dataset
"""
def __init__(self, mode, dataset_size, device, nb, nf):

path = join(HITCHHIKING_ROOT_DIR, "assets", "datasets", "fourier_dataset",
f"fourier_dataset_{mode}_nb{nb}_nf{nf}.pkl")
def __init__(self, mode, dataset_size, device, nb, nf):
path = join(
HITCHHIKING_ROOT_DIR, "assets", "datasets", "fourier_dataset", f"fourier_dataset_{mode}_nb{nb}_nf{nf}.pkl"
)

if os.path.exists(path):
dic = load_pickle(path)
Expand All @@ -39,9 +41,9 @@ def __len__(self):
def __getitem__(self, idx):
return roma.unitquat_to_rotmat(self.quats[idx]).type(torch.float32), self.features[idx]

class random_fourier_function():

def __init__(self, n_basis, seed, A0=0., L=1.):
class random_fourier_function:
def __init__(self, n_basis, seed, A0=0.0, L=1.0):
np.random.seed(seed)
self.L = L
self.n_basis = n_basis
Expand All @@ -53,10 +55,14 @@ def __init__(self, n_basis, seed, A0=0., L=1.):
def __call__(self, x):
fFs = self.A0 / 2
for k in range(len(self.A)):
fFs = (fFs + self.A[k] * np.cos((k + 1) * np.pi * np.matmul(self.matrix, x) / self.L) +
self.B[k] * np.sin((k + 1) * np.pi * np.matmul(self.matrix, x) / self.L))
fFs = (
fFs
+ self.A[k] * np.cos((k + 1) * np.pi * np.matmul(self.matrix, x) / self.L)
+ self.B[k] * np.sin((k + 1) * np.pi * np.matmul(self.matrix, x) / self.L)
)
return fFs


def create_data(N_points, nb, seed):
"""
Create data from fourier series.
Expand All @@ -75,6 +81,7 @@ def create_data(N_points, nb, seed):
features = np.apply_along_axis(four_func, 1, inputs)
return rots.as_quat().astype(np.float32), features.astype(np.float32)


def plot_fourier_data(rotations, features):
import pandas as pd
import seaborn as sns
Expand All @@ -91,5 +98,6 @@ def plot_fourier_data(rotations, features):
g.set(xlim=(-1.2, 1.2), ylim=(-1.2, 1.2))
plt.show()


if __name__ == "__main__":
create_data(N_points=100, nb=2, seed=5)
create_data(N_points=100, nb=2, seed=5)

0 comments on commit 7f1cf0e

Please sign in to comment.