Skip to content

Commit

Permalink
towards allowing unconventional TFs
Browse files Browse the repository at this point in the history
  • Loading branch information
kkappler committed Nov 9, 2024
1 parent c872290 commit e8df262
Showing 1 changed file with 25 additions and 39 deletions.
64 changes: 25 additions & 39 deletions aurora/pipelines/process_mth5.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@
# =============================================================================


def make_stft_objects(
processing_config, i_dec_level, run_obj, run_xrds, units="MT"
):
def make_stft_objects(processing_config, i_dec_level, run_obj, run_xrds, units="MT"):
"""
Operates on a "per-run" basis. Applies STFT to all time series in the input run.
Expand Down Expand Up @@ -103,9 +101,7 @@ def make_stft_objects(
].channel_scale_factors
elif run_obj.station_metadata.id == processing_config.stations.remote[0].id:
scale_factors = (
processing_config.stations.remote[0]
.run_dict[run_id]
.channel_scale_factors
processing_config.stations.remote[0].run_dict[run_id].channel_scale_factors
)

stft_obj = calibrate_stft_obj(
Expand Down Expand Up @@ -152,9 +148,7 @@ def process_tf_decimation_level(
The transfer function values packed into an object
"""
frequency_bands = config.decimations[i_dec_level].frequency_bands_obj()
transfer_function_obj = TTFZ(
i_dec_level, frequency_bands, processing_config=config
)
transfer_function_obj = TTFZ(i_dec_level, frequency_bands, processing_config=config)
dec_level_config = config.decimations[i_dec_level]
# segment_weights = coherence_weights(dec_level_config, local_stft_obj, remote_stft_obj)
transfer_function_obj = process_transfer_functions(
Expand Down Expand Up @@ -183,9 +177,7 @@ def triage_issue_289(local_stfts: list, remote_stfts: list):
for i_chunk in range(n_chunks):
ok = local_stfts[i_chunk].time.shape == remote_stfts[i_chunk].time.shape
if not ok:
logger.warning(
"Mismatch in FC array lengths detected -- Issue #289"
)
logger.warning("Mismatch in FC array lengths detected -- Issue #289")
glb = max(
local_stfts[i_chunk].time.min(),
remote_stfts[i_chunk].time.min(),
Expand All @@ -196,18 +188,13 @@ def triage_issue_289(local_stfts: list, remote_stfts: list):
)
cond1 = local_stfts[i_chunk].time >= glb
cond2 = local_stfts[i_chunk].time <= lub
local_stfts[i_chunk] = local_stfts[i_chunk].where(
cond1 & cond2, drop=True
)
local_stfts[i_chunk] = local_stfts[i_chunk].where(cond1 & cond2, drop=True)
cond1 = remote_stfts[i_chunk].time >= glb
cond2 = remote_stfts[i_chunk].time <= lub
remote_stfts[i_chunk] = remote_stfts[i_chunk].where(
cond1 & cond2, drop=True
)
assert (
local_stfts[i_chunk].time.shape
== remote_stfts[i_chunk].time.shape
)
assert local_stfts[i_chunk].time.shape == remote_stfts[i_chunk].time.shape
return local_stfts, remote_stfts


Expand Down Expand Up @@ -306,9 +293,7 @@ def load_stft_obj_from_mth5(
An STFT from mth5.
"""
station_obj = station_obj_from_row(row)
fc_group = station_obj.fourier_coefficients_group.get_fc_group(
run_obj.metadata.id
)
fc_group = station_obj.fourier_coefficients_group.get_fc_group(run_obj.metadata.id)
fc_decimation_level = fc_group.get_decimation_level(f"{i_dec_level}")
stft_obj = fc_decimation_level.to_xarray(channels=channels)

Expand Down Expand Up @@ -369,10 +354,7 @@ def save_fourier_coefficients(dec_level_config, row, run_obj, stft_obj) -> None:
raise NotImplementedError(msg)

# Get FC group (create if needed)
if (
run_obj.metadata.id
in station_obj.fourier_coefficients_group.groups_list
):
if run_obj.metadata.id in station_obj.fourier_coefficients_group.groups_list:
fc_group = station_obj.fourier_coefficients_group.get_fc_group(
run_obj.metadata.id
)
Expand All @@ -393,9 +375,7 @@ def save_fourier_coefficients(dec_level_config, row, run_obj, stft_obj) -> None:
dec_level_name,
decimation_level_metadata=decimation_level_metadata,
)
fc_decimation_level.from_xarray(
stft_obj, decimation_level_metadata.sample_rate
)
fc_decimation_level.from_xarray(stft_obj, decimation_level_metadata.sample_rate)
fc_decimation_level.update_metadata()
fc_group.update_metadata()
else:
Expand Down Expand Up @@ -535,9 +515,7 @@ def process_mth5_legacy(
local_merged_stft_obj,
remote_merged_stft_obj,
)
ttfz_obj.apparent_resistivity(
tfk.config.channel_nomenclature, units=units
)
ttfz_obj.apparent_resistivity(tfk.config.channel_nomenclature, units=units)
tf_dict[i_dec_level] = ttfz_obj

if show_plot:
Expand All @@ -549,10 +527,20 @@ def process_mth5_legacy(
tf_dict=tf_dict, processing_config=tfk.config
)

tf_cls = tfk.export_tf_collection(tf_collection)

if z_file_path:
tf_cls.write(z_file_path)
try:
tf_cls = tfk.export_tf_collection(tf_collection)
if z_file_path:
tf_cls.write(z_file_path)
except Exception as e:
msg = "TF collection could not export to mt_metadata TransferFunction\n"
msg += f"Failed with exception {e}\n"
msg += "Perhaps an unconventional mixture of input/output channels was used\n"
msg += f"Input channels were {tfk.config.decimations[0].input_channels}\n"
msg += f"Output channels were {tfk.config.decimations[0].output_channels}\n"
msg += "No z-file will be written in this case\n"
msg += "Will return a legacy TransferFunctionCollection object, not mt_metadata object."
logger.error(msg)
return_collection = True

tfk.dataset.close_mth5s()
if return_collection:
Expand Down Expand Up @@ -602,9 +590,7 @@ def process_mth5(
The transfer function object
"""
if processing_type not in SUPPORTED_PROCESSINGS:
raise NotImplementedError(
f"Processing type {processing_type} not supported"
)
raise NotImplementedError(f"Processing type {processing_type} not supported")

if processing_type == "legacy":
try:
Expand Down

0 comments on commit e8df262

Please sign in to comment.