Skip to content

Commit

Permalink
Fixed an error where shape introspection wasn't happening on control …
Browse files Browse the repository at this point in the history
…rate2 targets
  • Loading branch information
robfalck committed Aug 14, 2023
1 parent 7bd1ff1 commit 5817db0
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 77 deletions.
93 changes: 45 additions & 48 deletions dymos/trajectory/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ..phase.options import TrajParameterOptionsDictionary
from ..transcriptions.common import ParameterComp
from ..utils.misc import get_rate_units, _unspecified
from ..utils.introspection import get_promoted_vars, get_source_metadata
from ..utils.introspection import get_promoted_vars, get_source_metadata, _get_common_metadata
from .._options import options as dymos_options


Expand Down Expand Up @@ -276,6 +276,29 @@ def add_parameter(self, name, units=_unspecified, val=_unspecified, desc=_unspec
upper=upper, scaler=scaler, adder=adder, ref0=ref0, ref=ref, shape=shape,
dynamic=dynamic, static_target=static_target)

def _get_phase_parameters(self):
"""
Retrieve a dict of parameter options for each phase within the trajectory.
Returns
-------
dict
A dictionary keyed by phase name. Each associated value is a dictionary
keyed by parameter name and the associated values are parameter options
for each parameter.
"""
phase_param_options = {}
for phs in self.phases._subsystems_myproc:
phase_param_options.update({phs.name: phs.parameter_options})

if self.comm.size > 1:
data = self.comm.gather(phase_param_options, root=0)
if data:
for d in data:
phase_param_options.update(d)
return phase_param_options

def _setup_parameters(self):
"""
Adds an IndepVarComp if necessary and issues appropriate connections based
Expand Down Expand Up @@ -395,7 +418,6 @@ def _configure_parameters(self):
"""
parameter_options = self.parameter_options
promoted_inputs = []

for name, options in parameter_options.items():
promoted_inputs.append(f'parameters:{name}')
targets = options['targets']
Expand All @@ -414,21 +436,16 @@ def _configure_parameters(self):
# For each phase, use introspection to get the units and shape.
# If units do not match across all phases, require user to set them.
# If shapes do not match across all phases, this is an error.
tgts = []
tgt_units = {}
tgt_shapes = {}
tgt_vals = {}
targets_per_phase = {}

for phase_name, phs in self._phases.items():
target_param = None

if targets is None or phase_name not in targets:
# Attempt to connect to an input parameter of the same name in the phase, if
# it exists.
if name in phs.parameter_options:
tgt = f'{phase_name}.parameters:{name}'
tgt_shapes[phs.name] = phs.parameter_options[name]['shape']
tgt_units[phs.name] = phs.parameter_options[name]['units']
tgt_vals[phs.name] = phs.parameter_options[name]['val']
target_param = name
else:
continue
elif targets[phase_name] is None:
Expand All @@ -437,10 +454,7 @@ def _configure_parameters(self):
elif isinstance(targets[phase_name], str):
if targets[phase_name] in phs.parameter_options:
# Connect to an input parameter with a different name in this phase
tgt = f'{phase_name}.parameters:{targets[phase_name]}'
tgt_shapes[phs.name] = phs.parameter_options[targets[phase_name]]['shape']
tgt_units[phs.name] = phs.parameter_options[targets[phase_name]]['units']
tgt_vals[phs.name] = phs.parameter_options[targets[phase_name]]['val']
target_param = targets[phase_name]
else:
msg = f'Invalid target for trajectory `{self.pathname}` parameter `{name}` in phase ' \
f"`{phase_name}`.\nTarget for phase `{phase_name}` is '{targets[phase_name]}' but " \
Expand All @@ -450,10 +464,7 @@ def _configure_parameters(self):
if name in phs.parameter_options:
# User gave a list of ODE targets which were passed to the creation of a
# new input parameter in setup, just connect to that new input parameter
tgt = f'{phase_name}.parameters:{name}'
tgt_shapes[phs.name] = phs.parameter_options[name]['shape']
tgt_units[phs.name] = phs.parameter_options[name]['units']
tgt_vals[phs.name] = phs.parameter_options[name]['val']
target_param = name
else:
msg = f'Invalid target for trajectory `{self.pathname}` parameter `{name}` in phase ' \
f"`{phase_name}`.\nThe phase did not add the parameter as expected. Please file an " \
Expand All @@ -463,9 +474,11 @@ def _configure_parameters(self):
raise ValueError(f'Unhandled target(s) ({targets[phase_name]}) for parameter {name} in '
f'phase {phase_name}. If connecting to ODE inputs in the phase, '
f'format the targets as a sequence of strings.')
tgts.append(tgt)

if not tgts:
if target_param is not None:
targets_per_phase[phase_name] = target_param

if not targets_per_phase:
# Find the reason
if targets is None:
reason = f'Option `targets=None` but no phase in the trajectory has a parameter named `{name}`.'
Expand All @@ -475,36 +488,19 @@ def _configure_parameters(self):
reason = ''
raise ValueError(f'No target was found for trajectory parameter `{name}` in any phase.\n{reason}')

if options['shape'] in {_unspecified, None}:
if len(set(tgt_shapes.values())) == 1:
options['shape'] = next(iter(tgt_shapes.values()))
else:
raise ValueError(f'Parameter {name} in Trajectory {self.pathname} is connected to '
f'targets in multiple phases that have different shapes.')
# If metadata is unspecified, use introspection to find
# it based on common values among the targets.
params_by_phase = self._get_phase_parameters()

targets = {phase_name: phs_params[targets_per_phase[phase_name]]
for phase_name, phs_params in params_by_phase.items()
if phase_name in targets_per_phase and targets_per_phase[phase_name] in phs_params}

if options['units'] is _unspecified:
tgt_units_set = set(tgt_units.values())
if len(tgt_units_set) == 1:
options['units'] = tgt_units_set.pop()
else:
ValueError(f'Parameter {name} in Trajectory {self.pathname} is connected to '
f'targets in multiple phases that have different units. You must '
f'explicitly provide units for the parameter since they cannot be '
f'inferred.')

if options['val'] is _unspecified:
val_list = list(tgt_vals.values())
unique_val = True
for val in val_list[1:]:
if not np.array_equal(val_list[0], val, equal_nan=True):
unique_val = False
if unique_val:
options['val'] = val_list[0]
else:
raise ValueError(f'Unable to automatically assign {metadata_key} based on targets. \n'
f'Targets have multiple values assigned: {err_dict}. \n'
f'Either promote targets and use set_input_defaults to assign common '
f'{metadata_key}, or explicitly provide {metadata_key} to the variable.')
options['units'] = _get_common_metadata(targets, metadata_key='units')

if options['shape'] in {None, _unspecified}:
options['shape'] = _get_common_metadata(targets, metadata_key='shape')

param_comp = self._get_subsystem('param_comp')
param_comp.add_parameter(name, val=options['val'], shape=options['shape'], units=options['units'])
Expand All @@ -520,6 +516,7 @@ def _configure_parameters(self):
ref0=options['ref0'],
ref=options['ref'])

tgts = [f'{phase_name}.parameters:{param_name}' for phase_name, param_name in targets_per_phase.items()]
self.connect(f'parameter_vals:{name}', tgts)

return promoted_inputs
Expand Down
54 changes: 25 additions & 29 deletions dymos/utils/introspection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections.abc import Iterable, Sequence
import fnmatch
from numbers import Number
import re

import openmdao.api as om
import numpy as np
from openmdao.utils.array_utils import shape_to_len
from openmdao.utils.general_utils import ensure_compatible
from dymos.utils.misc import _unspecified
from .._options import options as dymos_options
from ..phase.options import StateOptionsDictionary, TimeseriesOutputOptionsDictionary
Expand Down Expand Up @@ -357,7 +359,7 @@ def configure_controls_introspection(control_options, ode, time_units='s'):
options['units'] = f'{time_units**2}' if rate2_target_units is None \
else f'{rate2_target_units}*{time_units}**2'

if options['shape'] is _unspecified:
if options['shape'] in {None, _unspecified}:
shape = _get_common_metadata(rate2_targets, metadata_key='shape')
if len(shape) == 1:
options['shape'] = (1,)
Expand Down Expand Up @@ -410,36 +412,30 @@ def configure_parameters_introspection(parameter_options, ode):
options['units'] = _get_common_metadata(targets, metadata_key='units')

if options['shape'] in {_unspecified, None}:
static_shapes = {}
dynamic_shapes = {}
param_shape = None
# First find the shapes of the static targets
for tgt, meta in targets.items():
if tgt in options['static_targets']:
static_shapes[tgt] = meta['shape']
else:
if len(meta['shape']) == 1:
dynamic_shapes[tgt] = (1,)
if isinstance(options['val'], Number):
static_shapes = {}
dynamic_shapes = {}
# First find the shapes of the static targets
for tgt, meta in targets.items():
if tgt in options['static_targets']:
static_shapes[tgt] = meta['shape']
else:
dynamic_shapes[tgt] = meta['shape'][1:]
all_shapes = {**dynamic_shapes, **static_shapes}
# Check that they're unique
if len(set(all_shapes.values())) != 1:
raise RuntimeError(f'Unable to obtain shape of parameter {name} via introspection.\n'
f'Targets have multiple shapes.\n'
f'{all_shapes}')
if len(meta['shape']) == 1:
dynamic_shapes[tgt] = (1,)
else:
dynamic_shapes[tgt] = meta['shape'][1:]
all_shapes = {**dynamic_shapes, **static_shapes}
# Check that they're unique
if len(set(all_shapes.values())) != 1:
raise RuntimeError(f'Unable to obtain shape of parameter {name} via introspection.\n'
f'Targets have multiple shapes.\n'
f'{all_shapes}')
else:
options['shape'] = next(iter(set(all_shapes.values())))
else:
options['shape'] = next(iter(set(all_shapes.values())))

if isinstance(options['val'], Sequence):
options['val'] = np.asarray(options['val'])
options['shape'] = np.asarray(options['val']).shape

if np.ndim(options['val']) > 0 and options['val'].shape != options['shape']:
# If the introspected val is a long array (a value at each node), then only
# take the value from the first node.
options['val'] = np.asarray([val[0, ...]])
else:
options['val'] = options['val'] * np.ones(options['shape'])
options['val'], options['shape'] = ensure_compatible(name, options['val'], options['shape'])


def configure_time_introspection(time_options, ode):
Expand Down Expand Up @@ -1194,7 +1190,7 @@ def _get_common_metadata(targets, metadata_key):
raise ValueError(f'Unable to automatically assign {metadata_key} based on targets. \n'
f'No targets were found.')
else:
err_dict = {tgt: meta[metadata_key] for tgt in targets}
err_dict = {tgt: meta[metadata_key] for tgt, meta in targets.items()}
raise ValueError(f'Unable to automatically assign {metadata_key} based on targets. \n'
f'Targets have multiple {metadata_key} assigned: {err_dict}. \n'
f'Either promote targets and use set_input_defaults to assign common '
Expand Down

0 comments on commit 5817db0

Please sign in to comment.