Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chalk-Numpy Version 3 #140

Open
wants to merge 55 commits into
base: master
Choose a base branch
from
Open

Chalk-Numpy Version 3 #140

wants to merge 55 commits into from

Conversation

srush
Copy link
Collaborator

@srush srush commented Aug 5, 2024

This is a third attempt at getting a version of Chalk-Numpy running. I am not sure it is fully compatible yet, but it seems to be getting to the point that I can run the examples without any errors.

I made an attempt to add some comments and examples, but would you mayb be free to have call at some point to discuss?

Goal:

  • Relatively full rewrite a Chalk to support the PyTree / Numpy paradigm. The goal is to maintain the current Chalk API, but have it be more scalable and simpler. It supports Numpy / Jax currently.

Why?

  • Mostly its just a summer project. But it was an interesting learning experience.

What changed?

  • I removed a lot of code: Functional Envelopes / traces. All the rendering code. Shapes. Planar etc.

  • Added core transformations to the main code base.

  • Added support for jaxtyping which runtime checks shapes, and vectorization which compiles extra dimensions, throughout the code base.

  • Rendering code only implements matplotlib path api. (in Patch.py) this let's us remove specific renderers for svg etc. This also makes everything the same size. Text is now implemented by just producing Patch's for all backends.

  • Removed support for everything except arc trails. Segment now are just arc trails. No need for other shapes we previously had.

  • Segments / Text / Style are now all Numpy. This allows fast calculation by batching. Also allows jax compilation (which is useful for things like animations).

  • Layout / Envelope / Traces are implemented by collapsing diagram trees to a batch of segments. This allows efficient calculation without lots of slow python function calls. The one downside of this is that it breaks pad and some of the other functional transforms.

  • Diagam is now a PyTree which support the jax.tree.map functionality. This works both for numpy and jax. Eventually we can switch to optree which is a fast functional tree object that doesn't require jax.

  • Support for batch diagrams. This just means that you can create diagrams with arbitrary sizes and then collapse them by calling circle(np.arange(1, 5)).hcat(). This was a bit tricky to implement and requires keeping an explicit order when collapsing compound diagrams. The main new construct is the ComposeAxis() node type which is added with calls to concat.

What's missing?

  • - Arrow heads that stay the same size
  • - Dashed lines
  • - Better tests for tracing which is complex.
  • - Easy ways to switch between JAX / numpy (right now it is a os env var).

Mostly the example diagrams work and are pretty fast.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant