Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

restrict classical action of certain arithmetic bloqs #1518

Merged
merged 8 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions qualtran/bloqs/arithmetic/subtraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Dict, Optional, Tuple, TYPE_CHECKING, Union

import numpy as np
Expand Down Expand Up @@ -40,6 +39,7 @@
from qualtran.bloqs.bookkeeping import Allocate, Cast, Free
from qualtran.bloqs.mcmt.multi_target_cnot import MultiTargetCNOT
from qualtran.drawing import Text
from qualtran.simulation.classical_sim import add_ints

if TYPE_CHECKING:
from qualtran.drawing import WireSymbol
Expand Down Expand Up @@ -270,10 +270,15 @@ def signature(self):
def on_classical_vals(
self, a: 'ClassicalValT', b: 'ClassicalValT'
) -> Dict[str, 'ClassicalValT']:
unsigned = isinstance(self.dtype, (QUInt, QMontgomeryUInt))
bitsize = self.dtype.bitsize
N = 2**bitsize if unsigned else 2 ** (bitsize - 1)
return {'a': a, 'b': int(math.fmod(b - a, N))}
return {
'a': a,
'b': add_ints(
int(b),
-int(a),
num_bits=int(self.dtype.bitsize),
is_signed=isinstance(self.dtype, QInt),
),
}

def wire_symbol(
self, reg: Optional['Register'], idx: Tuple[int, ...] = tuple()
Expand Down
9 changes: 9 additions & 0 deletions qualtran/bloqs/arithmetic/subtraction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,12 @@ def test_subtract_from_bloq_decomposition():
want[(a << 4) | c][a_b] = 1
got = gate.tensor_contract()
np.testing.assert_allclose(got, want)


@pytest.mark.parametrize('bitsize', range(2, 5))
def test_subtractfrom_classical_action(bitsize):
dtype = QInt(bitsize)
blq = SubtractFrom(dtype)
qlt_testing.assert_consistent_classical_action(
blq, a=tuple(dtype.get_classical_domain()), b=tuple(dtype.get_classical_domain())
)
30 changes: 28 additions & 2 deletions qualtran/bloqs/mod_arithmetic/mod_addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,17 @@ def signature(self) -> 'Signature':
def on_classical_vals(
self, x: 'ClassicalValT', y: 'ClassicalValT'
) -> Dict[str, 'ClassicalValT']:
return {'x': x, 'y': (x + y) % self.mod}
if not (0 <= x < self.mod):
raise ValueError(
f'{x=} is outside the valid interval for modular addition [0, {self.mod})'
)
if not (0 <= y < self.mod):
raise ValueError(
f'{y=} is outside the valid interval for modular addition [0, {self.mod})'
)

y = (x + y) % self.mod
return {'x': x, 'y': y}

def build_composite_bloq(self, bb: 'BloqBuilder', x: Soquet, y: Soquet) -> Dict[str, 'SoquetT']:
if is_symbolic(self.bitsize):
Expand Down Expand Up @@ -307,6 +317,12 @@ def on_classical_vals(
return {'ctrl': 0, 'x': x}

assert ctrl == 1, 'Bad ctrl value.'

if not (0 <= x < self.mod):
raise ValueError(
f'{x=} is outside the valid interval for modular addition [0, {self.mod})'
)

x = (x + self.k) % self.mod
return {'ctrl': ctrl, 'x': x}

Expand Down Expand Up @@ -492,7 +508,17 @@ def on_classical_vals(
if ctrl != self.cv:
return {'ctrl': ctrl, 'x': x, 'y': y}

return {'ctrl': ctrl, 'x': x, 'y': (x + y) % self.mod}
if not (0 <= x < self.mod):
raise ValueError(
f'{x=} is outside the valid interval for modular addition [0, {self.mod})'
)
if not (0 <= y < self.mod):
raise ValueError(
f'{y=} is outside the valid interval for modular addition [0, {self.mod})'
)

y = (x + y) % self.mod
return {'ctrl': ctrl, 'x': x, 'y': y}

def build_composite_bloq(
self, bb: 'BloqBuilder', ctrl, x: Soquet, y: Soquet
Expand Down
12 changes: 12 additions & 0 deletions qualtran/bloqs/mod_arithmetic/mod_addition_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,15 @@ def test_cmod_add_complexity_vs_ref():
def test_mod_add_classical_action(bitsize, prime):
b = ModAdd(bitsize, prime)
assert_consistent_classical_action(b, x=range(prime), y=range(prime))


def test_cmodadd_tensor():
blq = CModAddK(bitsize=4, mod=7, k=1)
want = np.zeros((7, 7))
for i in range(7):
j = (i + 1) % 7
want[j, i] = 1

tn = blq.tensor_contract()
np.testing.assert_allclose(tn[:7, :7], np.eye(7)) # ctrl = 0
np.testing.assert_allclose(tn[16 : 16 + 7, 16 : 16 + 7], want) # ctrl = 1
Loading