[WIP] Hessenberg: Use fori_loop() instead of python for-loop #43

Merged
merged 7 commits into from
Oct 10, 2023


paolot-gc

In the Hessenberg transformation, this PR replaces the Python for-loop that scans the columns of the matrix with a jax.lax.fori_loop().
This avoids the code bloat caused by unrolling the Python for-loop.
In its first version, this PR does not use the "smart indexing" that avoids some of the matrix processing.
Currently, the code crashes for sizes larger than 128x128.
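
To illustrate the change described above, here is a minimal sketch in plain JAX. It assumes a Householder-based reduction, which this description does not spell out, and all function names (`hessenberg_step`, `hessenberg_fori`, `hessenberg_unrolled`) are hypothetical, not the ones used in this repository. The loop body uses a fixed-shape mask instead of index-dependent slicing (that is, no "smart indexing"), so the same body can be traced once by `jax.lax.fori_loop()` or unrolled by a Python for-loop.

```python
import jax
import jax.numpy as jnp


def hessenberg_step(k, carry):
    """One Householder reflection zeroing column k below the subdiagonal."""
    A, Q = carry
    n = A.shape[0]
    rows = jnp.arange(n)
    # Fixed-shape masking instead of slicing A[k+1:, k]: every iteration
    # processes the full column, so array shapes do not depend on k.
    x = jnp.where(rows > k, A[:, k], 0.0)
    norm_x = jnp.linalg.norm(x)
    sign = jnp.where(x[k + 1] >= 0, 1.0, -1.0)
    # Householder vector v = x + sign * ||x|| * e_{k+1}, then normalised.
    v = x + sign * norm_x * (rows == k + 1)
    norm_v = jnp.linalg.norm(v)
    v = v / jnp.where(norm_v > 0, norm_v, 1.0)  # guard against a zero column
    # Apply H = I - 2 v v^T from the left and from the right.
    A = A - 2.0 * jnp.outer(v, v @ A)
    A = A - 2.0 * jnp.outer(A @ v, v)
    # Accumulate the orthogonal factor Q <- Q H.
    Q = Q - 2.0 * jnp.outer(Q @ v, v)
    return A, Q


@jax.jit
def hessenberg_fori(A):
    """Loop body is traced once; the compiled program contains a single loop."""
    n = A.shape[0]
    Q0 = jnp.eye(n, dtype=A.dtype)
    return jax.lax.fori_loop(0, n - 2, hessenberg_step, (A, Q0))


def hessenberg_unrolled(A):
    """Python loop: under jit, the body would be traced n - 2 times."""
    n = A.shape[0]
    carry = (A, jnp.eye(n, dtype=A.dtype))
    for k in range(n - 2):
        carry = hessenberg_step(k, carry)
    return carry
```

Both variants return a pair `(H, Q)` with `Q @ H @ Q.T` approximately equal to the input. The masked version does redundant work on entries that are already zero, which is the cost of keeping shapes independent of the loop index.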

@paolot-gc paolot-gc marked this pull request as draft October 3, 2023 11:06
@paolot-gc paolot-gc changed the title Hessenberg: Use fori_loop() instead of python for-loop [WIP] Hessenberg: Use fori_loop() instead of python for-loop Oct 3, 2023
@paolot-gc

The crash was due to memory not being aligned properly. I fixed the issue by slapping a MIN_ALIGN on all vertex Vectors (as suggested by @balancap). The code now compiles and runs correctly with size 736.
Unfortunately, it appears to take twice as many cycles as the unrolled version, but this was expected, since I have removed the smart indexing for the time being.
[Image: execution trace]

On the plus side, the compile time is just 55s (vs. 15 minutes) and the memory usage is much smaller than for the unrolled-loop version.

@paolot-gc

With the use of the gather_p primitive (commit) suggested by @balancap, the cost of slicing a column of the R matrix and of the sdiag_full matrix has been reduced considerably.
For size 1472, the overall number of cycles for the computation is 37.6e6 (vs. 46.6e6 with direct indexing).
The two charts below illustrate the execution trace for the fori_loop() (implemented by JAX with a while() loop) before (30,676 cycles) and after (24,507 cycles).

[Images: execution traces before and after]
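
As a rough illustration of the two ways of extracting a column inside a loop, the sketch below compares a dynamic slice with an explicit `jax.lax.gather` call in plain JAX. This is an assumption about what "direct indexing" and the gather-based path look like at the JAX level; the actual commit binds the gather_p primitive for IPU tiles, which is not reproduced here. The helper names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def column_by_dynamic_slice(R, i):
    """Extract column i with a dynamic slice of shape (n, 1), then squeeze."""
    n = R.shape[0]
    return lax.dynamic_slice(R, (0, i), (n, 1))[:, 0]


def column_by_gather(R, i):
    """Extract column i with an explicit gather (the op behind gather_p)."""
    n = R.shape[0]
    dnums = lax.GatherDimensionNumbers(
        offset_dims=(0,),          # output dim 0 spans the rows of the slice
        collapsed_slice_dims=(1,),  # the size-1 column dim is dropped
        start_index_map=(1,),      # the index selects along operand dim 1
    )
    start = jnp.asarray(i, dtype=jnp.int32).reshape((1,))
    return lax.gather(R, start, dimension_numbers=dnums, slice_sizes=(n, 1))


@jax.jit
def sum_of_columns(R):
    """Toy fori_loop that reads one column per iteration via gather."""
    def body(i, acc):
        return acc + column_by_gather(R, i)

    init = jnp.zeros(R.shape[0], dtype=R.dtype)
    return lax.fori_loop(0, R.shape[1], body, init)
```

Both helpers return the same values; any difference in cycle counts depends on how each operation is lowered on the target hardware, which this sketch does not measure.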

@paolot-gc

With the last commit, the mapping to tiles has been changed to allow more than one row per tile.
The memory chart below is for size 1472*3 = 4416, which is probably the largest we can handle.
510+ kiB per tile.
One iteration takes ~164k cycles, whereas the entire computation takes ~726e6 cycles.
[Image: memory chart]

@paolot-gc paolot-gc marked this pull request as ready for review October 6, 2023 16:42
@paolot-gc paolot-gc requested a review from balancap October 6, 2023 16:47
@paolot-gc paolot-gc merged commit da44969 into main Oct 10, 2023
6 checks passed
@paolot-gc paolot-gc deleted the hessenberg_loop branch October 10, 2023 13:26