Skip to content

Commit

Permalink
First pass
Browse files Browse the repository at this point in the history
  • Loading branch information
caleb-johnson committed Oct 28, 2024
1 parent a90d78e commit 62282e4
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 7 deletions.
46 changes: 39 additions & 7 deletions qiskit_addon_sqd/fermion.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from __future__ import annotations

import warnings
from collections.abc import Sequence

import numpy as np
from jax import Array, config, grad, jit, vmap
Expand All @@ -43,6 +44,35 @@
config.update("jax_enable_x64", True) # To deal with large integers


class SCIState(fci.selected_ci.SCIVector):
"""An immutable, lightweight wrapper for the ``pyscf.fci.selected_ci.SCIVector`` class."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Light check on `_strs` structure once during initialization
if not (isinstance(self._strs, Sequence) and len(self._strs) == 2) or not all(isinstance(strs, np.ndarray) for strs in self._strs):
raise ValueError(
"Cannot instantiate SCIState with input _strs field: {self._strs}."
)
if self.shape != (len(self._strs[0]), len(self._strs[1])):
raise ValueError(
"Cannot instantiate SCIState with array shape ({self.shape}) and CI "
"string shape ({len(self._strs[0])}, {len(self._strs[1])})."
)

# Don't allow the array to be mutated
self.flags.writeable = False

@property
def ci_strs_a(self):
"""The alpha determinants."""
return self._strs[0]
@property
def ci_strs_b(self):
"""The beta determinants."""
return self._strs[1]


def solve_fermion(
bitstring_matrix: tuple[np.ndarray, np.ndarray] | np.ndarray,
/,
Expand All @@ -53,7 +83,7 @@ def solve_fermion(
spin_sq: int | None = None,
max_davidson: int = 100,
verbose: int | None = None,
) -> tuple[float, np.ndarray, list[np.ndarray], float]:
) -> tuple[float, SCIState, list[np.ndarray], float]:
"""
Approximate the ground state given molecular integrals and a set of electronic configurations.
Expand Down Expand Up @@ -82,8 +112,8 @@ def solve_fermion(
Returns:
A tuple containing:
- Minimum energy from SCI calculation
- SCI coefficients
- Average orbital occupancy
- The SCI ground state
- Average occupancy of the alpha and beta orbitals, respectively
- Expectation value of spin-squared
"""
if isinstance(bitstring_matrix, tuple):
Expand All @@ -107,7 +137,7 @@ def solve_fermion(
myci = fci.selected_ci.SelectedCI()
if spin_sq is not None:
myci = fci.addons.fix_spin_(myci, ss=spin_sq)
e_sci, coeffs_sci = fci.selected_ci.kernel_fixed_space(
e_sci, sci_vec = fci.selected_ci.kernel_fixed_space(
myci,
hcore,
eri,
Expand All @@ -117,14 +147,16 @@ def solve_fermion(
verbose=verbose,
max_cycle=max_davidson,
)
# Convert the PySCF SCIVector to internal format
sci_state = SCIState(sci_vec)
# Calculate the avg occupancy of each orbital
dm1 = myci.make_rdm1s(coeffs_sci, norb, (num_up, num_dn))
dm1 = myci.make_rdm1s(sci_state, norb, (num_up, num_dn))
avg_occupancy = [np.diagonal(dm1[0]), np.diagonal(dm1[1])]

# Compute total spin
spin_squared = myci.spin_square(coeffs_sci, norb, (num_up, num_dn))[0]
spin_squared = myci.spin_square(sci_state, norb, (num_up, num_dn))[0]

return e_sci, coeffs_sci, avg_occupancy, spin_squared
return e_sci, sci_state, avg_occupancy, spin_squared


def optimize_orbitals(
Expand Down
4 changes: 4 additions & 0 deletions releasenotes/notes/wf-amps-d0fd5b930346adaf.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
upgrade:
- |
The ground state returned by :func:`qiskit_addon_sqd.fermion.solve_fermion` will now be an instance of :class:`qiskit_addon_sqd.fermion.SCIState`, rather than a PySCF ``SCIVector`` instance.

0 comments on commit 62282e4

Please sign in to comment.