Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(WIP?) vectorized log_likelihood function for NumPyro #2390

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

aporsch1
Copy link

@aporsch1 aporsch1 commented Oct 3, 2024

Description

Checklist

  • Follows official PR format
  • New features are properly documented
  • Code style correct (follows pylint and black guidelines)

📚 Documentation preview 📚: https://arviz--2390.org.readthedocs.build/en/2390/

@aporsch1
Copy link
Author

aporsch1 commented Oct 3, 2024

Hey, I looked at the checks that failed, and they are failing because they can't even find test cases. I don't think that is related to the updated code at all? Let me know if I am missing something, though.

@OriolAbril
Copy link
Member

@virajpandya could you try it out and see how timing compares to the ~80 mins from the latest release and setting log_likelihood=False?

You can install the arviz version of this PR with:

pip install "arviz @ git+https://github.com/aporsch1/arviz"

@OriolAbril OriolAbril changed the title (WIP?) vectorized log_likelihood function for NumPyro (https://github.com/arviz-devs/arviz/issues/2373) (WIP?) vectorized log_likelihood function for NumPyro Oct 7, 2024
@OriolAbril OriolAbril linked an issue Oct 7, 2024 that may be closed by this pull request
@OriolAbril
Copy link
Member

Hey, I looked at the checks that failed, and they are failing because they can't even find test cases. I don't think that is related to the updated code at all? Let me know if I am missing something, though.

The pylint checks are failing. These are the specific errors:

************* Module arviz.data.io_numpyro
arviz/data/io_numpyro.py:195:64: C0303: Trailing whitespace (trailing-whitespace)
arviz/data/io_numpyro.py:196:0: C0301: Line too long (105/100) (line-too-long)
arviz/data/io_numpyro.py:195:34: E0602: Undefined variable 'jax' (undefined-variable)

For the jax import, note that it is not a dependency of ArviZ (nor it should be) so it needs to be imported at runtime from inside the method itself. This is already done in the __init__ method for example: https://github.com/arviz-devs/arviz/blob/main/arviz/data/io_numpyro.py#L67

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Log likelihood computation in numpyro can be extremely slow
2 participants