
JAX #4

Open
fmaussion opened this issue Dec 9, 2018 · 4 comments


@fmaussion
Member

Out for a month now: https://github.com/google/jax

Twitter is divided about how JAX actually differs from PyTorch, but its clearly defined scope makes it worth a try (maybe it is even faster than PyTorch?)

@fmaussion
Member Author

In particular, I wonder how these limitations compare to PyTorch:

TLDR Do use

  • Functional programming
  • Many of NumPy’s functions (help us add more!)
  • Some SciPy functions
  • Indexing and slicing of arrays like `x = A[[5, 1, 7], :, 2:4]`
  • Explicit array creation from lists like `A = np.array([x, y])`
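
A minimal sketch of the "do use" subset, assuming a recent JAX release (the function name `loss` and the toy values are hypothetical): a pure function written with `jax.numpy` (imported as `jnp`), using integer-array indexing plus a slice and explicit array creation, can be differentiated directly with `jax.grad`. The index list is wrapped in `jnp.array` here to stay on the safe side.

```python
import jax
import jax.numpy as jnp

def loss(A):
    # Integer-array indexing combined with a slice: rows 2 and 0, columns 1:3
    sub = A[jnp.array([2, 0]), 1:3]
    # Explicit array creation from a list of traced values
    stacked = jnp.array([sub.sum(), (sub ** 2).sum()])
    return stacked.sum()

A = jnp.arange(12.0).reshape(3, 4)
# Gradient has the same shape as A; entries not selected by the index are zero
g = jax.grad(loss)(A)
```

Because `loss` has no side effects, `jax.grad` (and likewise `jax.jit`) can trace it without modification.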

Don’t use

  • Assignment into arrays like `A[0, 0] = x`
  • Implicit casting to arrays like `np.sum([x, y])` (use `np.sum(np.array([x, y]))` instead)
  • `A.dot(B)` method syntax for functions of more than one argument (use
    `np.dot(A, B)` instead)
  • Side-effects like mutation of arguments or mutation of global variables
  • The `out` argument of NumPy functions
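
The recommended replacements for implicit casting and method syntax look like this when sketched against a recent JAX release (the values are illustrative only):

```python
import jax.numpy as jnp

x, y = jnp.float32(1.5), jnp.float32(2.5)
# Build the array explicitly instead of passing a Python list to sum
total = jnp.sum(jnp.array([x, y]))

A = jnp.eye(2)
B = jnp.array([[1.0, 2.0], [3.0, 4.0]])
# Function syntax for the two-argument product instead of A.dot(B)
C = jnp.dot(A, B)
```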

For jit functions, also don’t use

  • Control flow based on dynamic values `if x > 0: ...`. Control flow based
    on shapes is fine: `if x.shape[0] > 2: ...` and `for subarr in array`.
  • Slicing `A[i:i+5]` for dynamic index `i` (use `lax.dynamic_slice` instead)
    or boolean indexing `A[bool_ind]` for traced values `bool_ind`.
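
A minimal sketch of both `jit` workarounds, assuming a recent JAX release (the function name `clipped_window` and the data are hypothetical): `jnp.where` replaces a value-dependent `if`, and `lax.dynamic_slice` replaces `A[i:i+5]` with a traced start index `i`.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def clipped_window(A, i):
    # A Python `if x > 0:` on a traced value fails under jit;
    # jnp.where evaluates both alternatives and selects elementwise instead
    A = jnp.where(A > 0, A, 0.0)
    # The start index i may be traced, but the slice size (5) must be static
    return lax.dynamic_slice(A, (i,), (5,))

A = jnp.array([-1.0, 2.0, -3.0, 4.0, 5.0, -6.0, 7.0, 8.0])
out = clipped_window(A, 2)
```

The output shape depends only on the static size `(5,)`, which is what lets `jit` compile the function once for any `i`. Note that `lax.dynamic_slice` clamps the start index so that the window stays in bounds.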

@fmaussion
Member Author

I'm most worried about the forbidden assignment operations.
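
For reference, current JAX releases replace in-place assignment with the functional `.at[...]` update syntax, which returns a new array; this interface was added after this discussion. A minimal sketch (the function name `set_corner` is hypothetical):

```python
import jax
import jax.numpy as jnp

A = jnp.zeros((3, 3))
# Equivalent of A[0, 0] = 5.0: returns an updated copy, A itself is unchanged
B = A.at[0, 0].set(5.0)
# Augmented assignments have counterparts too, e.g. A[1, :] += 1.0
C = B.at[1, :].add(1.0)

@jax.jit
def set_corner(M, value):
    # Inside jit, XLA can often perform such updates in place
    return M.at[0, 0].set(value)
```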

@phigre
Collaborator

phigre commented Dec 9, 2018

Hmm, sounds really interesting overall.
Currently, some of our operations violate rules required by JAX:

* Assignment into arrays like `A[0, 0] = x`

At the moment, this only affects the final ice thickness computation in each step. It could probably be avoided by reverting to the original model's indexing scheme and computing the ice thickness at the domain border as well, but then fancy indexing would be needed again...
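
One possible way around the assignment is to compute only the interior cells and rebuild the full grid with `jnp.pad`. This is a minimal sketch under stated assumptions: `H` is a 2-D thickness field, the border cells are held at zero thickness, and `update_thickness`, `dH_dt` and `dt` are hypothetical names, not the model's actual code.

```python
import jax.numpy as jnp

def update_thickness(H, dH_dt, dt):
    """Hypothetical step: explicit update of interior cells, zero-thickness border.

    H has shape (ny, nx); dH_dt has shape (ny - 2, nx - 2), interior only.
    """
    # Explicit update of the interior cells only
    interior = H[1:-1, 1:-1] + dt * dH_dt
    # Rebuild the full grid instead of assigning H[1:-1, 1:-1] = interior
    return jnp.pad(interior, 1, mode="constant", constant_values=0.0)

H = jnp.ones((4, 5))
dH = jnp.full((2, 3), -0.25)
H_new = update_thickness(H, dH, 2.0)
```

If the border should keep its existing values rather than being zeroed, `H.at[1:-1, 1:-1].set(interior)` would be an equivalent functional alternative.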

Furthermore, the stability criterion is currently implemented with an if-statement, which probably prevents JIT compilation:

* Control flow based on dynamic values `if x > 0: ...`. Control flow based
  on shapes is fine: `if x.shape[0] > 2: ...` and `for subarr in array`.

I will have to take a closer look to evaluate this properly.
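
If the stability criterion only caps the time step, the `if` can often be replaced by an elementwise minimum, which traces fine under `jax.jit`. This is a sketch assuming a CFL-type limit for an explicit diffusion scheme, dt ≤ c·dx²/max(D); the function name `stable_dt`, the constant `c` and the diffusivity field `D` are hypothetical, not the model's actual criterion.

```python
import jax
import jax.numpy as jnp

@jax.jit
def stable_dt(D, dx, dt_max, c=0.125):
    # Hypothetical diffusion stability limit: dt <= c * dx**2 / max(D)
    dt_stab = c * dx**2 / jnp.max(D)
    # Replaces `if dt_max > dt_stab: dt = dt_stab` without Python control flow
    return jnp.minimum(dt_max, dt_stab)

D = jnp.array([[1.0, 2.0], [4.0, 0.5]])
dt = stable_dt(D, 2.0, 10.0)
```

When the two branches compute genuinely different things rather than choosing a value, `jax.lax.cond` is the traced counterpart of `if`/`else`.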

@phigre
Collaborator

phigre commented Dec 9, 2018

At first glance, it also seems that JAX does not yet support convolution, which we currently use to determine the boundary of the ice cap.
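
Current JAX versions do provide convolutions, for example `jax.scipy.signal.convolve2d` and the lower-level `jax.lax.conv_general_dilated`. Below is a minimal sketch of detecting the ice-cap margin with a neighbour-count convolution, assuming the boundary is defined as ice cells with at least one ice-free 4-neighbour (this definition and the name `ice_boundary` are assumptions, not the model's actual rule):

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

def ice_boundary(H):
    # Binary ice mask from a thickness field
    mask = (H > 0).astype(jnp.float32)
    # 4-neighbour kernel: counts ice-covered neighbours of each cell
    kernel = jnp.array([[0.0, 1.0, 0.0],
                        [1.0, 0.0, 1.0],
                        [0.0, 1.0, 0.0]])
    # mode="same" keeps the grid shape; cells outside the domain count as ice-free
    neighbours = convolve2d(mask, kernel, mode="same")
    # Boundary: ice cell with fewer than 4 ice neighbours (3.5 guards against round-off)
    return (mask > 0) & (neighbours < 3.5)

H = jnp.zeros((5, 5)).at[1:4, 1:4].set(1.0)
edge = ice_boundary(H)
```

Since the kernel is symmetric, convolution and correlation give the same result here.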
