Skip to content

Commit

Permalink
Merge pull request #66 from nasa/feature/tuple_units
Browse files Browse the repository at this point in the history
  • Loading branch information
teubert authored Jul 13, 2023
2 parents aaa4a6b + d650984 commit 1485c8a
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/progpy/data_models/lstm_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright © 2021 United States Government as represented by the Administrator of the
# National Aeronautics and Space Administration. All Rights Reserved.

from collections import abc
from itertools import chain
import matplotlib.pyplot as plt
from numbers import Number
Expand Down Expand Up @@ -476,8 +477,8 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
raise ValueError(f"layers must be greater than 0, got {params['layers']}")
if np.isscalar(params['units']):
params['units'] = [params['units'] for _ in range(params['layers'])]
if not isinstance(params['units'], (list, np.ndarray)):
raise TypeError(f"units must be a list of integers, not {type(params['units'])}")
if not isinstance(params['units'], (abc.Sequence, np.ndarray)):
raise TypeError(f"units must be a Sequence (e.g., list or tuple) of integers, not {type(params['units'])}")
if len(params['units']) != params['layers']:
raise ValueError(f"units must be a list of integers of length {params['layers']}, got {params['units']}")
for i in range(params['layers']):
Expand All @@ -487,7 +488,7 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
raise TypeError(f"dropout must be an float greater than or equal to 0, not {type(params['dropout'])}")
if params['dropout'] < 0:
raise ValueError(f"dropout must be greater than or equal to 0, got {params['dropout']}")
if not isinstance(params['activation'], (list, np.ndarray)):
if not isinstance(params['activation'], (list, tuple, np.ndarray)):
params['activation'] = [params['activation'] for _ in range(params['layers'])]
if not np.isscalar(params['validation_split']):
raise TypeError(f"validation_split must be an float between 0 and 1, not {type(params['validation_split'])}")
Expand Down
1 change: 1 addition & 0 deletions tests/test_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def future_loading(t, x=None):
[future_loading for _ in range(5)],
dt=[TIMESTEP, TIMESTEP/2, TIMESTEP/4, TIMESTEP*2, TIMESTEP*4],
window=2,
units=(16, ), # Units as tuple
epochs=20)

# Should get keys from original model
Expand Down

0 comments on commit 1485c8a

Please sign in to comment.