Skip to content

Commit

Permalink
Merge pull request #193 from IBM/extend/symbolic
Browse files Browse the repository at this point in the history
Extending support for all the hyperbolic functions
  • Loading branch information
Joao-L-S-Almeida authored Jul 26, 2024
2 parents 23a953a + a19e83e commit c8556af
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
36 changes: 32 additions & 4 deletions simulai/residuals/_pytorch_residuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(
self.processing = processing
self.periodic_bc_protected_key = "periodic"

self.protected_funcs = ["cos", "sin", "sqrt", "exp", "tanh"]
self.protected_funcs = ["cos", "sin", "sqrt", "exp", "tanh", "cosh", "sech", "sinh"]
self.protected_operators = ["L", "Div", "Identity", "Kronecker"]

self.protected_funcs_subs = self._construct_protected_functions()
Expand Down Expand Up @@ -186,9 +186,18 @@ def _construct_protected_functions(self):
Returns:
dict: A dictionary of function names and their corresponding function objects.
"""
protected_funcs = {
func: getattr(self.engine, func) for func in self.protected_funcs
}
protected_funcs = dict()

for func in self.protected_funcs:
try:
func_op = getattr(self.engine, func)
except:
try:
func_op = getattr(self, func)
except:
raise Exception(f"Operator {func} is not defined.")

protected_funcs[func] = func_op

return protected_funcs

Expand Down Expand Up @@ -618,6 +627,25 @@ def inner(inputs):

return jacobian(inner, inputs)

def sech(self, x):

cosh = getattr(self.engine, "cosh")

return 1/cosh(x)

def csch(self, x):

sinh = getattr(self.engine, "sinh")

return 1/sinh(x)

def coth(self, x):

cosh = getattr(self.engine, "cosh")
sinh = getattr(self.engine, "sinh")

return cosh(x)/sinh(x)


def diff(feature: torch.Tensor, param: torch.Tensor) -> torch.Tensor:
"""Calculates the gradient of the given feature with respect to the given parameter.
Expand Down
4 changes: 2 additions & 2 deletions tests/residuals/test_symbolicoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def model_operator():
branch_network=branch_net,
var_dim=n_outputs,
rescale_factors=np.array([1]),
devices="gpu",
devices="cpu",
model_id="flame_net",
)

Expand Down Expand Up @@ -171,7 +171,7 @@ def test_symbolic_buitin_functions(self):
assert all([isinstance(item, torch.Tensor) for item in residual(input_data)])

def test_symbolic_operator_ode(self):
for token in ["sin", "cos", "sqrt", "tanh"]:
for token in ["sin", "cos", "sqrt", "tanh", "cosh", "sech", "coth", "sinh"]:
f = f"D(u, t) - alpha*{token}(u)"

input_labels = ["t"]
Expand Down

0 comments on commit c8556af

Please sign in to comment.