diff --git a/qiskit_addon_sqd/fermion.py b/qiskit_addon_sqd/fermion.py index afb4ce4..52a373d 100644 --- a/qiskit_addon_sqd/fermion.py +++ b/qiskit_addon_sqd/fermion.py @@ -31,6 +31,7 @@ from __future__ import annotations import warnings +from collections.abc import Sequence import numpy as np from jax import Array, config, grad, jit, vmap @@ -43,6 +44,35 @@ config.update("jax_enable_x64", True) # To deal with large integers +class SCIState(fci.selected_ci.SCIVector): + """An immutable, lightweight wrapper for the ``pyscf.fci.selected_ci.SCIVector`` class.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Light check on `_strs` structure once during initialization + if not (isinstance(self._strs, Sequence) and len(self._strs) == 2) or not all(isinstance(strs, np.ndarray) for strs in self._strs): + raise ValueError( + "Cannot instantiate SCIState with input _strs field: {self._strs}." + ) + if self.shape != (len(self._strs[0]), len(self._strs[1])): + raise ValueError( + "Cannot instantiate SCIState with array shape ({self.shape}) and CI " + "string shape ({len(self._strs[0])}, {len(self._strs[1])})." + ) + + # Don't allow the array to be mutated + self.flags.writeable = False + + @property + def ci_strs_a(self): + """The alpha determinants.""" + return self._strs[0] + @property + def ci_strs_b(self): + """The beta determinants.""" + return self._strs[1] + + def solve_fermion( bitstring_matrix: tuple[np.ndarray, np.ndarray] | np.ndarray, /, @@ -53,7 +83,7 @@ def solve_fermion( spin_sq: int | None = None, max_davidson: int = 100, verbose: int | None = None, -) -> tuple[float, np.ndarray, list[np.ndarray], float]: +) -> tuple[float, SCIState, list[np.ndarray], float]: """ Approximate the ground state given molecular integrals and a set of electronic configurations. @@ -82,8 +112,8 @@ def solve_fermion( Returns: A tuple containing: - Minimum energy from SCI calculation - - SCI coefficients - - Average orbital occupancy + - The SCI ground state + - Average occupancy of the alpha and beta orbitals, respectively - Expectation value of spin-squared """ if isinstance(bitstring_matrix, tuple): @@ -107,7 +137,7 @@ def solve_fermion( myci = fci.selected_ci.SelectedCI() if spin_sq is not None: myci = fci.addons.fix_spin_(myci, ss=spin_sq) - e_sci, coeffs_sci = fci.selected_ci.kernel_fixed_space( + e_sci, sci_vec = fci.selected_ci.kernel_fixed_space( myci, hcore, eri, @@ -117,14 +147,16 @@ def solve_fermion( verbose=verbose, max_cycle=max_davidson, ) + # Convert the PySCF SCIVector to internal format + sci_state = SCIState(sci_vec) # Calculate the avg occupancy of each orbital - dm1 = myci.make_rdm1s(coeffs_sci, norb, (num_up, num_dn)) + dm1 = myci.make_rdm1s(sci_state, norb, (num_up, num_dn)) avg_occupancy = [np.diagonal(dm1[0]), np.diagonal(dm1[1])] # Compute total spin - spin_squared = myci.spin_square(coeffs_sci, norb, (num_up, num_dn))[0] + spin_squared = myci.spin_square(sci_state, norb, (num_up, num_dn))[0] - return e_sci, coeffs_sci, avg_occupancy, spin_squared + return e_sci, sci_state, avg_occupancy, spin_squared def optimize_orbitals( diff --git a/releasenotes/notes/wf-amps-d0fd5b930346adaf.yaml b/releasenotes/notes/wf-amps-d0fd5b930346adaf.yaml new file mode 100644 index 0000000..ab98fbf --- /dev/null +++ b/releasenotes/notes/wf-amps-d0fd5b930346adaf.yaml @@ -0,0 +1,4 @@ +upgrade: + - | + The ground state returned by :func:`qiskit_addon_sqd.fermion.solve_fermion` will now be an instance of :class:`qiskit_addon_sqd.fermion.SCIState`, rather than a PySCF ``SCIVector`` instance. +