Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Implement Windowed Modular Exponentiation #1468

Draft
wants to merge 29 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d70debe
Add rsa files - needs a lot of work just stashing it for now
fpapa250 Sep 25, 2024
74bab90
Merge branch 'main' into rsa-improvements
fpapa250 Sep 28, 2024
e5daf75
Made some structure changes RSA
fpapa250 Sep 28, 2024
dcf1284
Rework rsa mod exp bloqs to work in a rsa phase estimation circuit
fpapa250 Sep 29, 2024
dd9d251
Fix mypy issues
fpapa250 Sep 29, 2024
cb82d39
Add some classical simulation
fpapa250 Oct 1, 2024
14927c0
Implement primitives for ModExp
fpapa250 Oct 2, 2024
1ec788f
Merge branch 'main' into mod_exp_subroutines
fpapa250 Oct 2, 2024
fbdd7e1
Fix serialization test error
fpapa250 Oct 2, 2024
92c76e7
Merge branch 'mod_exp_subroutines' of github.com:fpapa250/Qualtran in…
fpapa250 Oct 2, 2024
5fdc417
Change Union -> SymbolicInt
fpapa250 Oct 2, 2024
6435600
Fix nits
fpapa250 Oct 6, 2024
f06e2f8
Better symbolic messages
fpapa250 Oct 7, 2024
dd45a8f
Better symbolic decomposition error messages
fpapa250 Oct 7, 2024
9fb62fd
Merge branch 'main' into mod_exp_subroutines
fpapa250 Oct 10, 2024
a9b9f56
Fix merge conflicts
fpapa250 Oct 10, 2024
3bf7e63
Fixed docstring to be more readable (hopefully)
fpapa250 Oct 11, 2024
284c552
Refactor RSA to have a phase estimation circuit and a classical simul…
fpapa250 Oct 12, 2024
885b33e
Merge branch 'main' into rsa-improvements
fpapa250 Oct 12, 2024
bd41420
Fix notebook specs merge conflict
fpapa250 Oct 12, 2024
f665ef0
Merge branch 'mod_exp_subroutines' into rsa-window
fpapa250 Oct 13, 2024
7d2bcba
Super WIP windowed mod exp
fpapa250 Oct 13, 2024
ba18d5f
Bloq decomposition complete - needs testing
fpapa250 Oct 14, 2024
45b58b4
More work on decomposition
fpapa250 Oct 15, 2024
d7460a4
stash current changes
fpapa250 Oct 16, 2024
db4828c
Merge branch 'main' into rsa-window
fpapa250 Oct 22, 2024
36e42b3
Partial bugfix windowed arithmetic
fpapa250 Oct 23, 2024
3138cae
Fix merge conflicts
fpapa250 Oct 23, 2024
edbcc51
stash changes for now
fpapa250 Oct 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 148 additions & 22 deletions qualtran/bloqs/factoring/rsa/rsa_mod_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,19 @@
SoquetT,
)
from qualtran._infra.registers import Side
from qualtran.bloqs.arithmetic import Add
from qualtran.bloqs.arithmetic.subtraction import SubtractFrom
from qualtran.bloqs.basic_gates.swap import Swap
from qualtran.bloqs.basic_gates.z_basis import IntState
from qualtran.bloqs.data_loading.qroam_clean import QROAMClean
from qualtran.bloqs.mod_arithmetic import CModMulK
from qualtran.bloqs.mod_arithmetic.mod_addition import ModAdd
from qualtran.bloqs.mod_arithmetic.mod_subtraction import ModSub
from qualtran.drawing import Text, WireSymbol
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.resource_counting.generalizers import ignore_split_join
from qualtran.simulation.classical_sim import ClassicalValT
from qualtran.symbolics import is_symbolic
from qualtran.symbolics.types import SymbolicInt
from qualtran.symbolics import is_symbolic, Shaped, SymbolicInt


@frozen
Expand All @@ -54,10 +59,13 @@ class ModExp(Bloq):
This bloq decomposes into controlled modular exponentiation for each exponent bit.

Args:
base: The integer base of the exponentiation
mod: The integer modulus
exp_bitsize: The size of the `exponent` thru-register
x_bitsize: The size of the `x` right-register
base: The integer base of the exponentiation.
mod: The integer modulus.
exp_bitsize: The size of the `exponent` thru-register.
x_bitsize: The size of the `x` right-register.
exp_window_size: The window size of windowed arithmetic on the controlled modular
multiplications.
mult_window_size: The window size of windowed arithmetic on the modular product additions.

Registers:
exponent: The exponent
Expand All @@ -66,12 +74,20 @@ class ModExp(Bloq):
References:
[How to factor 2048 bit RSA integers in 8 hours using 20 million noisy qubits](https://arxiv.org/abs/1905.09749).
Gidney and Ekerå. 2019.

[Circuit for Shor's algorithm using 2n+3 qubits](https://arxiv.org/abs/quant-ph/0205095).
Stephane Beauregard. 2003.

[Windowed quantum arithmetic](https://arxiv.org/abs/1905.07682).
Craig Gidney. 2019.
"""

base: 'SymbolicInt'
mod: 'SymbolicInt'
exp_bitsize: 'SymbolicInt'
x_bitsize: 'SymbolicInt'
exp_window_size: Optional['SymbolicInt'] = None
mult_window_size: Optional['SymbolicInt'] = None

def __attrs_post_init__(self):
if not is_symbolic(self.base, self.mod):
Expand All @@ -87,12 +103,7 @@ def signature(self) -> 'Signature':
)

@classmethod
def make_for_shor(
cls,
big_n: 'SymbolicInt',
g: Optional['SymbolicInt'] = None,
rs: Optional[np.random.RandomState] = None,
):
def make_for_shor(cls, big_n: 'SymbolicInt', g: Optional['SymbolicInt'] = None, exp_window_size: Optional['SymbolicInt'] = None, mult_window_size: Optional['SymbolicInt'] = None, rs: Optional[np.random.RandomState] = None):
"""Factory method that sets up the modular exponentiation for a factoring run.

Args:
Expand All @@ -115,29 +126,132 @@ def make_for_shor(
g = rs.randint(2, int(big_n))
if math.gcd(g, int(big_n)) == 1:
break
return cls(base=g, mod=big_n, exp_bitsize=2 * little_n, x_bitsize=little_n)
return cls(base=g, mod=big_n, exp_bitsize=2 * little_n, x_bitsize=little_n, exp_window_size=exp_window_size, mult_window_size=mult_window_size)

def qrom(self, data):
if is_symbolic(self.exp_bitsize) or is_symbolic(self.exp_window_size):
log_block_sizes = None
if is_symbolic(self.exp_bitsize) and not is_symbolic(self.exp_window_size):
# We assume that bitsize is much larger than window_size
log_block_sizes = (0,)
return QROAMClean(
[
data,
],
selection_bitsizes=(self.exp_window_size, self.mult_window_size),
target_bitsizes=(self.x_bitsize,),
log_block_sizes=log_block_sizes,
)

return QROAMClean(
[data],
selection_bitsizes=(self.exp_window_size, self.mult_window_size),
target_bitsizes=(self.x_bitsize,),
)


def _CtrlModMul(self, k: 'SymbolicInt'):
"""Helper method to return a `CModMulK` with attributes forwarded."""
return CModMulK(QUInt(self.x_bitsize), k=k, mod=self.mod)

def build_composite_bloq(self, bb: 'BloqBuilder', exponent: 'Soquet') -> Dict[str, 'SoquetT']:
if is_symbolic(self.exp_bitsize):
raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `exp_bitsize`.")
# https://en.wikipedia.org/wiki/Modular_exponentiation#Right-to-left_binary_method
raise DecomposeTypeError(f"Cannot decompose {self} with symbolic `exp_bitsize`.")
x = bb.add(IntState(val=1, bitsize=self.x_bitsize))
exponent = bb.split(exponent)

base = self.base % self.mod
for j in range(self.exp_bitsize - 1, 0 - 1, -1):
exponent[j], x = bb.add(self._CtrlModMul(k=base), ctrl=exponent[j], x=x)
base = (base * base) % self.mod
if self.exp_window_size is not None and self.mult_window_size is not None:
k = self.base

a = bb.split(x)
b = bb.add(IntState(val=0, bitsize=self.x_bitsize))

ei = np.split(np.array(exponent), self.exp_bitsize // self.exp_window_size)
for i in range(self.exp_bitsize // self.exp_window_size):
kes = [pow(k, 2**i * x_e, self.mod) for x_e in range(2**self.exp_window_size)]
kes_inv = [pow(x_e, -1, self.mod) for x_e in kes]

mi = np.split(np.array(a), self.x_bitsize // self.mult_window_size)
for j in range(self.x_bitsize // self.mult_window_size):
data = list([(ke * f * 2**j) % self.mod for f in range(2**self.mult_window_size)] for ke in kes)
ei_i = bb.join(ei[(self.exp_bitsize // self.exp_window_size) - i - 1], QUInt((self.exp_window_size)))
mi_i = bb.join(mi[(self.x_bitsize // self.mult_window_size) - j - 1], QUInt((self.mult_window_size)))
ei_i, mi_i, t, *junk = bb.add(self.qrom(data), selection0=ei_i, selection1=mi_i)
t, b = bb.add(ModAdd(self.x_bitsize, self.mod), x=t, y=b)
junk_mapping = {f'junk_target{i}_': junk[i] for i in range(len(junk))}
ei_i, mi_i = bb.add(self.qrom(data).adjoint(), selection0=ei_i, selection1=mi_i, target0_=t, **junk_mapping)
ei[(self.exp_bitsize // self.exp_window_size) - i - 1] = bb.split(ei_i)
mi[(self.x_bitsize // self.mult_window_size) - j - 1] = bb.split(mi_i)

a = np.concatenate(mi, axis=None)
a = bb.join(a, QUInt(self.x_bitsize))

b = bb.split(b)
mi = np.split(np.array(b), self.x_bitsize // self.mult_window_size)
for j in range(self.x_bitsize // self.mult_window_size):
data = list([(ke_inv * f * 2**j) % self.mod for f in range(2**self.mult_window_size)] for ke_inv in kes_inv)
ei_i = bb.join(ei[(self.exp_bitsize // self.exp_window_size) - i - 1], QUInt((self.exp_window_size)))
mi_i = bb.join(mi[(self.x_bitsize // self.mult_window_size) - j - 1], QUInt((self.mult_window_size)))
ei_i, mi_i, t, *junk = bb.add(self.qrom(data), selection0=ei_i, selection1=mi_i)
t, a = bb.add(ModSub(QUInt(self.x_bitsize), self.mod), x=t, y=a)
junk_mapping = {f'junk_target{i}_': junk[i] for i in range(len(junk))}
ei_i, mi_i = bb.add(self.qrom(data).adjoint(), selection0=ei_i, selection1=mi_i, target0_=t, **junk_mapping)
ei[(self.exp_bitsize // self.exp_window_size) - i - 1] = bb.split(ei_i)
mi[(self.x_bitsize // self.mult_window_size) - j - 1] = bb.split(mi_i)

b = np.concatenate(mi, axis=None)

b = bb.join(b, QUInt(self.x_bitsize))

a, b = bb.add(Swap(self.x_bitsize), x=a, y=b)

a = bb.split(a)

x = bb.join(a, QUInt(self.x_bitsize))
exponent = np.concatenate(ei, axis=None)
bb.free(b, dirty=True)
else:
# https://en.wikipedia.org/wiki/Modular_exponentiation#Right-to-left_binary_method
base = self.base % self.mod
for j in range(self.exp_bitsize - 1, 0 - 1, -1):
exponent[j], x = bb.add(self._CtrlModMul(k=base), ctrl=exponent[j], x=x)
base = (base * base) % self.mod

return {'exponent': bb.join(exponent, dtype=QUInt(self.exp_bitsize)), 'x': x}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
k = ssa.new_symbol('k')
return {self._CtrlModMul(k=k): self.exp_bitsize, IntState(val=1, bitsize=self.x_bitsize): 1}
if self.exp_window_size is not None and self.mult_window_size is not None:
if is_symbolic(self.exp_window_size, self.mult_window_size):
num_iterations = self.exp_bitsize // self.exp_window_size
return {self.qrom(Shaped((2**(self.exp_window_size+self.mult_window_size),))): 1,
self.qrom(Shaped((2**(self.exp_window_size+self.mult_window_size),))).adjoint(): 1,
ModAdd(self.x_bitsize, self.mod): 1,
ModSub(QUInt(self.x_bitsize), self.mod): 1,
IntState(val=1, bitsize=self.x_bitsize): 1, Swap(self.x_bitsize): self.exp_bitsize // self.exp_window_size}
else:
cg = {IntState(val=1, bitsize=self.x_bitsize): 1, Swap(self.x_bitsize): self.exp_bitsize // self.exp_window_size}

k = self.base
for i in range(self.exp_bitsize // self.exp_window_size):
kes = [pow(k, 2**i * x_e, self.mod) for x_e in range(2**self.exp_window_size)]
kes_inv = [pow(x_e, -1, self.mod) for x_e in kes]

for j in range(self.x_bitsize // self.mult_window_size):
data = list([(ke * f * 2**j) % self.mod for f in range(2**self.mult_window_size)] for ke in kes)
cg[self.qrom(data)] = cg.get(self.qrom(data), 0) + 1
cg[ModAdd(self.x_bitsize, self.mod)] = cg.get(ModAdd(self.x_bitsize, self.mod), 0) + 1
cg[self.qrom(data).adjoint()] = cg.get(self.qrom(data).adjoint(), 0) + 1

for j in range(self.x_bitsize // self.mult_window_size):
data = list([(ke_inv * f * 2**j) % self.mod for f in range(2**self.mult_window_size)] for ke_inv in kes_inv)
cg[self.qrom(data)] = cg.get(self.qrom(data), 0) + 1
cg[ModSub(QUInt(self.x_bitsize), self.mod)] = cg.get(ModSub(QUInt(self.x_bitsize), self.mod), 0) + 1
cg[self.qrom(data).adjoint()] = cg.get(self.qrom(data).adjoint(), 0) + 1

return cg
else:
k = ssa.new_symbol('k')
return {self._CtrlModMul(k=k): self.exp_bitsize, IntState(val=1, bitsize=self.x_bitsize): 1}

def on_classical_vals(self, exponent) -> Dict[str, Union['ClassicalValT', sympy.Expr]]:
return {'exponent': exponent, 'x': (self.base**exponent) % self.mod}
Expand Down Expand Up @@ -172,11 +286,23 @@ def _modexp() -> ModExp:
return modexp


@bloq_example(generalizer=(ignore_split_join, _generalize_k))
def _modexp_window() -> ModExp:
modexp_window = ModExp.make_for_shor(big_n=13 * 17, g=9, exp_window_size=8, mult_window_size=4)
return modexp_window


@bloq_example
def _modexp_symb() -> ModExp:
g, N, n_e, n_x = sympy.symbols('g N n_e, n_x')
modexp_symb = ModExp(base=g, mod=N, exp_bitsize=n_e, x_bitsize=n_x)
return modexp_symb

@bloq_example
def _modexp_window_symb() -> ModExp:
g, N, n_e, n_x, w_e, w_m = sympy.symbols('g N n_e, n_x w_e w_m')
modexp_window_symb = ModExp(base=g, mod=N, exp_bitsize=n_e, x_bitsize=n_x, exp_window_size=w_e, mult_window_size=w_m)
return modexp_window_symb


_RSA_MODEXP_DOC = BloqDocSpec(bloq_cls=ModExp, examples=(_modexp_small, _modexp, _modexp_symb))
_RSA_MODEXP_DOC = BloqDocSpec(bloq_cls=ModExp, examples=(_modexp_small, _modexp, _modexp_symb, _modexp_window, _modexp_window_symb))
38 changes: 36 additions & 2 deletions qualtran/bloqs/factoring/rsa/rsa_mod_exp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from qualtran import Bloq
from qualtran.bloqs.bookkeeping import Join, Split
from qualtran.bloqs.factoring.rsa.rsa_mod_exp import _modexp, _modexp_small, _modexp_symb, ModExp
from qualtran.bloqs.factoring.rsa.rsa_mod_exp import _modexp, _modexp_small, _modexp_symb, _modexp_window, _modexp_window_symb, ModExp
from qualtran.bloqs.mod_arithmetic import CModMulK
from qualtran.drawing import Text
from qualtran.resource_counting import SympySymbolAllocator
Expand All @@ -48,6 +48,40 @@ def test_mod_exp_consistent_classical():
for i in range(len(ret1)):
np.testing.assert_array_equal(ret1[i], ret2[i])

@pytest.mark.parametrize('p', [11, 13])
def test_mod_exp_window_consistent_classical_fast(p):
bloq = ModExp.make_for_shor(big_n=p, exp_window_size=2, mult_window_size=2)

rs = np.random.RandomState(52)
n_x = int(np.ceil(np.log2(p)))

for _ in range(10):
exponent = rs.randint(1, 2**n_x)

ret1 = bloq.call_classically(exponent=exponent)
ret2 = bloq.decompose_bloq().call_classically(exponent=exponent)
assert len(ret1) == len(ret2)
for i in range(len(ret1)):
np.testing.assert_array_equal(ret1[i], ret2[i])

'''
@pytest.mark.slow
@pytest.mark.parametrize('p, w_e, w_m', [(p, w_e, w_m) for p in (7, 11, 13) for w_e in range(1, (2 * int(np.ceil(np.log2(p)))) + 1) if (2 * int(np.ceil(np.log2(p)))) % w_e == 0 for w_m in range(1, int(np.ceil(np.log2(p))) + 1) if int(np.ceil(np.log2(p))) % w_m == 0])
def test_mod_exp_window_consistent_classical(p, w_e, w_m):
bloq = ModExp.make_for_shor(big_n=p, exp_window_size=w_e, mult_window_size=w_m)

rs = np.random.RandomState(52)
n_x = int(np.ceil(np.log2(p)))

for _ in range(10):
exponent = rs.randint(1, 2**n_x)

ret1 = bloq.call_classically(exponent=exponent)
ret2 = bloq.decompose_bloq().call_classically(exponent=exponent)
assert len(ret1) == len(ret2)
for i in range(len(ret1)):
np.testing.assert_array_equal(ret1[i], ret2[i])
'''

def test_modexp_symb_manual():
g, N, n_e, n_x = sympy.symbols('g N n_e, n_x')
Expand Down Expand Up @@ -89,7 +123,7 @@ def test_mod_exp_t_complexity():
assert tcomp.t > 0


@pytest.mark.parametrize('bloq', [_modexp, _modexp_symb, _modexp_small])
@pytest.mark.parametrize('bloq', [_modexp, _modexp_symb, _modexp_small, _modexp_window, _modexp_window_symb])
def test_modexp(bloq_autotester, bloq):
bloq_autotester(bloq)

Expand Down
13 changes: 12 additions & 1 deletion qualtran/bloqs/factoring/rsa/rsa_phase_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,26 @@ class RSAPhaseEstimate(Bloq):
n: The bitsize of the modulus N.
mod: The modulus N; a part of the public key for RSA.
base: A base for modular exponentiation.
exp_window_size: The window size of windowed arithmetic on the controlled modular
multiplications.
mult_window_size: The window size of windowed arithmetic on the modular product additions.

References:
[How to factor 2048 bit RSA integers in 8 hours using 20 million noisy qubits](https://arxiv.org/abs/1905.09749).
Gidney and Ekerå. 2019.

[Circuit for Shor's algorithm using 2n+3 qubits](https://arxiv.org/abs/quant-ph/0205095).
Beauregard. 2003. Fig 1.
Stephane Beauregard. 2003.

[Windowed quantum arithmetic](https://arxiv.org/abs/1905.07682).
Craig Gidney. 2019.
"""

n: 'SymbolicInt'
mod: 'SymbolicInt'
base: 'SymbolicInt'
exp_window_size: 'SymbolicInt' = 1
mult_window_size: 'SymbolicInt' = 1

@cached_property
def signature(self) -> 'Signature':
Expand Down
3 changes: 2 additions & 1 deletion qualtran/serialization/resolver_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@
"qualtran.bloqs.data_loading.qrom.QROM": qualtran.bloqs.data_loading.qrom.QROM,
"qualtran.bloqs.data_loading.qroam_clean.QROAMClean": qualtran.bloqs.data_loading.qroam_clean.QROAMClean,
"qualtran.bloqs.data_loading.qroam_clean.QROAMCleanAdjoint": qualtran.bloqs.data_loading.qroam_clean.QROAMCleanAdjoint,
"qualtran.bloqs.data_loading.qroam_clean.QROAMCleanAdjointWrapper": qualtran.bloqs.data_loading.qroam_clean.QROAMCleanAdjointWrapper,
"qualtran.bloqs.data_loading.select_swap_qrom.SelectSwapQROM": qualtran.bloqs.data_loading.select_swap_qrom.SelectSwapQROM,
"qualtran.bloqs.mod_arithmetic.CModAddK": qualtran.bloqs.mod_arithmetic.CModAddK,
"qualtran.bloqs.mod_arithmetic.mod_addition.ModAdd": qualtran.bloqs.mod_arithmetic.mod_addition.ModAdd,
Expand All @@ -340,7 +341,7 @@
"qualtran.bloqs.mod_arithmetic.mod_addition.CModAddK": qualtran.bloqs.mod_arithmetic.mod_addition.CModAddK,
"qualtran.bloqs.mod_arithmetic.mod_addition.CtrlScaleModAdd": qualtran.bloqs.mod_arithmetic.CtrlScaleModAdd,
"qualtran.bloqs.mod_arithmetic.ModAdd": qualtran.bloqs.mod_arithmetic.ModAdd,
"qualtran.bloqs.mod_arithmetic.ModSub": qualtran.bloqs.mod_arithmetic.ModSub,
"qualtran.bloqs.mod_arithmetic.mod_subtraction.ModSub": qualtran.bloqs.mod_arithmetic.mod_subtraction.ModSub,
"qualtran.bloqs.mod_arithmetic.CModSub": qualtran.bloqs.mod_arithmetic.CModSub,
"qualtran.bloqs.mod_arithmetic.mod_subtraction.ModNeg": qualtran.bloqs.mod_arithmetic.mod_subtraction.ModNeg,
"qualtran.bloqs.mod_arithmetic.mod_subtraction.CModNeg": qualtran.bloqs.mod_arithmetic.mod_subtraction.CModNeg,
Expand Down
Loading