Skip to content

Commit

Permalink
Adding a refant to the fringefitter
Browse files Browse the repository at this point in the history
  • Loading branch information
dessmalljive committed Sep 13, 2024
1 parent 8062811 commit 51fd462
Show file tree
Hide file tree
Showing 3 changed files with 297 additions and 18 deletions.
188 changes: 188 additions & 0 deletions src/astroviper/calibration/apply_fringe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
from xradio.vis.read_processing_set import read_processing_set

import dask
import numpy as np
import xarray as xa
import pandas as pd
import datetime

# I am surprised this is not some kind of standard function, but the
# pandas version scolds me for using it on a string.
#
# dicts preserve insertion order, so dict.fromkeys() gives us an
# order-preserving de-duplication in a single linear pass.
def unique(s):
    """Get the unique characters in a string in the order they occur in the original.

    Parameters:
        s: any iterable of hashable items (in this file, a string of
           polarization characters).

    Returns:
        A list holding each distinct item of `s` once, ordered by first
        occurrence.
    """
    return list(dict.fromkeys(s))

def nanCount(j):
    """Return the number of NaN entries in `j`."""
    nan_mask = np.isnan(j)
    return np.sum(nan_mask)

def numberCount(j):
    """Return the number of entries in `j` that are not NaN."""
    not_nan_mask = np.logical_not(np.isnan(j))
    return np.sum(not_nan_mask)


def makeCalTable(xds):
    """An attempt to make a calibration table out of coordinates.

    Builds a Dataset with a single zero-filled complex variable ``cals``
    whose dimensions are exactly the coordinates constructed here: the
    mid-point of the first and last time stamps of `xds`, its antenna
    names, the distinct polarization characters, and four placeholder
    parameter labels.
    """
    # All correlation labels run together; unique() then keeps each
    # character once, in order of first appearance.
    correlation_chars = ''.join(xds.polarization.values)
    pols_ant = unique(correlation_chars)
    mid_time = 0.5*(xds.time[0]+xds.time[-1])
    coords = xa.Coordinates(coords={
        'time': xa.Coordinates(coords={'time': mid_time}),
        'antenna_name': xds.antenna_xds.antenna_name,
        'polarization': pols_ant,
        'parameter': ['one', 'two', 'three', 'ah-ha'],
    })
    dim_names = coords.sizes.keys()
    zero_cals = np.zeros(tuple(coords.sizes.values()), complex)
    return xa.Dataset(data_vars={'cals': (dim_names, zero_cals)},
                      coords=coords)

#############################################################################
# How to make a (time, frequency) grid:
#
# np.expand_dims(dt, 1) + np.expand_dims(df, 0) => array of shape (nt, nf)
#############################################################################

class GridJonesCalculator(object):
    """Accumulate per-antenna fringe-fit Jones matrices on the grid of an xds.

    After calcGridJones() has run, ``j_a1_composed`` holds the product of
    the corrections for the first antenna of every baseline (to be applied
    from the left) and ``j_a2_composed`` the Hermitian-conjugated
    corrections for the second antenna (to be applied from the right).
    """

    def __init__(self, xds):
        """Capture the metadata of `xds` and reset the accumulated matrices.

        Parameters:
            xds: a visibility dataset exposing frequency, time,
                 baseline_id, VISIBILITY, antenna_xds.antenna_name and
                 baseline_antenna{1,2}_name.
        """
        # We scrape all the metadata from an xds. Maybe this is wise, maybe not.
        self.xds = xds
        self.frequency = xds.frequency
        self.time = xds.time
        self.baseline_id = xds.baseline_id
        self.VISIBILITY = xds.VISIBILITY
        self.n_ants = xds.antenna_xds.antenna_name.size
        # We expand out the baseline antenna-name arrays to have shape
        # (1, n_baselines, 1, 1, 1) so that they broadcast against the
        # per-baseline Jones arrays built in calcGridJones().
        self.ant1_mask = np.expand_dims(xds.baseline_antenna1_name.values, (0, 2, 3, 4))
        self.ant2_mask = np.expand_dims(xds.baseline_antenna2_name.values, (0, 2, 3, 4))
        self.makeAccumulatedJoneses()

    def makeAccumulatedJoneses(self):
        """Reset both accumulated Jones products to the complex 2x2 identity.

        Raises:
            AssertionError: if the last axis of VISIBILITY does not have
                length 4.
        """
        vcs = self.VISIBILITY.shape
        # We are going to use 2x2 matrices for our Jones matrices because
        # that's how they multiply, which needs exactly four correlations.
        # We'll figure other cases out later.
        assert vcs[-1] == 4, "only four-correlation visibilities are supported"
        # The identity broadcasts against the full-size arrays it is
        # multiplied with, so no full-size allocation is needed here.
        self.j_a1_composed = np.identity(2, dtype=complex)
        self.j_a2_composed = np.identity(2, dtype=complex)

    def insertBaselineDimension(self, j, nbaselines):
        """We add a baseline dimension, but we also broadcast to it.

        Parameters:
            j: array whose first axis is kept in place; the new axis is
               inserted directly after it.
            nbaselines: length of the inserted axis.

        Returns:
            A broadcast view of `j` with shape
            ``j.shape[:1] + (nbaselines,) + j.shape[1:]``.
        """
        target_shape = j.shape[:1] + (nbaselines,) + j.shape[1:]
        return np.broadcast_to(np.expand_dims(j, 1), target_shape)

    def calcGridJonesAnt(self, fp, df, dt):
        """Return Jones matrices for all grid points of an xds for a single set of fringefit parameters (which come in pairs one for each polarization).

        Parameters:
            fp: unpacks into (phi0, tau, r); tau and r must expose
                ``.values``.  Each is assumed to be a 2-vector, one entry
                per polarization.
            df: frequency offsets (assumed 1-D).
            dt: time offsets (assumed 1-D).

        Returns:
            Complex array of shape (len(dt), len(df), 2, 2) with
            exp(i*(phi0 + 2*pi*tau*df + 2*pi*r*dt)) on the diagonal and
            zeros elsewhere.
        """
        # We now assume phi0, tau and r are 2-vectors
        phi0, tau, r = fp
        # We upscale the dimensions so that things broadcast nicely to
        # (time, frequency, polarization):
        df_shaped = np.expand_dims(df, (0, 2))
        dt_shaped = np.expand_dims(dt, (1, 2))
        phi_shaped = np.expand_dims(phi0, (0, 1))
        # Calculate phases:
        phi = (phi_shaped +
               2*np.pi*tau.values*df_shaped +
               2*np.pi*r.values*dt_shaped)
        # And then phasors, written onto the diagonal of otherwise-zero
        # 2x2 matrices.
        many_jones_diags = np.exp(1J*phi, dtype=complex)
        many_jones = np.zeros(many_jones_diags.shape + (2,), complex)
        diag = np.arange(2)
        many_jones[..., diag, diag] = many_jones_diags
        return many_jones

    def calcGridJones(self, cal_quantum):
        """Fold the fringe-fit solutions in `cal_quantum` into the accumulated Jones matrices.

        Parameters:
            cal_quantum: object exposing ``coords['antenna_name']``,
                ``sel(antenna_name=...)`` and the attributes ``t_ref`` and
                ``f_ref``.  Antennas whose parameters are all NaN are
                skipped.

        Each antenna's Jones matrices only touch the baselines that contain
        that antenna; every other baseline is multiplied by the identity.
        Progress and diagnostics are printed to stdout.
        """
        dt = (self.time - cal_quantum.t_ref).values
        df = (self.frequency - cal_quantum.f_ref).values
        identity = np.identity(2, dtype=complex)
        n_baselines = self.xds.baseline_id.size
        # This needs fixed too
        for ant in cal_quantum.coords['antenna_name'].values:
            fp = cal_quantum.sel(antenna_name=ant)
            params = np.sum(~np.isnan(fp.values))
            print(f"{ant=} {params=}")
            if params == 0:
                continue
            print(f"{fp.values}")
            j = self.calcGridJonesAnt(fp, df, dt)
            count = np.sum(~np.isnan(j))
            print(f"{count=}")
            print(f"{np.max(np.abs(j.flatten()))=}")
            j = self.insertBaselineDimension(j, n_baselines)
            # Restrict the correction to baselines whose first antenna is
            # this one; all other baselines get the identity matrix.
            j_1 = np.where(self.ant1_mask == ant, j, identity)
            # The second antenna needs the Hermitian conjugate, taken on
            # the last two axes, and again only on its own baselines.
            j_herm = np.conj(np.swapaxes(j, -1, -2))
            j_2 = np.where(self.ant2_mask == ant, j_herm, identity)
            # Then we can apply those entries to our accumulated arrays by
            # multiplication.
            # First antenna corrected by multiplication from the left:
            print(f"{np.max(np.abs(self.j_a1_composed.flatten()))=}")
            self.j_a1_composed = np.matmul(j_1, self.j_a1_composed)
            # Second antenna in baseline corrected (by Hermitian matrix) from the right
            self.j_a2_composed = np.matmul(self.j_a2_composed, j_2)
            print(f"{np.max(np.abs(self.j_a1_composed.flatten()))=}")


# Ad-hoc driver: load a processing set, build a fringe-fit "quantum"
# (one DataArray of parameters per antenna and polarization) and turn it
# into accumulated Jones matrices with GridJonesCalculator.
#
# We need to consult the data for polarizations now.
ps = read_processing_set('n14c3.zarr')
# The result of keys() is discarded; this line only shows something when
# run interactively.
ps.keys()

# Current version of this ps is split by SPW and not by field
xds = ps['n14c3_099']

# In fact, all xdses have the same polarization setup here, but who can
# say if that is always true?
# Actually, I think maybe we could?
pols_ant = unique(''.join([c for c in ''.join(xds.polarization.values)]))
# Dimension order is (antenna_name, parameter, polarization);
# calcGridJonesAnt relies on 'parameter' being the first axis left after
# an antenna has been selected.
quantumCoords = xa.Coordinates(coords={'antenna_name' : xds.antenna_xds.antenna_name,
    'parameter' : range(3),
    'polarization' : pols_ant
    })
q = xa.DataArray(coords=quantumCoords)

# Reference frequency and time are the first grid points of this xds;
# GridJonesCalculator.calcGridJones reads them back as q.f_ref / q.t_ref.
q.attrs['f_ref'] = xds.frequency[0]
q.attrs['t_ref'] = xds.time[0]

# Note that the f_ref attr copies over a lot of metadata.
# Which I think is a good thing?
# (The value of this expression is discarded; it only checks that the
# key is present.)
q.attrs['f_ref'].attrs['spectral_window_name']


# You can't assign to DataArrays by name, only by integer index.
# NOTE(review): fringefit.py assigns through q.loc[dict(...)], so
# label-based assignment may be possible here too -- confirm.
# Rows are the three parameters (unpacked as phi0, tau, r by
# calcGridJonesAnt) and columns the polarizations, for the first antenna.
q[0] = [[0.0, 0.0], [1.0e-9,-1.0e-9], [0, 0]]

gjc = GridJonesCalculator(xds)
gjc.calcGridJones(q)

def squareUpLastDimension(v):
    """Reshape the last axis of `v` into a trailing 2x2 block.

    All leading axes are kept as they are; the last axis is split into
    two axes of length 2.
    """
    leading_axes = v.shape[:-1]
    return np.reshape(v, (*leading_axes, 2, 2))




# Visibilities with the correlation axis reshaped into 2x2 matrices.
v = squareUpLastDimension(xds.VISIBILITY.values)
# And this is the payoff I guess: the accumulated antenna-1 matrices
# multiply from the left, the antenna-2 matrices from the right.
v2 = gjc.j_a1_composed @ v @ gjc.j_a2_composed

# Dataset.assign returns a new dataset instead of modifying xds in place,
# so the result has to be bound or the new variable is lost.
# This works, although we should do better along the polarization axis.
xds = xds.assign({'FROBBED' : xa.DataArray(coords=(xds.time, xds.baseline_id, xds.frequency, pols_ant, pols_ant), data=v2)})

# Isn't there meant to be a nice way to get spw now?
#>>> xds.partition_info['spectral_window_name']

# We can also now do this at ps level:
ps2 = ps.sel(spw_name='spw_0')
29 changes: 29 additions & 0 deletions src/astroviper/calibration/exercise_fringe_sbd2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from xradio.vis.read_processing_set import read_processing_set
from graphviper.graph_tools.coordinate_utils import (interpolate_data_coords_onto_parallel_coords,
make_parallel_coord)
from graphviper.graph_tools.generate_dask_workflow import generate_dask_workflow
from graphviper.graph_tools.coordinate_utils import make_time_coord
from graphviper.graph_tools.coordinate_utils import make_frequency_coord

from astroviper.calibration.fringefit import fringefit_single
import dask
import xarray as xa

# Ad-hoc exercise script: run the single-band fringe fit on one partition
# of the n14c3 processing set.
ps = read_processing_set('n14c3.zarr')
# The result of keys() is discarded; only useful when run interactively.
ps.keys()

xds = ps['n14c3_000']

# Time axis for the parallel coordinates: two samples starting at
# 2014-10-22 13:18:00 with time_delta=120 (units are not shown here).
meas = make_time_coord(time_start='2014-10-22 13:18:00', time_delta=120, n_samples=2)
# One chunk along each of baseline_id and time.
parallel_coords = {}
parallel_coords['baseline_id'] = make_parallel_coord(
    coord=xds.baseline_id, n_chunks=1)
parallel_coords['time'] = make_parallel_coord(meas, n_chunks=1)
node_task_data_mapping = interpolate_data_coords_onto_parallel_coords(parallel_coords,
    ps, ps_partition=['spectral_window_name'])
subsel = {'polarization': 'LL'}
# NOTE(review): the fringefit_single definition shown in this commit
# takes a fourth `ref_ant` argument with no default -- confirm that this
# three-argument call still matches the signature.
res = fringefit_single(ps, node_task_data_mapping, subsel)

# print(res)

98 changes: 80 additions & 18 deletions src/astroviper/calibration/fringefit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,71 @@
from graphviper.graph_tools.generate_dask_workflow import generate_dask_workflow
from typing import Dict, Union

from xradio.vis.read_processing_set import read_processing_set

import dask
import numpy as np
import xarray as xa
import pandas as pd
import datetime

## I should figure out where this belongs at some point
def unique(s):
    """Get the unique characters in a string in the order they occur in the original.

    Parameters:
        s: any iterable of hashable items (in this file, a string of
           polarization characters).

    Returns:
        A list holding each distinct item of `s` once, ordered by first
        occurrence.
    """
    # dicts preserve insertion order, so this is an order-preserving
    # de-duplication in a single linear pass.
    return list(dict.fromkeys(s))




def getFourierSpacings(xds):
    """Return the pair (ddelay, drate) of Fourier-grid spacings for `xds`.

    For the frequency axis and then the time axis, the mean sample spacing
    is taken from the first and last values, multiplied by the number of
    samples to give a total span, and inverted.
    """
    def reciprocal_span(samples):
        n_samples = len(samples)
        step = (samples[-1] - samples[0])/(n_samples - 1)
        return 1/(n_samples*step)

    return (reciprocal_span(xds.frequency.values),
            reciprocal_span(xds.time.values))

def makeCalArray(xds, ref_ant):
    """Create an empty calibration DataArray for the antennas of `xds`.

    The array has the dimensions (antenna_name, polarization, parameter)
    with three parameters per antenna and polarization.  The reference
    frequency of the frequency axis, the first time stamp and `ref_ant`
    are stored in its attrs.
    """
    # All correlation labels run together; unique() then keeps each
    # character once, in order of first appearance.
    pols_ant = unique(''.join(xds.polarization.values))
    quantumCoords = xa.Coordinates(coords={
        'antenna_name': xds.antenna_xds.antenna_name,
        'polarization': pols_ant,
        'parameter': range(3),
    })
    q = xa.DataArray(coords=quantumCoords)
    # Should we choose this reference time?
    q.attrs.update(
        reference_frequency=xds.frequency.reference_frequency['data'],
        reference_time=xds.time[0],
        reference_antenna=ref_ant,
    )
    return q

def _fringe_node_task(input_params: Dict):
ps = input_params['ps']
data_selection = input_params['data_selection']
ref_ant = input_params['ref_ant']
# FIXME: for now we do single band
if len(data_selection.keys())>1:
if len(data_selection.keys()) > 1:
print(f'{data_selection.keys()=}')
raise RuntimeError("We only do single xdses so far")
name = list(data_selection.keys())[0]
xds = ps[name]
q = makeCalArray(xds, ref_ant)
data_sub_selection = input_params['data_sub_selection']
pol = data_sub_selection['polarization']
pols = data_sub_selection['polarization']
# FIXME!
pol = pols[0]
xds2 = xds.isel(**data_selection[name])
xds2 = xds2.sel(polarization=pol)
xds2 = xds2.sel(polarization=pols)
ddelay, drate = getFourierSpacings(xds2)
vis = xds2.VISIBILITY
ang = np.angle(vis)
nvis = np.exp(1J*ang)
Expand All @@ -31,26 +83,35 @@ def _fringe_node_task(input_params: Dict):
),
axes=(0,2)
)
ref_ant = 1
res = {}
bl_slice = data_selection[name]["baseline_id"]
baselines = xds.baseline_id[bl_slice].values
ant1s = xds2.baseline_antenna1_id.values
ant2s = xds2.baseline_antenna1_id.values
baselines = xds2.baseline_id[bl_slice].values
ant1s = xds2.baseline_antenna1_name.values
ant2s = xds2.baseline_antenna2_name.values
try:
# FIXME:
#
# In the case of subcubes we *don't* get all the antenna1_id stuff!
# antenna1_id.values: [ 6 6 6 7 7 7 7 7 8 8 8 8 9 9 9 10 10 11]
# baselines : [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77]
for i, (bl, ant1, ant2) in enumerate(zip(baselines, ant1s, ant2s)):
a = np.abs(fftvis[:,i,:])
if ref_ant not in [ant1, ant2]:
# print(f"Skipping {ant1}-{ant2}")
continue
if ref_ant == ant1 and ref_ant==ant2:
print("Skipping autos")
# print(f"{ant1}-{ant2}")
ant = ant1 if (ant2 == ref_ant) else ant2
spw = xds.partition_info['spectral_window_name']
t = xds.time[0].values
print(f"{ant} {spw} {t}")
a = np.abs(fftvis[:, i, :])
ind = np.unravel_index(np.argmax(a, axis=None), a.shape)
res.setdefault(xds2.time.values[0], {})[bl] = (ind, a[ind], a.shape)
# breakpoint()
ix, iy = ind
phi0 = np.angle(a[ind])
delay = ix*ddelay
ref_freq = xds.frequency.reference_frequency['data']
rate = iy*drate/ref_freq
q.loc[dict(antenna_name=ant, polarization=pol)] = [phi0, delay, rate]
except IndexError as e:
print(f'{xds2.baseline_antenna1_id.values}\n{baselines=}')
print(f'{xds2.baseline_antenna1_name.values}\n{baselines=}')
raise e
return res
return q

def _fringefit_reduce(graph_inputs: xr.Dataset, input_params: Dict):
merged = {}
Expand All @@ -64,13 +125,14 @@ def _fringefit_reduce(graph_inputs: xr.Dataset, input_params: Dict):
return merged


def fringefit_single(ps, node_task_data_mapping: Dict, sub_selection: Dict):
def fringefit_single(ps, node_task_data_mapping: Dict, sub_selection: Dict, ref_ant: int):
"""
TODO!
"""
input_params = {}
input_params['data_sub_selection'] = sub_selection
input_params['ps'] = ps
input_params['ref_ant'] = ref_ant
graph = map(
input_data = ps,
node_task_data_mapping = node_task_data_mapping,
Expand Down

0 comments on commit 51fd462

Please sign in to comment.