Skip to content

Commit

Permalink
permit TEOS10 conserv temp and abs sal
Browse files Browse the repository at this point in the history
  • Loading branch information
jpolton committed May 13, 2024
1 parent 70879df commit 634225f
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 18 deletions.
11 changes: 8 additions & 3 deletions coast/data/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,8 +934,13 @@ def construct_density(
# jth self.dataset.z_dim.size,
# self.dataset.id_dim.size,
)
sal = self.dataset.practical_salinity.to_masked_array()
temp = self.dataset.potential_temperature.to_masked_array()

if CT_AS:
temp = self.dataset.conservative_temperature.to_masked_array()
sal = self.dataset.absolute_salinity.to_masked_array()
else:
temp = self.dataset.potential_temperature.to_masked_array()
sal = self.dataset.practical_salinity.to_masked_array()

if np.shape(sal) != shape_ds:
sal = sal.T
Expand Down Expand Up @@ -1053,7 +1058,7 @@ def construct_density(
else:
attributes = {"units": "kg / m^3", "standard name": "In-situ density "}

density = np.squeeze(density)
#density = np.squeeze(density) # squeezing out id_dim, if size=1 is bad.
self.dataset[new_var_name] = xr.DataArray(density, coords=coords, dims=dims, attrs=attributes)

except AttributeError as err:
Expand Down
70 changes: 55 additions & 15 deletions coast/diagnostics/profile_stratification.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,13 @@ def __init__(self, profile: xr.Dataset):
self.nz = profile.dataset.dims["z_dim"]
debug(f"Initialised {get_slug(self)}")

def clean_data(profile: xr.Dataset, gridded: xr.Dataset, Zmax, limits=[0, 0, 0, 0], rmax=25.0):
def clean_data(profile: xr.Dataset, gridded: xr.Dataset, Zmax, CT_AS: bool=False, limits=[0, 0, 0, 0], rmax=25.0):
"""
parameters:
CT_AS: bool - determines whether conservative_temperature and absolute salinity are expected (if True).
if False: potential_temperature and practical_salinity
Cleaning data for stratification metric calculations
Stage 1:...
Expand All @@ -54,10 +59,17 @@ def clean_data(profile: xr.Dataset, gridded: xr.Dataset, Zmax, limits=[0, 0, 0,
# find profiles good for SST and NBT
dz_max = 25.0

if not CT_AS:
temperature_var = "potential_temperature"
salinity_var = "practical_salinity"
else:
temperature_var = "conservative_temperature"
salinity_var = "absolute_salinity"

n_prf = profile.dataset.id_dim.shape[0]
n_depth = profile.dataset.z_dim.shape[0]
tmp_clean = profile.dataset.potential_temperature.values[:, :]
sal_clean = profile.dataset.practical_salinity.values[:, :]
tmp_clean = profile.dataset[temperature_var].values[:, :]
sal_clean = profile.dataset[salinity_var].values[:, :]

any_tmp = np.sum(~np.isnan(tmp_clean), axis=1) != 0
any_sal = np.sum(~np.isnan(sal_clean), axis=1) != 0
Expand All @@ -71,6 +83,11 @@ def first_nonzero(arr, axis=0, invalid_val=np.nan):
profile.gridded_to_profile_2d(gridded, "bathymetry", limits=limits, rmax=rmax)
D_prf = profile.dataset.bathymetry.values
z = profile.dataset.depth
if np.shape(z.values) != (n_prf, n_depth): z = z.transpose()
if np.shape(z.values) != (n_prf, n_depth): print(f"Problem with the shape of profile.dataset.depth")

print(f"shape pot temp:{np.shape(profile.dataset[temperature_var].values[:,:])}")
print(f"shape z:{np.shape(z)}. shape D_prf:{np.shape(np.repeat(D_prf[:, np.newaxis], n_depth, axis=1))}")
test_surface = z < np.minimum(dz_max, 0.25 * np.repeat(D_prf[:, np.newaxis], n_depth, axis=1))
test_tmp = np.logical_and(test_surface, ~np.isnan(tmp_clean))
test_sal = np.logical_and(test_surface, ~np.isnan(sal_clean))
Expand Down Expand Up @@ -112,9 +129,12 @@ def first_nonzero(arr, axis=0, invalid_val=np.nan):
# fill holes in data
# jth This is slow, there may be a more 'vector' way of doing it
# %%
tmp1 = profile.dataset.potential_temperature.values[:, :]
sal1 = profile.dataset.practical_salinity.values[:, :]
tmp1 = profile.dataset[temperature_var].values[:, :]
sal1 = profile.dataset[salinity_var].values[:, :]
z1 = profile.dataset.depth.values[:, :]
if np.shape(z1) != (n_prf, n_depth): z1 = z1.transpose()
if np.shape(z1) != (n_prf, n_depth): print(f"Problem with the shape of profile.dataset.depth")

for i_prf in range(n_prf):
tmp = tmp1[i_prf, :]
sal = sal1[i_prf, :]
Expand All @@ -134,16 +154,16 @@ def first_nonzero(arr, axis=0, invalid_val=np.nan):
"longitude": (("id_dim"), profile.dataset.longitude.values),
}
dims = ["id_dim", "z_dim"]
profile.dataset["potential_temperature"] = xr.DataArray(tmp_clean, coords=coords, dims=dims)
profile.dataset["practical_salinity"] = xr.DataArray(sal_clean, coords=coords, dims=dims)
profile.dataset[temperature_var] = xr.DataArray(tmp_clean, coords=coords, dims=dims)
profile.dataset[salinity_var] = xr.DataArray(sal_clean, coords=coords, dims=dims)
profile.dataset["sea_surface_temperature"] = xr.DataArray(SST, coords=coords, dims=["id_dim"])
profile.dataset["sea_surface_salinity"] = xr.DataArray(SSS, coords=coords, dims=["id_dim"])
profile.dataset["good_profile"] = xr.DataArray(good_profile, coords=coords, dims=["id_dim"])
print("All nice and clean")
# %%
return profile

def calc_pea(self, profile: xr.Dataset, gridded: xr.Dataset, Zmax, rmax=25.0, limits=[0, 0, 0, 0]):
def calc_pea(self, profile: xr.Dataset, gridded: xr.Dataset, Zmax, CT_AS: bool=False, rmax=25.0, limits=[0, 0, 0, 0]):
"""
Calculates Potential Energy Anomaly
Expand All @@ -158,7 +178,25 @@ def calc_pea(self, profile: xr.Dataset, gridded: xr.Dataset, Zmax, rmax=25.0, li
gravity = 9.81
# Clean data This is quit slow and over writes potential temperature and practical salinity variables

profile = ProfileStratification.clean_data(profile, gridded, Zmax)
if not CT_AS:
temperature_var = "potential_temperature"
salinity_var = "practical_salinity"
else:
temperature_var = "conservative_temperature"
salinity_var = "absolute_salinity"

## JP ## profile = ProfileStratification.clean_data(profile, gridded, Zmax, CT_AS)
n_prf = profile.dataset.id_dim.shape[0]
coords = {
"time": ("id_dim", profile.dataset.time.values),
"latitude": (("id_dim"), profile.dataset.latitude.values),
"longitude": (("id_dim"), profile.dataset.longitude.values),
}
good_profile = np.array(np.ones(n_prf), dtype=bool)
profile.dataset["good_profile"] = xr.DataArray(good_profile, coords=coords, dims=["id_dim"])
profile.dataset["sea_surface_temperature"] = profile.dataset[temperature_var].isel(z_dim=0)
profile.dataset["sea_surface_salinity"] = profile.dataset[salinity_var].isel(z_dim=0)


# Define grid spacing, dz. Required for depth integral
profile.calculate_vertical_spacing()
Expand All @@ -178,9 +216,9 @@ def calc_pea(self, profile: xr.Dataset, gridded: xr.Dataset, Zmax, rmax=25.0, li
# ) # jth why not just use depth here?

if not "density" in profile.dataset:
profile.construct_density(CT_AS=False, pot_dens=True)
profile.construct_density(CT_AS=CT_AS, pot_dens=True)
if not "density_bar" in profile.dataset:
profile.construct_density(CT_AS=False, rhobar=True, Zd_mask=Zd_mask, pot_dens=True)
profile.construct_density(CT_AS=CT_AS, rhobar=True, Zd_mask=Zd_mask, pot_dens=True)
rho = profile.dataset.variables["density"].fillna(0) # density
rhobar = profile.dataset.variables["density_bar"] # density with depth-mean T and S

Expand All @@ -197,13 +235,15 @@ def calc_pea(self, profile: xr.Dataset, gridded: xr.Dataset, Zmax, rmax=25.0, li
"longitude": (("id_dim"), profile.dataset.longitude.values),
}
dims = ["id_dim"]
attributes = {"units": "J / m^3", "standard_name": "Potential Energy Anomaly"}
self.dataset["pea"] = xr.DataArray(pot_energy_anom, coords=coords, dims=dims, attrs=attributes)
pea_attributes = {"units": "J / m^3", "standard_name": "Potential Energy Anomaly"}
sst_attributes = {"units": "deg C", "standard_name": "Sea Surface Temperature"}
sss_attributes = {"units": "psu", "standard_name": "Sea Surface Salinity"}
self.dataset["pea"] = xr.DataArray(pot_energy_anom, coords=coords, dims=dims, attrs=pea_attributes)
self.dataset["sst"] = xr.DataArray(
profile.dataset.variables["sea_surface_temperature"], coords=coords, dims=dims, attrs=attributes
profile.dataset.variables["sea_surface_temperature"], coords=coords, dims=dims, attrs=sst_attributes
)
self.dataset["sss"] = xr.DataArray(
profile.dataset.variables["sea_surface_salinity"], coords=coords, dims=dims, attrs=attributes
profile.dataset.variables["sea_surface_salinity"], coords=coords, dims=dims, attrs=sss_attributes
)

def quick_plot(self, var: xr.DataArray = None):
Expand Down

0 comments on commit 634225f

Please sign in to comment.