Skip to content

Commit

Permalink
redo without black fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
hatemhelal committed Sep 18, 2023
1 parent 3720ee6 commit e603725
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions pyscf_ipu/nanoDFT/nanoDFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pyscf
import h5py
import chex
from jaxtyping import Float, Array
from jaxtyping import Float, Array, Int
from jsonargparse import CLI, Namespace
from functools import partial
from collections import namedtuple
Expand Down Expand Up @@ -114,6 +114,8 @@ def _nanoDFT(state, opts, mol):
FloatNxN = Float[Array, "N N"]
FloatNxNxNxN = Float[Array, "N N N N"]
Grid = Float[Array, "4 grid_size N"]
FloatArray = Float[Array, "..."]
IntArray = Int[Array, "..."]

@chex.dataclass
class IterationState:
Expand All @@ -134,13 +136,13 @@ class IterationState:
O (FloatNxN): Overlap integrals in the LCAO basis set.
grid_AO (Grid): Numerical grid used to evaluate the exchange-correlation energy integral.
ERI (FloatNxNxNxN): Two-electron repulsion integrals in the LCAO basis set.
grid_weights (Array): Weights associated with the grid_AO
grid_weights (FloatArray): Weights associated with the grid_AO
mask (FloatN): Orbital occupation mask.
input_floats (Array): Supplementary vector of floats for ERI evaluation with libcint
input_ints (Array): Supplementary vector of ints for ERI evaluation with libcint
input_floats (FloatArray): Supplementary vector of floats for ERI evaluation with libcint
input_ints (IntArray): Supplementary vector of ints for ERI evaluation with libcint
L_inv (FloatNxN): Defined as the inverse of the Cholesky decomposition of the overlap matrix.
Used to change generalised eig problem into an eigh one.
diis_history (Array): Direct Inversion of Iterative Subspace (DIIS) is an optional method that
diis_history (FloatArray): Direct Inversion of Iterative Subspace (DIIS) is an optional method that
can accelerate convergence of the SCF iterations. Maintains a history of how the Hamiltonian
is evolving across the SCF iterations.
Expand All @@ -153,12 +155,12 @@ class IterationState:
O: FloatNxN
grid_AO: Grid
ERI: FloatNxNxNxN
grid_weights: Array
grid_weights: FloatArray
mask: FloatN
input_floats: Array
input_ints: Array
input_floats: FloatArray
input_ints: IntArray
L_inv: FloatNxN
diis_history: Array
diis_history: FloatArray


def init_dft_tensors_cpu(mol, opts, DIIS_iters=9):
Expand Down

0 comments on commit e603725

Please sign in to comment.