Introduce dtype inference and improve dtype in ops.numpy.* #938

Closed

Conversation

james77777778
Contributor

@james77777778 james77777778 commented Sep 21, 2023

This PR unifies the default dtype behavior in ops.numpy.* and ensures that these ops respect backend.floatx().

A subtle bug was caught in dropout_rnn_cell_test.py:
We need to perform a custom mixed-precision check because the cell can't be initialized with dtype="mixed_float16" in self.run_layer_test.

EDITED:

WIP:

  • Add backend.dtypes functionality
  • Add unit tests
  • Refactor ops.numpy

@codecov

codecov bot commented Sep 21, 2023

Codecov Report

Patch coverage: 69.91% and project coverage change: +11.40% 🎉

Comparison is base (6383d8a) 72.28% compared to head (94daabd) 83.69%.
Report is 1 commits behind head on main.

Additional details and impacted files
@@             Coverage Diff             @@
##             main     #938       +/-   ##
===========================================
+ Coverage   72.28%   83.69%   +11.40%     
===========================================
  Files         319      320        +1     
  Lines       28879    29058      +179     
  Branches     5529     5579       +50     
===========================================
+ Hits        20876    24320     +3444     
+ Misses       6632     3195     -3437     
- Partials     1371     1543      +172     
Flag Coverage Δ
keras_core 83.58% <69.91%> (+11.35%) ⬆️
keras_core-jax 67.02% <54.23%> (+<0.01%) ⬆️
keras_core-numpy 60.51% <56.35%> (?)
keras_core-tensorflow 66.99% <51.69%> (+<0.01%) ⬆️
keras_core-torch 68.95% <55.08%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

Files Changed Coverage Δ
keras_core/backend/common/dtypes.py 61.76% <61.76%> (ø)
keras_core/ops/numpy.py 93.85% <62.85%> (-0.57%) ⬇️
keras_core/backend/torch/numpy.py 94.56% <84.21%> (+94.56%) ⬆️
keras_core/backend/numpy/numpy.py 97.32% <92.85%> (+97.32%) ⬆️
keras_core/backend/jax/numpy.py 97.77% <93.33%> (-0.23%) ⬇️
keras_core/backend/tensorflow/numpy.py 93.83% <93.33%> (-0.11%) ⬇️
keras_core/backend/__init__.py 95.12% <100.00%> (+25.12%) ⬆️
keras_core/backend/common/__init__.py 100.00% <100.00%> (ø)

... and 45 files with indirect coverage changes

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Contributor

@fchollet fchollet left a comment

Thanks for the PR!

I think it's a good idea to use dtype=None instead of dtype="float32" in signatures.

We should add dtype checks in unit tests for all operations affected here, to check that we're in fact getting the same dtype across backends, including for array. I think there might be ops where some backends will return float64 instead of float32. This will help us avoid inconsistencies.

@@ -348,6 +353,7 @@ def less_equal(x1, x2):
def linspace(
    start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0
):
    dtype = dtype or config.floatx()
Contributor

Things like this will deviate from the NumPy convention in the sense that NumPy tries to infer the dtype from argument dtypes. IMO defaulting to float32 is much better: simpler, more consistent. So I think we can go with it.

However if we're going to make this deviation, we should do it consistently, in all ops that infer output dtype from argument dtype, such as arange.

The alternative is to stick to the NumPy dtype inference convention (but with float32 instead of float64).
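To make the two options concrete, here is a minimal pure-Python sketch. The helper names, the `_FLOATX` variable, and the choice of "int32" as the integer default are assumptions made for illustration; none of them come from the PR, and complex numbers are ignored.

```python
# Hypothetical stand-in for the backend-wide float setting (config.floatx()).
_FLOATX = "float32"


def floatx():
    return _FLOATX


def dtype_default_to_floatx(dtype=None):
    """Option 1: ignore argument dtypes and fall back to floatx."""
    return dtype or floatx()


def dtype_inferred_from_args(*args, dtype=None):
    """Option 2: NumPy-like inference from Python scalar arguments,
    with floatx used where NumPy would pick float64."""
    if dtype is not None:
        return dtype
    if any(isinstance(a, float) for a in args):
        return floatx()
    # bool subclasses int, so an all-bool call is handled before the int case.
    if args and all(isinstance(a, bool) for a in args):
        return "bool"
    return "int32"  # assumed integer default for this sketch
```

With these helpers, an `arange(0, 10)`-style call would resolve to "float32" under option 1 and to "int32" under option 2, which is the kind of difference the consistency concern above is about.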

Contributor Author

I think we should stick to the JAX dtype inference convention instead of NumPy, as it should be better suited for DL. What do you think?

We can consider reimplementing jnp.result_dtype for all backends
https://github.com/google/jax/blob/2cba122bbe512f7927d165fdbb29108dcf0fe124/jax/_src/dtypes.py#L638

It may require some time if we decide to do so.
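As a rough idea of what a backend-independent promotion helper could look like, here is a simplified pure-Python sketch (JAX's public entry point for this is `jnp.result_type`). It is not JAX's actual promotion lattice: it only handles "bool", "intN" and "floatN" names, treats Python scalars as weakly typed (they decide the kind but not the width), and every name in it, including `result_dtype_sketch` and `FLOATX`, is hypothetical.

```python
FLOATX = "float32"  # assumed stand-in for backend.floatx()

_KIND_RANK = {"bool": 0, "int": 1, "float": 2}
_WEAK_DEFAULT = {"bool": "bool", "int": "int32", "float": FLOATX}


def _parse(dtype):
    """Split a dtype name into (kind, bit width), e.g. 'float16' -> ('float', 16)."""
    if dtype == "bool":
        return "bool", 8
    for kind in ("int", "float"):
        if dtype.startswith(kind) and dtype[len(kind):].isdigit():
            return kind, int(dtype[len(kind):])
    raise ValueError(f"Unsupported dtype in this sketch: {dtype!r}")


def _scalar_kind(value):
    """Kind of a Python scalar; bool is checked first because it subclasses int."""
    if isinstance(value, bool):
        return "bool"
    if isinstance(value, int):
        return "int"
    if isinstance(value, float):
        return "float"
    raise TypeError(f"Unsupported operand in this sketch: {value!r}")


def result_dtype_sketch(*operands):
    """Promote dtype names (strong) and Python scalars (weak) to one dtype name."""
    if not operands:
        raise ValueError("At least one operand is required.")
    strong = [_parse(op) for op in operands if isinstance(op, str)]
    weak = [_scalar_kind(op) for op in operands if not isinstance(op, str)]
    kinds = [kind for kind, _ in strong] + weak
    top = max(kinds, key=_KIND_RANK.get)
    # Only strongly typed operands of the winning kind contribute a width.
    widths = [width for kind, width in strong if kind == top]
    if not widths:
        return _WEAK_DEFAULT[top]
    return "bool" if top == "bool" else f"{top}{max(widths)}"
```

In this sketch a Python float combined with an "int32" operand yields "float32", while a Python int combined with "float16" stays "float16", so scalars never widen an existing float dtype.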

@fchollet
Contributor

Also, we should start testing consistency between symbolic outputs and real op outputs. That's a lot of checks overall, so it would justify the introduction of a new TestCase for dtypes.

@james77777778
Contributor Author

We should add dtype checks in unit tests for all operations affected here, to check that we're in fact getting the same dtype across backends, including for array. I think there might be ops where some backends will return float64 instead of float32. This will help us avoid inconsistencies.

Also, we should start testing consistency between symbolic outputs and real op outputs. That's a lot of checks overall, so it would justify the introduction of a new TestCase for dtypes.

I can add some new test cases in keras_core/ops/numpy_test.py:

class NumpySymbolicDtypeTest(testing.TestCase):
    ...

class NumpyTensorDtypeTest(testing.TestCase):
    ...

Does that look good?

However, it may take some time to implement the result_dtype-like function for all backends.

@fchollet
Contributor

class NumpyTensorDtypeTest(testing.TestCase):

Yes, that sounds good!

However, it may take some time to implement the result_dtype-like function for all backends.

We may be able to use a test parameterization to save time/code. We can parameterize the input dtype, for instance. But in some cases we may also be able to parameterize the op functions, for groups of ops that have similar arguments.

@james77777778 james77777778 marked this pull request as draft September 22, 2023 09:54
@james77777778 james77777778 changed the title Improve dtype in ops.numpy.* Introduce dtype inference and improve dtype in ops.numpy.* Sep 22, 2023
@james77777778
Contributor Author

Hi @fchollet

I want to verify whether this PR is on the right track.

I am attempting to implement a Keras Core version of result_dtype in keras_core/backend/common/dtypes.py.
Currently, the result matches jnp.result_dtype when the inputs are Python scalar types
(as demonstrated in keras_core/backend/common/dtypes_test.py).

If this looks good, I will refactor some of ops.numpy.* and add the previously mentioned tests.

@fchollet
Contributor

Keras Core is becoming Keras 3, and we're switching development to the main repository! Please reopen this PR in the keras-team/keras repository. Unfortunately we aren't able to automatically transfer PRs (but we have transferred all issues).
