Skip to content

Commit

Permalink
fixed import onnx bug which was happening due to name conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
rimjhimittal committed Mar 21, 2024
1 parent 22aed97 commit cdbf8ab
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
16 changes: 16 additions & 0 deletions src/modeci_mdf/functions/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,24 @@

import numpy as np
import onnxruntime as ort
import sys
import os

onnx_lib_path = next(p for p in sys.path if 'site-packages' in p)

current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir in sys.path:
sys.path.remove(current_dir)

sys.path.insert(0, onnx_lib_path)

import onnx.defs

sys.path.remove(onnx_lib_path)
if current_dir not in sys.path:
sys.path.insert(0, current_dir)


try:
import torch

Expand Down
26 changes: 19 additions & 7 deletions src/modeci_mdf/functions/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,24 @@
import warnings
import types
from typing import List, Dict, Callable
import sys
import os

from docstring_parser import parse
onnx_lib_path = next(p for p in sys.path if 'site-packages' in p)

current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir in sys.path:
sys.path.remove(current_dir)

sys.path.insert(0, onnx_lib_path)
import onnx
import onnx.defs

sys.path.remove(onnx_lib_path)
if current_dir not in sys.path:
sys.path.insert(0, current_dir)

from docstring_parser import parse

# Make sure we import math and numpy for Python expression strings. These imports
# are important, do not remove even though they appear unused.
Expand Down Expand Up @@ -101,11 +116,7 @@ def substitute_args(expression_string: str = None, args: Dict[str, str] = None)
return expression_string


def create_python_function(
name: str = None,
expression_string: str = None,
arguments: List[str] = None,
) -> "types.FunctionType":
def create_python_function(name: str, expression_string: str, arguments: List[str]) -> types.FunctionType:
"""Create a Python function e.g. linear, exponential, sin, cos, ReLu
Args:
Expand All @@ -122,7 +133,7 @@ def create_python_function(
# assumes expression is one line
name = name.replace(":", "_")
expr = create_python_expression(expression_string)
func_str = f"def {name}({','.join(arguments)}):\n\treturn {expr}"
func_str = f"def {name}({','.join(arguments)}):\n\treturn {expression_string}"

res = {}
exec(func_str, globals(), res)
Expand Down Expand Up @@ -248,6 +259,7 @@ def add_public_functions_from_module(module, module_alias: str = None):
arguments=[STANDARD_ARG_0, "scale"],
expression_string="scale * tan(%s)" % (STANDARD_ARG_0),
)


add_mdf_function(
"sinh",
Expand Down

0 comments on commit cdbf8ab

Please sign in to comment.