Skip to content

Commit

Permalink
refactor: use custom function for interpolation function
Browse files Browse the repository at this point in the history
  • Loading branch information
mgreminger committed Sep 7, 2024
1 parent bc6e2de commit 208cee7
Showing 1 changed file with 23 additions and 15 deletions.
38 changes: 23 additions & 15 deletions public/dimensional_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1333,25 +1333,33 @@ def get_interpolation_wrapper(interpolation_function: InterpolationFunction):
if not NP.all(NP.diff(input_values) > 0):
raise ValueError('The input values must be an increasing sequence for interpolation')

def interpolation_wrapper(input: Expr):
global NP
NP = cast(Any, NP)
class interpolation_wrapper(Function):
is_real = True

if input.is_number:
float_input = float(input)
@staticmethod
def _imp_(arg1):
return cast(Any, NP).interp(float(arg1), input_values, output_values)

def _eval_evalf(self, prec):
if (len(self.args) != 1):
raise TypeError(f'The interpolation function {interpolation_function["name"]} requires 1 input value, ({len(self.args)} given)')

if (self.args[0].is_number):
float_input = float(cast(Expr, self.args[0]))

if float_input < input_values[0] or float_input > input_values[-1]:
raise ValueError('Attempt to extrapolate with an interpolation function')
if float_input < input_values[0] or float_input > input_values[-1]:
raise ValueError('Attempt to extrapolate with an interpolation function')

return sympify(NP.interp(input, input_values, output_values))
else:
if "symbolic_function" not in interpolation_function:
custom_func = cast(Callable[[Expr], Expr], Function(interpolation_function["name"], real=True))
custom_func = implemented_function(custom_func, lambda arg1: cast(Any, NP).interp(float(arg1), input_values, output_values) )
interpolation_function["symbolic_function"] = cast(UndefinedFunction, custom_func)
return sympify(cast(Any, NP).interp(float_input, input_values, output_values))

return interpolation_function["symbolic_function"](input)

def fdiff(self, argindex=1):
delta = sympify(1e-8)
upper_args = [arg if i != argindex-1 else arg + delta for i, arg in enumerate(self.args)]

return (interpolation_wrapper(*upper_args) - interpolation_wrapper(*self.args)) / delta # type: ignore

interpolation_wrapper.__name__ = interpolation_function["name"]

def interpolation_dims_wrapper(input):
ensure_dims_all_compatible(get_dims(interpolation_function["inputDims"]), input)

Expand Down

0 comments on commit 208cee7

Please sign in to comment.