Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce backend.result_type #18482

Merged
merged 29 commits into from
Sep 28, 2023

Conversation

james77777778
Copy link
Contributor

@james77777778 james77777778 commented Sep 23, 2023

[MOVED FROM KERAS CORE PR]
keras-team/keras-core#938

Related to #18400

This PR adds backend.result_type which is inspired by and modified from JAX:
https://github.com/google/jax/blob/main/jax/_src/dtypes.py

The major difference is as follows:

  • Keras' backend.result_type does not canonicalize the resulting dtype. Consequently, Keras allows the computation of high-precision types such as float64, int64 and uint64.
  • Keras' backend.result_type utilizes the precision of backend.floatx() for weak type. It defaults to "32", which means that float will be float32 and so on.

In this PR, result_type has been applied to:

  • ops.add
  • ops.arange
  • ops.sqrt

@codecov-commenter
Copy link

codecov-commenter commented Sep 23, 2023

Codecov Report

Attention: 25 lines in your changes are missing coverage. Please review.

Comparison is base (299419a) 77.34% compared to head (6bb815f) 77.54%.
Report is 2 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #18482      +/-   ##
==========================================
+ Coverage   77.34%   77.54%   +0.19%     
==========================================
  Files         332      333       +1     
  Lines       32000    32163     +163     
  Branches     6248     6277      +29     
==========================================
+ Hits        24751    24940     +189     
+ Misses       5664     5637      -27     
- Partials     1585     1586       +1     
Flag Coverage Δ
keras 77.44% <89.13%> (+0.18%) ⬆️
keras-jax 63.32% <64.78%> (+1.13%) ⬆️
keras-numpy 57.27% <71.30%> (+1.14%) ⬆️
keras-tensorflow 63.15% <64.34%> (+1.14%) ⬆️
keras-torch 64.06% <63.91%> (+0.12%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files Coverage Δ
keras/backend/__init__.py 95.00% <100.00%> (+0.12%) ⬆️
keras/backend/common/__init__.py 100.00% <100.00%> (ø)
keras/backend/jax/numpy.py 98.87% <100.00%> (+1.16%) ⬆️
keras/backend/numpy/numpy.py 98.52% <100.00%> (+1.55%) ⬆️
keras/backend/tensorflow/numpy.py 95.01% <100.00%> (+1.07%) ⬆️
keras/backend/torch/numpy.py 94.87% <100.00%> (+0.78%) ⬆️
keras/ops/numpy.py 94.86% <97.14%> (+0.63%) ⬆️
keras/backend/common/dtypes.py 76.69% <76.69%> (ø)

... and 3 files with indirect coverage changes

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

The logic here is quite complex overall, and it may be difficult to extend and maintain in the future. I assume it is adapted from JAX? Could it be significantly simplified?

I think we can just need a function that can resolve result_dtype(a_dtype, b_dtype). Once we have it we can do pairwise reductions to get the result dtype for any list of tensors / dtypes. The function itself can just be half of a symmetric matrix (or the entire matrix, if you don't want to order the inputs). That should be readable and would not take much code.

keras/backend/common/dtypes.py Outdated Show resolved Hide resolved
keras/backend/common/dtypes.py Outdated Show resolved Hide resolved
keras/backend/common/dtypes.py Outdated Show resolved Hide resolved
keras/backend/common/dtypes.py Outdated Show resolved Hide resolved
keras/backend/common/dtypes.py Outdated Show resolved Hide resolved
keras/backend/common/dtypes.py Outdated Show resolved Hide resolved
keras/backend/common/dtypes.py Outdated Show resolved Hide resolved
keras/backend/common/dtypes.py Outdated Show resolved Hide resolved
keras/backend/common/dtypes.py Outdated Show resolved Hide resolved
keras/backend/common/dtypes.py Outdated Show resolved Hide resolved
@james77777778
Copy link
Contributor Author

The logic here is quite complex overall, and it may be difficult to extend and maintain in the future. I assume it is adapted from JAX? Could it be significantly simplified?

Yes, it is adapted from JAX. In the latest commit, I believe I have simplified it significantly.

I think we can just need a function that can resolve result_dtype(a_dtype, b_dtype).

I think most of the complexity arises from dealing the weak type (python's int and float). We can simplify the logic by removing them, but we might lose some value in result_dtype.
For example, there are some ops allow python scalar type for their arguments. It would be inconvenient if we couldn't handle weak type in result_dtype.

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the update -- this is a nice simplification!

keras/backend/common/dtypes.py Show resolved Hide resolved
keras/backend/common/dtypes.py Outdated Show resolved Hide resolved
keras/backend/common/dtypes.py Outdated Show resolved Hide resolved
@james77777778 james77777778 changed the title Introduce backend.result_dtype and improve dtypes in ops.numpy.* Introduce backend.result_dtype Sep 26, 2023
@james77777778 james77777778 changed the title Introduce backend.result_dtype Introduce backend.result_type Sep 26, 2023
@james77777778
Copy link
Contributor Author

james77777778 commented Sep 26, 2023

Hi @fchollet

I think backend.result_type should now be ready. I have applied it to the following ops:

  • ops.ones
  • ops.zeros
  • ops.empty
  • ops.identity
  • ops.tri
  • ops.eye

and also added the corresponding tests.

It is worth noting:
There is an inevitable gap between tensorflow and other backends, where tf.Variable with int32 cannot be placed on the GPU, and as a workaround we need to use int64.
https://www.tensorflow.org/xla/known_issues#tfvariable_on_a_different_device
This issue will break self.add_variable because constant initializer extensively use ops.ones, ops.zeros and ops.eye in the __call__.

So I have skipped the canonicalization when the resulting dtype is int64 with tensorflow in backend.result_type.

EDITED:
On second thought:
It should be safe to cast the value for the initialization of tf.Variable.
This resolves the int32 and int64 issue in tensorflow without making the exception of the type inference rule.

@james77777778
Copy link
Contributor Author

While there are still some dtype inconsistencies in other ops, we can address them in a separate PR.

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the great contribution -- it's looking very good.

keras/backend/jax/numpy.py Outdated Show resolved Hide resolved
keras/backend/jax/numpy.py Outdated Show resolved Hide resolved
keras/backend/numpy/numpy.py Show resolved Hide resolved
Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the update!

keras/backend/jax/numpy.py Outdated Show resolved Hide resolved
keras/backend/jax/numpy.py Outdated Show resolved Hide resolved
@james77777778
Copy link
Contributor Author

There are valid use cases for float64. Doing print(ops.ones((1,), dtype="float64").dtype) should definitely return float64 with all backends other than JAX. In JAX it is an unfortunate limitation inherited from historical reasons.

Requesting float64 should return a float64 output if the backend supports it (as all do, except JAX).

Got it.
Generally, we should adhere to JAX's type inference rules with enable_x64=True. Is that correct?

@fchollet
Copy link
Collaborator

Generally, we should adhere to JAX's type inference rules with enable_x64=True. Is that correct?

Yes, that is a good plan. My experience with JAX dtype promotion policy is that it is more user friendly than TF's (with the exception of 64 bit dtypes being disabled by default).

@james77777778
Copy link
Contributor Author

james77777778 commented Sep 28, 2023

Yes, that is a good plan. My experience with JAX dtype promotion policy is that it is more user friendly than TF's (with the exception of 64 bit dtypes being disabled by default).

Currently, backend.result_type and some of ops.numpy.* behavior matches JAX when using JAX_DEFAULT_DTYPE_BITS=32 and JAX_ENABLE_X64=true

EDITED:
This PR should be ready.
In ops.add, the type inference of jax and torch is the same, so we can omit result_type and casting.

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates!

keras/backend/jax/numpy.py Outdated Show resolved Hide resolved
keras/backend/tensorflow/core.py Outdated Show resolved Hide resolved
@@ -3872,3 +3874,165 @@ def test_tri(self):
self.assertAllClose(knp.Tri()(3), np.tri(3))
self.assertAllClose(knp.Tri()(3, 4), np.tri(3, 4))
self.assertAllClose(knp.Tri()(3, 4, 1), np.tri(3, 4, 1))


class NumpyDtypeTest(testing.TestCase, parameterized.TestCase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the most useful class -- only through rigorous testing can we achieve consistency 👍

@james77777778
Copy link
Contributor Author

james77777778 commented Sep 28, 2023

Thank you for the thorough review.
Please let me know if any changes are necessary.

EDITED:
result_type has been applied to:

  • ops.add
  • ops.arange
  • ops.sqrt

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wonderful. Thank you for the awesome contribution! 🚀

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants