-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce backend.result_type
#18482
Introduce backend.result_type
#18482
Conversation
Codecov ReportAttention:
Additional details and impacted files@@ Coverage Diff @@
## master #18482 +/- ##
==========================================
+ Coverage 77.34% 77.54% +0.19%
==========================================
Files 332 333 +1
Lines 32000 32163 +163
Branches 6248 6277 +29
==========================================
+ Hits 24751 24940 +189
+ Misses 5664 5637 -27
- Partials 1585 1586 +1
Flags with carried forward coverage won't be shown. Click here to find out more.
☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
The logic here is quite complex overall, and it may be difficult to extend and maintain in the future. I assume it is adapted from JAX? Could it be significantly simplified?
I think we can just need a function that can resolve result_dtype(a_dtype, b_dtype)
. Once we have it we can do pairwise reductions to get the result dtype for any list of tensors / dtypes. The function itself can just be half of a symmetric matrix (or the entire matrix, if you don't want to order the inputs). That should be readable and would not take much code.
Yes, it is adapted from JAX. In the latest commit, I believe I have simplified it significantly.
I think most of the complexity arises from dealing the weak type (python's |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the update -- this is a nice simplification!
backend.result_dtype
and improve dtypes in ops.numpy.*
backend.result_dtype
backend.result_dtype
backend.result_type
Hi @fchollet I think
and also added the corresponding tests.
EDITED: |
While there are still some dtype inconsistencies in other ops, we can address them in a separate PR. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the great contribution -- it's looking very good.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the update!
Got it. |
Yes, that is a good plan. My experience with JAX dtype promotion policy is that it is more user friendly than TF's (with the exception of 64 bit dtypes being disabled by default). |
…X_DEFAULT_DTYPE_BITS=32`
Currently, EDITED: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates!
@@ -3872,3 +3874,165 @@ def test_tri(self): | |||
self.assertAllClose(knp.Tri()(3), np.tri(3)) | |||
self.assertAllClose(knp.Tri()(3, 4), np.tri(3, 4)) | |||
self.assertAllClose(knp.Tri()(3, 4, 1), np.tri(3, 4, 1)) | |||
|
|||
|
|||
class NumpyDtypeTest(testing.TestCase, parameterized.TestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the most useful class -- only through rigorous testing can we achieve consistency 👍
Thank you for the thorough review. EDITED:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wonderful. Thank you for the awesome contribution! 🚀
[MOVED FROM KERAS CORE PR]
keras-team/keras-core#938
Related to #18400
This PR adds
backend.result_type
which is inspired by and modified from JAX:https://github.com/google/jax/blob/main/jax/_src/dtypes.py
The major difference is as follows:
backend.result_type
does not canonicalize the resulting dtype. Consequently, Keras allows the computation of high-precision types such asfloat64
,int64
anduint64
.backend.result_type
utilizes the precision ofbackend.floatx()
for weak type. It defaults to"32"
, which means thatfloat
will befloat32
and so on.In this PR,
result_type
has been applied to:ops.add
ops.arange
ops.sqrt