Error in tile_linalg_jacobi for hydrogen #27

Closed
blazejba opened this issue Aug 18, 2023 · 3 comments · Fixed by #32


@blazejba
Contributor

Computing two hydrogen atoms with the CPU backend works:
python density_functional_theory.py -H -backend cpu -float32

but with the IPU backend
python density_functional_theory.py -H -backend ipu -float32

the following error is thrown:

Traceback (most recent call last):
  File "density_functional_theory.py", line 1440, in <module>
    elif args.H: recompute(args, None, 0, 0, our_fun=jax_dft, str=[["H", (0, 0, 0)],
  File "density_functional_theory.py", line 1095, in recompute
    energies, our_energy, our_hlgap, t_us, t_main_loop, us_hlgap = our_fun(str)
  File "density_functional_theory.py", line 1205, in jax_dft
    vals = density_functional_theory(atom_positions)
  File "density_functional_theory.py", line 665, in density_functional_theory
    vals = jax.jit(_do_compute, static_argnums=(10,11), device=device_1) ( density_matrix, kinetic, nuclear, overlap,
  File "density_functional_theory.py", line 148, in _do_compute
    vals = jax.lax.fori_loop(0, args.its, iter, vals)
  File "density_functional_theory.py", line 782, in iter
    eigvects = _eigh(generalized_hamiltonian )[1]
  File "density_functional_theory.py", line 1002, in _eigh
    eigvects, eigvals = ipu_eigh(x, sort_eigenvalues=True, num_iters=12)
  File "/nethome/blazejb/.venvs/3.2.0+1277/3.2.0+1277_poptorch/lib/python3.8/site-packages/tessellate_ipu/linalg/tile_linalg_jacobi.py", line 310, in ipu_eigh
    A, VT = ipu_jacobi_eigh(x, num_iters=num_iters)
  File "/nethome/blazejb/.venvs/3.2.0+1277/3.2.0+1277_poptorch/lib/python3.8/site-packages/tessellate_ipu/linalg/tile_linalg_jacobi.py", line 215, in ipu_jacobi_eigh
    Apcols, Aqcols, Vpcols, Vqcols = jax.lax.fori_loop(
  File "/nethome/blazejb/.venvs/3.2.0+1277/3.2.0+1277_poptorch/lib/python3.8/site-packages/tessellate_ipu/linalg/tile_linalg_jacobi.py", line 213, in <lambda>
    eigh_iteration_fn = lambda _, x: ipu_jacobi_eigh_iteration(x, Atiles, Vtiles)
  File "/nethome/blazejb/.venvs/3.2.0+1277/3.2.0+1277_poptorch/lib/python3.8/site-packages/tessellate_ipu/linalg/tile_linalg_jacobi.py", line 180, in ipu_jacobi_eigh_iteration
    Apcols, Aqcols = tile_rotate_columns(Apcols, Aqcols, rotset)
  File "/nethome/blazejb/.venvs/3.2.0+1277/3.2.0+1277_poptorch/lib/python3.8/site-packages/tessellate_ipu/linalg/tile_linalg_jacobi.py", line 276, in tile_rotate_columns
    assert len(pcols_indices_new) == halfN
AssertionError
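For context on where the failure happens: a parallel Jacobi eigensolver updates N/2 disjoint column pairs (p, q) at a time and reshuffles the pairs before the next round. The names Apcols, Aqcols and tile_rotate_columns suggest ipu_jacobi_eigh follows this pattern, and the failing line checks that the reshuffled p-column index list still holds halfN entries. Below is a minimal, generic sketch of one common pairing schedule (the round-robin or "circle" ordering) in plain Python. round_robin_pairs is a hypothetical name, the sketch assumes an even N, and it is not TessellateIPU's actual rotation logic.

```python
from typing import List, Tuple


def round_robin_pairs(n: int) -> List[List[Tuple[int, int]]]:
    """Generic round-robin ("circle method") schedule for a parallel Jacobi sweep.

    Hypothetical helper, not TessellateIPU code. For an even n it returns n - 1
    rounds; each round holds n // 2 disjoint (p, q) column pairs, and every pair
    of columns is visited exactly once per sweep.
    """
    if n < 2 or n % 2:
        raise ValueError("this sketch assumes an even matrix size n >= 2")
    half_n = n // 2
    top = list(range(0, n, 2))  # p columns of the current round
    bottom = list(range(1, n, 2))  # q columns of the current round
    rounds = []
    for _ in range(n - 1):
        rounds.append(list(zip(top, bottom)))
        if half_n > 1:
            # Keep top[0] fixed and move every other index one seat around the circle.
            top, bottom = [top[0], bottom[0]] + top[1:-1], bottom[1:] + [top[-1]]
        # Same kind of invariant as `assert len(pcols_indices_new) == halfN`.
        assert len(top) == len(bottom) == half_n
    return rounds


for round_pairs in round_robin_pairs(4):
    print(round_pairs)
```

For n = 4 this visits the pairs (0, 1), (2, 3), then (0, 3), (1, 2), then (0, 2), (3, 1), so all six column pairs are rotated once per sweep.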
@AlexanderMath
Contributor

ipu_eigh only works for matrices of size d >= 6; we fixed this in nanoDFT here. The same fix should also work in density_functional_theory.py.

The fix is just a switch:

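def eigh(x):
    d = x.shape[0]  # matrix size
    if d <= 6:  # ipu_eigh fails on small matrices: use jnp.linalg.eigh
        return jnp.linalg.eigh(x)  # (eigvals, eigvects)
    eigvects, eigvals = ipu_eigh(x, sort_eigenvalues=True, num_iters=12)
    return eigvals, eigvects  # same order as jnp.linalg.eigh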
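A self-contained sketch of the same size-based switch, assuming the (eigvects, eigvals) return order of ipu_eigh seen in the traceback (line 1002 of density_functional_theory.py). make_eigh and fake_ipu_eigh are hypothetical names; fake_ipu_eigh is a CPU stand-in for tessellate_ipu's ipu_eigh so the example runs without an IPU. Because x.shape[0] is a plain Python int while jax.jit traces, the `if` picks a branch at trace time and does not need jax.lax.cond.

```python
from typing import Callable, Tuple

import jax
import jax.numpy as jnp

# Returns (eigvects, eigvals), as in the ipu_eigh call from the traceback (assumed).
IpuEighFn = Callable[[jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]


def make_eigh(ipu_eigh_fn: IpuEighFn, max_fallback_size: int = 6):
    """Hypothetical wrapper: jnp.linalg.eigh for d <= max_fallback_size, ipu_eigh_fn otherwise.

    Always returns (eigvals, eigvects), the same order as jnp.linalg.eigh.
    """

    def eigh(x):
        d = x.shape[0]  # a Python int even under jax.jit, so this `if` is resolved at trace time
        if d <= max_fallback_size:
            return jnp.linalg.eigh(x)
        eigvects, eigvals = ipu_eigh_fn(x)  # assumed (eigvects, eigvals) order
        return eigvals, eigvects

    return eigh


def fake_ipu_eigh(x):
    """CPU stand-in for tessellate_ipu's ipu_eigh (hypothetical), with the same output order."""
    eigvals, eigvects = jnp.linalg.eigh(x)
    return eigvects, eigvals


eigh = jax.jit(make_eigh(fake_ipu_eigh))
vals, vecs = eigh(jnp.array([[2.0, 1.0], [1.0, 2.0]]))  # 2x2: takes the jnp.linalg.eigh branch
```

With the real library one would pass something like `lambda x: ipu_eigh(x, sort_eigenvalues=True, num_iters=12)` as ipu_eigh_fn, mirroring the call in the traceback. Normalizing the return order inside the wrapper keeps callers such as `_eigh(generalized_hamiltonian)[1]` independent of which backend ran.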

@AlexanderMath
Contributor

This PR fixes it. Feel free to merge it and close this issue.

@balancap
Contributor

Thanks @blazejba for finding the issue! I just opened a bug ticket on TessellateIPU so we can solve it there: graphcore-research/tessellate-ipu#12
@AlexanderMath Thanks for the PR. Let's merge that for now until there is a fix in TessellateIPU.
