diff --git a/simulai/residuals/_pytorch_residuals.py b/simulai/residuals/_pytorch_residuals.py index 5038f23..b6a38f3 100644 --- a/simulai/residuals/_pytorch_residuals.py +++ b/simulai/residuals/_pytorch_residuals.py @@ -72,7 +72,7 @@ def __init__( self.processing = processing self.periodic_bc_protected_key = "periodic" - self.protected_funcs = ["cos", "sin", "sqrt", "exp", "tanh"] + self.protected_funcs = ["cos", "sin", "sqrt", "exp", "tanh", "cosh", "sech", "sinh"] self.protected_operators = ["L", "Div", "Identity", "Kronecker"] self.protected_funcs_subs = self._construct_protected_functions() @@ -186,9 +186,18 @@ def _construct_protected_functions(self): Returns: dict: A dictionary of function names and their corresponding function objects. """ - protected_funcs = { - func: getattr(self.engine, func) for func in self.protected_funcs - } + protected_funcs = dict() + + for func in self.protected_funcs: + try: + func_op = getattr(self.engine, func) + except: + try: + func_op = getattr(self, func) + except: + raise Exception(f"Operator {func} is not defined.") + + protected_funcs[func] = func_op return protected_funcs @@ -618,6 +627,25 @@ def inner(inputs): return jacobian(inner, inputs) + def sech(self, x): + + cosh = getattr(self.engine, "cosh") + + return 1/cosh(x) + + def csch(self, x): + + sinh = getattr(self.engine, "sinh") + + return 1/sinh(x) + + def coth(self, x): + + cosh = getattr(self.engine, "cosh") + sinh = getattr(self.engine, "sinh") + + return cosh(x)/sinh(x) + def diff(feature: torch.Tensor, param: torch.Tensor) -> torch.Tensor: """Calculates the gradient of the given feature with respect to the given parameter. diff --git a/tests/residuals/test_symbolicoperator.py b/tests/residuals/test_symbolicoperator.py index 4023e4f..4a410d7 100644 --- a/tests/residuals/test_symbolicoperator.py +++ b/tests/residuals/test_symbolicoperator.py @@ -98,7 +98,7 @@ def model_operator(): branch_network=branch_net, var_dim=n_outputs, rescale_factors=np.array([1]), - devices="gpu", + devices="cpu", model_id="flame_net", ) @@ -171,7 +171,7 @@ def test_symbolic_buitin_functions(self): assert all([isinstance(item, torch.Tensor) for item in residual(input_data)]) def test_symbolic_operator_ode(self): - for token in ["sin", "cos", "sqrt", "tanh"]: + for token in ["sin", "cos", "sqrt", "tanh", "cosh", "sech", "coth", "sinh"]: f = f"D(u, t) - alpha*{token}(u)" input_labels = ["t"]