Skip to content

Commit

Permalink
adj list first oracle wip
Browse files Browse the repository at this point in the history
  • Loading branch information
anurudhp committed Aug 30, 2024
1 parent e10662d commit 490afa9
Showing 1 changed file with 63 additions and 7 deletions.
70 changes: 63 additions & 7 deletions qualtran/bloqs/max_k_xor_sat/kikuchi_adjacency_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,32 +23,51 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import Counter

import sympy
from attrs import frozen

from qualtran import Bloq, bloq_example, BloqBuilder, BloqDocSpec, QAny, Signature, Soquet, SoquetT
from qualtran import (
Bloq,
bloq_example,
BloqBuilder,
BloqDocSpec,
QAny,
Signature,
Soquet,
SoquetT,
QBit,
)
from qualtran.resource_counting import BloqCountT, SympySymbolAllocator
from qualtran.symbolics import SymbolicInt
from .arithmetic import SymmetricDifference
from .arithmetic.equals import Equals

from .kxor_instance import KXorInstance
from .shims import ArbitraryGate, soft_O
from ..arithmetic import AddK
from ..mcmt import And


@frozen
class ColumnOfKthNonZeroEntry(Bloq):
r"""Given $(S, k)$, compute the column of the $k$-th non-zero entry in row $S$.
If the output is denoted as $f(S, k)$, then this bloq maps
$(S, k, z)$ to $(S, k, z \oplus f(S, k)$.
$(S, k, z, b)$ to $(S, k, z \oplus f'(S, k), b \oplus (k \ge s))$.
where $s$ is the sparsity, and $f'(S, k)$ is by extending $f$
such that for all $k \ge s$, $f'(S, k) = k$.
Using $f'$ ensures the computation is reversible.
Note: we must use the same extension $f'$ for both oracles.
This algorithm is described by the following pseudo-code:
```
def forward(S, k) -> f_S_k:
nnz := 0 # counter
for j in range(\bar{m}):
T := S \Delta U_j
entry := KikuchiMatrixEntry(S, T)
if entry != 0:
if |T| == l:
nnz := nnz + 1
if nnz == k:
f_S_k ^= T
Expand All @@ -70,14 +89,44 @@ def forward(S, k) -> f_S_k:
@property
def signature(self) -> 'Signature':
return Signature.build_from_dtypes(
S=QAny(self.index_bitsize), k=QAny(self.index_bitsize), T=QAny(self.index_bitsize)
S=QAny(self.index_bitsize),
k=QAny(self.index_bitsize),
T=QAny(self.index_bitsize),
flag=QBit(),
)

@property
def index_bitsize(self) -> SymbolicInt:
return self.ell * self.inst.index_bitsize

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> set['BloqCountT']:
m = self.inst.num_unique_constraints
ell, k = self.ell, self.inst.k
logn = self.inst.index_bitsize

counts_forward = Counter[Bloq]()

# compute symmetric differences for each constraint
counts_forward[SymmetricDifference(ell, k, ell, logn)] += m

# counter
counts_forward[AddK(logn, 1).controlled()] += m

# compare counter each time
counts_forward[Equals(logn)] += m

# when counter is equal (and updated in this iteration), we can copy the result
counts_forward[And()] += m

# all counts
counts = Counter[Bloq]()

for bloq, nb in counts_forward.items():
counts[bloq] += nb
counts[bloq.adjoint()] += nb

return bb.add(counts.items())

return {(ArbitraryGate(), soft_O(self.inst.m * self.ell * self.index_bitsize))}


Expand All @@ -86,7 +135,11 @@ class IndexOfNonZeroColumn(Bloq):
r"""Given $(S, T)$, compute $k$ such that $T$ is the $k$-th non-zero entry in row $S$.
If $f(S, k)$ denotes the $k$-th non-zero entry in row $S$,
then this bloq maps $(S, f(S, k), z)$ to $(S, f(S, k), z \oplus k)$.
then this bloq maps $(S, f'(S, k), z, b)$ to $(S, f'(S, k), z \oplus k, b \oplus )$.
where $s$ is the sparsity, and $f'(S, k)$ is by extending $f$
such that for all $k \ge s$, $f'(S, k) = k$.
Using $f'$ ensures the computation is reversible.
Note: we must use the same extension $f'$ for both oracles.
This algorithm is described by the following pseudo-code:
```
Expand Down Expand Up @@ -115,7 +168,10 @@ def reverse(S, f_S_k) -> k:
@property
def signature(self) -> 'Signature':
return Signature.build_from_dtypes(
S=QAny(self.index_bitsize), T=QAny(self.index_bitsize), k=QAny(self.index_bitsize)
S=QAny(self.index_bitsize),
T=QAny(self.index_bitsize),
k=QAny(self.index_bitsize),
flag=QBit(),
)

@property
Expand Down

0 comments on commit 490afa9

Please sign in to comment.