Skip to content

Commit

Permalink
Merge pull request #79 from antarctica/hom_condition_hotfix
Browse files Browse the repository at this point in the history
Hom condition hotfix
  • Loading branch information
hjabbot authored Sep 20, 2024
2 parents bc26dde + f8982e2 commit d29d34a
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 103 deletions.
2 changes: 1 addition & 1 deletion meshiphi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.1.12"
__version__ = "2.1.13"
__description__ = "MeshiPhi: Earth's digital twin mapped on a non-uniform mesh"
__license__ = "MIT"
__author__ = "Autonomous Marine Operations Planning (AMOP) Team, AI Lab, British Antarctic Survey"
Expand Down
12 changes: 10 additions & 2 deletions meshiphi/dataloaders/scalar/abstract_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,11 @@ def get_hom_condition_from_df(dps, splitting_conds):
else:
# Determine fraction of datapoints over threshold value
num_over_threshold = dps[dps > splitting_conds['threshold']]
frac_over_threshold = num_over_threshold.shape[0]/dps.shape[0]
num_non_nan = np.count_nonzero(~np.isnan(dps))
if num_non_nan > 0:
frac_over_threshold = num_over_threshold.shape[0]/num_non_nan
else:
frac_over_threshold = 0

# Return homogeneity condition
if frac_over_threshold <= splitting_conds['lower_bound']: hom_type = "CLR"
Expand Down Expand Up @@ -560,7 +564,11 @@ def get_hom_condition_from_xr(dps, splitting_conds):
else:
# Determine fraction of datapoints over threshold value
num_over_threshold = np.count_nonzero(dps > splitting_conds['threshold'])
frac_over_threshold = num_over_threshold/dps.size
num_non_nan = np.count_nonzero(~np.isnan(dps))
if num_non_nan > 0:
frac_over_threshold = num_over_threshold/num_non_nan
else:
frac_over_threshold = 0
# Return homogeneity condition
if frac_over_threshold <= splitting_conds['lower_bound']: hom_type = "CLR"
elif frac_over_threshold >= splitting_conds['upper_bound']:
Expand Down
138 changes: 38 additions & 100 deletions meshiphi/dataloaders/vector/abstract_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,10 +422,16 @@ def get_value_from_df(dps, variable_names, bounds, agg_type, skipna):
values = [data_count, data_count]
elif agg_type == 'MIN':
index = dps['_magnitude'].idxmin(skipna=skipna)
values = [dps[name][index] for name in variable_names]
if ~np.isnan(index):
values = [dps[name][index] for name in variable_names]
else:
values = [np.nan for name in variable_names]
elif agg_type == 'MAX':
index = dps['_magnitude'].idxmax(skipna=skipna)
values = [dps[name][index] for name in variable_names]
if ~np.isnan(index):
values = [dps[name][index] for name in variable_names]
else:
values = [np.nan for name in variable_names]
elif agg_type == 'MEAN':
values = [dps[name].mean(skipna=skipna) for name in variable_names]
elif agg_type == 'STD':
Expand Down Expand Up @@ -558,20 +564,36 @@ def get_hom_condition(self, bounds, splitting_conds, agg_type='MEAN', data=None)
# To allow multiple modes of splitting, chuck them in the splitting conditions
# Split if magnitude of curl(data) is larger than threshold
if 'curl' in splitting_conds:
curl = self.calc_curl(bounds)
if np.abs(curl) > splitting_conds['curl']:
hom_type = 'HET'
flow = self.calc_curl(bounds, collapse=False)
sc = splitting_conds['curl']
# Split if max magnitude(any_vector - ave_vector) is larger than threshold
if 'dmag' in splitting_conds:
dmag = self.calc_dmag(bounds)
if np.abs(dmag) > splitting_conds['dmag']:
hom_type = 'HET'
elif 'dmag' in splitting_conds:
flow = self.calc_dmag(bounds, collapse=False)
sc = splitting_conds['dmag']

if 'split_lock' not in sc:
sc['split_lock'] = False

if isinstance(flow, type(np.nan)) and np.isnan(flow):
return "CLR"
num_over_threshold = (flow > sc['threshold']).sum()

num_non_nan = np.count_nonzero(~np.isnan(flow))
if num_non_nan > 0:
frac_over_threshold = num_over_threshold/num_non_nan
else:
frac_over_threshold = 0


if frac_over_threshold <= sc['lower_bound']:
hom_type = "CLR"
elif frac_over_threshold >= sc['upper_bound']:
if sc['split_lock'] == True:
hom_type = "HOM"
else:
hom_type = "CLR"
else: hom_type = "HET"

# Split if Reynolds number is larger than threshold
if 'reynolds' in splitting_conds:
reynolds = self.calc_reynolds_number(bounds)
if reynolds > splitting_conds['reynolds']:
hom_type = 'HET'

logging.debug(f"\thom_condition for attribute: '{self.data_name}' in bounds:'{bounds}' returned '{hom_type}'")

Expand Down Expand Up @@ -903,89 +925,6 @@ def set_data_col_name_list(self, new_names):
self.data_name_list = new_names
return self.set_data_col_name(new_data_name)

def calc_reynolds_number(self, bounds):
'''
Calculates an approximate Reynolds number from the mean vector velocity
and cellbox size.
CURRENTLY ASSUMES DENSITY AND VISCOSITY OF SEAWATER AT 4°C!
WILL NEED MINOR REWORKING TO INCLUDE DIFFERENT FLUIDS
Args:
bounds (Boundary):
Cellbox boundary to calculate characteristic length from
Returns:
float:
Reynolds number of cellbox
'''
# Extract the speed
velocity = self.get_value(bounds, agg_type='MEAN')
speed = np.linalg.norm(list(velocity.values())) # Calculates magnitude
# Extract the characteristic length
length = bounds.calc_size()
# Calculate the reynolds number and return
logging.warning("\tReynold number used for splitting, this function assumes properties of ocean water!")
return 1028 * 0.00167 * speed * length

def calc_divergence(self, bounds, data=None, collapse=True, agg_type='MAX'):
'''
Calculates the divergence of vectors in a cellbox
Args:
bounds (Boundary):
Cellbox boundary in which all relevant vectors are contained
data (pd.DataFrame or xr.Dataset):
Dataset with 'lat' and 'long' columns/dimensions with vectors
collapes (bool):
Flag determining whether to return an aggregated value, or a
vector field (values for each individual vector).
agg_type (str):
Method of aggregation if collapsing value.
Accepts 'MAX' or 'MEAN'
Returns:
float or pd.DataFrame:
float value of aggregated div if collapse=True, or
pd.DataFrame of div vector field if collapse=False
Raises:
ValueError: If agg_type is not 'MAX' or 'MEAN'
'''
if data is None: dps = self.trim_datapoints(bounds, data=data)
else: dps = data

# Create a meshgrid of vectors from the data
vector_field = self._create_vector_meshgrid(dps, self.data_name_list)

# Get component values for each vector
fx, fy = vector_field[:, :, 0], vector_field[:, :, 1]
# If not enough datapoints to compute gradient
if 1 in fx.shape or 1 in fy.shape:
logging.debug('\tUnable to compute gradient across cell for divergence calculation')
div = np.nan
else:
# Compute partial derivatives
dfx_dy = np.gradient(fx, axis=1)
dfy_dx = np.gradient(fy, axis=0)
# Compute curl
div = dfy_dx + dfx_dy

# If div is nan
if np.isnan(div).all():
logging.debug('\tAll NaN cellbox encountered')
return np.nan
# If want to collapse to max mag value, return scalar
elif collapse:
if agg_type == 'MAX': return max(np.nanmax(div), np.nanmin(div), key=abs)
elif agg_type == 'MEAN': return np.nanmean(div)
else:
raise ValueError(f"agg_type '{agg_type}' not understood! Requires 'MAX' or 'MEAN'")
# Else return field
else:
return div


def calc_curl(self, bounds, data=None, collapse=True, agg_type='MAX'):
'''
Calculates the curl of vectors in a cellbox
Expand Down Expand Up @@ -1027,7 +966,7 @@ def calc_curl(self, bounds, data=None, collapse=True, agg_type='MAX'):
# Compute curl
curl = dfy_dx - dfx_dy

# If div is nan
# If curl is nan
if np.isnan(curl).all():
logging.debug('\tAll NaN cellbox encountered')
return np.nan
Expand Down Expand Up @@ -1081,7 +1020,7 @@ def calc_dmag(self, bounds, data=None, collapse=True, agg_type='MEAN'):
if len(d_mag) == 0:
logging.debug('\tEmpty cellbox encountered')
return np.nan
# If div is nan
# If d_mag is nan
elif np.isnan(d_mag).all():
logging.debug('\tAll NaN cellbox encountered')
return np.nan
Expand All @@ -1094,7 +1033,6 @@ def calc_dmag(self, bounds, data=None, collapse=True, agg_type='MEAN'):
# Else return field
else: return d_mag


@staticmethod
def _create_vector_meshgrid(data, data_name_list):
'''
Expand Down

0 comments on commit d29d34a

Please sign in to comment.