Skip to content

Commit

Permalink
Updated dymos to utilize new OpenMDAO load_case capability. (#954)
Browse files Browse the repository at this point in the history
* impl load_case method on Phase, update tests

* update load_case occurrences

* deprecate previous load_case

* Updated load case to account for new timeseries updates and to clip interpolated values

* updated grid refinement to utilize phase.load_case

* revert docs workflow to master version

---------

Co-authored-by: swryan <[email protected]>
  • Loading branch information
robfalck and swryan authored Jul 26, 2023
1 parent 79ca083 commit 43846fe
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

import numpy as np
import openmdao.api as om
import openmdao

import dymos as dm
from dymos.examples.cannonball.size_comp import CannonballSizeComp
from dymos.examples.cannonball.cannonball_ode import CannonballODE

om_version = tuple([int(s) for s in openmdao.__version__.split('-')[0].split('.')])


@use_tempdirs
class TestTwoPhaseCannonballLoadCase(unittest.TestCase):
Expand Down Expand Up @@ -167,7 +170,10 @@ def test_load_case_missing_phase(self):
p.set_val('traj.descent.states:gam', descent.interp('gam', [0, -45]), units='deg')

case = om.CaseReader('dymos_solution.db').get_case('final')
dm.load_case(p, previous_solution=case)
if om_version < (3, 26, 1):
dm.load_case(p, previous_solution=case)
else:
p.load_case(case)

p.run_model()

Expand Down
31 changes: 24 additions & 7 deletions dymos/grid_refinement/refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
from .hp_adaptive.hp_adaptive import HPAdaptive
from .write_iteration import write_error, write_refine_iter

import openmdao.api as om
from dymos.grid_refinement.error_estimation import check_error
from dymos.load_case import load_case, find_phases

import numpy as np
import sys

import openmdao


def _refine_iter(problem, refine_iteration_limit=0, refine_method='hp', case_prefix=None, reset_iter_counts=True):
"""
Expand All @@ -33,8 +36,8 @@ def _refine_iter(problem, refine_iteration_limit=0, refine_method='hp', case_pre
refinement_methods = {'hp': HPAdaptive, 'ph': PHAdaptive}
_case_prefix = '' if case_prefix is None else f'{case_prefix}_'

failed = problem.run_driver(case_prefix=f'{_case_prefix}{refine_method}_0_'
if refine_iteration_limit > 0 else _case_prefix,
case_prefix = f'{_case_prefix}{refine_method}_0_'
failed = problem.run_driver(case_prefix=case_prefix if refine_iteration_limit > 0 else _case_prefix,
reset_iter_counts=reset_iter_counts)

if refine_iteration_limit > 0:
Expand All @@ -60,12 +63,25 @@ def _refine_iter(problem, refine_iteration_limit=0, refine_method='hp', case_pre
for stream in f, sys.stdout:
write_refine_iter(stream, i, phases, refine_results)

prev_soln = {'inputs': problem.model.list_inputs(out_stream=None, units=True, prom_name=True),
'outputs': problem.model.list_outputs(out_stream=None, units=True, prom_name=True)}

problem.setup()
om_version = tuple([int(s) for s in openmdao.__version__.split('-')[0].split('.')])
if om_version < (3, 27, 1):
prev_soln = {'inputs': problem.model.list_inputs(out_stream=None, units=True, prom_name=True),
'outputs': problem.model.list_outputs(out_stream=None, units=True, prom_name=True)}

load_case(problem, prev_soln)
problem.setup()
load_case(problem, prev_soln, deprecation_warning=False)
else:
prev_soln = {
'inputs': problem.model.list_inputs(out_stream=None, return_format='dict',
units=True, prom_name=True),
'outputs': problem.model.list_outputs(out_stream=None, return_format='dict',
units=True, prom_name=True)
}

problem.setup()
for phase_path in refined_phases:
phs = problem.model._get_subsystem(phase_path)
phs.load_case(prev_soln)

failed = problem.run_driver(case_prefix=f'{_case_prefix}{refine_method}_{i}_')

Expand All @@ -76,4 +92,5 @@ def _refine_iter(problem, refine_iteration_limit=0, refine_method='hp', case_pre
else:
print('Successfully completed grid refinement.', file=stream)
print(50 * '=')

return failed
2 changes: 1 addition & 1 deletion dymos/grid_refinement/test/test_grid_refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_refine_hp_non_ode_rate_sources(self):
p.driver.declare_coloring()

traj = p.model.add_subsystem('traj', dm.Trajectory())
tx = transcription = dm.Radau(num_segments=5)
tx = dm.Radau(num_segments=5)
phase = traj.add_phase('phase0', dm.Phase(ode_class=_BrysonDenhamODE, transcription=tx))

#
Expand Down
12 changes: 8 additions & 4 deletions dymos/load_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
from openmdao.recorders.case import Case
from .phase import AnalyticPhase, Phase
from .trajectory import Trajectory
from openmdao.utils.om_warnings import issue_warning
from openmdao.utils.om_warnings import issue_warning, warn_deprecation


def find_phases(sys):
"""
Finds all instances of Dymos Phases within the given system, and returns them as a dictionary.
They are keyed by promoted name if use_prom_path=True, otherwise they are keyed by their
absolute name.
Parameters
----------
Expand Down Expand Up @@ -59,7 +57,7 @@ def find_trajectories(sys):
return traj_paths


def load_case(problem, previous_solution):
def load_case(problem, previous_solution, deprecation_warning=True):
"""
Populate a guess for the given problem involving Dymos Phases by interpolating results
from the previous solution.
Expand All @@ -72,7 +70,13 @@ def load_case(problem, previous_solution):
A dictionary with key 'inputs' mapped to the output of problem.model.list_inputs for
a previous iteration, and key 'outputs' mapped to the output of prob.model.list_outputs.
Both list_inputs and list_outputs should be called with `units=True` and `prom_names=True`.
deprecation_warning : bool
When False, no deprecation warning will be issued, otherwise warning will be issued.
(defaults to True)
"""
if deprecation_warning:
warn_deprecation("The Dymos load_case method is deprecated for OpenMDAO 3.28.0 and later, "
"the load_case method on Problem should be used instead.")

# allow old style arguments using a Case or OpenMDAO problem instead of dictionary
assert (isinstance(previous_solution, Case) or isinstance(previous_solution, dict))
Expand Down
169 changes: 169 additions & 0 deletions dymos/phase/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
import openmdao
import openmdao.api as om
from openmdao.utils.mpi import MPI
from openmdao.utils.om_warnings import issue_warning
from openmdao.core.system import System
from openmdao.recorders.case import Case

import dymos as dm

from .options import ControlOptionsDictionary, ParameterOptionsDictionary, \
Expand Down Expand Up @@ -2643,3 +2646,169 @@ def _is_fixed(self, var_name, var_type, loc):
return not self.parameter_options[var_name]['opt']

return False # No way to know so we allow these to go through

def load_case(self, case):
"""
Pull all input and output variables from a case into the Phase.
Parameters
----------
case : Case or dict
A Case from a CaseReader, or a dictionary with key 'inputs' mapped to the
output of problem.model.list_inputs and key 'outputs' mapped to the output
of prob.model.list_outputs. Both list_inputs and list_outputs should be called
with `units=True`, `prom_names=True` and `return_format='dict'`.
"""
# allow old style arguments using a Case or OpenMDAO problem instead of dictionary
assert (isinstance(case, Case) or isinstance(case, dict))
if isinstance(case, Case):
previous_solution = {
'inputs': case.list_inputs(out_stream=None, return_format='dict',
units=True, prom_name=True),
'outputs': case.list_outputs(out_stream=None, return_format='dict',
units=True, prom_name=True)
}
else:
previous_solution = case

prev_vars_abs2prom = {}
prev_vars_abs2prom.update({k: v['prom_name'] for k, v in previous_solution['inputs'].items()})
prev_vars_abs2prom.update({k: v['prom_name'] for k, v in previous_solution['outputs'].items()})
prev_vars_prom2abs = {v: k for k, v in prev_vars_abs2prom.items()}

prev_vars = {}
prev_vars.update({v['prom_name']: {'val': v['val'], 'units': v['units'], 'abs_name': k}
for k, v in previous_solution['inputs'].items()})
prev_vars.update({v['prom_name']: {'val': v['val'], 'units': v['units'], 'abs_name': k}
for k, v in previous_solution['outputs'].items()})

phase_io = {'inputs': self.list_inputs(units=True, prom_name=True, out_stream=None),
'outputs': self.list_outputs(units=True, prom_name=True, out_stream=None)}

phase_vars = {}
phase_vars.update({f"{self.pathname}.{v['prom_name']}": {'val': v['val'], 'units': v['units'], 'abs_name': k}
for k, v in phase_io['inputs']})
phase_vars.update({f"{self.pathname}.{v['prom_name']}": {'val': v['val'], 'units': v['units'], 'abs_name': k}
for k, v in phase_io['outputs']})

phase_name = self.name

# Get the initial time and duration from the previous result and set them into the new phase.
integration_name = self.time_options['name']

try:
prev_time_path = prev_vars_abs2prom[f'{self.pathname}.timeseries.timeseries_comp.{integration_name}']
except KeyError:
om.issue_warning(f'load_case for phase {self.name} failed - phase not found in case data.')
return

prev_timeseries_prom_path, _, _ = prev_time_path.rpartition(f'.{integration_name}')
prev_phase_prom_path, _, _ = prev_timeseries_prom_path.rpartition('.timeseries')

prev_time_val = prev_vars[prev_time_path]['val']
prev_time_val, unique_idxs = np.unique(prev_time_val, return_index=True)
prev_time_units = prev_vars[prev_time_path]['units']

t_initial = prev_time_val[0]
t_duration = prev_time_val[-1] - prev_time_val[0]

self.set_val('t_initial', t_initial, units=prev_time_units)
self.set_val('t_duration', t_duration, units=prev_time_units)

# Interpolate the timeseries state outputs from the previous solution onto the new grid.
if not isinstance(self, dm.AnalyticPhase):
for state_name, options in self.state_options.items():
if f'{prev_timeseries_prom_path}.states:{state_name}' in prev_vars_prom2abs:
prev_state_path = f'{prev_timeseries_prom_path}.states:{state_name}'
elif f'{prev_timeseries_prom_path}.{state_name}' in prev_vars_prom2abs:
prev_state_path = f'{prev_timeseries_prom_path}.{state_name}'
else:
issue_warning(f'Unable to find state {state_name} in timeseries data from case being loaded.',
om.OpenMDAOWarning)
continue

prev_state_val = prev_vars[prev_state_path]['val']
prev_state_units = prev_vars[prev_state_path]['units']
interp_vals = self.interp(name=state_name,
xs=prev_time_val,
ys=prev_state_val[unique_idxs],
kind='slinear')
if options['lower'] is not None or options['upper'] is not None:
interp_vals = interp_vals.clip(options['lower'], options['upper'])
self.set_val(f'states:{state_name}',
interp_vals,
units=prev_state_units)
try:
self.set_val(f'initial_states:{state_name}', prev_state_val[0, ...], units=prev_state_units)
except KeyError:
pass

if options['fix_final']:
warning_message = f"{phase_name}.states:{state_name} specifies 'fix_final=True'. " \
f"If the given restart file has a" \
f" different final value this will overwrite the user-specified value"
issue_warning(warning_message)

# Interpolate the timeseries control outputs from the previous solution onto the new grid.
for control_name, options in self.control_options.items():
if f'{prev_timeseries_prom_path}.controls:{control_name}' in prev_vars_prom2abs:
prev_control_path = f'{prev_timeseries_prom_path}.controls:{control_name}'
elif f'{prev_timeseries_prom_path}.{control_name}' in prev_vars_prom2abs:
prev_control_path = f'{prev_timeseries_prom_path}.{control_name}'
else:
issue_warning(f'Unable to find control {control_name} in timeseries data from case being loaded.',
om.OpenMDAOWarning)
continue

prev_control_val = prev_vars[prev_control_path]['val']
prev_control_units = prev_vars[prev_control_path]['units']
interp_vals = self.interp(name=control_name,
xs=prev_time_val,
ys=prev_control_val[unique_idxs],
kind='slinear')
if options['lower'] is not None or options['upper'] is not None:
interp_vals = interp_vals.clip(options['lower'], options['upper'])
self.set_val(f'controls:{control_name}', interp_vals, units=prev_control_units)
if options['fix_final']:
warning_message = f"{phase_name}.controls:{control_name} specifies 'fix_final=True'. " \
f"If the given restart file has a" \
f" different final value this will overwrite the user-specified value"
issue_warning(warning_message)

# Set the output polynomial control outputs from the previous solution as the value
for pc_name, options in self.polynomial_control_options.items():
if f'{prev_timeseries_prom_path}.polynomial_controls:{pc_name}' in prev_vars_prom2abs:
prev_pc_path = f'{prev_timeseries_prom_path}.polynomial_controls:{pc_name}'
elif f'{prev_timeseries_prom_path}.{pc_name}' in prev_vars_prom2abs:
prev_pc_path = f'{prev_timeseries_prom_path}.{pc_name}'
else:
issue_warning(f'Unable to find polynomial control {pc_name} in timeseries data from case being '
f'loaded.', om.OpenMDAOWarning)
continue

prev_pc_val = prev_vars[prev_pc_path]['val']
prev_pc_units = prev_vars[prev_pc_path]['units']
interp_vals = self.interp(name=pc_name,
xs=prev_time_val,
ys=prev_pc_val[unique_idxs],
kind='slinear')
if options['lower'] is not None or options['upper'] is not None:
interp_vals = interp_vals.clip(options['lower'], options['upper'])
self.set_val(f'polynomial_controls:{pc_name}',
interp_vals,
units=prev_pc_units)
if options['fix_final']:
warning_message = f"{phase_name}.polynomial_controls:{pc_name} specifies 'fix_final=True'. " \
f"If the given restart file has a" \
f" different final value this will overwrite the user-specified value"
issue_warning(warning_message)

# Set the timeseries parameter outputs from the previous solution as the parameter value
for param_name in self.parameter_options:
if f'{prev_phase_prom_path}.parameter_vals:{param_name}' in prev_vars:
prev_param_val = prev_vars[f'{prev_phase_prom_path}.parameter_vals:{param_name}']['val']
prev_param_units = prev_vars[f'{prev_phase_prom_path}.parameter_vals:{param_name}']['units']
self.set_val(f'parameters:{param_name}', prev_param_val[0, ...], units=prev_param_units)
else:
issue_warning(f'Unable to find "{prev_phase_prom_path}.parameter_vals:{param_name}" '
f'in data from case being loaded.')
9 changes: 7 additions & 2 deletions dymos/run_problem.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import warnings

import openmdao
import openmdao.api as om
from openmdao.recorders.case import Case
from ._options import options as dymos_options
from dymos.trajectory.trajectory import Trajectory
from dymos.load_case import load_case
from dymos.visualization.timeseries_plots import timeseries_plots

from .grid_refinement.refinement import _refine_iter
Expand Down Expand Up @@ -83,7 +83,12 @@ def run_problem(problem, refine_method='hp', refine_iteration_limit=0, run_drive
problem.final_setup()

if restart is not None:
load_case(problem, case)
om_version = tuple([int(s) for s in openmdao.__version__.split('-')[0].split('.')])
if om_version < (3, 27, 1):
from dymos.load_case import load_case
load_case(problem, case, deprecation_warning=False)
else:
problem.load_case(case)

for traj in problem.model.system_iter(include_self=True, recurse=True, typ=Trajectory):
traj._check_phase_graph()
Expand Down
1 change: 0 additions & 1 deletion dymos/test/test_load_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def test_load_case_lgl_to_radau(self):
outputs = dict([(o[0], o[1]) for o in case.list_outputs(units=True, shape=True,
out_stream=None)])

print(outputs)
time_val = outputs['phase0.timeseries.timeseries_comp.time']['val']
theta_val = outputs['phase0.timeseries.timeseries_comp.theta']['val']

Expand Down

0 comments on commit 43846fe

Please sign in to comment.