
Symbolic tensor addition is not commutative w.r.t. dtypes #18415

Closed
jackd opened this issue Aug 27, 2023 · 4 comments
jackd commented Aug 27, 2023

The behaviour when adding tensors of different dtypes is unexpected and inconsistent with the backends: the backends all appear to be commutative with respect to dtypes, but symbolic addition is not.

import numpy as np
import keras_core as keras

x = keras.Input((), dtype="float32")
y = keras.Input((), dtype="float64")

z0 = x + y
print(f"z0: {z0.dtype}") # float32
z1 = y + x
print(f"z1: {z1.dtype}") # float64

m0 = keras.Model((x, y), z0)
m1 = keras.Model((x, y), z1)

v0 = m0((np.zeros((1,), "float32"), np.zeros((1,), "float64")))
print(f"v0: {v0.dtype}") # float64 (float32 in jax)
v1 = m1((np.zeros((1,), "float64"), np.zeros((1,), "float32")))
print(f"v1: {v1.dtype}") # float64 (float32 in jax)
fchollet commented

Thanks for the report! It sounds like there are several issues here:

  • We need a consistent casting policy when mixing dtypes. I think we should go with the JAX standard (downcasting). We can then package it as a backend utility like dtype = resolve_dtype(*dtypes).
  • In every op that takes more than one input, we use the utility to compute what to cast all inputs to.
  • We need to make sure to also use the same utility in the same way in the compute_output_spec logic of multi-input symbolic operations.

Would you be able to open a PR?
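The `resolve_dtype` utility described above could be sketched roughly as follows. This is a hypothetical illustration, not the actual Keras implementation; it assumes NumPy-style promotion rules (up-casting, which is where the thread eventually lands) via `np.result_type`:

```python
import numpy as np

def resolve_dtype(*dtypes):
    """Hypothetical helper: pick the common dtype for a multi-input op.

    Uses NumPy's promotion rules, which are commutative with respect
    to argument order; the real Keras utility may differ in detail.
    """
    return np.result_type(*dtypes).name

# Same result regardless of argument order:
print(resolve_dtype("float32", "float64"))  # float64
print(resolve_dtype("float64", "float32"))  # float64
```

Both the eager op implementations and the symbolic `compute_output_spec` paths would then call the same helper, so the two can no longer disagree.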

jackd commented Aug 27, 2023

Just did a quick check and realized JAX doesn't actually down-cast: I hadn't enabled float64, so it was just adding two float32s. Once float64 is enabled, the implicit cast behaviour is to up-cast. That makes me think going with the consensus (up-casting) would be less surprising than introducing another policy.

Happy to look into a PR.
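The up-casting consensus can be seen in NumPy's own promotion rules, which are commutative with respect to operand order (a reference sketch only, not the Keras fix itself):

```python
import numpy as np

# Mixing float widths up-casts to the wider dtype,
# and the result is the same in either order:
a = np.zeros((1,), dtype="float32")
b = np.zeros((1,), dtype="float64")

print((a + b).dtype)  # float64
print((b + a).dtype)  # float64
```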

fchollet commented

> That makes me think going with the consensus (up-casting) would be less surprising than introducing another policy.

Definitely, let's do that.

> Happy to look into a PR.

Thank you so much!

@fchollet fchollet transferred this issue from keras-team/keras-core Sep 22, 2023
@jackd
Copy link
Contributor Author

jackd commented Oct 2, 2023

Looks like this was resolved in #18482

@jackd jackd closed this as completed Oct 2, 2023