AttributeError: module 'jax.random' has no attribute 'PRNGKeyArray' #1089

Open
Aki1991 opened this issue Jul 24, 2024 · 9 comments

Comments

@Aki1991

Aki1991 commented Jul 24, 2024

Hi all, I am trying to fine-tune our model using the owl_vit model.

But when I try to run it, I get this error: AttributeError: module 'jax.random' has no attribute 'PRNGKeyArray'. The JAX version I am using is 0.4.30. With jax 0.4.23 it works, but then training does not use the GPU, which slows it down a lot. Is there a way to use jax 0.4.30 and fix this error?

If I replace PRNGKeyArray with key, I get the following error at a later stage:

Traceback (most recent call last):
 File "/home/user/anaconda3/envs/owl_gpu/lib/python3.10/runpy.py", line 196, in _run_module_as_main
   return _run_code(code, main_globals, None,
 File "/home/user/anaconda3/envs/owl_gpu/lib/python3.10/runpy.py", line 86, in _run_code
   exec(code, run_globals)
 File "/home/user/Akash/Owl/scenic/scenic/projects/owl_vit/main.py", line 61, in <module>
   app.run(main=main)
 File "/home/user/Akash/Owl/scenic/scenic/app.py", line 68, in run
   app.run(functools.partial(_run_main, main=main))
 File "/home/user/anaconda3/envs/owl_gpu/lib/python3.10/site-packages/absl/app.py", line 308, in run
   _run_main(main, args)
 File "/home/user/anaconda3/envs/owl_gpu/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
   sys.exit(main(argv))
 File "/home/user/Akash/Owl/scenic/scenic/app.py", line 109, in _run_main
   main(rng=rng, config=config, workdir=workdir, writer=writer)
 File "/home/user/Akash/Owl/scenic/scenic/projects/owl_vit/main.py", line 51, in main
   trainer.train(
 File "/home/user/Akash/Owl/scenic/scenic/projects/owl_vit/trainer.py", line 218, in train
   gflops) = train_utils.initialize_model(
 File "/home/user/Akash/Owl/scenic/scenic/train_lib/train_utils.py", line 187, in initialize_model
   flops = debug_utils.compute_flops(
 File "/home/user/Akash/Owl/scenic/scenic/common_lib/debug_utils.py", line 139, in compute_flops
   flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable

Can anyone suggest what I can do here? Thank you.

UPDATE: I installed all libraries with compatible versions and got training running on the GPU with jax==0.4.23, but I am still getting the error mentioned above:

    flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable
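One way to at least avoid the crash while the root cause is investigated is to treat a missing cost analysis as "FLOPs unknown". This is a minimal sketch, assuming the value bound to `analysis` may be None, a single dict, or (in some JAX versions) a list of per-module dicts; `extract_flops` is a hypothetical helper, not part of scenic, and it does not explain why the analysis comes back as None.

```python
from typing import Any, Mapping, Optional, Sequence, Union

# Assumed shapes of a cost-analysis result: missing, one dict, or a list of dicts.
CostAnalysis = Union[None, Mapping[str, Any], Sequence[Optional[Mapping[str, Any]]]]


def extract_flops(analysis: CostAnalysis) -> Optional[float]:
    """Return the 'flops' entry of a cost analysis, or None if unavailable."""
    if analysis is None:
        # The case behind "TypeError: 'NoneType' object is not subscriptable".
        return None
    if isinstance(analysis, Mapping):
        flops = analysis.get("flops")
        return float(flops) if flops is not None else None
    # List of per-module dicts: sum whatever 'flops' entries are present.
    total = 0.0
    found = False
    for entry in analysis:
        if entry and entry.get("flops") is not None:
            total += float(entry["flops"])
            found = True
    return total if found else None
```

Inside `compute_flops`, `flops = analysis['flops']` could then become `flops = extract_flops(analysis)`, but the calling code (e.g. where `gflops` is reported in `initialize_model`) would also need to tolerate a None result.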
@dr4thmos

Same issue here.

@Aki1991
Author

Aki1991 commented Jul 31, 2024

AttributeError: module 'jax.random' has no attribute 'PRNGKeyArray' can be solved by replacing jax.random.PRNGKeyArray with jax.Array.
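That replacement can also be written as a version-tolerant alias, so the same code runs on both older and newer JAX. This is a minimal sketch, assuming `PRNGKeyArray` is only used as a type annotation in the code being patched; `split_rng` is a hypothetical helper included only to show the alias in use.

```python
import jax

# jax.random.PRNGKeyArray is gone in newer JAX releases (the AttributeError
# above was reported on 0.4.30), so fall back to jax.Array when it is missing.
# On older JAX versions the original attribute is still picked up.
PRNGKeyArray = getattr(jax.random, "PRNGKeyArray", jax.Array)


def split_rng(rng: PRNGKeyArray, num: int = 2) -> PRNGKeyArray:
    """Split an RNG key into `num` new keys (illustrative helper, not from scenic)."""
    return jax.random.split(rng, num)


if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    subkeys = split_rng(key, 3)
    print(PRNGKeyArray, subkeys.shape)
```

As the follow-up below notes, this only removes the AttributeError; it does not affect the later `analysis['flops']` failure.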

However, this does not fix:

flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable

@LihanWa

LihanWa commented Aug 8, 2024

This should be fixed in ott-jax==0.3.1

@Aki1991
Copy link
Author

Aki1991 commented Aug 8, 2024

I am already using ott-jax==0.3.1 and still get the same error.

@LihanWa

LihanWa commented Aug 8, 2024

Sorry, you should first run "pip install ott-jax==0.4.5"; if you then get an error about "transport", run "pip install ott-jax==0.3.1".
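Before switching versions, it can help to confirm which combination is actually installed. This is a small diagnostic sketch using only the standard library: it reports the installed jax, jaxlib, and ott-jax versions and whether `from ott.tools import transport` would succeed. The package list and function names are assumptions chosen for illustration.

```python
import importlib
import importlib.metadata
from typing import Dict, Optional

# Distributions discussed in this thread.
PACKAGES = ("jax", "jaxlib", "ott-jax")


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def has_ott_transport() -> bool:
    """Return True if the ott.tools.transport module can be imported."""
    try:
        importlib.import_module("ott.tools.transport")
    except Exception:
        # ImportError (module missing in this ott-jax release) or another error
        # raised while importing ott against an incompatible JAX counts as "no".
        return False
    return True


def environment_report() -> Dict[str, object]:
    report: Dict[str, object] = {name: installed_version(name) for name in PACKAGES}
    report["ott.tools.transport importable"] = has_ott_transport()
    return report


if __name__ == "__main__":
    for name, value in environment_report().items():
        print(f"{name}: {value}")
```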

@Aki1991
Author

Aki1991 commented Aug 9, 2024

Yes, I am getting the "transport" error; that's why I am using ott-jax==0.3.1. And that leads to this error:

flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable

@thecho7

thecho7 commented Aug 13, 2024

pip install ott-jax==0.2.0 works

@TTy32

TTy32 commented Dec 6, 2024

None of the ott-jax versions work for me:

pip install ott-jax==0.4.5

    from ott.tools import transport
ImportError: cannot import name 'transport' from 'ott.tools' (/home/u/miniforge3/envs/scenic/lib/python3.10/site-packages/ott/tools/__init__.py)

pip install ott-jax==0.3.1

jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

pip install ott-jax==0.2.0

  File "/home/u/scenic/scenic/common_lib/debug_utils.py", line 139, in compute_flops
    flops = analysis['flops']
TypeError: 'NoneType' object is not subscriptable

Any ideas? I don't know how to proceed.

@TTy32

TTy32 commented Dec 12, 2024

Update: I started from a clean environment (conda create -n myenv python=3.10). After pip install ott-jax==0.2.0, the error was gone.

  • CUDA 12.1
  • PyTorch 2.5.1+cu121
  • Make sure to set LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH
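To confirm that such a setup is actually picked up, a small check can report whether the CUDA library directory is on LD_LIBRARY_PATH and which GPU devices JAX sees (relevant to both the "not using GPU" report and the "DNN library initialization failed" error above). This is a minimal sketch; the helper names are hypothetical and the CUDA path is the one quoted in the list.

```python
import os
from typing import Mapping, Optional

import jax


def lib_dir_on_path(lib_dir: str, env: Optional[Mapping[str, str]] = None) -> bool:
    """Return True if lib_dir is one of the entries of LD_LIBRARY_PATH."""
    env = os.environ if env is None else env
    entries = env.get("LD_LIBRARY_PATH", "").split(os.pathsep)
    target = os.path.normpath(lib_dir)
    # normpath makes "/usr/local/cuda-12.1/lib64/" match "/usr/local/cuda-12.1/lib64".
    return any(os.path.normpath(entry) == target for entry in entries if entry)


def gpu_devices() -> list:
    """Return JAX's GPU devices, or an empty list if no GPU backend is available."""
    try:
        return jax.devices("gpu")
    except RuntimeError:
        # Raised when no GPU platform is present or it failed to initialize.
        return []


if __name__ == "__main__":
    print("CUDA 12.1 libs on LD_LIBRARY_PATH:",
          lib_dir_on_path("/usr/local/cuda-12.1/lib64"))
    print("JAX default backend:", jax.default_backend())
    print("JAX GPU devices:", gpu_devices())
```

LD_LIBRARY_PATH is read by the dynamic loader when the process starts, so it generally has to be exported in the shell before launching Python; changing `os.environ` from inside the training script is not enough.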
