From c571c6d0b185af40c3d97c7a39168824d401fd81 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Sun, 8 Oct 2023 12:07:37 +0000 Subject: [PATCH] add num_segments in eri segment_sum --- pyscf_ipu/experimental/integrals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyscf_ipu/experimental/integrals.py b/pyscf_ipu/experimental/integrals.py index 74224cf..32ddac8 100644 --- a/pyscf_ipu/experimental/integrals.py +++ b/pyscf_ipu/experimental/integrals.py @@ -279,7 +279,7 @@ def eri_basis_sparse(b: Basis): tree_map(lambda x: jnp.take(x, idx, axis=0), primitives) for idx in indices ] eris = cijkl * vmap_eri_primitives(*pijkl) - return segment_sum(eris, batch) + return segment_sum(eris, batch, num_segments=count + 1) def eri_basis(b: Basis):