Skip to content

Commit

Permalink
Update AstroVIPER to work with latest version of XRADIO and GraphVIPER.
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan-Willem committed Nov 18, 2024
1 parent 27d39c5 commit a3ea040
Show file tree
Hide file tree
Showing 16 changed files with 3,291 additions and 4,516 deletions.
767 changes: 0 additions & 767 deletions dev/alma_single_field/single_field_tutorial.ipynb

This file was deleted.

360 changes: 173 additions & 187 deletions docs/astroviper_tutorial_calculate_stats.ipynb

Large diffs are not rendered by default.

772 changes: 393 additions & 379 deletions docs/astroviper_tutorial_mosaics.ipynb

Large diffs are not rendered by default.

454 changes: 74 additions & 380 deletions docs/astroviper_tutorial_spectral_moment.ipynb

Large diffs are not rendered by default.

990 changes: 512 additions & 478 deletions docs/atsroviper_tutorial_fft.ipynb

Large diffs are not rendered by default.

2,042 changes: 2,042 additions & 0 deletions examples/convert_spectral_reference_frame.ipynb

Large diffs are not rendered by default.

2,256 changes: 0 additions & 2,256 deletions examples/example_graph_build.ipynb

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def _standard_grid_psf_numpy_wrap(uvw, weight, freq_chan, cgk_1D, grid_parms):


# When jit is used round is repolaced by standard c++ round that is different to python round
@jit(nopython=True, cache=True, nogil=True) #fastmath=True
@jit(nopython=True, cache=True, nogil=True) # fastmath=True
def _standard_grid_jit(
grid,
sum_weight,
Expand Down Expand Up @@ -200,8 +200,8 @@ def _standard_grid_jit(
Returns
-------
"""
#By hardcoding the support and oversampling values, the innermost for loops can be unrolled by the compiler leading to significantly faster code.

# By hardcoding the support and oversampling values, the innermost for loops can be unrolled by the compiler leading to significantly faster code.
# support = 7
# oversampling = 100

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#Enable fastmath (don't check dims).
# Enable fastmath (don't check dims).
from numba import jit
import numpy as np
import math
Expand Down Expand Up @@ -30,24 +30,37 @@ def _create_prolate_spheroidal_kernel(oversampling, support, n_uv):
support_center = support // 2
oversampling_center = oversampling // 2

support_values = (np.arange(support) - support_center)
support_values = np.arange(support) - support_center
if (oversampling % 2) == 0:
oversampling_values = ((np.arange(oversampling + 1) - oversampling_center) / oversampling)[:, None]
kernel_points_1D = (np.broadcast_to(support_values, (oversampling + 1, support)) + oversampling_values)
oversampling_values = (
(np.arange(oversampling + 1) - oversampling_center) / oversampling
)[:, None]
kernel_points_1D = (
np.broadcast_to(support_values, (oversampling + 1, support))
+ oversampling_values
)
else:
oversampling_values = ((np.arange(oversampling) - oversampling_center) / oversampling)[:, None]
kernel_points_1D = (np.broadcast_to(support_values, (oversampling, support)) + oversampling_values)
oversampling_values = (
(np.arange(oversampling) - oversampling_center) / oversampling
)[:, None]
kernel_points_1D = (
np.broadcast_to(support_values, (oversampling, support))
+ oversampling_values
)

kernel_points_1D = kernel_points_1D / support_center

_, kernel_1D = _prolate_spheroidal_function(kernel_points_1D)
# kernel_1D /= np.sum(np.real(kernel_1D[oversampling_center,:]))

if (oversampling % 2) == 0:
kernel = np.zeros((oversampling + 1, oversampling + 1, support, support),
dtype=np.double) # dtype=np.complex128
kernel = np.zeros(
(oversampling + 1, oversampling + 1, support, support), dtype=np.double
) # dtype=np.complex128
else:
kernel = np.zeros((oversampling, oversampling, support, support), dtype=np.double)
kernel = np.zeros(
(oversampling, oversampling, support, support), dtype=np.double
)

for x in range(oversampling):
for y in range(oversampling):
Expand Down Expand Up @@ -75,12 +88,14 @@ def _create_prolate_spheroidal_kernel_1D(oversampling, support):
support_center = support // 2
oversampling_center = oversampling // 2
u = np.arange(oversampling * (support_center)) / (support_center * oversampling)
#print(u)
# print(u)

long_half_kernel_1D = np.zeros(oversampling * (support_center + 1))
_, long_half_kernel_1D[0:oversampling * (support_center)] = _prolate_spheroidal_function(u)

#print(_prolate_spheroidal_function(u))
_, long_half_kernel_1D[0 : oversampling * (support_center)] = (
_prolate_spheroidal_function(u)
)

# print(_prolate_spheroidal_function(u))
return long_half_kernel_1D


Expand All @@ -98,9 +113,18 @@ def _prolate_spheroidal_function(u):
to the edge. The grid correction function is just 1/GRDSF(NU) where NU
is now the distance to the edge of the image.
"""
p = np.array([[8.203343e-2, -3.644705e-1, 6.278660e-1, -5.335581e-1, 2.312756e-1],
[4.028559e-3, -3.697768e-2, 1.021332e-1, -1.201436e-1, 6.412774e-2]])
q = np.array([[1.0000000e0, 8.212018e-1, 2.078043e-1], [1.0000000e0, 9.599102e-1, 2.918724e-1]])
p = np.array(
[
[8.203343e-2, -3.644705e-1, 6.278660e-1, -5.335581e-1, 2.312756e-1],
[4.028559e-3, -3.697768e-2, 1.021332e-1, -1.201436e-1, 6.412774e-2],
]
)
q = np.array(
[
[1.0000000e0, 8.212018e-1, 2.078043e-1],
[1.0000000e0, 9.599102e-1, 2.918724e-1],
]
)

_, n_p = p.shape
_, n_q = q.shape
Expand All @@ -114,7 +138,7 @@ def _prolate_spheroidal_function(u):
uend[(u >= 0.0) & (u < 0.75)] = 0.75
uend[(u >= 0.75) & (u <= 1.0)] = 1.0

delusq = u ** 2 - uend ** 2
delusq = u**2 - uend**2

top = p[part, 0]
for k in range(1, n_p): # small constant size loop
Expand All @@ -125,18 +149,17 @@ def _prolate_spheroidal_function(u):
bot += q[part, k] * np.power(delusq, k)

grdsf = np.zeros(u.shape, dtype=np.float64)
ok = (bot > 0.0)
ok = bot > 0.0
grdsf[ok] = top[ok] / bot[ok]
ok = np.abs(u > 1.0)
grdsf[ok] = 0.0

# Return the correcting image and the gridding kernel value
return grdsf, (1 - u ** 2) * grdsf
return grdsf, (1 - u**2) * grdsf


def _coordinates(npixel: int):
""" 1D array which spans [-.5,.5[ with 0 at position npixel/2
"""
"""1D array which spans [-.5,.5[ with 0 at position npixel/2"""
return (np.arange(npixel) - npixel // 2) / npixel


Expand All @@ -146,4 +169,3 @@ def _coordinates2(npixel: int):
2. (0,0) at pixel (floor(n/2),floor(n/2))
"""
return (np.mgrid[0:npixel, 0:npixel] - npixel // 2) / npixel

6 changes: 3 additions & 3 deletions src/astroviper/_domain/_imaging/_make_imaging_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _make_imaging_weights(ms_xds, grid_parms, imaging_weights_parms, sel_parms):
)

# Calculate Briggs
#print('weight_density_grid',weight_density_grid)
# print('weight_density_grid',weight_density_grid)
briggs_factors = _calculate_briggs_parms(
weight_density_grid, sum_weight, _imaging_weights_parms
) # 2 x chan x pol
Expand All @@ -87,8 +87,8 @@ def _calculate_briggs_parms(grid_of_imaging_weights, sum_weight, imaging_weights
if imaging_weights_parms["weighting"] == "briggs":
robust = imaging_weights_parms["robust"]
briggs_factors = np.ones((2,) + sum_weight.shape)
squared_sum_weight = np.sum((grid_of_imaging_weights)**2, axis=(2, 3))

squared_sum_weight = np.sum((grid_of_imaging_weights) ** 2, axis=(2, 3))
briggs_factors[0, :, :] = (
np.square(5.0 * 10.0 ** (-robust)) / (squared_sum_weight / sum_weight)
)[None, None, :, :]
Expand Down
4 changes: 1 addition & 3 deletions src/astroviper/_domain/_imaging/_make_uv_sampling_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ def _make_uv_sampling_grid(
)



def _make_uv_sampling_grid_single_field(
ms_xds, cgk_1D, img_xds, vis_sel_parms, img_sel_parms, grid_parms
):
Expand Down Expand Up @@ -173,7 +172,7 @@ def _make_uv_sampling_grid_single_field(

do_psf = True
do_imaging_weight = False

_standard_grid_jit(
grid,
sum_weight,
Expand All @@ -191,4 +190,3 @@ def _make_uv_sampling_grid_single_field(
support=7,
oversampling=100,
)

3 changes: 0 additions & 3 deletions src/astroviper/_domain/_imaging/_make_visibility_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,6 @@ def _make_visibility_grid(
)




def _make_visibility_grid_single_field(
ms_xds, cgk_1D, img_xds, vis_sel_parms, img_sel_parms, grid_parms
):
Expand Down Expand Up @@ -193,7 +191,6 @@ def _make_visibility_grid_single_field(
do_psf = False
do_imaging_weight = False


_standard_grid_jit(
grid,
sum_weight,
Expand Down
2 changes: 1 addition & 1 deletion src/astroviper/imaging/_utils/_make_image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
def _make_image(input_params):
import time
from xradio.correlated_data.load_processing_set import ProcessingSetIterator
from xradio.measurement_set.load_processing_set import ProcessingSetIterator
import toolviper.utils.logger as logger
import dask

Expand Down
65 changes: 37 additions & 28 deletions src/astroviper/imaging/_utils/_make_image_single_field.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@


def _make_image_single_field(input_params):
import time
from xradio.correlated_data.load_processing_set import ProcessingSetIterator
Expand All @@ -20,7 +18,9 @@ def _make_image_single_field(input_params):
_make_uv_sampling_grid_single_field,
)
from xradio.image import make_empty_sky_image
from astroviper._domain._imaging._make_visibility_grid import _make_visibility_grid_single_field
from astroviper._domain._imaging._make_visibility_grid import (
_make_visibility_grid_single_field,
)
from astroviper._domain._imaging._fft_norm_img_xds import _fft_norm_img_xds

import xarray as xr
Expand All @@ -30,7 +30,6 @@ def _make_image_single_field(input_params):
start_1 = time.time()
grid_params = input_params["grid_params"]


image_freq_coord = input_params["task_coords"]["frequency"]["data"]

if input_params["polarization"] is not None:
Expand Down Expand Up @@ -79,16 +78,19 @@ def _make_image_single_field(input_params):
load_sub_datasets=True,
)
logger.debug("1.5 Created ProcessingSetIterator ")

from astroviper._domain._imaging._imaging_utils.gcf_prolate_spheroidal import _create_prolate_spheroidal_kernel_1D

from astroviper._domain._imaging._imaging_utils.gcf_prolate_spheroidal import (
_create_prolate_spheroidal_kernel_1D,
)

cgk_1D = _create_prolate_spheroidal_kernel_1D(100, 7)

start_2 = time.time()
for ms_xds in ps_iter:
start_compute = time.time()

# Create a mask where baseline_antenna1_name does not equal baseline_antenna2_name
mask = ms_xds['baseline_antenna1_name'] != ms_xds['baseline_antenna2_name']
mask = ms_xds["baseline_antenna1_name"] != ms_xds["baseline_antenna2_name"]
# Apply the mask to the Dataset
ms_xds = ms_xds.where(mask, drop=True)

Expand All @@ -100,7 +102,7 @@ def _make_image_single_field(input_params):
sel_parms={"data_group_in": data_group},
)
T_weights = T_weights + time.time() - start_4

start_7 = time.time()
_make_uv_sampling_grid_single_field(
ms_xds,
Expand All @@ -111,7 +113,6 @@ def _make_image_single_field(input_params):
grid_parms=grid_params,
) # Will become the PSF.
T_uv_sampling_grid = T_uv_sampling_grid + time.time() - start_7


start_8 = time.time()
_make_visibility_grid_single_field(
Expand All @@ -124,22 +125,25 @@ def _make_image_single_field(input_params):
)
T_vis_grid = T_vis_grid + time.time() - start_8
T_compute = T_compute + time.time() - start_compute


#print(img_xds)
from astroviper._domain._imaging._imaging_utils._make_pb_symmetric import _airy_disk_rorder

# print(img_xds)
from astroviper._domain._imaging._imaging_utils._make_pb_symmetric import (
_airy_disk_rorder,
)

pb_parms = {}
pb_parms["list_dish_diameters"] = np.array([10.7])
pb_parms["list_dish_diameters"] = np.array([10.7])
pb_parms["list_blockage_diameters"] = np.array([0.75])
pb_parms["ipower"] = 1

grid_params["image_center"] = (np.array(grid_params["image_size"]) // 2).tolist()
#(1, 1, len(pol), 1, 1))
#print(_airy_disk_rorder(ms_xds.frequency.values, ms_xds.polarization.values, pb_parms, grid_params).shape)
# (1, 1, len(pol), 1, 1))
# print(_airy_disk_rorder(ms_xds.frequency.values, ms_xds.polarization.values, pb_parms, grid_params).shape)

#img_xds["PRIMARY_BEAM"] = xr.DataArray(_airy_disk_rorder(ms_xds.frequency.values, ms_xds.polarization.values, pb_parms, grid_params)[0,...], dims=("frequency", "polarization", "l", "m"))
img_xds["PRIMARY_BEAM"] = xr.DataArray(np.ones(img_xds.UV_SAMPLING.shape), dims=("frequency", "polarization", "l", "m"))

# img_xds["PRIMARY_BEAM"] = xr.DataArray(_airy_disk_rorder(ms_xds.frequency.values, ms_xds.polarization.values, pb_parms, grid_params)[0,...], dims=("frequency", "polarization", "l", "m"))
img_xds["PRIMARY_BEAM"] = xr.DataArray(
np.ones(img_xds.UV_SAMPLING.shape), dims=("frequency", "polarization", "l", "m")
)

T_load = time.time() - start_2 - T_compute

Expand All @@ -152,15 +156,20 @@ def _make_image_single_field(input_params):
logger.debug("Compute " + str(T_compute))

start_9 = time.time()

gcf_xds = xr.Dataset()
gcf_xds.attrs["oversampling"] = [100,100]
gcf_xds.attrs["SUPPORT"] = [7,7]
from astroviper._domain._imaging._imaging_utils.gcf_prolate_spheroidal import _create_prolate_spheroidal_kernel
_, ps_corr_image = _create_prolate_spheroidal_kernel(100, 7, n_uv=img_xds["UV_SAMPLING"].shape[-2:])
gcf_xds.attrs["oversampling"] = [100, 100]
gcf_xds.attrs["SUPPORT"] = [7, 7]
from astroviper._domain._imaging._imaging_utils.gcf_prolate_spheroidal import (
_create_prolate_spheroidal_kernel,
)

_, ps_corr_image = _create_prolate_spheroidal_kernel(
100, 7, n_uv=img_xds["UV_SAMPLING"].shape[-2:]
)

#print(ps_corr_image.shape)
gcf_xds['PS_CORR_IMAGE'] = xr.DataArray(ps_corr_image, dims=("l","m"))
# print(ps_corr_image.shape)
gcf_xds["PS_CORR_IMAGE"] = xr.DataArray(ps_corr_image, dims=("l", "m"))

_fft_norm_img_xds(
img_xds,
Expand Down
2 changes: 1 addition & 1 deletion src/astroviper/imaging/cube_imaging_niter0.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def cube_imaging_niter0(
import xarray as xr
import dask
import os
from xradio.correlated_data import open_processing_set
from xradio.measurement_set import open_processing_set
from graphviper.graph_tools.coordinate_utils import make_parallel_coord
from graphviper.graph_tools import generate_dask_workflow, generate_airflow_workflow
from graphviper.graph_tools import map, reduce
Expand Down
12 changes: 7 additions & 5 deletions tests/domain/imaging/test_cube_imaging_niter0.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
@pytest.fixture
def antennae_from_s3():
from toolviper.dask.client import local_client

viper_client = local_client(cores=4, memory_limit="4GB")

from astroviper.imaging.cube_imaging_niter0 import cube_imaging_niter0
from xradio.correlated_data import open_processing_set

ps_store = "s3://viper-test-data/Antennae_North.cal.lsrk.split.py39.v3.vis.zarr"
ps = open_processing_set(ps_store, intents=["OBSERVE_TARGET#ON_SOURCE"])

Expand All @@ -21,7 +22,9 @@ def antennae_from_s3():
"image_size": [500, 500],
"cell_size": np.array([-0.13, 0.13]) * np.pi / (180 * 3600),
"fft_padding": 1.0,
"phase_direction": ps['Antennae_North.cal.lsrk.split_04'].VISIBILITY.field_and_source_xds.FIELD_PHASE_CENTER,
"phase_direction": ps[
"Antennae_North.cal.lsrk.split_04"
].VISIBILITY.field_and_source_xds.FIELD_PHASE_CENTER,
}
ms_name = "Antennae_North.cal.lsrk.split_00"
n_chunks = None
Expand All @@ -40,7 +43,7 @@ def antennae_from_s3():
data_variables=data_variables,
)
yield output

assert os.path.exists("Antennae_North_Cube.img.zarr")

# cleanup
Expand All @@ -49,4 +52,3 @@ def antennae_from_s3():

def test_file_creation(antennae_from_s3):
assert os.path.exists("Antennae_North_Cube.img.zarr")

0 comments on commit a3ea040

Please sign in to comment.