Skip to content

Commit

Permalink
Reorganise unit tests to better match source layout (#122)
Browse files Browse the repository at this point in the history
* rename test files

* default to showing 10 slowest tests

* move pyscf interop tests

* move ipu tests

* remove duplicated test case

* remove redundant parametrize on test_gto
  • Loading branch information
hatemhelal authored Oct 10, 2023
1 parent 9ca8e13 commit 8b0ee07
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 104 deletions.
2 changes: 1 addition & 1 deletion .pytest.ini
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[pytest]
addopts = -s -v
addopts = -s -v --durations=10
103 changes: 0 additions & 103 deletions test/test_experimental.py → test/test_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
import jax.numpy as jnp
import numpy as np
import pytest
from jax import tree_map, vmap
from numpy.testing import assert_allclose

from pyscf_ipu.experimental.basis import basisset
from pyscf_ipu.experimental.device import has_ipu, ipu_func
from pyscf_ipu.experimental.integrals import (
eri_basis,
eri_basis_sparse,
Expand All @@ -19,47 +17,10 @@
overlap_primitives,
)
from pyscf_ipu.experimental.interop import to_pyscf
from pyscf_ipu.experimental.mesh import electron_density, uniform_mesh
from pyscf_ipu.experimental.primitive import Primitive
from pyscf_ipu.experimental.structure import molecule


@pytest.mark.parametrize("basis_name", ["sto-3g", "6-31g**"])
def test_to_pyscf(basis_name):
mol = molecule("water")
basis = basisset(mol, basis_name)
pyscf_mol = to_pyscf(mol, basis_name)
assert basis.num_orbitals == pyscf_mol.nao


@pytest.mark.parametrize("basis_name", ["sto-3g", "6-31+g"])
def test_gto(basis_name):
from pyscf.dft.numint import eval_rho

# Atomic orbitals
structure = molecule("water")
basis = basisset(structure, basis_name)
mesh, _ = uniform_mesh()
actual = basis(mesh)

mol = to_pyscf(structure, basis_name)
expect_ao = mol.eval_gto("GTOval_cart", np.asarray(mesh))
assert_allclose(actual, expect_ao, atol=1e-6)

# Molecular orbitals
mf = mol.KS()
mf.kernel()
C = jnp.array(mf.mo_coeff, dtype=jnp.float32)
actual = basis.occupancy * C @ C.T
expect = jnp.array(mf.make_rdm1(), dtype=jnp.float32)
assert_allclose(actual, expect, atol=1e-6)

# Electron density
actual = electron_density(basis, mesh, C)
expect = eval_rho(mol, expect_ao, mf.make_rdm1(), "lda")
assert_allclose(actual, expect, atol=1e-6)


def test_overlap():
# Exercise 3.21 of "Modern quantum chemistry: introduction to advanced
# electronic structure theory."" by Szabo and Ostlund
Expand Down Expand Up @@ -147,19 +108,6 @@ def test_water_nuclear():
assert_allclose(actual, expect, atol=1e-4)


def eri_orbitals(orbitals):
def take(orbital, index):
p = tree_map(lambda *xs: jnp.stack(xs), *orbital.primitives)
p = tree_map(lambda x: jnp.take(x, index, axis=0), p)
c = jnp.take(orbital.coefficients, index)
return p, c

indices = [jnp.arange(o.num_primitives) for o in orbitals]
indices = [i.reshape(-1) for i in jnp.meshgrid(*indices)]
prim, coef = zip(*[take(o, i) for o, i in zip(orbitals, indices)])
return jnp.sum(jnp.prod(jnp.stack(coef), axis=0) * vmap(eri_primitives)(*prim))


def test_eri():
# PyQuante test cases for ERI
a, b, c, d = [Primitive()] * 4
Expand All @@ -168,18 +116,6 @@ def test_eri():
c, d = [Primitive(lmn=jnp.array([1, 0, 0]))] * 2
assert_allclose(eri_primitives(a, b, c, d), 0.940316, atol=1e-5)

# H2 molecule in sto-3g: See equation 3.235 of Szabo and Ostlund
h2 = molecule("h2")
basis = basisset(h2, "sto-3g")
indices = [(0, 0, 0, 0), (0, 0, 1, 1), (1, 0, 0, 0), (1, 0, 1, 0)]
expected = [0.7746, 0.5697, 0.4441, 0.2970]

for ijkl, expect in zip(indices, expected):
actual = eri_orbitals([basis.orbitals[aoid] for aoid in ijkl])
assert_allclose(actual, expect, atol=1e-4)


def test_eri_basis():
# H2 molecule in sto-3g: See equation 3.235 of Szabo and Ostlund
h2 = molecule("h2")
basis = basisset(h2, "sto-3g")
Expand Down Expand Up @@ -215,42 +151,3 @@ def test_water_eri(sparse):
aosym = "s8" if sparse else "s1"
expect = to_pyscf(h2o, basis_name=basis_name).intor("int2e_cart", aosym=aosym)
assert_allclose(actual, expect, atol=1e-4)


@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!")
def test_ipu_overlap():
from pyscf_ipu.experimental.integrals import _overlap_primitives

a, b = [Primitive()] * 2
actual = ipu_func(_overlap_primitives)(a, b)
assert_allclose(actual, overlap_primitives(a, b))


@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!")
def test_ipu_kinetic():
from pyscf_ipu.experimental.integrals import _kinetic_primitives

a, b = [Primitive()] * 2
actual = ipu_func(_kinetic_primitives)(a, b)
assert_allclose(actual, kinetic_primitives(a, b))


@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!")
def test_ipu_nuclear():
from pyscf_ipu.experimental.integrals import _nuclear_primitives

# PyQuante test case for nuclear attraction integral
a, b = [Primitive()] * 2
c = jnp.zeros(3)
actual = ipu_func(_nuclear_primitives)(a, b, c)
assert_allclose(actual, -1.595769, atol=1e-5)


@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!")
def test_ipu_eri():
from pyscf_ipu.experimental.integrals import _eri_primitives

# PyQuante test cases for ERI
a, b, c, d = [Primitive()] * 4
actual = ipu_func(_eri_primitives)(a, b, c, d)
assert_allclose(actual, 1.128379, atol=1e-5)
47 changes: 47 additions & 0 deletions test/test_integrals_ipu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import jax.numpy as jnp
import pytest
from numpy.testing import assert_allclose

from pyscf_ipu.experimental.device import has_ipu, ipu_func
from pyscf_ipu.experimental.integrals import kinetic_primitives, overlap_primitives
from pyscf_ipu.experimental.primitive import Primitive


@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!")
def test_overlap():
from pyscf_ipu.experimental.integrals import _overlap_primitives

a, b = [Primitive()] * 2
actual = ipu_func(_overlap_primitives)(a, b)
assert_allclose(actual, overlap_primitives(a, b))


@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!")
def test_kinetic():
from pyscf_ipu.experimental.integrals import _kinetic_primitives

a, b = [Primitive()] * 2
actual = ipu_func(_kinetic_primitives)(a, b)
assert_allclose(actual, kinetic_primitives(a, b))


@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!")
def test_nuclear():
from pyscf_ipu.experimental.integrals import _nuclear_primitives

# PyQuante test case for nuclear attraction integral
a, b = [Primitive()] * 2
c = jnp.zeros(3)
actual = ipu_func(_nuclear_primitives)(a, b, c)
assert_allclose(actual, -1.595769, atol=1e-5)


@pytest.mark.skipif(not has_ipu(), reason="Skipping ipu test!")
def test_eri():
from pyscf_ipu.experimental.integrals import _eri_primitives

# PyQuante test cases for ERI
a, b, c, d = [Primitive()] * 4
actual = ipu_func(_eri_primitives)(a, b, c, d)
assert_allclose(actual, 1.128379, atol=1e-5)
46 changes: 46 additions & 0 deletions test/test_interop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
import jax.numpy as jnp
import numpy as np
import pytest
from numpy.testing import assert_allclose

from pyscf_ipu.experimental.basis import basisset
from pyscf_ipu.experimental.interop import to_pyscf
from pyscf_ipu.experimental.mesh import electron_density, uniform_mesh
from pyscf_ipu.experimental.structure import molecule


@pytest.mark.parametrize("basis_name", ["sto-3g", "6-31g**"])
def test_to_pyscf(basis_name):
mol = molecule("water")
basis = basisset(mol, basis_name)
pyscf_mol = to_pyscf(mol, basis_name)
assert basis.num_orbitals == pyscf_mol.nao


def test_gto():
from pyscf.dft.numint import eval_rho

# Atomic orbitals
basis_name = "6-31+g"
structure = molecule("water")
basis = basisset(structure, basis_name)
mesh, _ = uniform_mesh()
actual = basis(mesh)

mol = to_pyscf(structure, basis_name)
expect_ao = mol.eval_gto("GTOval_cart", np.asarray(mesh))
assert_allclose(actual, expect_ao, atol=1e-6)

# Molecular orbitals
mf = mol.KS()
mf.kernel()
C = jnp.array(mf.mo_coeff, dtype=jnp.float32)
actual = basis.occupancy * C @ C.T
expect = jnp.array(mf.make_rdm1(), dtype=jnp.float32)
assert_allclose(actual, expect, atol=1e-6)

# Electron density
actual = electron_density(basis, mesh, C)
expect = eval_rho(mol, expect_ao, mf.make_rdm1(), "lda")
assert_allclose(actual, expect, atol=1e-6)
File renamed without changes.

0 comments on commit 8b0ee07

Please sign in to comment.