From 084031b48e4bfc976f565539fd3595154de9e9e4 Mon Sep 17 00:00:00 2001 From: Nicholas Reinicke Date: Fri, 17 Nov 2023 11:55:27 -0700 Subject: [PATCH] formatting --- nrel/routee/powertrain/validation/errors.py | 8 ++------ scripts/developers/train_model_catalog.py | 8 ++++---- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/nrel/routee/powertrain/validation/errors.py b/nrel/routee/powertrain/validation/errors.py index 9739a15..9920fc9 100644 --- a/nrel/routee/powertrain/validation/errors.py +++ b/nrel/routee/powertrain/validation/errors.py @@ -28,9 +28,7 @@ } -def mean_squared_error( - A, B, axis: Optional[int] = None -) -> float: +def mean_squared_error(A, B, axis: Optional[int] = None) -> float: return np.square(A - B).mean(axis=axis) @@ -41,9 +39,7 @@ def net_energy_error(target, target_pred) -> float: return net_error -def weighted_relative_percent_difference( - target, target_pred -) -> float: +def weighted_relative_percent_difference(target, target_pred) -> float: epsilon = np.finfo(np.float64).eps w = np.array(np.abs(target) / np.sum(np.abs(target))) diff --git a/scripts/developers/train_model_catalog.py b/scripts/developers/train_model_catalog.py index 20991ed..391a5e2 100644 --- a/scripts/developers/train_model_catalog.py +++ b/scripts/developers/train_model_catalog.py @@ -312,17 +312,17 @@ def load_all_files(files, file_limit=FILE_LIMIT): pt.DataColumn(name="grade_dec", units="decimal"), pt.DataColumn(name="entry_angle", units="degrees"), ] -features = [ +features = [ [pt.DataColumn(name="speed_mph", units="mph")], [ pt.DataColumn(name="speed_mph", units="mph"), - pt.DataColumn(name="grade_dec", units="decimal") + pt.DataColumn(name="grade_dec", units="decimal"), ], [ pt.DataColumn(name="speed_mph", units="mph"), pt.DataColumn(name="previous_speed_mph", units="mph"), pt.DataColumn(name="grade_dec", units="decimal"), - pt.DataColumn(name="previous_grade_dec", units="decimal") + pt.DataColumn(name="previous_grade_dec", units="decimal"), ], [ pt.DataColumn(name="previous_speed_mph", units="mph"), @@ -330,7 +330,7 @@ def load_all_files(files, file_limit=FILE_LIMIT): pt.DataColumn(name="previous_grade_dec", units="decimal"), pt.DataColumn(name="grade_dec", units="decimal"), pt.DataColumn(name="entry_angle", units="degrees"), - pt.DataColumn(name="exit_angle", units="degrees") + pt.DataColumn(name="exit_angle", units="degrees"), ], ]