diff --git a/hitchhiking_rotations/cfgs/__init__.py b/hitchhiking_rotations/cfgs/__init__.py index b4a02de..3f37770 100644 --- a/hitchhiking_rotations/cfgs/__init__.py +++ b/hitchhiking_rotations/cfgs/__init__.py @@ -7,4 +7,3 @@ from .cfg_pcd_to_pose import get_cfg_pcd_to_pose from .cfg_pose_to_cube_image import get_cfg_pose_to_cube_image from .cfg_pose_to_fourier import get_cfg_pose_to_fourier - diff --git a/hitchhiking_rotations/datasets/fourier_dataset.py b/hitchhiking_rotations/datasets/fourier_dataset.py index 9eb8615..89a9898 100644 --- a/hitchhiking_rotations/datasets/fourier_dataset.py +++ b/hitchhiking_rotations/datasets/fourier_dataset.py @@ -11,14 +11,16 @@ from hitchhiking_rotations import HITCHHIKING_ROOT_DIR from hitchhiking_rotations.utils import save_pickle, load_pickle + class PoseToFourierDataset(Dataset): """ Loads data from fourier dataset """ - def __init__(self, mode, dataset_size, device, nb, nf): - path = join(HITCHHIKING_ROOT_DIR, "assets", "datasets", "fourier_dataset", - f"fourier_dataset_{mode}_nb{nb}_nf{nf}.pkl") + def __init__(self, mode, dataset_size, device, nb, nf): + path = join( + HITCHHIKING_ROOT_DIR, "assets", "datasets", "fourier_dataset", f"fourier_dataset_{mode}_nb{nb}_nf{nf}.pkl" + ) if os.path.exists(path): dic = load_pickle(path) @@ -39,9 +41,9 @@ def __len__(self): def __getitem__(self, idx): return roma.unitquat_to_rotmat(self.quats[idx]).type(torch.float32), self.features[idx] -class random_fourier_function(): - def __init__(self, n_basis, seed, A0=0., L=1.): +class random_fourier_function: + def __init__(self, n_basis, seed, A0=0.0, L=1.0): np.random.seed(seed) self.L = L self.n_basis = n_basis @@ -53,10 +55,14 @@ def __init__(self, n_basis, seed, A0=0., L=1.): def __call__(self, x): fFs = self.A0 / 2 for k in range(len(self.A)): - fFs = (fFs + self.A[k] * np.cos((k + 1) * np.pi * np.matmul(self.matrix, x) / self.L) + - self.B[k] * np.sin((k + 1) * np.pi * np.matmul(self.matrix, x) / self.L)) + fFs = ( + fFs + + self.A[k] * np.cos((k + 1) * np.pi * np.matmul(self.matrix, x) / self.L) + + self.B[k] * np.sin((k + 1) * np.pi * np.matmul(self.matrix, x) / self.L) + ) return fFs + def create_data(N_points, nb, seed): """ Create data from fourier series. @@ -75,6 +81,7 @@ def create_data(N_points, nb, seed): features = np.apply_along_axis(four_func, 1, inputs) return rots.as_quat().astype(np.float32), features.astype(np.float32) + def plot_fourier_data(rotations, features): import pandas as pd import seaborn as sns @@ -91,5 +98,6 @@ def plot_fourier_data(rotations, features): g.set(xlim=(-1.2, 1.2), ylim=(-1.2, 1.2)) plt.show() + if __name__ == "__main__": - create_data(N_points=100, nb=2, seed=5) \ No newline at end of file + create_data(N_points=100, nb=2, seed=5)