Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ECC] Cost enhancements #1405

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions qualtran/bloqs/arithmetic/_shims.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from qualtran import Bloq, QBit, QUInt, Register, Signature
from qualtran.bloqs.basic_gates import Toffoli
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator
from qualtran.resource_counting import BloqCountDictT, CostKey, QubitCount, SympySymbolAllocator


@frozen
Expand All @@ -36,7 +36,13 @@ def signature(self) -> 'Signature':
return Signature([Register('ctrl', QBit(), shape=(self.n,)), Register('target', QBit())])

def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
return {Toffoli(): self.n - 2}
return {Toffoli(): self.n - 1}

def my_static_costs(self, cost_key: 'CostKey'):
# TODO https://github.com/quantumlib/Qualtran/issues/1261
if cost_key == QubitCount():
return self.n + 1
return NotImplemented


@frozen
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is no need for this shim. you can use the LinearDepthGreaterThan bloq

Expand All @@ -51,7 +57,7 @@ def signature(self) -> 'Signature':
)

def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
# litinski
# Litinski 2023. Figure/Table 3.
return {Toffoli(): self.n}


Expand All @@ -62,3 +68,9 @@ class CHalf(Bloq):
@cached_property
def signature(self) -> 'Signature':
return Signature([Register('ctrl', QBit()), Register('x', QUInt(self.n))])

def build_call_graph(self, ssa: SympySymbolAllocator) -> BloqCountDictT:
# It's unclear what this operation is (as part of the ModInv circuit).
# If we assume it's a modular halving, then we can just run `ModDbl`
# backwards, and the cost is the same.
return {(Toffoli(), 2 * self.n)}
5 changes: 3 additions & 2 deletions qualtran/bloqs/arithmetic/addition.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
"## `Add`\n",
"An n-bit addition gate.\n",
"\n",
"Implements $U|a\\rangle|b\\rangle \\rightarrow |a\\rangle|a+b\\rangle$ using $4n - 4 T$ gates.\n",
"Implements $U|a\\rangle|b\\rangle \\rightarrow |a\\rangle|a+b\\rangle$ using $n-1$ AND gates.\n",
"\n",
"#### Parameters\n",
" - `a_dtype`: Quantum datatype used to represent the integer a.\n",
Expand All @@ -49,7 +49,8 @@
" - `b`: A b_dtype.bitsize-sized input/output register (register b above). \n",
"\n",
"#### References\n",
" - [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). \n"
" - [Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648). Gidney 2018. The construction used in this bloq, evolved from [2].\n",
" - [A new quantum ripple-carry addition circuit](https://arxiv.org/abs/quant-ph/0410184). Cuccaro et. al. 2004.\n"
]
},
{
Expand Down
12 changes: 6 additions & 6 deletions qualtran/bloqs/arithmetic/addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
class Add(Bloq):
r"""An n-bit addition gate.

Implements $U|a\rangle|b\rangle \rightarrow |a\rangle|a+b\rangle$ using $4n - 4 T$ gates.
Implements $U|a\rangle|b\rangle \rightarrow |a\rangle|a+b\rangle$ using $n-1$ AND gates.

Args:
a_dtype: Quantum datatype used to represent the integer a.
Expand All @@ -89,7 +89,11 @@ class Add(Bloq):
b: A b_dtype.bitsize-sized input/output register (register b above).

References:
[Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648)
[Halving the cost of quantum addition](https://arxiv.org/abs/1709.06648).
Gidney 2018. The construction used in this bloq, evolved from [2].

[A new quantum ripple-carry addition circuit](https://arxiv.org/abs/quant-ph/0410184).
Cuccaro et. al. 2004.
"""

a_dtype: Union[QInt, QUInt, QMontgomeryUInt] = field()
Expand Down Expand Up @@ -479,10 +483,6 @@ def build_composite_bloq(

# Rejoin the qubits representing k for in-place addition.
k = bb.join(k_split, dtype=x.reg.dtype)
if not isinstance(x.reg.dtype, (QInt, QUInt, QMontgomeryUInt)):
raise ValueError(
"Only QInt, QUInt and QMontgomerUInt types are supported for composite addition."
)
k, x = bb.add(Add(x.reg.dtype, x.reg.dtype), a=k, b=x)

# Resplit the k qubits in order to undo the original bit flips to go from the binary
Expand Down
6 changes: 4 additions & 2 deletions qualtran/bloqs/arithmetic/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,8 +749,8 @@ def on_classical_vals(
def build_composite_bloq(
self, bb: 'BloqBuilder', a: Soquet, b: Soquet, target: SoquetT
) -> Dict[str, 'SoquetT']:
if isinstance(self.bitsize, sympy.Expr):
raise DecomposeTypeError(f"Cannot decompose symbolic {self}.")
if is_symbolic(self.bitsize):
raise DecomposeTypeError(f'Symbolic decomposition is not supported for {self}')

# Base Case: Comparing two qubits.
# Signed doesn't matter because we can't represent signed integers with 1 qubit.
Expand Down Expand Up @@ -1080,6 +1080,8 @@ def build_composite_bloq(
a = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)), x=a)
b = bb.add(SignExtend(self.dtype, QInt(self.dtype.bitsize + 1)), x=b)
else:
if self.dtype.is_symbolic():
raise DecomposeTypeError(f"Cannot decompose symoblic {self}.")
a = bb.join(np.concatenate([[bb.allocate(1)], bb.split(a)]))
b = bb.join(np.concatenate([[bb.allocate(1)], bb.split(b)]))

Expand Down
3 changes: 3 additions & 0 deletions qualtran/bloqs/arithmetic/controlled_addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
bloq_example,
BloqBuilder,
BloqDocSpec,
DecomposeTypeError,
QBit,
QInt,
QUInt,
Expand Down Expand Up @@ -134,6 +135,8 @@ def wire_symbol(self, soq: 'Soquet') -> 'WireSymbol':
def build_composite_bloq(
self, bb: 'BloqBuilder', ctrl: 'Soquet', a: 'Soquet', b: 'Soquet'
) -> Dict[str, 'SoquetT']:
if self.a_dtype.is_symbolic():
raise DecomposeTypeError(f"Cannot decompose symbolic {self}.")
a_arr = bb.split(a)
ctrl_q = bb.split(ctrl)[0]
ancilla_arr = []
Expand Down
29 changes: 25 additions & 4 deletions qualtran/bloqs/factoring/ecc/_ecc_shims.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
# limitations under the License.

from functools import cached_property
from typing import Optional, Tuple
from typing import Optional, Sequence, Tuple

import attrs
from attrs import frozen

from qualtran import Bloq, CompositeBloq, DecomposeTypeError, QBit, Register, Side, Signature
from qualtran import Bloq, DecomposeTypeError, QBit, QUInt, Register, Side, Signature
from qualtran.drawing import RarrowTextBox, Text, WireSymbol
from qualtran.resource_counting import CostKey, QubitCount


@frozen
Expand All @@ -41,5 +43,24 @@ def wire_symbol(
return RarrowTextBox('MeasQFT')
raise ValueError(f'Unrecognized register name {reg.name}')

def cost_attrs(self):
return [('n', self.n)]
def my_static_costs(self, cost_key: 'CostKey'):
# TODO https://github.com/quantumlib/Qualtran/issues/1261
if cost_key == QubitCount():
return self.n
return NotImplemented


@frozen
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why a shim? why not use the QROM bloq itself (with Shaped as data)?

class SimpleQROM(Bloq):
selection_bitsize: int
targets: Sequence[Tuple[str, int]] = attrs.field(converter=tuple)

@cached_property
def signature(self) -> 'Signature':
return Signature(
[Register('selection', QUInt(self.selection_bitsize))]
+ [Register(tname, QUInt(tsize)) for tname, tsize in self.targets]
)

def __str__(self):
return 'QROM'
20 changes: 17 additions & 3 deletions qualtran/bloqs/factoring/ecc/ec_add.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,20 @@
"ec_add = ECAdd(n, mod=p)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cdbe32a7",
"metadata": {
"cq.autogen": "ECAdd.ec_add_256"
},
"outputs": [],
"source": [
"ec_add_256 = ECAdd(\n",
" n=256, mod=0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF\n",
")"
]
},
{
"cell_type": "markdown",
"id": "39210af4",
Expand All @@ -111,8 +125,8 @@
"outputs": [],
"source": [
"from qualtran.drawing import show_bloqs\n",
"show_bloqs([ec_add],\n",
" ['`ec_add`'])"
"show_bloqs([ec_add, ec_add_256],\n",
" ['`ec_add`', '`ec_add_256`'])"
]
},
{
Expand Down Expand Up @@ -157,7 +171,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.11.8"
}
},
"nbformat": 4,
Expand Down
13 changes: 11 additions & 2 deletions qualtran/bloqs/factoring/ecc/ec_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def signature(self) -> 'Signature':
)

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
# litinksi
# Litinski 2023.
# These counts are transcribed from the table in Figure 5.
return {
MultiCToffoli(n=self.n): 18,
ModAdd(bitsize=self.n, mod=self.mod): 3,
Expand All @@ -86,4 +87,12 @@ def _ec_add() -> ECAdd:
return ec_add


_EC_ADD_DOC = BloqDocSpec(bloq_cls=ECAdd, examples=[_ec_add])
@bloq_example
def _ec_add_256() -> ECAdd:
ec_add_256 = ECAdd(
n=256, mod=0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF
)
return ec_add_256


_EC_ADD_DOC = BloqDocSpec(bloq_cls=ECAdd, examples=[_ec_add, _ec_add_256])
43 changes: 39 additions & 4 deletions qualtran/bloqs/factoring/ecc/ec_add_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,23 @@
import sympy
from attrs import frozen

from qualtran import Bloq, bloq_example, BloqDocSpec, QBit, QUInt, Register, Signature
from qualtran import (
Bloq,
bloq_example,
BloqBuilder,
BloqDocSpec,
QBit,
QUInt,
Register,
Signature,
Soquet,
SoquetT,
)
from qualtran.drawing import Circle, Text, TextBox, WireSymbol
from qualtran.simulation.classical_sim import ClassicalValT

from ._ecc_shims import SimpleQROM
from .ec_add import ECAdd
from .ec_point import ECPoint


Expand Down Expand Up @@ -140,6 +153,31 @@ def signature(self) -> 'Signature':
]
)

@cached_property
def lookup_bloq(self) -> SimpleQROM:
return SimpleQROM(
selection_bitsize=self.window_size,
targets=[('a', self.n), ('b', self.n), ('lam', self.n)],
)

def build_composite_bloq(
self, bb: 'BloqBuilder', ctrl: 'SoquetT', x: Soquet, y: Soquet
) -> Dict[str, 'SoquetT']:
ctrl = bb.join(ctrl)
a = bb.allocate(dtype=QUInt(self.n))
b = bb.allocate(dtype=QUInt(self.n))
lam = bb.allocate(dtype=QUInt(self.n))

mod = self.R.mod
ctrl, a, b, lam = bb.add(self.lookup_bloq, selection=ctrl, a=a, b=b, lam=lam)
a, b, x, y, lam = bb.add(ECAdd(n=self.n, mod=mod), a=a, b=b, x=x, y=y, lam=lam)
ctrl, a, b, lam = bb.add(self.lookup_bloq.adjoint(), selection=ctrl, a=a, b=b, lam=lam)
bb.free(a)
bb.free(b)
bb.free(lam)

return {'ctrl': bb.split(ctrl), 'x': x, 'y': y}

def wire_symbol(
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
) -> 'WireSymbol':
Expand All @@ -153,9 +191,6 @@ def wire_symbol(
return TextBox(f'$+{self.R.y}$')
raise ValueError(f'Unrecognized register name {reg.name}')

def __str__(self):
return f'ECWindowAddR({self.n=})'


@bloq_example
def _ec_window_add() -> ECWindowAddR:
Expand Down
25 changes: 20 additions & 5 deletions qualtran/bloqs/factoring/ecc/ec_phase_estimate_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from functools import cached_property
from typing import Dict
from typing import Dict, Set, Union

import sympy
from attrs import frozen
Expand All @@ -23,6 +23,7 @@
bloq_example,
BloqBuilder,
BloqDocSpec,
DecomposeNotImplementedError,
DecomposeTypeError,
QUInt,
Register,
Expand All @@ -34,7 +35,7 @@
from qualtran.resource_counting import BloqCountDictT, SympySymbolAllocator

from ._ecc_shims import MeasureQFT
from .ec_add_r import ECAddR
from .ec_add_r import ECAddR, ECWindowAddR
from .ec_point import ECPoint


Expand All @@ -48,27 +49,41 @@ class ECPhaseEstimateR(Bloq):
Args:
n: The bitsize of the elliptic curve points' x and y registers.
point: The elliptic curve point to phase estimate against.
window_size: If non-zero, use windowed elliptic curve point addition.
"""

n: int
point: ECPoint
window_size: int = 0

@cached_property
def signature(self) -> 'Signature':
return Signature([Register('x', QUInt(self.n)), Register('y', QUInt(self.n))])

@property
def ec_add(self) -> Union[ECAddR, ECWindowAddR]:
if self.window_size == 0:
return functools.partial(ECAddR, n=self.n)
return functools.partial(ECWindowAddR, n=self.n, window_size=self.window_size)

def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet, y: Soquet) -> Dict[str, 'SoquetT']:
if isinstance(self.n, sympy.Expr):
raise DecomposeTypeError("Cannot decompose symbolic `n`.")
if self.window_size != 0:
raise DecomposeNotImplementedError("We don't support a windowed addition circuit yet.")

ctrl = [bb.add(PlusState()) for _ in range(self.n)]
for i in range(self.n):
ctrl[i], x, y = bb.add(ECAddR(n=self.n, R=2**i * self.point), ctrl=ctrl[i], x=x, y=y)

bb.add(MeasureQFT(n=self.n), x=ctrl)
return {'x': x, 'y': y}

def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
return {ECAddR(n=self.n, R=self.point): self.n, MeasureQFT(n=self.n): 1}
def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:
return {
(self.ec_add(R=self.point), self.n / (2**self.window_size)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

float division?

(MeasureQFT(n=self.n), 1),
}

def __str__(self) -> str:
return f'PE${self.point}$'
Expand Down
Loading
Loading