ipu_eigh OOM for 500x500 matrix #21
Found issue. Consider for
Each Potential fix: Compute
Found a fix; the profile for N=512 is attached below. I pushed the code to this branch. Let me know if it makes sense to open a PR.

Changes (modifications to `ipu_jacobi_eigh`):
- `rotset` was compiled to a constant, which added 1 MB to certain tiles and caused the OOM: each tile held N=512 copies of `rotset`, for a total of 1472 × 512 ≈ 750k copies.
- `rotset` is now computed on the fly using `jax.numpy`, so each tile holds at most 2 rotsets at any time.
- `static_gather` is replaced with `all_cols[all_indices, :]`, which (my guess) gets compiled to

Potential tough questions before PR:
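The idea of computing `rotset` on the fly, instead of baking every step's pairs into a constant, can be sketched in plain JAX. This is a minimal sketch, not the code on the branch: it assumes the rotation sets follow a round-robin ("circle method") ordering of N (even) indices, and `rotset`, `pair_coverage` and `gather_pairs` are hypothetical names. Because the step counter `k` can be a traced value inside `lax.fori_loop`, each step's pairs are derived from `k` rather than read from a stored (N-1) × N/2 table.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


def rotset(n: int, k):
    """Hypothetical helper: index pairs (p, q) rotated together at step k.

    Assumes a round-robin "circle method" ordering: index 0 stays fixed and
    indices 1..n-1 shift by k positions per step. n must be even; k may be a
    traced integer, so no (n - 1) x n/2 table of pairs is stored as a constant.
    """
    pos_left = jnp.arange(n // 2)      # positions 0 .. n/2 - 1
    pos_right = n - 1 - pos_left       # partner positions n-1 .. n/2

    def index_at(pos):
        # Position 0 always holds index 0; the other positions rotate with k.
        return jnp.where(pos == 0, 0, 1 + (pos - 1 + k) % (n - 1))

    p, q = index_at(pos_left), index_at(pos_right)
    return jnp.stack([jnp.minimum(p, q), jnp.maximum(p, q)], axis=1)


@partial(jax.jit, static_argnums=0)
def pair_coverage(n: int):
    """Count how often each (p, q) pair is visited over one sweep of n-1 steps."""
    def body(k, cover):
        pairs = rotset(n, k)  # k is a traced loop counter here
        return cover.at[pairs[:, 0], pairs[:, 1]].add(1)

    return lax.fori_loop(0, n - 1, body, jnp.zeros((n, n), jnp.int32))


def gather_pairs(all_cols, k):
    """Gather the rows touched at step k with plain advanced indexing."""
    all_indices = rotset(all_cols.shape[0], k).reshape(-1)
    return all_cols[all_indices, :]
```

With this ordering, `pair_coverage(n)` visits every pair p < q exactly once per sweep, and `gather_pairs` mirrors the `all_cols[all_indices, :]` indexing from the fix. Whether that indexing lowers to something cheaper than `static_gather` on IPU is not something this sketch can show.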
Closing, as PR #42 fixes the issue using a proper
Reproducer
PopVision profile