Skip to content

Commit

Permalink
feat: mem_swap function for swapping two inout values (#653)
Browse files Browse the repository at this point in the history
not sure if prelude is best place for the custom compiler

also not sure about name `mem_swap` but don't want to confuse with
quantum logic swap

...And not sure about the test location either.


Closes #652

// to coerce release-please: 

Release-As: 0.13.1
  • Loading branch information
ss2165 authored Nov 15, 2024
1 parent 8d8c8b1 commit 89e10a5
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 1 deletion.
15 changes: 15 additions & 0 deletions guppylang/std/_internal/compiler/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from hugr import tys as ht
from hugr import val as hv

from guppylang.definition.custom import CustomCallCompiler
from guppylang.definition.value import CallReturnWires
from guppylang.error import InternalGuppyError

if TYPE_CHECKING:
from hugr.build.dfg import DfBase

Expand Down Expand Up @@ -123,3 +127,14 @@ def build_unwrap(
result is an error.
"""
return build_unwrap_right(builder, result, error_msg, error_signal)


class MemSwapCompiler(CustomCallCompiler):
"""Compiler for the `mem_swap` function."""

def compile_with_inouts(self, args: list[Wire]) -> CallReturnWires:
[x, y] = args
return CallReturnWires(regular_returns=[], inout_returns=[y, x])

def compile(self, args: list[Wire]) -> list[Wire]:
raise InternalGuppyError("Call compile_with_inouts instead")
6 changes: 6 additions & 0 deletions guppylang/std/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
ListPushCompiler,
ListSetitemCompiler,
)
from guppylang.std._internal.compiler.prelude import MemSwapCompiler
from guppylang.std._internal.util import (
float_op,
int_op,
Expand Down Expand Up @@ -880,3 +881,8 @@ def zip(x): ...

@guppy.custom(checker=UnsupportedChecker(), higher_order_value=False)
def __import__(x): ...


@guppy.custom(MemSwapCompiler())
def mem_swap(x: L, y: L) -> None:
"""Swaps the values of two variables."""
20 changes: 19 additions & 1 deletion tests/integration/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from guppylang.decorator import guppy
from guppylang.module import GuppyModule
from guppylang.std.builtins import array, owned
from guppylang.std.builtins import array, owned, mem_swap
from tests.util import compile_guppy

from guppylang.std.quantum import qubit
Expand Down Expand Up @@ -258,3 +258,21 @@ def main() -> int:
package = module.compile()
validate(package)
run_int_fn(package, expected=6)


def test_mem_swap(validate):
module = GuppyModule("test")

module.load(qubit)
@guppy(module)
def foo(x: qubit, y: qubit) -> None:
mem_swap(x, y)

@guppy(module)
def main() -> array[qubit, 2]:
a = array(qubit(), qubit())
foo(a[0], a[1])
return a

package = module.compile()
validate(package)

0 comments on commit 89e10a5

Please sign in to comment.