From 43846feee955662ad44b23fd8fc093854a9067a5 Mon Sep 17 00:00:00 2001 From: Rob Falck Date: Wed, 26 Jul 2023 08:50:41 -0400 Subject: [PATCH] Updated dymos to utilize new OpenMDAO load_case capability. (#954) * impl load_case method on Phase, update tests * update load_case occurrences * deprecate previous load_case * Updated load case to account for new timeseries updates and to clip interpolated values * updated grid refinement to utilize phase.load_case * revert docs workflow to master version --------- Co-authored-by: swryan --- .../test/test_load_case_missing_phase.py | 8 +- dymos/grid_refinement/refinement.py | 31 +++- .../test/test_grid_refinement.py | 2 +- dymos/load_case.py | 12 +- dymos/phase/phase.py | 169 ++++++++++++++++++ dymos/run_problem.py | 9 +- dymos/test/test_load_case.py | 1 - 7 files changed, 216 insertions(+), 16 deletions(-) diff --git a/dymos/examples/cannonball/test/test_load_case_missing_phase.py b/dymos/examples/cannonball/test/test_load_case_missing_phase.py index ae64d53cc..b3f516d24 100644 --- a/dymos/examples/cannonball/test/test_load_case_missing_phase.py +++ b/dymos/examples/cannonball/test/test_load_case_missing_phase.py @@ -5,11 +5,14 @@ import numpy as np import openmdao.api as om +import openmdao import dymos as dm from dymos.examples.cannonball.size_comp import CannonballSizeComp from dymos.examples.cannonball.cannonball_ode import CannonballODE +om_version = tuple([int(s) for s in openmdao.__version__.split('-')[0].split('.')]) + @use_tempdirs class TestTwoPhaseCannonballLoadCase(unittest.TestCase): @@ -167,7 +170,10 @@ def test_load_case_missing_phase(self): p.set_val('traj.descent.states:gam', descent.interp('gam', [0, -45]), units='deg') case = om.CaseReader('dymos_solution.db').get_case('final') - dm.load_case(p, previous_solution=case) + if om_version < (3, 26, 1): + dm.load_case(p, previous_solution=case) + else: + p.load_case(case) p.run_model() diff --git a/dymos/grid_refinement/refinement.py b/dymos/grid_refinement/refinement.py index 71ab8680a..381e7ca04 100644 --- a/dymos/grid_refinement/refinement.py +++ b/dymos/grid_refinement/refinement.py @@ -5,12 +5,15 @@ from .hp_adaptive.hp_adaptive import HPAdaptive from .write_iteration import write_error, write_refine_iter +import openmdao.api as om from dymos.grid_refinement.error_estimation import check_error from dymos.load_case import load_case, find_phases import numpy as np import sys +import openmdao + def _refine_iter(problem, refine_iteration_limit=0, refine_method='hp', case_prefix=None, reset_iter_counts=True): """ @@ -33,8 +36,8 @@ def _refine_iter(problem, refine_iteration_limit=0, refine_method='hp', case_pre refinement_methods = {'hp': HPAdaptive, 'ph': PHAdaptive} _case_prefix = '' if case_prefix is None else f'{case_prefix}_' - failed = problem.run_driver(case_prefix=f'{_case_prefix}{refine_method}_0_' - if refine_iteration_limit > 0 else _case_prefix, + case_prefix = f'{_case_prefix}{refine_method}_0_' + failed = problem.run_driver(case_prefix=case_prefix if refine_iteration_limit > 0 else _case_prefix, reset_iter_counts=reset_iter_counts) if refine_iteration_limit > 0: @@ -60,12 +63,25 @@ def _refine_iter(problem, refine_iteration_limit=0, refine_method='hp', case_pre for stream in f, sys.stdout: write_refine_iter(stream, i, phases, refine_results) - prev_soln = {'inputs': problem.model.list_inputs(out_stream=None, units=True, prom_name=True), - 'outputs': problem.model.list_outputs(out_stream=None, units=True, prom_name=True)} - - problem.setup() + om_version = tuple([int(s) for s in openmdao.__version__.split('-')[0].split('.')]) + if om_version < (3, 27, 1): + prev_soln = {'inputs': problem.model.list_inputs(out_stream=None, units=True, prom_name=True), + 'outputs': problem.model.list_outputs(out_stream=None, units=True, prom_name=True)} - load_case(problem, prev_soln) + problem.setup() + load_case(problem, prev_soln, deprecation_warning=False) + else: + prev_soln = { + 'inputs': problem.model.list_inputs(out_stream=None, return_format='dict', + units=True, prom_name=True), + 'outputs': problem.model.list_outputs(out_stream=None, return_format='dict', + units=True, prom_name=True) + } + + problem.setup() + for phase_path in refined_phases: + phs = problem.model._get_subsystem(phase_path) + phs.load_case(prev_soln) failed = problem.run_driver(case_prefix=f'{_case_prefix}{refine_method}_{i}_') @@ -76,4 +92,5 @@ def _refine_iter(problem, refine_iteration_limit=0, refine_method='hp', case_pre else: print('Successfully completed grid refinement.', file=stream) print(50 * '=') + return failed diff --git a/dymos/grid_refinement/test/test_grid_refinement.py b/dymos/grid_refinement/test/test_grid_refinement.py index 0a07981eb..ffc0b135b 100644 --- a/dymos/grid_refinement/test/test_grid_refinement.py +++ b/dymos/grid_refinement/test/test_grid_refinement.py @@ -45,7 +45,7 @@ def test_refine_hp_non_ode_rate_sources(self): p.driver.declare_coloring() traj = p.model.add_subsystem('traj', dm.Trajectory()) - tx = transcription = dm.Radau(num_segments=5) + tx = dm.Radau(num_segments=5) phase = traj.add_phase('phase0', dm.Phase(ode_class=_BrysonDenhamODE, transcription=tx)) # diff --git a/dymos/load_case.py b/dymos/load_case.py index edd3c2e24..b16f5bfc8 100644 --- a/dymos/load_case.py +++ b/dymos/load_case.py @@ -4,14 +4,12 @@ from openmdao.recorders.case import Case from .phase import AnalyticPhase, Phase from .trajectory import Trajectory -from openmdao.utils.om_warnings import issue_warning +from openmdao.utils.om_warnings import issue_warning, warn_deprecation def find_phases(sys): """ Finds all instances of Dymos Phases within the given system, and returns them as a dictionary. - They are keyed by promoted name if use_prom_path=True, otherwise they are keyed by their - absolute name. Parameters ---------- @@ -59,7 +57,7 @@ def find_trajectories(sys): return traj_paths -def load_case(problem, previous_solution): +def load_case(problem, previous_solution, deprecation_warning=True): """ Populate a guess for the given problem involving Dymos Phases by interpolating results from the previous solution. @@ -72,7 +70,13 @@ def load_case(problem, previous_solution): A dictionary with key 'inputs' mapped to the output of problem.model.list_inputs for a previous iteration, and key 'outputs' mapped to the output of prob.model.list_outputs. Both list_inputs and list_outputs should be called with `units=True` and `prom_names=True`. + deprecation_warning : bool + When False, no deprecation warning will be issued, otherwise warning will be issued. + (defaults to True) """ + if deprecation_warning: + warn_deprecation("The Dymos load_case method is deprecated for OpenMDAO 3.28.0 and later, " + "the load_case method on Problem should be used instead.") # allow old style arguments using a Case or OpenMDAO problem instead of dictionary assert (isinstance(previous_solution, Case) or isinstance(previous_solution, dict)) diff --git a/dymos/phase/phase.py b/dymos/phase/phase.py index 2b93461f7..bb32cb829 100644 --- a/dymos/phase/phase.py +++ b/dymos/phase/phase.py @@ -9,7 +9,10 @@ import openmdao import openmdao.api as om from openmdao.utils.mpi import MPI +from openmdao.utils.om_warnings import issue_warning from openmdao.core.system import System +from openmdao.recorders.case import Case + import dymos as dm from .options import ControlOptionsDictionary, ParameterOptionsDictionary, \ @@ -2643,3 +2646,169 @@ def _is_fixed(self, var_name, var_type, loc): return not self.parameter_options[var_name]['opt'] return False # No way to know so we allow these to go through + + def load_case(self, case): + """ + Pull all input and output variables from a case into the Phase. + + Parameters + ---------- + case : Case or dict + A Case from a CaseReader, or a dictionary with key 'inputs' mapped to the + output of problem.model.list_inputs and key 'outputs' mapped to the output + of prob.model.list_outputs. Both list_inputs and list_outputs should be called + with `units=True`, `prom_names=True` and `return_format='dict'`. + """ + # allow old style arguments using a Case or OpenMDAO problem instead of dictionary + assert (isinstance(case, Case) or isinstance(case, dict)) + if isinstance(case, Case): + previous_solution = { + 'inputs': case.list_inputs(out_stream=None, return_format='dict', + units=True, prom_name=True), + 'outputs': case.list_outputs(out_stream=None, return_format='dict', + units=True, prom_name=True) + } + else: + previous_solution = case + + prev_vars_abs2prom = {} + prev_vars_abs2prom.update({k: v['prom_name'] for k, v in previous_solution['inputs'].items()}) + prev_vars_abs2prom.update({k: v['prom_name'] for k, v in previous_solution['outputs'].items()}) + prev_vars_prom2abs = {v: k for k, v in prev_vars_abs2prom.items()} + + prev_vars = {} + prev_vars.update({v['prom_name']: {'val': v['val'], 'units': v['units'], 'abs_name': k} + for k, v in previous_solution['inputs'].items()}) + prev_vars.update({v['prom_name']: {'val': v['val'], 'units': v['units'], 'abs_name': k} + for k, v in previous_solution['outputs'].items()}) + + phase_io = {'inputs': self.list_inputs(units=True, prom_name=True, out_stream=None), + 'outputs': self.list_outputs(units=True, prom_name=True, out_stream=None)} + + phase_vars = {} + phase_vars.update({f"{self.pathname}.{v['prom_name']}": {'val': v['val'], 'units': v['units'], 'abs_name': k} + for k, v in phase_io['inputs']}) + phase_vars.update({f"{self.pathname}.{v['prom_name']}": {'val': v['val'], 'units': v['units'], 'abs_name': k} + for k, v in phase_io['outputs']}) + + phase_name = self.name + + # Get the initial time and duration from the previous result and set them into the new phase. + integration_name = self.time_options['name'] + + try: + prev_time_path = prev_vars_abs2prom[f'{self.pathname}.timeseries.timeseries_comp.{integration_name}'] + except KeyError: + om.issue_warning(f'load_case for phase {self.name} failed - phase not found in case data.') + return + + prev_timeseries_prom_path, _, _ = prev_time_path.rpartition(f'.{integration_name}') + prev_phase_prom_path, _, _ = prev_timeseries_prom_path.rpartition('.timeseries') + + prev_time_val = prev_vars[prev_time_path]['val'] + prev_time_val, unique_idxs = np.unique(prev_time_val, return_index=True) + prev_time_units = prev_vars[prev_time_path]['units'] + + t_initial = prev_time_val[0] + t_duration = prev_time_val[-1] - prev_time_val[0] + + self.set_val('t_initial', t_initial, units=prev_time_units) + self.set_val('t_duration', t_duration, units=prev_time_units) + + # Interpolate the timeseries state outputs from the previous solution onto the new grid. + if not isinstance(self, dm.AnalyticPhase): + for state_name, options in self.state_options.items(): + if f'{prev_timeseries_prom_path}.states:{state_name}' in prev_vars_prom2abs: + prev_state_path = f'{prev_timeseries_prom_path}.states:{state_name}' + elif f'{prev_timeseries_prom_path}.{state_name}' in prev_vars_prom2abs: + prev_state_path = f'{prev_timeseries_prom_path}.{state_name}' + else: + issue_warning(f'Unable to find state {state_name} in timeseries data from case being loaded.', + om.OpenMDAOWarning) + continue + + prev_state_val = prev_vars[prev_state_path]['val'] + prev_state_units = prev_vars[prev_state_path]['units'] + interp_vals = self.interp(name=state_name, + xs=prev_time_val, + ys=prev_state_val[unique_idxs], + kind='slinear') + if options['lower'] is not None or options['upper'] is not None: + interp_vals = interp_vals.clip(options['lower'], options['upper']) + self.set_val(f'states:{state_name}', + interp_vals, + units=prev_state_units) + try: + self.set_val(f'initial_states:{state_name}', prev_state_val[0, ...], units=prev_state_units) + except KeyError: + pass + + if options['fix_final']: + warning_message = f"{phase_name}.states:{state_name} specifies 'fix_final=True'. " \ + f"If the given restart file has a" \ + f" different final value this will overwrite the user-specified value" + issue_warning(warning_message) + + # Interpolate the timeseries control outputs from the previous solution onto the new grid. + for control_name, options in self.control_options.items(): + if f'{prev_timeseries_prom_path}.controls:{control_name}' in prev_vars_prom2abs: + prev_control_path = f'{prev_timeseries_prom_path}.controls:{control_name}' + elif f'{prev_timeseries_prom_path}.{control_name}' in prev_vars_prom2abs: + prev_control_path = f'{prev_timeseries_prom_path}.{control_name}' + else: + issue_warning(f'Unable to find control {control_name} in timeseries data from case being loaded.', + om.OpenMDAOWarning) + continue + + prev_control_val = prev_vars[prev_control_path]['val'] + prev_control_units = prev_vars[prev_control_path]['units'] + interp_vals = self.interp(name=control_name, + xs=prev_time_val, + ys=prev_control_val[unique_idxs], + kind='slinear') + if options['lower'] is not None or options['upper'] is not None: + interp_vals = interp_vals.clip(options['lower'], options['upper']) + self.set_val(f'controls:{control_name}', interp_vals, units=prev_control_units) + if options['fix_final']: + warning_message = f"{phase_name}.controls:{control_name} specifies 'fix_final=True'. " \ + f"If the given restart file has a" \ + f" different final value this will overwrite the user-specified value" + issue_warning(warning_message) + + # Set the output polynomial control outputs from the previous solution as the value + for pc_name, options in self.polynomial_control_options.items(): + if f'{prev_timeseries_prom_path}.polynomial_controls:{pc_name}' in prev_vars_prom2abs: + prev_pc_path = f'{prev_timeseries_prom_path}.polynomial_controls:{pc_name}' + elif f'{prev_timeseries_prom_path}.{pc_name}' in prev_vars_prom2abs: + prev_pc_path = f'{prev_timeseries_prom_path}.{pc_name}' + else: + issue_warning(f'Unable to find polynomial control {pc_name} in timeseries data from case being ' + f'loaded.', om.OpenMDAOWarning) + continue + + prev_pc_val = prev_vars[prev_pc_path]['val'] + prev_pc_units = prev_vars[prev_pc_path]['units'] + interp_vals = self.interp(name=pc_name, + xs=prev_time_val, + ys=prev_pc_val[unique_idxs], + kind='slinear') + if options['lower'] is not None or options['upper'] is not None: + interp_vals = interp_vals.clip(options['lower'], options['upper']) + self.set_val(f'polynomial_controls:{pc_name}', + interp_vals, + units=prev_pc_units) + if options['fix_final']: + warning_message = f"{phase_name}.polynomial_controls:{pc_name} specifies 'fix_final=True'. " \ + f"If the given restart file has a" \ + f" different final value this will overwrite the user-specified value" + issue_warning(warning_message) + + # Set the timeseries parameter outputs from the previous solution as the parameter value + for param_name in self.parameter_options: + if f'{prev_phase_prom_path}.parameter_vals:{param_name}' in prev_vars: + prev_param_val = prev_vars[f'{prev_phase_prom_path}.parameter_vals:{param_name}']['val'] + prev_param_units = prev_vars[f'{prev_phase_prom_path}.parameter_vals:{param_name}']['units'] + self.set_val(f'parameters:{param_name}', prev_param_val[0, ...], units=prev_param_units) + else: + issue_warning(f'Unable to find "{prev_phase_prom_path}.parameter_vals:{param_name}" ' + f'in data from case being loaded.') diff --git a/dymos/run_problem.py b/dymos/run_problem.py index 3026a7d9f..625edb901 100755 --- a/dymos/run_problem.py +++ b/dymos/run_problem.py @@ -1,10 +1,10 @@ import warnings +import openmdao import openmdao.api as om from openmdao.recorders.case import Case from ._options import options as dymos_options from dymos.trajectory.trajectory import Trajectory -from dymos.load_case import load_case from dymos.visualization.timeseries_plots import timeseries_plots from .grid_refinement.refinement import _refine_iter @@ -83,7 +83,12 @@ def run_problem(problem, refine_method='hp', refine_iteration_limit=0, run_drive problem.final_setup() if restart is not None: - load_case(problem, case) + om_version = tuple([int(s) for s in openmdao.__version__.split('-')[0].split('.')]) + if om_version < (3, 27, 1): + from dymos.load_case import load_case + load_case(problem, case, deprecation_warning=False) + else: + problem.load_case(case) for traj in problem.model.system_iter(include_self=True, recurse=True, typ=Trajectory): traj._check_phase_graph() diff --git a/dymos/test/test_load_case.py b/dymos/test/test_load_case.py index 1ffc27afc..458cc3138 100644 --- a/dymos/test/test_load_case.py +++ b/dymos/test/test_load_case.py @@ -148,7 +148,6 @@ def test_load_case_lgl_to_radau(self): outputs = dict([(o[0], o[1]) for o in case.list_outputs(units=True, shape=True, out_stream=None)]) - print(outputs) time_val = outputs['phase0.timeseries.timeseries_comp.time']['val'] theta_val = outputs['phase0.timeseries.timeseries_comp.theta']['val']