Skip to content

Commit

Permalink
Add integration test for prescribed jext
Browse files Browse the repository at this point in the history
  • Loading branch information
theo-brown committed Jul 16, 2024
1 parent a50e64b commit 0c8d5e6
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 53 deletions.
10 changes: 8 additions & 2 deletions torax/sources/source_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,16 +222,22 @@ def _build_psi_profiles(
dict of psi source profiles.
"""
psi_profiles = {}
# jext is precomputed in the core profiles.
# jext is computed separately, as it is stored separately in the SourceModels
dynamic_jext_runtime_params = dynamic_runtime_params_slice.sources[
source_models.jext_name
]
_, jext_profile = source_models.jext.get_value(
dynamic_runtime_params_slice,
dynamic_jext_runtime_params,
geo,
core_profiles,
)
psi_profiles[source_models.jext_name] = jax_utils.select(
jnp.logical_or(
explicit == dynamic_jext_runtime_params.is_explicit,
calculate_anyway,
),
core_profiles.currents.jext,
jext_profile,
jnp.zeros_like(geo.r),
)
# Iterate through the rest of the sources and compute profiles for the ones
Expand Down
Binary file added torax/tests/test_data/test_prescribed_jext.nc
Binary file not shown.
79 changes: 79 additions & 0 deletions torax/tests/test_data/test_prescribed_jext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""Tests driving a prescribed, time-varying external current source.
Constant transport coefficient model, circular geometry.
"""
import jax.numpy as jnp

# Define the jext profile
def gaussian(r, center, width, amplitude):
return amplitude * jnp.exp(-((r - center) ** 2) / (2 * width ** 2))

r = jnp.linspace(0, 1, 32)
jext_profile_0 = gaussian(r, center=0.35, width=0.05, amplitude=1e6)
jext_profile_1 = gaussian(r, center=0.15, width=0.1, amplitude=1e6)


# Create the config
CONFIG = {
'runtime_params': {
'profile_conditions': {
'set_pedestal': False,
'nbar': 0.85,
# set flat Ohmic current to provide larger range of current
# evolution for test
# 'nu': 0,
},
'numerics': {
'ion_heat_eq': True,
'el_heat_eq': True,
'dens_eq': True,
'current_eq': True,
'resistivity_mult': 100, # to shorten current diffusion time
't_final': 5,
},
},
'geometry': {
'geometry_type': 'circular',
},
'sources': {
# Only drive the external current source
'jext': {
'mode': 'prescribed',
'prescribed_values': {
0: {r_i.item(): jext_i.item() for r_i, jext_i in zip(r, jext_profile_0)},
2.5: {r_i.item(): jext_i.item() for r_i, jext_i in zip(r, jext_profile_1)},
},
# Disable the formula-based jext term by setting its width to 0
'wext': 0.0,
},
# Disable density sources/sinks
'nbi_particle_source': {
'S_nbi_tot': 0.0,
},
'gas_puff_source': {
'S_puff_tot': 0.0,
},
'pellet_source': {
'S_pellet_tot': 0.0,
},
# Use default heat sources
'generic_ion_el_heat_source': {},
'qei_source': {},
},
'transport': {
'transport_model': 'constant',
'constant_params': {
# diffusion coefficient in electron density equation in m^2/s
'De_const': 0.5,
# convection coefficient in electron density equation in m^2/s
'Ve_const': -0.2,
},
},
'stepper': {
'stepper_type': 'linear',
'predictor_corrector': False,
},
'time_step_calculator': {
'calculator_type': 'chi',
},
}
Binary file not shown.
51 changes: 0 additions & 51 deletions torax/tests/test_data/test_prescribed_particle_source.py

This file was deleted.

0 comments on commit 0c8d5e6

Please sign in to comment.