Skip to content

Commit

Permalink
Change the core_sources to always be the merged profiles
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 723090670
  • Loading branch information
tamaranorman authored and Torax team committed Feb 4, 2025
1 parent 75d848b commit e410d09
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 51 deletions.
2 changes: 1 addition & 1 deletion torax/fvm/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,7 +830,7 @@ def _calc_coeffs_full(
v_face=v_face,
source_mat_cell=source_mat_cell,
source_cell=source_cell,
auxiliary_outputs=(implicit_source_profiles, transport_coeffs),
auxiliary_outputs=(merged_source_profiles, transport_coeffs),
)

return coeffs
Expand Down
9 changes: 0 additions & 9 deletions torax/orchestration/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,6 @@ def __call__(
explicit=True,
)

# The previous time step's state has an incomplete set of source profiles
# which was computed based on the previous time step's "guess" of the core
# profiles at this time step's t. We can merge those "implicit" source
# profiles with the explicit ones computed here.
input_state.core_sources = source_profiles_lib.SourceProfiles.merge(
explicit_source_profiles=explicit_source_profiles,
implicit_source_profiles=input_state.core_sources,
)

dt, time_step_calculator_state = self.init_time_step_calculator(
dynamic_runtime_params_slice_t,
geo_t,
Expand Down
24 changes: 0 additions & 24 deletions torax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.sources import source_models as source_models_lib
from torax.sources import source_profile_builders
from torax.sources import source_profiles as source_profiles_lib
from torax.stepper import stepper as stepper_lib
from torax.time_step_calculator import chi_time_step_calculator
from torax.time_step_calculator import time_step_calculator as ts
Expand Down Expand Up @@ -622,29 +621,6 @@ def _run_simulation(
# The "sim_state" here has been updated by the loop above.
_log_timestep(sim_state)

# Update the final time step's source profiles based on the explicit source
# profiles computed based on the final state.
logging.info("Updating last step's source profiles.")
dynamic_runtime_params_slice, geo = (
runtime_params_slice.get_consistent_dynamic_runtime_params_slice_and_geometry(
t=sim_state.t,
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
geometry_provider=geometry_provider,
)
)
explicit_source_profiles = source_profile_builders.build_source_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo,
core_profiles=sim_state.core_profiles,
source_models=step_fn.stepper.source_models,
explicit=True,
)
sim_state.core_sources = source_profiles_lib.SourceProfiles.merge(
explicit_source_profiles=explicit_source_profiles,
implicit_source_profiles=sim_state.core_sources,
)

# If the first step of the simulation was very long, call it out. It might
# have to do with tracing the jitted step_fn.
std_devs = 2 # Check if the first step is more than 2 std devs longer.
Expand Down
16 changes: 6 additions & 10 deletions torax/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,16 +507,12 @@ class ToraxSimState:
dt: timestep interval.
core_profiles: Core plasma profiles at time t.
core_transport: Core plasma transport coefficients computed at time t.
core_sources: Profiles for all sources/sinks. For any state-dependent source
models, the profiles in this dataclass are computed based on the core
profiles at time t, almost. When running `sim.run_simulation()`, any
profile from an "explicit" state-dependent source will be computed with
the core profiles at time t. Any profile from an "implicit"
state-dependent source will be computed with an intermediate state from
the previous time step's solver. This should be close to the core profiles
at time t, but is not guaranteed to be. In case exact source profiles are
required for each time step, they must be recomputed manually after
running `run_simulation()`.
core_sources: Profiles for all sources/sinks. These are the profiles that
are used to calculate the coefficients for the t+dt time step. For the
explicit sources, these are calculated at the start of the time step, so
are the values at time t. For the implicit sources, these are the most
recent guess for time t+dt. The profiles here are the merged version of
the explicit and implicit profiles.
post_processed_outputs: variables for output or intermediate observations
for overarching workflows, calculated after each simulation step.
geometry: Geometry at this time step used for the simulation.
Expand Down
24 changes: 17 additions & 7 deletions torax/stepper/stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ def __call__(
Returns:
new_core_profiles: Updated core profiles.
core_sources: Source profiles of the implicit sources, computed at the
most recent guess for time t+dt. Any state-dependent source profiles
will not be computed based on the exact state of the core profiles at
time t+dt, but rather they will be computed based on the final guess the
solver used while calculating coeffs in the solver.
core_sources: Merged source profiles of all sources, including explicit
and implicit. This is the version of the source profiles that is used
to calculate the coefficients for the t+dt time step. For the explicit
sources, this is the same as the explicit_source_profiles input. For
the implicit sources, this is the most recent guess for time t+dt.
core_transport: Transport coefficients for time t+dt.
stepper_numeric_output: Error and iteration info.
"""
Expand Down Expand Up @@ -144,9 +144,19 @@ def __call__(
)
else:
x_new = tuple()
core_sources = source_profile_builders.build_all_zero_profiles(
# Calculate implicit source profiles and return the merged version. This
# is useful for inspecting prescribed sources in the output state.
implicit_source_profiles = source_profile_builders.build_source_profiles(
source_models=self.source_models,
geo=geo_t,
dynamic_runtime_params_slice=dynamic_runtime_params_slice_t_plus_dt,
static_runtime_params_slice=static_runtime_params_slice,
geo=geo_t_plus_dt,
core_profiles=core_profiles_t_plus_dt,
explicit=False,
)
core_sources = source_profiles.SourceProfiles.merge(
explicit_source_profiles=explicit_source_profiles,
implicit_source_profiles=implicit_source_profiles,
)
core_transport = state.CoreTransport.zeros(geo_t)
stepper_numeric_output = state.StepperNumericOutputs()
Expand Down

0 comments on commit e410d09

Please sign in to comment.