Skip to content

Commit

Permalink
Fixed an issue where grid refinement tests were still expecting some …
Browse files Browse the repository at this point in the history
…things to be in the timeseries. (#946)

* Fixed an issue where grid refinement tests were still expecting some things to be in the timeseries.

* error estimation now uses appropriate timeseries prefix.
  • Loading branch information
robfalck authored Jun 30, 2023
1 parent 32352cb commit 5a5e729
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 22 deletions.
43 changes: 24 additions & 19 deletions dymos/grid_refinement/error_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,11 @@ def eval_ode_on_grid(phase, transcription):
if t_duration_targets:
p_refine.set_val(f't_duration', t_duration)

state_prefix = 'states:' if phase.timeseries_options['use_prefix'] else ''
control_prefix = 'controls:' if phase.timeseries_options['use_prefix'] else ''

for name, options in phase.state_options.items():
x_prev = phase.get_val(f'timeseries.{name}', units=options['units'])
x_prev = phase.get_val(f'timeseries.{state_prefix}{name}', units=options['units'])
x[name] = np.dot(L, x_prev)
targets = get_targets(ode, name, options['targets'])
if targets:
Expand All @@ -178,40 +181,42 @@ def eval_ode_on_grid(phase, transcription):
rate_targets = get_targets(ode, f'{name}_rate', options['rate_targets'])
rate2_targets = get_targets(ode, f'{name}_rate12', options['rate2_targets'])

u_prev = phase.get_val(f'timeseries.{name}', units=options['units'])
u_prev = phase.get_val(f'timeseries.{control_prefix}{name}', units=options['units'])
u[name] = np.dot(L, u_prev)
if targets:
p_refine.set_val(f'controls:{name}', u[name])

if rate_targets:
u_rate_prev = phase.get_val(f'timeseries.{name}_rate')
u_rate[name] = np.dot(L, u_rate_prev)
p_refine.set_val(f'control_rates:{name}_rate', u_rate[name])
if phase.timeseries_options['include_control_rates']:
if rate_targets:
u_rate_prev = phase.get_val(f'timeseries.control_rates:{name}_rate')
u_rate[name] = np.dot(L, u_rate_prev)
p_refine.set_val(f'control_rates:{name}_rate', u_rate[name])

if rate2_targets:
u_rate2_prev = phase.get_val(f'timeseries.{name}_rate2')
u_rate2[name] = np.dot(L, u_rate2_prev)
p_refine.set_val(f'control_rates:{name}_rate2', u_rate2[name])
if rate2_targets:
u_rate2_prev = phase.get_val(f'timeseries.control_rates:{name}_rate2')
u_rate2[name] = np.dot(L, u_rate2_prev)
p_refine.set_val(f'control_rates:{name}_rate2', u_rate2[name])

for name, options in phase.polynomial_control_options.items():
targets = get_targets(ode, name, options['targets'])
rate_targets = get_targets(ode, f'{name}_rate', options['rate_targets'])
rate2_targets = get_targets(ode, f'{name}_rate2', options['rate2_targets'])

p_prev = phase.get_val(f'timeseries.{name}', units=options['units'])
p_prev = phase.get_val(f'timeseries.polynomial_controls:{name}', units=options['units'])
p[name] = np.dot(L, p_prev)
if targets:
p_refine.set_val(f'polynomial_controls:{name}', p[name])

p_rate_prev = phase.get_val(f'timeseries.{name}_rate')
p_rate[name] = np.dot(L, p_rate_prev)
if rate_targets:
p_refine.set_val(f'polynomial_control_rates:{name}_rate', p_rate[name])
if phase.timeseries_options['include_control_rates']:
p_rate_prev = phase.get_val(f'timeseries.polynomial_control_rates:{name}_rate')
p_rate[name] = np.dot(L, p_rate_prev)
if rate_targets:
p_refine.set_val(f'polynomial_control_rates:{name}_rate', p_rate[name])

p_rate2_prev = phase.get_val(f'timeseries.{name}_rate2')
p_rate2[name] = np.dot(L, p_rate2_prev)
if rate2_targets:
p_refine.set_val(f'polynomial_control_rates:{name}_rate2', p_rate2[name])
p_rate2_prev = phase.get_val(f'timeseries.polynomial_control_rates:{name}_rate2')
p_rate2[name] = np.dot(L, p_rate2_prev)
if rate2_targets:
p_refine.set_val(f'polynomial_control_rates:{name}_rate2', p_rate2[name])

# Configure the parameters
for name, options in phase.parameter_options.items():
Expand Down
8 changes: 5 additions & 3 deletions dymos/grid_refinement/test/test_error_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def _run_brachistochrone(self, transcription_class=dm.Radau, control_type='contr

traj = dm.Trajectory()
phase = dm.Phase(ode_class=BrachistochroneODE, transcription=tx)
phase.timeseries_options['use_prefix'] = True
phase.timeseries_options['include_state_rates'] = True
p.model.add_subsystem('traj0', traj)
traj.add_phase('phase0', phase)

Expand Down Expand Up @@ -104,17 +106,17 @@ def test_compute_state_quadratures(self):
print(f'{tx_class.__name__} - {control_type} - g = {g}')

for name, options in phase.control_options.items():
u_solution = phase.get_val(f'timeseries.{name}')
u_solution = phase.get_val(f'timeseries.controls:{name}')
print(f'{name} interpolation error',
max(np.abs(u[name].ravel() - u_solution.ravel())))

for name, options in phase.polynomial_control_options.items():
p_solution = phase.get_val(f'timeseries.{name}')
p_solution = phase.get_val(f'timeseries.polynomial_controls:{name}')
print(f'{name} interpolation error',
max(np.abs(p[name].ravel() - p_solution.ravel())))

for name, options in phase.state_options.items():
x_solution = phase.get_val(f'timeseries.{name}')
x_solution = phase.get_val(f'timeseries.states:{name}')
f_solution = phase.get_val(f'timeseries.state_rates:{name}')

print(f'{name} interpolation error', max(np.abs(x[name].ravel() - x_solution.ravel())))
Expand Down

0 comments on commit 5a5e729

Please sign in to comment.