Skip to content

Commit

Permalink
dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
caleb-johnson committed Oct 28, 2024
1 parent 144f1cf commit a8f8681
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions qiskit_addon_sqd/fermion.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from __future__ import annotations

import warnings
from typing import NamedTuple
from dataclasses import dataclass

import numpy as np
from jax import Array, config, grad, jit, vmap
Expand All @@ -29,7 +29,8 @@
config.update("jax_enable_x64", True) # To deal with large integers


class SCIState(NamedTuple):
@dataclass(frozen=True)
class SCIState:
"""The amplitudes and determinants describing a quantum state.
``amplitudes`` is an ``MxN`` array such that ``M`` = ``len(ci_strs_a)``
Expand All @@ -41,6 +42,17 @@ class SCIState(NamedTuple):
ci_strs_a: np.ndarray
ci_strs_b: np.ndarray

def __post_init__(self):
"""Validate dimensions of inputs."""
object.__setattr__(
self, "amplitudes", np.asarray(self.amplitudes)
) # Convert to ndarray if not already
if self.amplitudes.shape != (len(self.ci_strs_a), len(self.ci_strs_b)):
raise ValueError(
f"'amplitudes' shape must be ({len(self.ci_strs_a)}, {len(self.ci_strs_b)}) "
f"but got {self.amplitudes.shape}"
)

def save(self, filename):
"""Save the SCIState object to an .npz file."""
np.savez(
Expand Down

0 comments on commit a8f8681

Please sign in to comment.