Motivation: This repo was originally a series of lessons introducing Jax concepts, and it can still be used for that. However, it has since grown to include example code solving various real-world problems. These case studies can serve as a follow-up for those who have finished the introductory lessons.
Blurb: Jax is often thought of as Numpy for the GPU, but it is so much more (both in features and in sharp edges). The tutorials presented here, one aimed at a general audience and the other at computational neuroscientists, were inspired by a roadblock I encountered in my research. Specifically, I was working on a leaky integrate-and-fire (LIF) simulation problem that, despite using vectorized Numpy, took excessively long to run. By incorporating Jax into my workflow and iterating on it, I managed to reduce the runtime from ~10 seconds to ~0.2 seconds.
The exercises folder contains the code structured as a series of exercises for you to work through to reinforce the concepts.
- Using jit (sketched below)
- Understanding when to use jit, a.k.a. why not jit everything?
- Timing Jax (sketched below)
- Reading Haskell-like function signatures
- fori_loop, while_loop, scan (sketched below)
- Make your code look more like the math described in the papers
- In prior notebooks we introduced methods to speed up code, including JIT compilation. Let's investigate whether, and by how much, they actually speed up code!
- Learn the design decisions behind Jax's RNG implementation (sketched below)
- Learn the basics of Jax's grad methods, which will cover 90% of use cases
- Should probably be called "grad manipulations": we stop gradients, skip applications, and more (sketched below)
- Learn how Jax internally handles data structures, and how to register your own custom data structure with the pytree registry (sketched below)
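
To give a flavour of the jit and timing exercises, here is a minimal sketch (the rowwise_softmax function, array shapes, and variable names are illustrative assumptions, not code from the exercises): jit-compile a pure function, warm it up once so compilation time is excluded, and call block_until_ready() before stopping the clock, because Jax dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp


def rowwise_softmax(x):
    # A pure function, so it is safe to jit.
    x = x - x.max(axis=-1, keepdims=True)
    e = jnp.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


rowwise_softmax_jit = jax.jit(rowwise_softmax)

x = jax.random.normal(jax.random.PRNGKey(0), (4096, 4096))

# Warm-up call so compilation time is not counted in the measurement.
rowwise_softmax_jit(x).block_until_ready()

start = time.perf_counter()
rowwise_softmax(x).block_until_ready()       # un-jitted baseline
print(f"eager : {time.perf_counter() - start:.4f}s")

start = time.perf_counter()
rowwise_softmax_jit(x).block_until_ready()   # block: dispatch is asynchronous
print(f"jitted: {time.perf_counter() - start:.4f}s")
```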
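
The loop primitives are easiest to see with a tiny example. Below is a hypothetical cumulative sum written with lax.scan (not taken from the exercise); fori_loop and while_loop follow the same carry-passing pattern.

```python
import jax
import jax.numpy as jnp


def cumulative_sum(xs):
    def step(carry, x):
        carry = carry + x
        return carry, carry              # (new carry, per-step output)

    init = jnp.zeros((), dtype=xs.dtype)
    _, ys = jax.lax.scan(step, init, xs)
    return ys


print(cumulative_sum(jnp.arange(5.0)))   # [ 0.  1.  3.  6. 10.]
```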
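
The RNG exercise is about Jax's explicit-key design: instead of mutating a hidden global seed, you pass keys around and split them by hand. A minimal sketch (the seed and shape are arbitrary):

```python
import jax

key = jax.random.PRNGKey(42)           # explicit state instead of a global seed
key, subkey = jax.random.split(key)    # split before each use; never reuse a key
noise = jax.random.normal(subkey, (3,))
print(noise)
```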
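
As a preview of the gradient exercises, the toy loss below (my own example, not repo code) combines jax.grad with jax.lax.stop_gradient, one of the "grad manipulations" mentioned above.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # stop_gradient: this term is treated as a constant when differentiating.
    target = jax.lax.stop_gradient(2.0 * w)
    return jnp.sum((w * x - target) ** 2)


grad_fn = jax.grad(loss)                        # gradient w.r.t. the first argument
print(grad_fn(jnp.array(1.0), jnp.array(3.0)))  # 2 * (1*3 - 2) * 3 = 6.0
```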
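
Finally, a sketch of the data-structures topic: registering a custom container with the pytree registry so that tree utilities (and transformations such as jit and grad) can traverse it. The Params class and its fields are hypothetical.

```python
import jax
from jax.tree_util import register_pytree_node


class Params:
    def __init__(self, w, b):
        self.w = w
        self.b = b


def _flatten(p):
    return (p.w, p.b), None              # (children, auxiliary data)


def _unflatten(aux, children):
    return Params(*children)


register_pytree_node(Params, _flatten, _unflatten)

scaled = jax.tree_util.tree_map(lambda x: 10 * x, Params(1.0, 2.0))
print(scaled.w, scaled.b)                # 10.0 20.0
```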
Einsum isn't specific to Jax, but it's still useful to know!
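
For example, a batched matrix multiplication written as an einsum (shapes chosen arbitrarily):

```python
import jax.numpy as jnp

a = jnp.ones((8, 3, 4))
b = jnp.ones((8, 4, 5))
c = jnp.einsum("bij,bjk->bik", a, b)     # batched matmul; c has shape (8, 3, 5)
print(c.shape)
```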
The case studies build on the exercises and apply the concepts covered in the lessons to real-world problems.
- Grad Advanced, which will cover custom gradients, jacrev, and jacfwd (sketched after this list)
- 3D parallelism
  - data parallelism
  - pipeline parallelism
  - tensor parallelism
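
As a taste of the advanced-gradients case study, the sketch below (a made-up two-output function, not code from the case study) contrasts jacrev and jacfwd, which both compute the Jacobian but build it from opposite directions:

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.array([x[0] ** 2, x[0] * x[1]])


x = jnp.array([2.0, 3.0])
print(jax.jacrev(f)(x))  # reverse mode: one VJP per output row
print(jax.jacfwd(f)(x))  # forward mode: one JVP per input column
```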
 
If you use this software in your research, please cite it as follows:
@misc{numpy_to_jax,
  title = {Numpy To Jax},
  author = {Ian Quah and Bryan Quah},
  year = {2024},
  url = {https://github.com/IanQS/numpy_to_jax},
  version = {1.0.0},
  note = {Jax is often thought of as Numpy for the GPU, but it is so much more (both in features and in sharp edges). The tutorials presented here, one aimed at a general audience and the other at computational neuroscientists, were inspired by a roadblock I encountered in my research.}
}