-
Notifications
You must be signed in to change notification settings - Fork 225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
converting sympy.NumberSymbol to torch.tensor in export_torch.py #726
base: master
Are you sure you want to change the base?
Conversation
attempting to address MilesCranmer#656
Nice! Do you want to add a unit test for the MWE you described in the issue? |
Pull Request Test Coverage Report for Build 11200885649Details
💛 - Coveralls |
Hi I've added a unit test to here, let me know if that is sufficient or if i need to do anything else |
for more information, see https://pre-commit.ci
Seems like the test is failing |
Apologies, I won't be able to address it until next week |
No worries! |
I figured out what's going on.
But So i have added a line to make perhaps you would rather fix it at a different level of abstraction? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if the underlying issue is the use of issubclass
instead of isinstance
... Maybe try replacing those conditions with the following code?
def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
super().__init__(**kwargs)
self._sympy_func = expr.func
if isinstance(expr, sympy.Float):
self._value = torch.nn.Parameter(torch.tensor(float(expr)))
self._torch_func = lambda: self._value
self._args = ()
elif isinstance(expr, sympy.Rational) and not isinstance(expr, sympy.Integer):
# This is some fraction fixed in the operator.
self._value = float(expr)
self._torch_func = lambda: self._value
self._args = ()
elif isinstance(expr, sympy.UnevaluatedExpr):
if len(expr.args) != 1 or not isinstance(expr.args[0], sympy.Float):
raise ValueError(
"UnevaluatedExpr should only be used to wrap floats."
)
self.register_buffer("_value", torch.tensor(float(expr.args[0])))
self._torch_func = lambda: self._value
self._args = ()
elif isinstance(expr, sympy.Integer):
# Handles Integer special cases like NegativeOne, One, Zero
self._value = int(expr)
self._torch_func = lambda: self._value
self._args = ()
elif isinstance(expr, sympy.NumberSymbol):
# Handles mathematical constants like pi, E
self._value = float(expr)
self._torch_func = lambda: self._value
self._args = ()
elif isinstance(expr, sympy.Symbol):
self._name = expr.name
self._torch_func = lambda value: value
self._args = ((lambda memodict: memodict[expr.name]),)
else:
try:
self._torch_func = _func_lookup[expr.func]
except KeyError:
raise KeyError(
f"Function {expr.func} was not found in Torch function mappings. "
"Please add it to extra_torch_mappings in the format, e.g., "
"{sympy.sqrt: torch.sqrt}."
)
args = []
for arg in expr.args:
try:
arg_ = _memodict[arg]
except KeyError:
arg_ = type(self)(
expr=arg,
_memodict=_memodict,
_func_lookup=_func_lookup,
**kwargs,
)
_memodict[arg] = arg_
args.append(arg_)
self._args = torch.nn.ModuleList(args)
Thanks! That code gives the following error
note that it is now saying that's due to this code: elif isinstance(expr, sympy.Integer):
# Handles Integer special cases like NegativeOne, One, Zero
self._value = int(expr)
self._torch_func = lambda: self._value
self._args = () is there a reason that for the would it be ok to just add the following code before the integer case? elif isinstance(expr, sympy.core.numbers.One):
# Handles Integer special cases like NegativeOne, One, Zero
self._value = torch.tensor(int(expr))
self._torch_func = lambda: self._value
self._args = () or should it be either way, it then passes the test |
attempting to address #656