Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

restrict classical action of certain arithmetic bloqs #1518

Merged
merged 8 commits into from
Jan 21, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix bug in KaliskiStep3 and add tests for all steps
NoureldinYosri committed Nov 13, 2024
commit cd00a9a4989b9a469bad68f8c2008c87ecb8eded
19 changes: 8 additions & 11 deletions qualtran/bloqs/arithmetic/comparison.py
Original file line number Diff line number Diff line change
@@ -1462,20 +1462,17 @@ def on_classical_vals(
c: Optional['ClassicalValT'] = None,
target: Optional['ClassicalValT'] = None,
) -> Dict[str, 'ClassicalValT']:
if self._op_symbol in ('>', '<='):
c_val = add_ints(-int(a), int(b), num_bits=self.dtype.bitsize + 1, is_signed=False)
else:
c_val = add_ints(int(a), -int(b), num_bits=self.dtype.bitsize + 1, is_signed=False)
if self.uncompute:
assert c == add_ints(
int(a),
int(b),
num_bits=int(self.dtype.bitsize),
is_signed=isinstance(self.dtype, QInt),
)
assert c == c_val
assert target == self._classical_comparison(a, b)
return {'a': a, 'b': b}
if self._op_symbol in ('>', '<='):
c = add_ints(-int(a), int(b), num_bits=self.dtype.bitsize + 1, is_signed=False)
else:
c = add_ints(int(a), -int(b), num_bits=self.dtype.bitsize + 1, is_signed=False)
return {'a': a, 'b': b, 'c': c, 'target': int(self._classical_comparison(a, b))}
assert c is None
assert target is None
return {'a': a, 'b': b, 'c': c_val, 'target': int(self._classical_comparison(a, b))}

def _compute(self, bb: 'BloqBuilder', a: 'Soquet', b: 'Soquet') -> Dict[str, 'SoquetT']:
if self._op_symbol in ('>', '<='):
23 changes: 11 additions & 12 deletions qualtran/bloqs/mod_arithmetic/mod_division.py
Original file line number Diff line number Diff line change
@@ -72,8 +72,6 @@ def signature(self) -> 'Signature':
def on_classical_vals(
self, v: int, m: int, f: int, is_terminal: int
) -> Dict[str, 'ClassicalValT']:
print('here')
assert False
m ^= f & (v == 0)
assert is_terminal == 0
is_terminal ^= m
@@ -101,10 +99,10 @@ def build_composite_bloq(

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
if is_symbolic(self.bitsize):
cvs: Union[HasLength, List[int]] = HasLength(self.bitsize)
cvs: Union[HasLength, List[int]] = HasLength(self.bitsize + 1)
else:
cvs = [0] * int(self.bitsize)
return {MultiAnd(cvs=cvs): 1, MultiAnd(cvs=cvs).adjoint(): 1, CNOT(): 2}
cvs = [0] * int(self.bitsize) + [1]
return {MultiAnd(cvs=cvs): 1, MultiAnd(cvs=cvs).adjoint(): 1, CNOT(): 3}


@frozen
@@ -197,25 +195,25 @@ def on_classical_vals(
def build_composite_bloq(
self, bb: 'BloqBuilder', u: Soquet, v: Soquet, b: Soquet, a: Soquet, m: Soquet, f: Soquet
) -> Dict[str, 'SoquetT']:
u, v, junk, greater_than = bb.add(
u, v, junk_c, greater_than = bb.add(
LinearDepthHalfGreaterThan(QMontgomeryUInt(self.bitsize)), a=u, b=v
)

(greater_than, f, b), junk, ctrl = bb.add(
(greater_than, f, b), junk_m, ctrl = bb.add(
MultiAnd(cvs=(1, 1, 0)), ctrl=(greater_than, f, b)
)

ctrl, a = bb.add(CNOT(), ctrl=ctrl, target=a)
ctrl, m = bb.add(CNOT(), ctrl=ctrl, target=m)

greater_than, f, b = bb.add(
MultiAnd(cvs=(1, 1, 0)).adjoint(), ctrl=(greater_than, f, b), junk=junk, target=ctrl
MultiAnd(cvs=(1, 1, 0)).adjoint(), ctrl=(greater_than, f, b), junk=junk_m, target=ctrl
)
u, v = bb.add(
LinearDepthHalfGreaterThan(QMontgomeryUInt(self.bitsize)).adjoint(),
a=u,
b=v,
c=junk,
c=junk_c,
target=greater_than,
)
return {'u': u, 'v': v, 'b': b, 'a': a, 'm': m, 'f': f}
@@ -391,7 +389,7 @@ def build_composite_bloq(

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {
CNOT(): 4,
CNOT(): 3,
XGate(): 2,
ModDbl(QMontgomeryUInt(self.bitsize), self.mod): 1,
CSwapApprox(self.bitsize): 2,
@@ -475,7 +473,7 @@ def on_classical_vals(
of `f` and `m`.
"""
assert m == 0
is_terminal = f == 1 and v == 0
is_terminal = int(f == 1 and v == 0)
if f == 0:
# When `f = 0` this means that the algorithm is nearly over and that we just need to
# double the value of `r`.
@@ -489,7 +487,8 @@ def on_classical_vals(
f = 0
r = (r << 1) % self.mod
else:
m = (u % 2 == 1) & (v % 2 == 0)
m = ((u % 2 == 1) & (v % 2 == 0)) or (u % 2 == 1 and v % 2 == 1 and u > v)
m = int(m)
# Kaliski iteration as described in Fig7 of https://arxiv.org/pdf/2001.09580.
swap = (u % 2 == 0 and v % 2 == 1) or (u % 2 == 1 and v % 2 == 1 and u > v)
if swap:
78 changes: 77 additions & 1 deletion qualtran/bloqs/mod_arithmetic/mod_division_test.py
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@

import qualtran.testing as qlt_testing
from qualtran import QMontgomeryUInt
from qualtran.bloqs.mod_arithmetic import mod_division
from qualtran.bloqs.mod_arithmetic.mod_division import _kaliskimodinverse_example, KaliskiModInverse
from qualtran.resource_counting import get_cost_value, QECGatesCost
from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join
@@ -36,7 +37,7 @@ def test_kaliski_mod_inverse_classical_action(bitsize, mod):
continue
x_montgomery = dtype.uint_to_montgomery(x, mod)
res = blq.call_classically(x=x_montgomery)
print(x, x_montgomery)

assert res == cblq.call_classically(x=x_montgomery)
assert len(res) == 2
assert res[0] == dtype.montgomery_inverse(x_montgomery, mod)
@@ -99,3 +100,78 @@ def test_kaliskimodinverse_example(bloq_autotester):
@pytest.mark.notebook
def test_notebook():
qlt_testing.execute_notebook('mod_division')


def test_kaliski_iteration_decomposition():
mod = 7
bitsize = 5
b = mod_division._KaliskiIteration(bitsize, mod)
cb = b.decompose_bloq()
for x in range(mod):
u = mod
v = x
r = 0
s = 1
f = 1

for _ in range(2 * bitsize):
inputs = {'u': u, 'v': v, 'r': r, 's': s, 'm': 0, 'f': f, 'is_terminal': 0}
res = b.call_classically(**inputs)
assert res == cb.call_classically(**inputs), f'{inputs=}'
u, v, r, s, _, f, _ = res # type: ignore

qlt_testing.assert_valid_bloq_decomposition(b)
qlt_testing.assert_equivalent_bloq_counts(b, generalizer=(ignore_alloc_free, ignore_split_join))


def test_kaliski_steps():
bitsize = 5
mod = 7
steps = [
mod_division._KaliskiIterationStep1(bitsize),
mod_division._KaliskiIterationStep2(bitsize),
mod_division._KaliskiIterationStep3(bitsize),
mod_division._KaliskiIterationStep4(bitsize),
mod_division._KaliskiIterationStep5(bitsize),
mod_division._KaliskiIterationStep6(bitsize, mod),
]
csteps = [b.decompose_bloq() for b in steps]

# check decomposition is valid.
for step in steps:
qlt_testing.assert_valid_bloq_decomposition(step)
qlt_testing.assert_equivalent_bloq_counts(
step, generalizer=(ignore_alloc_free, ignore_split_join)
)

# check that for all inputs all 2n iteration work when excuted directly on the 6 steps
# and their decompositions.
for x in range(mod):
u, v, r, s, f = mod, x, 0, 1, 1

for _ in range(2 * bitsize):
a = b = m = is_terminal = 0

res = steps[0].call_classically(v=v, m=m, f=f, is_terminal=is_terminal)
assert res == csteps[0].call_classically(v=v, m=m, f=f, is_terminal=is_terminal)
v, m, f, is_terminal = res # type: ignore

res = steps[1].call_classically(u=u, v=v, b=b, a=a, m=m, f=f)
assert res == csteps[1].call_classically(u=u, v=v, b=b, a=a, m=m, f=f)
u, v, b, a, m, f = res # type: ignore

res = steps[2].call_classically(u=u, v=v, b=b, a=a, m=m, f=f)
assert res == csteps[2].call_classically(u=u, v=v, b=b, a=a, m=m, f=f)
u, v, b, a, m, f = res # type: ignore

res = steps[3].call_classically(u=u, v=v, r=r, s=s, a=a)
assert res == csteps[3].call_classically(u=u, v=v, r=r, s=s, a=a)
u, v, r, s, a = res # type: ignore

res = steps[4].call_classically(u=u, v=v, r=r, s=s, b=b, f=f)
assert res == csteps[4].call_classically(u=u, v=v, r=r, s=s, b=b, f=f)
u, v, r, s, b, f = res # type: ignore

res = steps[5].call_classically(u=u, v=v, r=r, s=s, b=b, a=a, m=m, f=f)
assert res == csteps[5].call_classically(u=u, v=v, r=r, s=s, b=b, a=a, m=m, f=f)
u, v, r, s, b, a, m, f = res # type: ignore