Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update reservoir precip accumulation implementation #2349

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

SECONDS_PER_DAY = 86400
TOLERANCE = 1.0e-12
ML_STEPPER_NAMES = ["machine_learning", "reservoir_predictor"]

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -92,7 +93,16 @@ def _column_pq2(ds: xr.Dataset) -> xr.DataArray:


def _column_dq1(ds: xr.Dataset) -> xr.DataArray:
if "net_heating_due_to_machine_learning" in ds:

ml_col_heating_names = {
f"column_heating_due_to_{stepper}" for stepper in ML_STEPPER_NAMES
}
if len(ml_col_heating_names.intersection(set(ds.variables))) > 0:
column_dq1 = xr.zeros_like(ds.PRATEsfc)
for var in ml_col_heating_names:
if var in ds:
column_dq1 = column_dq1 + ds[var]
elif "net_heating_due_to_machine_learning" in ds:
warnings.warn(
"'net_heating_due_to_machine_learning' is a deprecated variable name. "
"It will not be supported in future versions of fv3net. Use "
Expand All @@ -110,8 +120,6 @@ def _column_dq1(ds: xr.Dataset) -> xr.DataArray:
)
# fix isochoric vs isobaric transition issue
column_dq1 = 716.95 / 1004 * ds.net_heating
elif "column_heating_due_to_machine_learning" in ds:
column_dq1 = ds.column_heating_due_to_machine_learning
elif "storage_of_internal_energy_path_due_to_machine_learning" in ds:
column_dq1 = ds.storage_of_internal_energy_path_due_to_machine_learning
else:
Expand All @@ -125,8 +133,15 @@ def _column_dq1(ds: xr.Dataset) -> xr.DataArray:


def _column_dq2(ds: xr.Dataset) -> xr.DataArray:
if "net_moistening_due_to_machine_learning" in ds:
column_dq2 = SECONDS_PER_DAY * ds.net_moistening_due_to_machine_learning

ml_col_moistening_names = {
f"net_moistening_due_to_{stepper}" for stepper in ML_STEPPER_NAMES
}
if len(ml_col_moistening_names.intersection(set(ds.variables))) > 0:
column_dq2 = xr.zeros_like(ds.PRATEsfc)
for var in ml_col_moistening_names:
if var in ds:
column_dq2 = column_dq2 + ds[var]
elif "storage_of_specific_humidity_path_due_to_machine_learning" in ds:
column_dq2 = (
SECONDS_PER_DAY
Expand Down
31 changes: 21 additions & 10 deletions workflows/prognostic_c48_run/runtime/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,19 +586,32 @@ def _apply_reservoir_update_to_state(self) -> Diagnostics:
diags,
state_updates,
) = self._reservoir_predict_stepper(self._state.time, self._state)

logger.info(f"Reservoir stepper diagnostics: {list(diags.keys())}")
logger.info(
f"Reservoir stepper state updates: {list(state_updates.keys())}"
)

if self._reservoir_predict_stepper.is_diagnostic: # type: ignore
rename_diagnostics(diags, label="reservoir_predictor")

(
stepper_diags,
net_moistening,
diags_from_tendencies,
_,
) = self._reservoir_predict_stepper.get_diagnostics(
self._state, tendencies_from_state_prediction
)
diags.update(stepper_diags)
if self._reservoir_predict_stepper.is_diagnostic: # type: ignore
rename_diagnostics(diags, label="reservoir_predictor")
diags.update(diags_from_tendencies)

state_updates[TOTAL_PRECIP] = precipitation_sum(
self._state[TOTAL_PRECIP], net_moistening, self._timestep,
net_moistening_due_to_reservoir_adjustment = diags.get(
"net_moistening_due_to_reservoir_adjustment",
xr.zeros_like(self._state[TOTAL_PRECIP]),
)
precip = self._reservoir_predict_stepper.update_precip( # type: ignore
self._state[TOTAL_PRECIP], net_moistening_due_to_reservoir_adjustment,
)
diags.update(precip)
state_updates[TOTAL_PRECIP] = precip[TOTAL_PRECIP]

self._state.update_mass_conserving(state_updates)

Expand All @@ -609,9 +622,7 @@ def _apply_reservoir_update_to_state(self) -> Diagnostics:
"cnvprcp_after_python": self._fv3gfs.get_diagnostic_by_name(
"cnvprcp"
).data_array,
TOTAL_PRECIP_RATE: precipitation_rate(
self._state[TOTAL_PRECIP], self._timestep
),
TOTAL_PRECIP_RATE: precip["total_precip_rate_res_interval_avg"],
}
)

Expand Down
156 changes: 142 additions & 14 deletions workflows/prognostic_c48_run/runtime/steppers/reservoir.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import fv3fit
from fv3fit._shared.halos import append_halos_using_mpi
from fv3fit.reservoir.adapters import ReservoirDatasetAdapter
from runtime.names import SST, SPHUM, TEMP
from runtime.names import SST, SPHUM, TEMP, PHYSICS_PRECIP_RATE, TOTAL_PRECIP
from runtime.tendency import add_tendency, tendencies_from_state_updates
from runtime.diagnostics import (
enforce_heating_and_moistening_tendency_constraints,
Expand Down Expand Up @@ -63,6 +63,7 @@ class ReservoirConfig:
rename_mapping: NameDict = dataclasses.field(default_factory=dict)
hydrostatic: bool = False
mse_conserving_limiter: bool = False
interval_average_precipitation: bool = False


class _FiniteStateMachine:
Expand Down Expand Up @@ -104,6 +105,75 @@ def __call__(self, state: str):
)


class TendencyPrecipTracker:
def __init__(self, reservoir_timestep_seconds: float):
self.reservoir_timestep_seconds = reservoir_timestep_seconds
self.physics_precip_averager = TimeAverageInputs([PHYSICS_PRECIP_RATE])
self._air_temperature_at_previous_interval = None
self._specific_humidity_at_previous_interval = None

def increment_physics_precip_rate(self, physics_precip_rate):
self.physics_precip_averager.increment_running_average(
{PHYSICS_PRECIP_RATE: physics_precip_rate}
)

def average_physics_precip_rate(self):
return self.physics_precip_average.get_averages()[PHYSICS_PRECIP_RATE]

def update_tracked_state(self, air_temperature, specific_humidity):
self._air_temperature_at_previous_interval = air_temperature
self._specific_humidity_at_previous_interval = specific_humidity

def calculate_tendencies(self, air_temperature, specific_humidity):
if (
self._specific_humidity_at_previous_interval is None
or self._air_temperature_at_previous_interval is None
):
logger.info(
"Previous reservoir prediction of specific_humidity and "
"air_temperature not saved. Returning zero tendencies"
)
dQ1, dQ2 = xr.zeros_like(air_temperature), xr.zeros_like(air_temperature)
else:
dQ1 = (
air_temperature - self._air_temperature_at_previous_interval
) / self.reservoir_timestep_seconds
dQ2 = (
specific_humidity - self._specific_humidity_at_previous_interval
) / self.reservoir_timestep_seconds
return {"dQ1": dQ1, "dQ2": dQ2}

def interval_avg_precip_rates(self, net_moistening_due_to_reservoir):
physics_precip_rate = self.physics_precip_averager.get_averages()[
PHYSICS_PRECIP_RATE
]
total_precip_rate = physics_precip_rate - net_moistening_due_to_reservoir
total_precip_rate = total_precip_rate.where(total_precip_rate >= 0, 0)
reservoir_precip_rate = total_precip_rate - physics_precip_rate
return {
"total_precip_rate_res_interval_avg": total_precip_rate,
"physics_precip_rate_res_interval_avg": physics_precip_rate,
"reservoir_precip_rate_res_interval_avg": reservoir_precip_rate,
}

def accumulated_precip_update(
self,
physics_precip_total_over_model_timestep,
reservoir_precip_rate_over_res_interval,
reservoir_timestep,
):
# Since the reservoir correction is only applied every reservoir_timestep,
# all of the precip due to the reservoir is put into the accumulated precip
# in the model timestep at update time.
m_per_mm = 1 / 1000
reservoir_total_precip = (
reservoir_precip_rate_over_res_interval * reservoir_timestep * m_per_mm
)
total_precip = physics_precip_total_over_model_timestep + reservoir_total_precip
total_precip.attrs["units"] = "m"
return total_precip


class TimeAverageInputs:
"""
Copy of time averaging components from runtime.diagnostics.manager to
Expand Down Expand Up @@ -170,6 +240,7 @@ def __init__(
warm_start: bool = False,
hydrostatic: bool = False,
mse_conserving_limiter: bool = False,
tendency_precip_tracker: Optional[TendencyPrecipTracker] = None,
):
self.model = model
self.synchronize_steps = synchronize_steps
Expand All @@ -181,6 +252,7 @@ def __init__(
self.warm_start = warm_start
self.hydrostatic = hydrostatic
self.mse_conserving_limiter = mse_conserving_limiter
self.tendency_precip_tracker = tendency_precip_tracker

if state_machine is None:
state_machine = _FiniteStateMachine()
Expand Down Expand Up @@ -313,6 +385,7 @@ def predict(self, inputs, state):

self._state_machine(self._state_machine.PREDICT)
result = self.model.predict(inputs)

output_state = rename_dataset_members(result, self.rename_mapping)

diags = rename_dataset_members(
Expand Down Expand Up @@ -360,12 +433,19 @@ def __call__(self, time, state):
if self.input_averager is not None:
self.input_averager.increment_running_average(inputs)

if self.tendency_precip_tracker is not None:
self.tendency_precip_tracker.increment_physics_precip_rate(
state[PHYSICS_PRECIP_RATE]
)

tendencies, diags, updated_state = {}, {}, {}

if self._is_rc_update_step(time):
logger.info(f"Reservoir model predict at time {time}")
if self.input_averager is not None:
inputs.update(self.input_averager.get_averages())

tendencies, diags, updated_state = self.predict(inputs, state)
_, diags, updated_state = self.predict(inputs, state)

hybrid_diags = rename_dataset_members(
inputs, {k: f"{self.rename_mapping.get(k, k)}_hyb_in" for k in inputs}
Expand All @@ -375,14 +455,11 @@ def __call__(self, time, state):
# This check is done on the _rc_out diags since those are always available.
# This allows zero field diags to be returned on timesteps where the
# reservoir is not updating the state.
diags_Tq_vars = {f"{v}_{self.DIAGS_OUTPUT_SUFFIX}" for v in [TEMP, SPHUM]}

if diags_Tq_vars.issubset(list(diags.keys())):
# TODO: Currently the reservoir only predicts updated states and returns
# empty tendencies. If tendency predictions are implemented in the
# prognostic run, the limiter/conservation updates should be updated to
# take this option into account and use predicted tendencies directly.
tendencies_from_state_prediction = tendencies_from_state_updates(
# diags_Tq_vars = {f"{v}_{self.DIAGS_OUTPUT_SUFFIX}" for v in [TEMP, SPHUM]}
# if diags_Tq_vars.issubset(list(diags.keys())):

if self.tendency_precip_tracker is not None:
tendencies_over_model_timestep = tendencies_from_state_updates(
initial_state=state,
updated_state=updated_state,
dt=self.model_timestep,
Expand All @@ -392,7 +469,7 @@ def __call__(self, time, state):
diagnostics_updates_from_constraints,
) = enforce_heating_and_moistening_tendency_constraints(
state=state,
tendency=tendencies_from_state_prediction,
tendency=tendencies_over_model_timestep,
timestep=self.model_timestep,
mse_conserving=self.mse_conserving_limiter,
hydrostatic=self.hydrostatic,
Expand All @@ -401,19 +478,63 @@ def __call__(self, time, state):
zero_fill_missing_tendencies=True,
)

# net moistening from reservoir update is calculated using the
# difference from the last model timestep, but is interpreted
# as an update over the reservoir timestep
# Tendencies over model timesteps are popped- they are only
# used in the limiter and constraint adjustments
_, net_moistening_due_to_reservoir = self.get_diagnostics(
state,
{
"dQ1": tendency_updates_from_constraints.pop("dQ1"),
"dQ2": tendency_updates_from_constraints.pop("dQ2"),
},
)
net_moistening_res = net_moistening_due_to_reservoir * (
self.model_timestep / self.timestep.total_seconds()
)
diags.update(
{"net_moistening_due_to_reservoir_adjustment": net_moistening_res}
)
diags.update(diagnostics_updates_from_constraints)

updated_state = add_tendency(
state=state,
tendencies=tendency_updates_from_constraints,
dt=self.model_timestep,
)
tendencies.update(tendency_updates_from_constraints)

else:
tendencies, diags, updated_state = {}, {}, {}
tendencies = self.tendency_precip_tracker.calculate_tendencies(
updated_state.get(TEMP, state[TEMP]),
updated_state.get(SPHUM, state[SPHUM]),
)

self.tendency_precip_tracker.update_tracked_state(
updated_state.get(TEMP, state[TEMP]),
updated_state.get(SPHUM, state[SPHUM]),
)
diags.update(tendencies)

return tendencies, diags, updated_state

def update_precip(
self, physics_precip, net_moistening_due_to_reservoir,
):
diags = {}

# running average gets reset in this call
precip_rates = self.tendency_precip_tracker.interval_avg_precip_rates(
net_moistening_due_to_reservoir
)
diags.update(precip_rates)

diags[TOTAL_PRECIP] = self.tendency_precip_tracker.accumulated_precip_update(
physics_precip,
diags["reservoir_precip_rate_res_interval_avg"],
self.timestep.total_seconds(),
)
return diags

def get_diagnostics(self, state, tendency):
diags = compute_diagnostics(state, tendency, self.label, self.hydrostatic)
return diags, diags[f"net_moistening_due_to_{self.label}"]
Expand Down Expand Up @@ -463,6 +584,12 @@ def get_reservoir_steppers(
model, config.time_average_inputs
)

_precip_tracker_kwargs = {}
if config.interval_average_precipitation:
_precip_tracker_kwargs["tendency_precip_tracker"] = TendencyPrecipTracker(
reservoir_timestep_seconds=rc_tdelta.total_seconds(),
)

incrementer = ReservoirIncrementOnlyStepper(
model,
init_time,
Expand All @@ -487,5 +614,6 @@ def get_reservoir_steppers(
model_timestep=model_timestep,
hydrostatic=config.hydrostatic,
mse_conserving_limiter=config.mse_conserving_limiter,
**_precip_tracker_kwargs,
)
return incrementer, predictor
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ radiation_scheme: null
reservoir_corrector:
diagnostic_only: false
hydrostatic: false
interval_average_precipitation: false
models:
0: gs://vcm-ml-scratch/rc-model-tile-0
1: gs://vcm-ml-scratch/rc-model-tile-1
Expand Down