From 51fd4621a5b4d9d6724a5136cc1504215b877a78 Mon Sep 17 00:00:00 2001 From: Des Small Date: Fri, 13 Sep 2024 16:06:01 +0200 Subject: [PATCH] Adding a refant to the fringefitter --- src/astroviper/calibration/apply_fringe.py | 188 ++++++++++++++++++ .../calibration/exercise_fringe_sbd2.py | 29 +++ src/astroviper/calibration/fringefit.py | 98 +++++++-- 3 files changed, 297 insertions(+), 18 deletions(-) create mode 100644 src/astroviper/calibration/apply_fringe.py create mode 100644 src/astroviper/calibration/exercise_fringe_sbd2.py diff --git a/src/astroviper/calibration/apply_fringe.py b/src/astroviper/calibration/apply_fringe.py new file mode 100644 index 0000000..9303ee3 --- /dev/null +++ b/src/astroviper/calibration/apply_fringe.py @@ -0,0 +1,188 @@ +from xradio.vis.read_processing_set import read_processing_set + +import dask +import numpy as np +import xarray as xa +import pandas as pd +import datetime + +# I am surprised this is not some kind of standard function, but the +# pandas version scolds me for using it on a string. +# +# This implementation is O(n^2) but if that is ever an issue we are +# doing something else very wrong. +def unique(s): + "Get the unique characters in a string in the order they occur in the original" + u = [] + for c in s: + if c not in u: + u.append(c) + return u + +def nanCount(j): + return np.sum(np.isnan(j)) + +def numberCount(j): + return np.sum(~np.isnan(j)) + + +def makeCalTable(xds): + "An attempt to make a calibration table out of coordinates" + pols_ant = unique(''.join([c for c in ''.join(xds.polarization.values)])) + coords = xa.Coordinates(coords={'time' : xa.Coordinates(coords = {'time' : 0.5*(xds.time[0]+xds.time[-1])}), + 'antenna_name' : xds.antenna_xds.antenna_name, + 'polarization' : pols_ant, + 'parameter' : ['one', 'two', 'three', 'ah-ha'] + }) + cds = xa.Dataset(data_vars = dict(cals=(coords.sizes.keys(), np.zeros(tuple(coords.sizes.values()), complex))), + coords=coords) + return cds + +############################################################################# +# How to make a (time, frequency) grid: +# +# np.expand_dims(dt, 1) + np.expand_dims(df, 0) => array of shape (nt, nf) +############################################################################# + +class GridJonesCalculator(object): + def __init__(self, xds): + """ +""" + # We scrape all the metadata from an xds. Maybe this is wise, maybe not. + self.xds = xds + self.frequency = xds.frequency + self.time = xds.time + self.baseline_id = xds.baseline_id + self.VISIBILITY = xds.VISIBILITY + self.n_ants = xds.antenna_xds.antenna_name.size + # We expand out a copy of the baseline_antenna1_id array to have shape + # (1, n_baselines, 1, 1, 1) + self.ant1_mask = np.expand_dims(xds.baseline_antenna1_name.values, (0, 2, 3, 4)) + self.ant2_mask = np.expand_dims(xds.baseline_antenna2_name.values, (0, 2, 3, 4)) + self.makeAccumulatedJoneses() + + def makeAccumulatedJoneses(self): + vcs = self.VISIBILITY.shape + # We are going to use 2x2 matrices for our Jones matrices because that's how they multiply + assert vcs[-1] == 4 # We'll figure other cases out later + new_shape = vcs[:-1] + (2, 2) + self.j_a1_composed = np.zeros(new_shape, complex) + self.j_a1_composed = np.identity(2) + self.j_a2_composed = np.ones(new_shape, complex) + self.j_a2_composed = np.identity(2) + + def insertBaselineDimension(self, j, nbaselines): + """We add a baseline dimension, but we also broadcast to it""" + # Insert a baseline dimension to a calibration matrix. + j_shape = j.shape + new_shape = j_shape[:1] + (nbaselines,) + j_shape[1:] + j = np.expand_dims(j, 1) + j = np.broadcast_to(j, new_shape) + return j + + def calcGridJonesAnt(self, fp, df, dt): + """Return Jones matrices for all grid points of an xds for a single set of fringefit parameters (which come in pairs one for each polarization)""" + # We now assume phi0, tau and r are 2-vectors + phi0, tau, r = fp + # We upscale the dimensions so that things broadcast nicely: + df_shaped = np.expand_dims(df, (0, 2)) + dt_shaped = np.expand_dims(dt, (1, 2)) + phi_shaped = np.expand_dims(phi0, (0, 1)) + # Calculate phases: + phi = (phi_shaped + + 2*np.pi*tau.values*df_shaped + + 2*np.pi*r.values*dt_shaped) + # And then phasors. + # (I spent a long time trying to express this in a neat numpy way. + # I did not succeed, so I do it the ugly stupid way for now.) + many_jones_diags = np.exp(1J*phi, dtype=complex) + many_jones = np.zeros(many_jones_diags.shape + (2,), complex) + many_jones[:, :, 0, 0] = many_jones_diags[:, :, 0] + many_jones[:, :, 1, 1] = many_jones_diags[:, :, 1] + return many_jones + + def calcGridJones(self, cal_quantum): + dt = (self.time - cal_quantum.t_ref).values + df = (self.frequency - cal_quantum.f_ref).values + # This needs fixed too + for iant, ant in enumerate(cal_quantum.coords['antenna_name'].values): + fp = cal_quantum.sel(antenna_name=ant) + params = np.sum(~np.isnan(fp.values)) + print(f"{ant=} {params=}") + if params == 0: + continue + print(f"{fp.values}") + j = self.calcGridJonesAnt(fp, df, dt) + count = np.sum(~np.isnan(j)) + print(f"{count=}") + print(f"{np.max(np.abs(j.flatten()))=}") + j = self.insertBaselineDimension(j, self.xds.baseline_id.size) + # Then we can make a version of our baseline jones matrices that only affects a specific ant1: + j_1_mask = np.where(self.ant1_mask != ant, j, 1) + j_2_mask = np.where(self.ant2_mask != ant, j, 1) + # The second array needs to be hermitianized on the last two axes, which we have to do by hand + j_a2 = j.transpose(0, 1, 2, 4, 3).conj() + # Then we can apply those entries to our corrected data array by multiplication: + # First antenna corrected by multiplication from the left: + print(f"{np.max(np.abs(self.j_a1_composed.flatten()))=}") + self.j_a1_composed = np.matmul(j, self.j_a1_composed) + # Second antenna in baseline corrected (by Hermitian matrix) from the right + self.j_a2_composed = np.matmul(self.j_a2_composed, j) + if False: + print(f"{numberCount(self.j_a1_composed)=}") + print(f"{nanCount(self.j_a1_composed)=}") + print(f"{numberCount(self.j_a2_composed)=}") + print(f"{nanCount(self.j_a2_composed)=}") + print(f"{np.max(np.abs(self.j_a1_composed.flatten()))=}") + + +# We need to consult the data for polarizations now. +ps = read_processing_set('n14c3.zarr') +ps.keys() + +# Current version of this ps is split by SPW and not by field +xds = ps['n14c3_099'] + +# In fact, all xdses have the same polarization setup here, but whomst can say if that is always true? +# Actually, I think maybe we could? +pols_ant = unique(''.join([c for c in ''.join(xds.polarization.values)])) +quantumCoords = xa.Coordinates(coords={'antenna_name' : xds.antenna_xds.antenna_name, + 'parameter' : range(3), + 'polarization' : pols_ant + }) +q = xa.DataArray(coords=quantumCoords) + +q.attrs['f_ref'] = xds.frequency[0] +q.attrs['t_ref'] = xds.time[0] + +# Note that the f_ref attr copies over a lot of metadata. +# Which I think is a good thing? +q.attrs['f_ref'].attrs['spectral_window_name'] + + +# You can't assign to DataArrays by name, only by integer index. +q[0] = [[0.0, 0.0], [1.0e-9,-1.0e-9], [0, 0]] + +gjc = GridJonesCalculator(xds) +gjc.calcGridJones(q) + +def squareUpLastDimension(v): + s = v.shape[:-1] + (2,2) + v2 = np.reshape(v, s) + return v2 + + + + +v = squareUpLastDimension(xds.VISIBILITY.values) +# And this is the payoff I guess +v2 = gjc.j_a1_composed @ v @ gjc.j_a2_composed + +# This works, although we should do better along the polarization axis. +xds.assign({'FROBBED' : xa.DataArray( coords=(xds.time, xds.baseline_id, xds.frequency, pols_ant, pols_ant), data=v2)}) + +# Isn't there meant to be a nice way to get spw now? +#>>> xds.partition_info['spectral_window_name'] + +# We can also now do this at ps level: +ps2 = ps.sel(spw_name='spw_0') diff --git a/src/astroviper/calibration/exercise_fringe_sbd2.py b/src/astroviper/calibration/exercise_fringe_sbd2.py new file mode 100644 index 0000000..62d8d03 --- /dev/null +++ b/src/astroviper/calibration/exercise_fringe_sbd2.py @@ -0,0 +1,29 @@ +from xradio.vis.read_processing_set import read_processing_set +from graphviper.graph_tools.coordinate_utils import (interpolate_data_coords_onto_parallel_coords, + make_parallel_coord) +from graphviper.graph_tools.generate_dask_workflow import generate_dask_workflow +from graphviper.graph_tools.coordinate_utils import make_time_coord +from graphviper.graph_tools.coordinate_utils import make_frequency_coord + +from astroviper.calibration.fringefit import fringefit_single +import dask +import xarray as xa + +ps = read_processing_set('n14c3.zarr') +ps.keys() + +xds = ps['n14c3_000'] + +# +meas = make_time_coord(time_start='2014-10-22 13:18:00', time_delta=120, n_samples=2) +parallel_coords = {} +parallel_coords['baseline_id'] = make_parallel_coord( + coord=xds.baseline_id, n_chunks=1) +parallel_coords['time'] = make_parallel_coord(meas, n_chunks=1) +node_task_data_mapping = interpolate_data_coords_onto_parallel_coords(parallel_coords, + ps, ps_partition=['spectral_window_name']) +subsel = {'polarization': 'LL'} +res = fringefit_single(ps, node_task_data_mapping, subsel) + +# print(res) + diff --git a/src/astroviper/calibration/fringefit.py b/src/astroviper/calibration/fringefit.py index 4016518..62d544a 100644 --- a/src/astroviper/calibration/fringefit.py +++ b/src/astroviper/calibration/fringefit.py @@ -6,19 +6,71 @@ from graphviper.graph_tools.generate_dask_workflow import generate_dask_workflow from typing import Dict, Union +from xradio.vis.read_processing_set import read_processing_set + +import dask +import numpy as np +import xarray as xa +import pandas as pd +import datetime + +## I should figure out where this belongs at some point +def unique(s): + "Get the unique characters in a string in the order they occur in the original" + u = [] + for c in s: + if c not in u: + u.append(c) + return u + + + + +def getFourierSpacings(xds): + f = xds.frequency.values + df = (f[-1] - f[0])/(len(f)-1) + dF = len(f)*df + ddelay = 1/dF + # + t = xds.time.values + dt = (t[-1] - t[0])/(len(t)-1) + dT = len(t) * dt + drate = 1/dT + return (ddelay, drate) + +def makeCalArray(xds, ref_ant): + pols_ant = unique(''.join([c for c in ''.join(xds.polarization.values)])) + quantumCoords = xa.Coordinates(coords={'antenna_name' : xds.antenna_xds.antenna_name, + 'polarization' : pols_ant, + 'parameter' : range(3) + }) + q = xa.DataArray(coords=quantumCoords) + ref_freq = xds.frequency.reference_frequency['data'] + # Should we choose this reference time? + ref_time = xds.time[0] + q.attrs['reference_frequency'] = ref_freq + q.attrs['reference_time'] = ref_time + q.attrs['reference_antenna'] = ref_ant + return q + def _fringe_node_task(input_params: Dict): ps = input_params['ps'] data_selection = input_params['data_selection'] + ref_ant = input_params['ref_ant'] # FIXME: for now we do single band - if len(data_selection.keys())>1: + if len(data_selection.keys()) > 1: print(f'{data_selection.keys()=}') raise RuntimeError("We only do single xdses so far") name = list(data_selection.keys())[0] xds = ps[name] + q = makeCalArray(xds, ref_ant) data_sub_selection = input_params['data_sub_selection'] - pol = data_sub_selection['polarization'] + pols = data_sub_selection['polarization'] + # FIXME! + pol = pols[0] xds2 = xds.isel(**data_selection[name]) - xds2 = xds2.sel(polarization=pol) + xds2 = xds2.sel(polarization=pols) + ddelay, drate = getFourierSpacings(xds2) vis = xds2.VISIBILITY ang = np.angle(vis) nvis = np.exp(1J*ang) @@ -31,26 +83,35 @@ def _fringe_node_task(input_params: Dict): ), axes=(0,2) ) - ref_ant = 1 - res = {} bl_slice = data_selection[name]["baseline_id"] - baselines = xds.baseline_id[bl_slice].values - ant1s = xds2.baseline_antenna1_id.values - ant2s = xds2.baseline_antenna1_id.values + baselines = xds2.baseline_id[bl_slice].values + ant1s = xds2.baseline_antenna1_name.values + ant2s = xds2.baseline_antenna2_name.values try: - # FIXME: - # - # In the case of subcubes we *don't* get all the antenna1_id stuff! - # antenna1_id.values: [ 6 6 6 7 7 7 7 7 8 8 8 8 9 9 9 10 10 11] - # baselines : [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77] for i, (bl, ant1, ant2) in enumerate(zip(baselines, ant1s, ant2s)): - a = np.abs(fftvis[:,i,:]) + if ref_ant not in [ant1, ant2]: + # print(f"Skipping {ant1}-{ant2}") + continue + if ref_ant == ant1 and ref_ant==ant2: + print("Skipping autos") + # print(f"{ant1}-{ant2}") + ant = ant1 if (ant2 == ref_ant) else ant2 + spw = xds.partition_info['spectral_window_name'] + t = xds.time[0].values + print(f"{ant} {spw} {t}") + a = np.abs(fftvis[:, i, :]) ind = np.unravel_index(np.argmax(a, axis=None), a.shape) - res.setdefault(xds2.time.values[0], {})[bl] = (ind, a[ind], a.shape) + # breakpoint() + ix, iy = ind + phi0 = np.angle(a[ind]) + delay = ix*ddelay + ref_freq = xds.frequency.reference_frequency['data'] + rate = iy*drate/ref_freq + q.loc[dict(antenna_name=ant, polarization=pol)] = [phi0, delay, rate] except IndexError as e: - print(f'{xds2.baseline_antenna1_id.values}\n{baselines=}') + print(f'{xds2.baseline_antenna1_name.values}\n{baselines=}') raise e - return res + return q def _fringefit_reduce(graph_inputs: xr.Dataset, input_params: Dict): merged = {} @@ -64,13 +125,14 @@ def _fringefit_reduce(graph_inputs: xr.Dataset, input_params: Dict): return merged -def fringefit_single(ps, node_task_data_mapping: Dict, sub_selection: Dict): +def fringefit_single(ps, node_task_data_mapping: Dict, sub_selection: Dict, ref_ant: int): """ TODO! """ input_params = {} input_params['data_sub_selection'] = sub_selection input_params['ps'] = ps + input_params['ref_ant'] = ref_ant graph = map( input_data = ps, node_task_data_mapping = node_task_data_mapping,