From df4ff635e439d092781bbdb0c72154f14b23c100 Mon Sep 17 00:00:00 2001 From: Joao Lucas de Sousa Almeida Date: Thu, 25 Jul 2024 21:11:59 -0300 Subject: [PATCH] Testing tanh in symbolic expressions Signed-off-by: Joao Lucas de Sousa Almeida --- simulai/residuals/_pytorch_residuals.py | 2 +- tests/residuals/test_symbolicoperator.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/simulai/residuals/_pytorch_residuals.py b/simulai/residuals/_pytorch_residuals.py index 3f03309..5038f23 100644 --- a/simulai/residuals/_pytorch_residuals.py +++ b/simulai/residuals/_pytorch_residuals.py @@ -72,7 +72,7 @@ def __init__( self.processing = processing self.periodic_bc_protected_key = "periodic" - self.protected_funcs = ["cos", "sin", "sqrt", "exp"] + self.protected_funcs = ["cos", "sin", "sqrt", "exp", "tanh"] self.protected_operators = ["L", "Div", "Identity", "Kronecker"] self.protected_funcs_subs = self._construct_protected_functions() diff --git a/tests/residuals/test_symbolicoperator.py b/tests/residuals/test_symbolicoperator.py index 2df1f40..4023e4f 100644 --- a/tests/residuals/test_symbolicoperator.py +++ b/tests/residuals/test_symbolicoperator.py @@ -171,7 +171,7 @@ def test_symbolic_buitin_functions(self): assert all([isinstance(item, torch.Tensor) for item in residual(input_data)]) def test_symbolic_operator_ode(self): - for token in ["sin", "cos", "sqrt"]: + for token in ["sin", "cos", "sqrt", "tanh"]: f = f"D(u, t) - alpha*{token}(u)" input_labels = ["t"]