
Support NumPy and PyTorch #26

Open
4 of 17 tasks
NeilGirdhar opened this issue Apr 15, 2024 · 0 comments

NeilGirdhar (owner) commented Apr 15, 2024

This is now well within reach thanks to the Array API.
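
A minimal sketch of the per-method pattern the Array API makes possible, assuming NumPy inputs. `get_namespace` below is a hypothetical stand-in written for illustration. In practice the array-api-compat package's `array_namespace` does this job, and it also covers PyTorch tensors, which (as far as I know) do not expose `__array_namespace__` themselves. `normal_log_normalizer` is a made-up example method, not part of this project's API.

```python
import numpy as np


def get_namespace(*arrays):
    """Return the single array-API namespace shared by `arrays` (hypothetical helper).

    Arrays that follow the standard expose `__array_namespace__()`. NumPy
    arrays gained that method only in NumPy 2.0, so fall back to `numpy`.
    """
    namespaces = set()
    for x in arrays:
        if hasattr(x, "__array_namespace__"):
            namespaces.add(x.__array_namespace__())
        elif isinstance(x, np.ndarray):
            namespaces.add(np)
        else:
            raise TypeError(f"unsupported array type: {type(x).__name__}")
    if len(namespaces) != 1:
        raise TypeError("expected arrays from exactly one namespace")
    return namespaces.pop()


def normal_log_normalizer(eta1, eta2):
    """Log-normalizer of a univariate normal in natural parameters (hypothetical method).

    eta1 = mu / sigma**2 and eta2 = -1 / (2 * sigma**2); the base-measure
    constant 0.5 * log(2 * pi) is left out.
    """
    xp = get_namespace(eta1, eta2)  # instead of hard-coding jax.numpy
    return -(eta1**2) / (4 * eta2) - 0.5 * xp.log(-2 * eta2)


# mu = 1, sigma = 2, so the result should equal mu**2 / (2 * sigma**2) + log(sigma).
print(normal_log_normalizer(np.asarray([0.25]), np.asarray([-0.125])))
```

The method body never names a backend; only the namespace lookup depends on the input type, which is what lets the same code serve JAX, PyTorch, and NumPy.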

  • In every method, obtain the namespace with xp = get_namespace(*arrays) and use xp instead of jax.numpy.
  • Use the special-functions extension (xp.special) for special functions. This depends on RFC: special function extension data-apis/array-api#725.
  • Support sampling methods for:
    • Jax
    • PyTorch
    • NumPy
  • Support native fixed-point sampling methods (used in exp-to-nat) for:
    • Jax
    • PyTorch
    • NumPy
  • Generalize abstract_custom_jvp to PyTorch.
  • Port the Fisher information code (which depends on automatic differentiation) to PyTorch.
  • Move automatic JIT application from methods to tests.
  • Make tests work for each namespace:
    • Jax
    • PyTorch
    • NumPy
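
One way the two testing items might fit together, sketched with hypothetical helpers (`available_backends`, `maybe_jit`, `bernoulli_log_normalizer`) rather than this project's actual test suite. The library-style function never calls `jax.jit`; the test loops over whichever of NumPy, JAX, and PyTorch are installed and applies `jax.jit` only for JAX. For brevity the namespace is passed in explicitly instead of being looked up from the arguments, and a plain loop stands in for what would more likely be `pytest.mark.parametrize` in a real suite.

```python
import functools
import importlib
import importlib.util

import numpy as np


def available_backends():
    """Hypothetical helper: (name, namespace) pairs for every installed backend."""
    backends = [("numpy", np)]
    if importlib.util.find_spec("jax") is not None:
        backends.append(("jax", importlib.import_module("jax.numpy")))
    if importlib.util.find_spec("torch") is not None:
        backends.append(("torch", importlib.import_module("torch")))
    return backends


def maybe_jit(name, fn):
    """JIT-compile in the test only for JAX, keeping library code backend-agnostic."""
    if name == "jax":
        import jax

        return jax.jit(fn)
    return fn


def bernoulli_log_normalizer(xp, eta):
    """Hypothetical stand-in for a library method: A(eta) = log(1 + exp(eta))."""
    return xp.log1p(xp.exp(eta))


def check_all_backends():
    """Compare every installed backend against a float64 NumPy reference."""
    eta = [0.0, 1.0, -2.0]
    expected = np.log1p(np.exp(np.asarray(eta)))
    checked = []
    for name, xp in available_backends():
        fn = maybe_jit(name, functools.partial(bernoulli_log_normalizer, xp))
        result = np.asarray(fn(xp.asarray(eta)))
        # JAX and PyTorch default to float32, hence the loose tolerance.
        np.testing.assert_allclose(result, expected, rtol=1e-5)
        checked.append(name)
    return checked


print(check_all_backends())
```

Keeping `jax.jit` out of the method means NumPy and PyTorch callers never import JAX, while JAX users can still compile at the call site.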