Inconsistent type handling between backends: ops.sqrt #18400
I believe this is a limitation of TensorFlow, as it cannot directly compute `sqrt` on an integer tensor. Workaround:

```python
import tensorflow as tf
from keras_core import ops

print(ops.arange(10).dtype)
# <dtype: 'int32'>

# method 1: cast before sqrt
print(ops.sqrt(ops.cast(ops.arange(10), "float32")))
# tf.Tensor([0. 1. 1.4142135 1.7320508 2. 2.2360678 2.4494896
# 2.6457512 2.828427 3. ], shape=(10,), dtype=float32)

# method 2: create tensor with supported dtype
print(ops.sqrt(ops.arange(10, dtype="float32")))
# tf.Tensor([0. 1. 1.4142135 1.7320508 2. 2.2360678 2.4494896
# 2.6457512 2.828427 3. ], shape=(10,), dtype=float32)

# plain TensorFlow rejects the integer input
print(tf.sqrt(tf.range(10)))
# InvalidArgumentError
```
That is the underlying issue; however, at this point I'm a little unclear about the design goal of the keras_core.ops API. If it's meant to fully decouple the user from the underlying framework, this seems like a bug that needs addressing (potentially by performing the cast you suggest within keras_core.ops). If keras_core.ops is in fact expected to behave differently across backends, then it loses some of its value. In any case, thanks for the workaround in the interim!
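A minimal sketch of the "cast inside `ops`" idea raised above, written against NumPy so it runs without any Keras backend installed. The helper names `promote_for_sqrt` and `consistent_sqrt`, and the choice of `float32` as the target dtype for integer and bool inputs, are assumptions for illustration, not the actual keras_core implementation.

```python
import numpy as np


def promote_for_sqrt(dtype):
    """Hypothetical promotion rule: integer and bool inputs become float32,
    floating inputs keep their own dtype."""
    dtype = np.dtype(dtype)
    if np.issubdtype(dtype, np.floating):
        return dtype
    if np.issubdtype(dtype, np.integer) or dtype == np.bool_:
        return np.dtype("float32")
    # Complex and other dtypes are deliberately left out of this sketch.
    raise TypeError(f"dtype {dtype} is not handled in this sketch")


def consistent_sqrt(x):
    """Cast first, then take the square root, so the underlying
    sqrt kernel only ever sees a floating-point input."""
    x = np.asarray(x)
    return np.sqrt(x.astype(promote_for_sqrt(x.dtype)))


print(consistent_sqrt(np.arange(10, dtype="int32")).dtype)  # float32
```

Promoting integers to `float32` (rather than NumPy's own default of `float64`) matches the `float32` results of the workaround above; a real implementation would still have to decide how to treat `int64` and complex inputs.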
Some additional information:
I think we should standardize the behavior of these two ops?
Thanks for the report. I have fixed some of these issues (the title issue in particular), but it's likely there are still various dtype inconsistencies lurking. We should do extensive testing of dtype handling in backend ops.
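One possible starting point for that testing, sketched under assumptions: sweep a list of input dtypes through a given `sqrt` callable and record either the output dtype or the name of the exception raised. Running the same sweep with each backend's `sqrt` and diffing the reports would surface inconsistencies like the one in this issue. `sqrt_dtype_report` and the dtype list are hypothetical, and NumPy's `np.sqrt` stands in for a backend here.

```python
import numpy as np

# Illustrative dtype list; not the set Keras actually tests.
INPUT_DTYPES = ["bool", "uint8", "int8", "int16", "int32", "int64",
                "float16", "float32", "float64"]


def sqrt_dtype_report(sqrt_fn, dtypes=INPUT_DTYPES):
    """Map each input dtype name to the output dtype name,
    or to the exception class name if sqrt_fn raised."""
    report = {}
    for name in dtypes:
        x = np.arange(4).astype(name)
        try:
            report[name] = np.asarray(sqrt_fn(x)).dtype.name
        except Exception as exc:  # record the failure instead of stopping
            report[name] = type(exc).__name__
    return report


for name, out in sqrt_dtype_report(np.sqrt).items():
    print(f"{name:>8} -> {out}")
```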
I noticed a lot of these while working on a PR to resolve #18415. In particular, the ops will work when run symbolically, and often return an incorrectly typed symbolic tensor. For example, in the sqrt case I believe you'll get back a tensor with the same dtype as the input, even if it's an integer. The same goes for most trig functions, and even a lot of logical operations that are intended to be used on bools.
Good call, we also need to test consistency between symbolic outputs and real outputs.
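A sketch of such a consistency check, assuming each op exposes some rule for its symbolic output dtype that can be compared against what eager execution actually returns. `symbolic_sqrt_dtype` is a hypothetical stand-in that deliberately reproduces the bug described above (it echoes the input dtype), and NumPy plays the role of the eager backend.

```python
import numpy as np


def symbolic_sqrt_dtype(input_dtype):
    """Hypothetical symbolic output-dtype rule. It mimics the reported
    bug by returning the input dtype unchanged, even for integers."""
    return np.dtype(input_dtype)


def eager_sqrt_dtype(input_dtype):
    """Dtype actually produced when sqrt runs on concrete data (NumPy here)."""
    return np.sqrt(np.ones(3, dtype=input_dtype)).dtype


def find_mismatches(dtypes):
    """Return (input, symbolic, eager) dtype-name triples where the
    symbolic rule and the eager result disagree."""
    mismatches = []
    for d in dtypes:
        sym, real = symbolic_sqrt_dtype(d), eager_sqrt_dtype(d)
        if sym != real:
            mismatches.append((d, sym.name, real.name))
    return mismatches


print(find_mismatches(["int32", "float32"]))
```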
@lbortolotti, this is working fine in Keras 3.0.2; please find the working Gist attached.
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.
Sorry for the spam :-)
I'm migrating my codebase to exclusively use keras_core.ops methods, including all "base" operations, based on feedback received in keras-team/keras-core#919 and elsewhere.
I'm still having some trouble, however.
The following code works correctly with the jax backend, but with the tensorflow backend it does the following:
Throws: