Skip to content

Commit

Permalink
Fix image rotation by fixing rd/lm grid computation
Browse files Browse the repository at this point in the history
  • Loading branch information
aknierim committed Oct 9, 2024
1 parent e8ebb96 commit 68b2a04
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions pyvisgen/simulation/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def get_valid_subset(self, num_baselines, device):
date = (torch.from_numpy(t[:-1][mask] + t[1:][mask]) / 2).to(device)

return ValidBaselineSubset(
baseline_nums,
u_start,
u_stop,
u_valid,
Expand All @@ -74,13 +73,13 @@ def get_valid_subset(self, num_baselines, device):
w_start,
w_stop,
w_valid,
baseline_nums,
date,
)


@dataclass()
class ValidBaselineSubset:
baseline_nums: torch.tensor
u_start: torch.tensor
u_stop: torch.tensor
u_valid: torch.tensor
Expand All @@ -90,6 +89,7 @@ class ValidBaselineSubset:
w_start: torch.tensor
w_stop: torch.tensor
w_valid: torch.tensor
baseline_nums: torch.tensor
date: torch.tensor

def __getitem__(self, i):
Expand Down Expand Up @@ -350,7 +350,7 @@ def create_rd_grid(self):
Returns
-------
3d array
rd_grid : 3d array
Returns a 3d array with every pixel containing a RA and Dec value
"""
# transform to rad
Expand All @@ -370,9 +370,10 @@ def create_rd_grid(self):
- self.img_size / 2
) * res + dec

_, R = torch.meshgrid((r, r), indexing="ij")
D, _ = torch.meshgrid((d, d), indexing="ij")
R, _ = torch.meshgrid((r, r), indexing="ij")
_, D = torch.meshgrid((d, d), indexing="ij")
rd_grid = torch.cat([R[..., None], D[..., None]], dim=2)

return rd_grid

def create_lm_grid(self):
Expand All @@ -387,17 +388,17 @@ def create_lm_grid(self):
Returns
-------
3d array
lm_grid : 3d array
Returns a 3d array with every pixel containing a l and m value
"""
dec = torch.deg2rad(self.dec)

lm_grid = torch.zeros(self.rd.shape, device=self.device, dtype=torch.float64)
lm_grid[:, :, 0] = (torch.cos(self.rd[..., 1]) * torch.sin(self.rd[..., 0])).T
lm_grid[:, :, 1] = (
torch.sin(self.rd[..., 1]) * torch.cos(dec)
- torch.cos(self.rd[..., 1]) * torch.sin(dec) * torch.cos(self.rd[..., 0])
).T
lm_grid[..., 0] = torch.cos(self.rd[..., 1]) * torch.sin(self.rd[..., 0])
lm_grid[..., 1] = torch.sin(self.rd[..., 1]) * torch.cos(dec) - torch.cos(
self.rd[..., 1]
) * torch.sin(dec) * torch.cos(self.rd[..., 0])

return lm_grid

def get_baselines(self, times):
Expand Down

0 comments on commit 68b2a04

Please sign in to comment.