From 208cee78ba3947ad908426abee2687cdcb03cfbf Mon Sep 17 00:00:00 2001 From: mgreminger Date: Sat, 7 Sep 2024 17:55:21 -0500 Subject: [PATCH] refactor: use custom function for interpolation function --- public/dimensional_analysis.py | 38 ++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/public/dimensional_analysis.py b/public/dimensional_analysis.py index acc413b2..b4cc9bcf 100644 --- a/public/dimensional_analysis.py +++ b/public/dimensional_analysis.py @@ -1333,25 +1333,33 @@ def get_interpolation_wrapper(interpolation_function: InterpolationFunction): if not NP.all(NP.diff(input_values) > 0): raise ValueError('The input values must be an increasing sequence for interpolation') - def interpolation_wrapper(input: Expr): - global NP - NP = cast(Any, NP) + class interpolation_wrapper(Function): + is_real = True - if input.is_number: - float_input = float(input) + @staticmethod + def _imp_(arg1): + return cast(Any, NP).interp(float(arg1), input_values, output_values) + + def _eval_evalf(self, prec): + if (len(self.args) != 1): + raise TypeError(f'The interpolation function {interpolation_function["name"]} requires 1 input value, ({len(self.args)} given)') + + if (self.args[0].is_number): + float_input = float(cast(Expr, self.args[0])) - if float_input < input_values[0] or float_input > input_values[-1]: - raise ValueError('Attempt to extrapolate with an interpolation function') + if float_input < input_values[0] or float_input > input_values[-1]: + raise ValueError('Attempt to extrapolate with an interpolation function') - return sympify(NP.interp(input, input_values, output_values)) - else: - if "symbolic_function" not in interpolation_function: - custom_func = cast(Callable[[Expr], Expr], Function(interpolation_function["name"], real=True)) - custom_func = implemented_function(custom_func, lambda arg1: cast(Any, NP).interp(float(arg1), input_values, output_values) ) - interpolation_function["symbolic_function"] = cast(UndefinedFunction, custom_func) + return sympify(cast(Any, NP).interp(float_input, input_values, output_values)) - return interpolation_function["symbolic_function"](input) - + def fdiff(self, argindex=1): + delta = sympify(1e-8) + upper_args = [arg if i != argindex-1 else arg + delta for i, arg in enumerate(self.args)] + + return (interpolation_wrapper(*upper_args) - interpolation_wrapper(*self.args)) / delta # type: ignore + + interpolation_wrapper.__name__ = interpolation_function["name"] + def interpolation_dims_wrapper(input): ensure_dims_all_compatible(get_dims(interpolation_function["inputDims"]), input)