Skip to content

Commit

Permalink
Merge branch 'avoid_samp_obs' into fix_lm_grid
Browse files Browse the repository at this point in the history
  • Loading branch information
tgross03 committed Jul 12, 2024
2 parents cb63ded + f4da26b commit b48b076
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 26 deletions.
1 change: 1 addition & 0 deletions docs/changes/31.maintenance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- use observation class to pass sampling options to the fits writer
44 changes: 24 additions & 20 deletions pyvisgen/fits/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pyvisgen.layouts.layouts as layouts


def create_vis_hdu(data, conf, layout="vlba", source_name="sim-source-0"):
def create_vis_hdu(data, obs, layout="vlba", source_name="sim-source-0"):
u = data.u

v = data.v
Expand All @@ -23,7 +23,7 @@ def create_vis_hdu(data, conf, layout="vlba", source_name="sim-source-0"):

BASELINE = data.base_num

INTTIM = np.repeat(np.array(conf["corr_int_time"], dtype=">f4"), len(u))
INTTIM = np.repeat(np.array(obs.int_time, dtype=">f4"), len(u))

# visibility data
values = data.get_values()
Expand All @@ -36,10 +36,10 @@ def create_vis_hdu(data, conf, layout="vlba", source_name="sim-source-0"):
# in dim 4 = IFs , dim = 1, dim 4 = number of jones, 3 = real, imag, weight

# wcs
ra = conf["fov_center_ra"]
dec = conf["fov_center_dec"]
freq = (conf["ref_frequency"] * un.Hz).value
freq_d = (conf["bandwidths"][0] * un.Hz).value
ra = obs.ra.cpu().numpy().item()
dec = obs.dec.cpu().numpy().item()
freq = obs.ref_frequency.cpu().numpy().item()
freq_d = obs.bandwidths.cpu().numpy().item()

ws = wcs.WCS(naxis=7)
ws.wcs.crpix = [1, 1, 1, 1, 1, 1, 1]
Expand Down Expand Up @@ -87,7 +87,7 @@ def create_vis_hdu(data, conf, layout="vlba", source_name="sim-source-0"):
hdu_vis.header.comments["PTYPE6"] = "Relative Julian date ?"
hdu_vis.header.comments["PTYPE7"] = "Integration time"

date_obs = conf["scan_start"].date().strftime("%Y-%m-%d")
date_obs = obs.start.strftime("%Y-%m-%d")

date_map = Time.now().to_value(format="iso", subfmt="date")

Expand Down Expand Up @@ -165,24 +165,28 @@ def create_time_hdu(data):
return hdu_time


def create_frequency_hdu(conf):
def create_frequency_hdu(obs):
FRQSEL = np.array([1], dtype=">i4")
col1 = fits.Column(name="FRQSEL", format="1J", unit=" ", array=FRQSEL)

IF_FREQ = np.array(
[np.array(conf["frequency_offsets"])],
[np.array(obs.frequency_offsets.cpu().numpy())],
dtype=">f8",
) # start with 0, add ch_with per IF
col2 = fits.Column(
name="IF FREQ", format=str(IF_FREQ.shape[-1]) + "D", unit="Hz", array=IF_FREQ
)

CH_WIDTH = np.repeat(np.array([conf["bandwidths"]], dtype=">f4"), 1, axis=1)
CH_WIDTH = np.repeat(
np.array([obs.bandwidths.cpu().numpy()], dtype=">f4"), 1, axis=1
)
col3 = fits.Column(
name="CH WIDTH", format=str(CH_WIDTH.shape[-1]) + "E", unit="Hz", array=CH_WIDTH
)

TOTAL_BANDWIDTH = np.repeat(np.array([conf["bandwidths"]], dtype=">f4"), 1, axis=1)
TOTAL_BANDWIDTH = np.repeat(
np.array([obs.bandwidths.cpu().numpy()], dtype=">f4"), 1, axis=1
)
col4 = fits.Column(
name="TOTAL BANDWIDTH",
format=str(TOTAL_BANDWIDTH.shape[-1]) + "E",
Expand Down Expand Up @@ -220,8 +224,8 @@ def create_frequency_hdu(conf):
return hdu_freq


def create_antenna_hdu(conf):
array = layouts.get_array_layout(conf["layout"], writer=True)
def create_antenna_hdu(obs):
array = layouts.get_array_layout(obs.layout, writer=True)

ANNAME = np.chararray(len(array), itemsize=8, unicode=True)
ANNAME[:] = array["station_name"].values
Expand Down Expand Up @@ -288,8 +292,8 @@ def create_antenna_hdu(conf):
)
hdu_ant = fits.BinTableHDU.from_columns(coldefs_ant)

freq = (conf["ref_frequency"] * un.Hz).value
ref_date = Time(conf["scan_start"].isoformat(), format="isot")
freq = (obs.ref_frequency.cpu().numpy() * un.Hz).value
ref_date = obs.start

from astropy.utils import iers

Expand Down Expand Up @@ -325,7 +329,7 @@ def create_antenna_hdu(conf):
hdu_ant.header["UT1UTC"] = (iers_b.ut1_utc(ref_date).value, "UT1 - UTC (sec)")
hdu_ant.header["DATUTC"] = (0, "time system - UTC (sec)") # missing
hdu_ant.header["TIMSYS"] = ("UTC", "Time system")
hdu_ant.header["ARRNAM"] = (conf["layout"], "Array name")
hdu_ant.header["ARRNAM"] = (obs.layout, "Array name")
hdu_ant.header["XYZHAND"] = ("RIGHT", "Handedness of station coordinates")
hdu_ant.header["FRAME"] = ("????", "Coordinate frame, FOR IGNORANCE")
hdu_ant.header["NUMORB"] = (0, "Number orbital parameters in table (n orb)")
Expand Down Expand Up @@ -360,11 +364,11 @@ def create_antenna_hdu(conf):
return hdu_ant


def create_hdu_list(data, conf):
def create_hdu_list(data, obs):
warnings.filterwarnings("ignore", module="astropy.io.fits")
vis_hdu = create_vis_hdu(data, conf)
vis_hdu = create_vis_hdu(data, obs)
time_hdu = create_time_hdu(data)
freq_hdu = create_frequency_hdu(conf)
ant_hdu = create_antenna_hdu(conf)
freq_hdu = create_frequency_hdu(obs)
ant_hdu = create_antenna_hdu(obs)
hdu_list = fits.HDUList([vis_hdu, time_hdu, freq_hdu, ant_hdu])
return hdu_list
10 changes: 5 additions & 5 deletions pyvisgen/simulation/data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,21 @@ def simulate_data_set(config, slurm=False, job_id=None, n=None):
if len(SI.shape) == 2:
SI = SI.unsqueeze(0)

obs, samp_ops = create_observation(conf)
obs = create_observation(conf)
vis_data = vis_loop(obs, SI, noisy=conf["noisy"], mode=conf["mode"])
hdu_list = writer.create_hdu_list(vis_data, samp_ops)
hdu_list = writer.create_hdu_list(vis_data, obs)
hdu_list.writeto(out, overwrite=True)

else:
for i in tqdm(range(len(data))):
SIs = get_images(data, i)

for j, SI in enumerate(tqdm(SIs)):
obs, samp_ops = create_observation(conf)
obs = create_observation(conf)
vis_data = vis_loop(obs, SI, noisy=conf["noisy"], mode=conf["mode"])

out = out_path / Path("vis_" + str(j + len(SIs) * i) + ".fits")
hdu_list = writer.create_hdu_list(vis_data, samp_ops)
hdu_list = writer.create_hdu_list(vis_data, obs)
hdu_list.writeto(out, overwrite=True)


Expand Down Expand Up @@ -103,7 +103,7 @@ def create_observation(conf):
dense=dense,
sensitivity_cut=rc["sensitivity_cut"],
)
return obs, rc
return obs


def create_sampling_rc(conf):
Expand Down
1 change: 1 addition & 0 deletions pyvisgen/simulation/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def __init__(
self.sensitivity_cut = sensitivity_cut
self.device = torch.device(device)

self.layout = array_layout
self.array = layouts.get_array_layout(array_layout)
self.num_baselines = int(
len(self.array.st_num) * (len(self.array.st_num) - 1) / 2
Expand Down
2 changes: 1 addition & 1 deletion tests/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_vis_loop():
from pyvisgen.utils.data import load_bundles, open_bundles

bundles = load_bundles(conf["in_path"])
obs, samp_ops = create_observation(conf)
obs = create_observation(conf)
# num_active_telescopes = test_opts(samp_ops)
data = open_bundles(bundles[0])
SI = torch.tensor(data[0])[None]
Expand Down

0 comments on commit b48b076

Please sign in to comment.