ipu_eigh OOM for 500x500 matrix. #21

Closed
AlexanderMath opened this issue Aug 31, 2023 · 4 comments · Fixed by #42


AlexanderMath commented Aug 31, 2023

Reproducer

```python
import jax
import numpy as np

from tessellate_ipu.linalg import ipu_eigh


def linalg_eigh(x):
    eigvects, eigvals = ipu_eigh(x, sort_eigenvalues=True, num_iters=12)
    return eigvals, eigvects


# eigh expects a symmetric matrix, so symmetrize the random input.
a = np.random.normal(0, 1, (500, 500)).astype(np.float32)
a = (a + a.T) / 2

print(jax.jit(linalg_eigh, backend="ipu")(a))
```

PopVision profile: [two screenshots]


AlexanderMath commented Sep 1, 2023

Found the issue. For N=500, consider lines 142-144:

```python
# Python loop: unrolled at trace time, so the body runs N - 1 times.
for _ in range(1, N):
    rotset_sorted = jacobi_sort_rotation_set(rotset)
    print(rotset.shape, rotset.nbytes)  # (250, 2) 2000
    # Each rotset is materialized as a constant on every tile in Atiles.
    rotset_replicated = tile_constant_replicated(rotset_sorted, tiles=Atiles)
    ...
    # Next rotation set.
    rotset = jacobi_next_rotation_set(rotset)
```

Each rotset takes up 2 kB, and we create roughly N=500 of them, about 500*2 kB = 1 MB in total. Since these are replicated as constants over tiles=Atiles, we try to put 1 MB on every one of those tiles, hence the OOM.
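A back-of-the-envelope sketch of the ~1 MB estimate above. It assumes a rotation set is an (N // 2, 2) array of 4-byte indices, which matches the printed shape (250, 2) and nbytes 2000 for N=500, and that all N - 1 rotation sets end up replicated on the same tile. The function names are hypothetical.

```python
# Memory estimate for storing every Jacobi rotation set as a per-tile constant.
# Assumption: indices are 4 bytes (int32/uint32), matching nbytes=2000 for N=500.
INDEX_BYTES = 4


def rotset_bytes(n: int) -> int:
    """Bytes for one rotation set pairing n indices into n // 2 pairs."""
    return (n // 2) * 2 * INDEX_BYTES


def replicated_bytes_per_tile(n: int) -> int:
    """Bytes per tile when all n - 1 rotation sets are stored as constants."""
    return (n - 1) * rotset_bytes(n)


for n in (100, 500, 512):
    print(n, rotset_bytes(n), replicated_bytes_per_tile(n))
```

For N=500 this gives 998,000 bytes per tile, which is more than the roughly 624 KB of SRAM per tile on a Mk2-class IPU (an assumption about the hardware used here).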

Potential fix: compute rotset on the fly instead of as trace-time constants; it looks like this should be fine.
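A minimal sketch of computing the rotation set on the fly, assuming a round-robin ("circle method") pairing schedule. This is a swapped-in, textbook ordering, not necessarily what jacobi_next_rotation_set in tessellate_ipu does, and next_rotation_set / pair_coverage are hypothetical names. The rotation set is a loop-carried value of a jax.lax.fori_loop, so only the current rotset is live rather than N - 1 precomputed constants; the coverage matrix just checks that every pair (p, q) is visited exactly once.

```python
import jax
import jax.numpy as jnp


def next_rotation_set(rotset: jax.Array) -> jax.Array:
    """One circle-method step on an (N // 2, 2) pairing.

    Index rotset[0, 0] stays fixed; the other N - 1 indices rotate one
    position around the circle, so N - 1 rounds visit every pair once.
    """
    top, bottom = rotset[:, 0], rotset[:, 1]
    new_top = jnp.concatenate([top[:1], bottom[:1], top[1:-1]])
    new_bottom = jnp.concatenate([bottom[1:], top[-1:]])
    return jnp.stack([new_top, new_bottom], axis=1)


def pair_coverage(n: int) -> jax.Array:
    """Count visits of each (p, q) pair over n - 1 rounds (n even, n >= 4)."""
    rotset0 = jnp.arange(n, dtype=jnp.int32).reshape(n // 2, 2)
    cover0 = jnp.zeros((n, n), dtype=jnp.int32)

    def body(_, carry):
        rotset, cover = carry
        # Placeholder for the Jacobi rotations applied to the pairs in rotset.
        cover = cover.at[rotset[:, 0], rotset[:, 1]].add(1)
        return next_rotation_set(rotset), cover

    _, cover = jax.lax.fori_loop(0, n - 1, body, (rotset0, cover0))
    return cover


cov = pair_coverage(8)
print(cov + cov.T)  # ones off the diagonal, zeros on it
```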

@AlexanderMath (Author)

PopVision profile verification for N=100: tile 64 holds 100 constants of 400 bytes each (the profile numbers are zero-indexed, so label 99 is the 100th constant).
[PopVision screenshot]


AlexanderMath commented Sep 4, 2023

Found a fix; the profile for N=512 is attached below. I pushed the code to this branch. Let me know if it makes sense to open a PR.

Changes to ipu_jacobi_eigh:

  1. rotset was compiled to constants, which added ~1 MB to certain tiles and caused the OOM (each tile held N=512 copies of rotset, so about 1472*512 ≈ 750k copies in total). rotset is now computed on the fly using jax.numpy, so each tile holds at most 2 rotsets at any time.
  2. static_gather is replaced by all_cols[all_indices, :], which (my guess) gets compiled to all_cols.T[:, all_indices].T.
  3. Compile time blew up for N=512, so for i in range(1, N) is replaced by jax.lax.fori_loop(1, N, iteration, ..).

Numerical error for d=64: np.max(vals)=1.14e-05 and np.max(vects)=1.479. For N=512 it is 1e-4 even with num_iters=64 (I can't check whether the previous implementation behaves the same, since it OOMs at N=512).

Potentially tough questions before a PR:

  1. Do we want to retain control over unrolling vs. jax.lax.fori_loop?
  2. Profiling with PopVision took ~10 min for N=512.
  3. Do we want the flexibility to switch between np and jnp when computing rotset (i.e. compute it at trace time or on the fly)? This can be done by adding an argument: def ipu_eigh(..., np_backend=jax.numpy).
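Regarding question 1, a minimal sketch of the unroll vs. fori_loop trade-off: a Python loop is traced once per iteration, so the traced program (and compile time) grows with N, while jax.lax.fori_loop traces its body once. Here step is a hypothetical stand-in for one Jacobi iteration, and the number of jaxpr equations is only a rough proxy for program size.

```python
import jax
import jax.numpy as jnp


def step(i, x):
    # Hypothetical stand-in for one Jacobi iteration (i is unused here).
    return jnp.tanh(x) * 0.5 + x


def unrolled(x, n):
    # Python loop: traced n - 1 times, so the jaxpr grows linearly with n.
    for i in range(1, n):
        x = step(i, x)
    return x


def rolled(x, n):
    # fori_loop: the body is traced once and lowered to a single loop op.
    return jax.lax.fori_loop(1, n, step, x)


def num_eqns(fn, n):
    x = jnp.ones((4, 4), dtype=jnp.float32)
    return len(jax.make_jaxpr(lambda y: fn(y, n))(x).jaxpr.eqns)


for n in (8, 64, 512):
    print(n, num_eqns(unrolled, n), num_eqns(rolled, n))
```

Unrolling can still pay off at small N (more room for cross-iteration optimization), which is presumably why keeping control over it is worth discussing.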

[PopVision profile for N=512]

@balancap (Contributor)

Closing, as PR #42 fixes the issue by using a proper fori_loop.
