JAX support in DocArray v2 #21
Comments
Hello @Nick17t @samsja! I am Pranjal. While surfing GSoC projects, I came across this today. Having multi-modal data structures compatible with JAX modules sounds really cool to me. I had a small go through the I would love to know more and contribute to the project. |
@DevPranjal I added more info in the description of the issue. Be aware that this project is on DocArray v2 |
@Nick17t @samsja Project Description: DocArray is a library for representing, sending, and storing multi-modal data, with a focus on applications in ML and Neural Search. It currently supports PyTorch, NumPy, and TensorFlow as computational backends. We want to extend the backend support to include JAX. Here are the specific tasks involved:
Unit tests for each function in the computational backend, using predefined tensors and DocumentArrayStack. Expected outcomes: Upon successful completion of this project, DocArray v2 will support JAX as a computational backend alongside PyTorch, NumPy, and TensorFlow. The implementation will be thoroughly tested and documented. I would like to work on this project. |
Hey @samsja, I would like to contribute to the project. Please guide me on how to get started. I am proficient in Python as well as machine learning using TensorFlow. |
Hi @DevPranjal @Arnav131003 @tehami02 I am delighted to hear that you are interested in contributing to the Jina AI community! 🎉 To get started, please take a moment to fill out our survey so that we can learn more about you and your skills. Also, don't forget to mark your calendars for the GSoC x Jina AI webinar on March 23rd at 2 pm (CET). This is an excellent opportunity to learn more about the projects and ask any questions you have about the requirements and expectations. Our mentors will provide an in-depth overview of the projects and answer any questions you may have. So please don't hesitate to ask any questions or seek clarification on any aspect of the project. Is there anything specific you would like to learn from the webinar? Do you have any questions about the JAX support in DocArray v2 project that you would like to see clarified during the Q&A session? Let me know, and I'll be happy to help! Looking forward to seeing you at the webinar, and thank you for your interest in the Jina AI community! 😊 |
Hi @Nick17t, this is a very interesting project, and I have worked on a similar kind of project where we had to create a new backend module for JAX. To make DocumentArrayStack compatible with JAX, we need to ensure that DocumentArrayStack works seamlessly with the JAX backend. This will involve testing the existing DocumentArrayStack code with the new JAX backend and resolving any compatibility issues that arise. I would love to work on this project 😁. |
Project idea 6: JAX support in DocArray v2
Project Description
DocArray is a library for representing, sending, and storing multi-modal data, with a focus on applications in ML and Neural Search. It currently supports several deep learning frameworks, including PyTorch and TensorFlow. JAX is becoming increasingly popular for deep learning, so we want to integrate it into DocArray.
The project we propose is to add JAX as a backend for DocArray, alongside PyTorch and TensorFlow. The first part would involve rewriting all of DocArray's computational backend functions with the JAX framework. Then, we would battle-test the implementation against a real JAX use case, such as using DocArray with JAX support for model training and serving.
Expected outcomes
Desired skills
More details:
This project targets DocArray, specifically the current rewrite, DocArray v2, which is a new codebase.
We currently support three computational frameworks in DocArray v2: PyTorch, NumPy, and TensorFlow. We would like to add JAX support.
More info about JAX can be found here, but in short, it is a deep learning framework supported by Google that is getting a lot of traction, especially among researchers.
Concretely what is expected in this project:
Add a new backend to our Computational Backend, relying as much as possible on jnp (JAX NumPy), which is a NumPy-like interface for JAX. A similar approach can be found in the TensorFlow backend: https://github.com/docarray/docarray/blob/feat-rewrite-v2/docarray/computation/tensorflow_backend.py
Create a new Tensor object with the JAX backend. Example: ImageTensor will need a JAX variant (as will all of the other tensor types).
Make DocumentArrayStack compatible with JAX. Hopefully this will be straightforward, since DocumentArrayStack is meant to be agnostic to the computational backend, but we noticed some problems with the TensorFlow backend, so we can expect some friction here.
Battle test the whole computational backend:
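To make the first task in the list above more concrete, here is a minimal sketch of what a few backend functions could look like when written against jnp. The class name `JaxCompBackendSketch` and its method names are hypothetical and chosen only for illustration; they are not taken from DocArray's actual backend interface, which should be checked against the TensorFlow backend linked above. Only `jax.numpy` and `jax.lax.top_k` are real library calls here.

```python
import jax
import jax.numpy as jnp
import numpy as np


class JaxCompBackendSketch:
    """Hypothetical JAX backend; method names are illustrative assumptions,
    not DocArray's real computational backend API."""

    @staticmethod
    def stack(tensors, dim=0):
        # Stack per-document tensors into one batched array,
        # which is what a stacked document array would hold per field.
        return jnp.stack(tensors, axis=dim)

    @staticmethod
    def to_numpy(tensor):
        # Copy a JAX array back to host memory as a NumPy array.
        return np.asarray(tensor)

    @staticmethod
    def cosine_sim(x_mat, y_mat, eps=1e-7):
        # Row-wise L2 normalization, guarded against zero-length rows.
        x_len = jnp.linalg.norm(x_mat, axis=1, keepdims=True)
        y_len = jnp.linalg.norm(y_mat, axis=1, keepdims=True)
        x_unit = x_mat / jnp.maximum(x_len, eps)
        y_unit = y_mat / jnp.maximum(y_len, eps)
        # Pairwise similarities: shape (len(x_mat), len(y_mat)).
        return x_unit @ y_unit.T

    @staticmethod
    def top_k(values, k):
        # Largest k entries along the last axis, returned as (values, indices).
        return jax.lax.top_k(values, k)


# Usage: rank three stored embeddings against one query embedding.
query = jnp.array([[1.0, 0.0]])
database = JaxCompBackendSketch.stack(
    [jnp.array([2.0, 0.0]), jnp.array([0.0, 3.0]), jnp.array([1.0, 1.0])]
)
similarities = JaxCompBackendSketch.cosine_sim(query, database)
best_scores, best_indices = JaxCompBackendSketch.top_k(similarities, k=2)
```

Keeping each function a thin wrapper over jnp mirrors the "rely as much as possible on jnp" goal, and the same small inputs could serve as predefined tensors in unit tests for each backend function.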