From e603725c4949c3a9bc9c59debbf22c514138a210 Mon Sep 17 00:00:00 2001 From: Hatem Helal Date: Mon, 18 Sep 2023 08:56:58 +0000 Subject: [PATCH] redo without black fmt --- pyscf_ipu/nanoDFT/nanoDFT.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/pyscf_ipu/nanoDFT/nanoDFT.py b/pyscf_ipu/nanoDFT/nanoDFT.py index e9d1273..6464a19 100644 --- a/pyscf_ipu/nanoDFT/nanoDFT.py +++ b/pyscf_ipu/nanoDFT/nanoDFT.py @@ -6,7 +6,7 @@ import pyscf import h5py import chex -from jaxtyping import Float, Array +from jaxtyping import Float, Array, Int from jsonargparse import CLI, Namespace from functools import partial from collections import namedtuple @@ -114,6 +114,8 @@ def _nanoDFT(state, opts, mol): FloatNxN = Float[Array, "N N"] FloatNxNxNxN = Float[Array, "N N N N"] Grid = Float[Array, "4 grid_size N"] +FloatArray = Float[Array, "..."] +IntArray = Int[Array, "..."] @chex.dataclass class IterationState: @@ -134,13 +136,13 @@ class IterationState: O (FloatNxN): Overlap integrals in the LCAO basis set. grid_AO (Grid): Numerical grid used to evaluate the exchange-correlation energy integral. ERI (FloatNxNxNxN): Two-electron repulsion integrals in the LCAO basis set. - grid_weights (Array): Weights associated with the grid_AO + grid_weights (FloatArray): Weights associated with the grid_AO mask (FloatN): Orbital occupation mask. - input_floats (Array): Supplementary vector of floats for ERI evaluation with libcint - input_ints (Array): Supplementary vector of ints for ERI evaluation with libcint + input_floats (FloatArray): Supplementary vector of floats for ERI evaluation with libcint + input_ints (IntArray): Supplementary vector of ints for ERI evaluation with libcint L_inv (FloatNxN): Defined as the inverse of the Cholesky decomposition of the overlap matrix. Used to change generalised eig problem into an eigh one. - diis_history (Array): Direct Inversion of Iterative Subspace (DIIS) is an optional method that + diis_history (FloatArray): Direct Inversion of Iterative Subspace (DIIS) is an optional method that can accelerate convergence of the SCF iterations. Maintains a history of how the Hamiltonian is evolving across the SCF iterations. @@ -153,12 +155,12 @@ class IterationState: O: FloatNxN grid_AO: Grid ERI: FloatNxNxNxN - grid_weights: Array + grid_weights: FloatArray mask: FloatN - input_floats: Array - input_ints: Array + input_floats: FloatArray + input_ints: IntArray L_inv: FloatNxN - diis_history: Array + diis_history: FloatArray def init_dft_tensors_cpu(mol, opts, DIIS_iters=9):