Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add wasm support to the conversion #66

Merged
merged 4 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion _metadata.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__extension_version__ = "0.2.0rc15"
__extension_version__ = "0.2.0rc16"
__extension_name__ = "pytket-qir"
4 changes: 4 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Changelog
~~~~~~~~~

0.2.0rc16 (August 2023)
-----------------------
* add support for wasm in the conversion

0.2.0rc15 (August 2023)
-----------------------
* update the classical register handling to use i1* pointer
Expand Down
50 changes: 15 additions & 35 deletions pytket/qir/conversion/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
"""

from enum import Enum
from typing import Union
from typing import Optional, Union

import pyqir

from pytket import wasm
from pytket.circuit import Circuit

from .conversion import QirGenerator
Expand All @@ -40,7 +41,8 @@ def pytket_to_qir(
circ: Circuit,
name: str = "Generated from input pytket circuit",
qir_format: QIRFormat = QIRFormat.BINARY,
pyqir_0_6_compatibility: bool = False,
wfh: Optional[wasm.WasmFileHandler] = None,
int_type: int = 64,
) -> Union[str, bytes, None]:
"""converts given pytket circuit to qir

Expand All @@ -53,6 +55,8 @@ def pytket_to_qir(
:param pyqir_0_6_compatibility: converts the output to be compatible with
pyqir 0.6, default value false
:type pyqir_0_6_compatibility: bool
:param int_type: size of each integer, allowed value 32 and 64
:type int_type: int
"""

if len(circ.q_registers) > 1 or (
Expand All @@ -64,6 +68,9 @@ def pytket_to_qir(
compiler pass `FlattenRelabelRegistersPass`."""
)

if int_type != 32 and int_type != 64:
raise ValueError("the integer size must be 32 or 64")

for creg in circ.c_registers:
if creg.size > 64:
raise ValueError("classical registers must not have more than 64 bits")
Expand All @@ -75,49 +82,22 @@ def pytket_to_qir(
)

qir_generator = QirGenerator(
circuit=circ,
module=m,
wasm_int_type=32,
qir_int_type=64,
circuit=circ, module=m, wasm_int_type=int_type, qir_int_type=int_type, wfh=wfh
)

populated_module = qir_generator.circuit_to_module(
qir_generator.circuit, qir_generator.module, True
)

if pyqir_0_6_compatibility:
if len(circ.c_registers) > 1:
raise ValueError(
"""The qir optimised for pyqir 0.6 can only contain
one classical register"""
)
if wfh is not None:
wasm_dict: dict[str, str] = qir_generator.get_wasm_sar()

initial_result = str(populated_module.module.ir()) # type: ignore

initial_result = (
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is probably easier to understand this when looking at the individual commits

initial_result.replace("entry_point", "EntryPoint")
.replace("num_required_qubits", "requiredQubits")
.replace("num_required_results", "requiredResults")
)

def keep_line(line: str) -> bool:
return (
("@__quantum__qis__read_result__body" not in line)
and ("@set_creg_bit" not in line)
and ("@get_creg_bit" not in line)
and ("@set_creg_to_int" not in line)
and ("@get_int_from_creg" not in line)
and ("@create_creg" not in line)
)

result = "\n".join(filter(keep_line, initial_result.split("\n")))

# replace the use of the removed register variable with i64 0
result = result.replace("i64 %0", "i64 0")
result = result.replace("i64 %3", "i64 0")
for wf in wasm_dict:
initial_result = initial_result.replace(wf, wasm_dict[wf])

for _ in range(10):
result = result.replace("\n\n\n\n", "\n\n")
result = initial_result

bitcode = pyqir.Module.from_ir(pyqir.Context(), result).bitcode # type: ignore

Expand Down
63 changes: 61 additions & 2 deletions pytket/qir/conversion/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import pyqir
from pyqir import IntPredicate, Value

from pytket import Bit, Circuit, Qubit, predicates # type: ignore
from pytket import Bit, Circuit, Qubit, predicates, wasm # type: ignore
from pytket.circuit import ( # type: ignore
BitRegister,
ClassicalExpBox,
Expand Down Expand Up @@ -107,6 +107,7 @@ def __init__(
module: tketqirModule,
wasm_int_type: int,
qir_int_type: int,
wfh: Optional[wasm.WasmFileHandler] = None,
) -> None:
self.circuit = circuit
self.module = module
Expand All @@ -120,6 +121,12 @@ def __init__(
self.cregs = _retrieve_registers(self.circuit.bits, BitRegister)
self.target_gateset = self.module.gateset.base_gateset

self.wasm_dict: dict[str, str] = {}
self.wasm_dict[
"!llvm.module.flags"
] = 'attributes #1 = { "wasm" }\n\n!llvm.module.flags'
self.int_type_str = f"i{qir_int_type}"

self.target_gateset.add(OpType.PhasedX)
self.target_gateset.add(OpType.ZZPhase)
self.target_gateset.add(OpType.ZZMax)
Expand Down Expand Up @@ -244,6 +251,42 @@ def __init__(
self.circuit.n_qubits + 1
)

# void functionname()
if wfh is not None:
self.wasm: dict[str, pyqir.Function] = {}
for fn in wfh._functions:
wasm_func_interface = "declare "
parametertype = [self.qir_int_type] * wfh._functions[fn][0]
if wfh._functions[fn][1] == 0:
returntype = pyqir.Type.void(self.module.module.context)
wasm_func_interface += "void "
elif wfh._functions[fn][1] == 1:
returntype = self.qir_int_type
wasm_func_interface += f"i{self.int_type_str} "
else:
raise ValueError(
"wasm function which return more than"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we should make an issue for this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I will do that

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done at CQCL/pytket-quantinuum#201
I have added that to pytket-quantinuum because this is an issues that needs to be solved for qasm, qir and on the hardware side

+ " one value are not supported yet"
)

self.wasm[fn] = self.module.module.add_external_function(
f"{fn}",
pyqir.FunctionType(
returntype,
parametertype,
),
)

wasm_func_interface += f"@{fn}("
if wfh._functions[fn][0] > 0:
param_str = f"{self.int_type_str}, " * (wfh._functions[fn][0] - 1)
wasm_func_interface += param_str
wasm_func_interface += f"{self.int_type_str})"
else:
wasm_func_interface += ")"

self.wasm_dict[wasm_func_interface] = f"{wasm_func_interface} #1"

self.additional_quantum_gates: dict[OpType, pyqir.Function] = {}

for creg in self.circuit.c_registers:
Expand Down Expand Up @@ -495,6 +538,9 @@ def _get_ssa_from_cl_bit_op(
else:
raise ValueError("unsupported bisewise operation")

def get_wasm_sar(self) -> dict[str, str]:
return self.wasm_dict

def circuit_to_module(
self, circuit: Circuit, module: tketqirModule, record_output: bool = False
) -> tketqirModule:
Expand Down Expand Up @@ -635,7 +681,20 @@ def condition_block() -> None:
)

elif isinstance(op, WASMOp):
raise ValueError("WASM not supported yet")
paramreg, resultreg = self._get_c_regs_from_com(command)

paramssa = [self._get_i64_ssa_reg(p) for p in paramreg]

result = self.module.builder.call( # type: ignore
self.wasm[command.op.func_name],
[*paramssa],
)

if len(resultreg) == 1:
self.module.builder.call(
self.set_creg_to_int,
[self.ssa_vars[resultreg[0]], result],
)

elif op.type == OpType.ZZPhase:
assert len(command.bits) == 0
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
packages=find_namespace_packages(include=["pytket.*"]),
include_package_data=True,
install_requires=[
"pytket ~= 1.18",
"pytket == 1.19.0rc0",
"pyqir == 0.8.2",
"pyqir-generator == 0.7.0",
"pyqir-evaluator == 0.7.0",
Expand Down
29 changes: 0 additions & 29 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,35 +42,6 @@ def test_pytket_qir() -> None:
check_qir_result(result, "test_pytket_qir")


def test_pytket_qir_optimised() -> None:
circ = Circuit(
3,
)
circ.H(0)

result = pytket_to_qir(
circ,
name="test_pytket_qir",
qir_format=QIRFormat.STRING,
pyqir_0_6_compatibility=True,
)

check_qir_result(result, "test_pytket_qir_optimised")


def test_pytket_qir_optimised_ii() -> None:
circ = Circuit(2).H(0).CX(0, 1).measure_all()

result = pytket_to_qir(
circ,
name="test_pytket_qir",
qir_format=QIRFormat.STRING,
pyqir_0_6_compatibility=True,
)

check_qir_result(result, "test_pytket_qir_optimised_ii")


def test_pytket_api_qreg() -> None:
circ = Circuit(3)
circ.H(0)
Expand Down
30 changes: 0 additions & 30 deletions tests/qir/test_pytket_qir_optimised.ll

This file was deleted.

41 changes: 0 additions & 41 deletions tests/qir/test_pytket_qir_optimised_ii.ll

This file was deleted.

60 changes: 60 additions & 0 deletions tests/qir/test_pytket_qir_wasm.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
; ModuleID = 'test_pytket_qir_wasm'
source_filename = "test_pytket_qir_wasm"

%Qubit = type opaque
%Result = type opaque

define void @main() #0 {
entry:
call void @__quantum__qis__h__body(%Qubit* null)
call void @__quantum__rt__tuple_start_record_output()
call void @__quantum__rt__tuple_end_record_output()
ret void
}

declare i1 @get_creg_bit(i1*, i32)

declare void @set_creg_bit(i1*, i32, i1)

declare void @set_creg_to_int(i1*, i32)

declare i1 @__quantum__qis__read_result__body(%Result*)

declare i1* @create_creg(i32)

declare i32 @get_int_from_creg(i1*)

declare void @__quantum__rt__int_record_output(i32, i8*)

declare void @__quantum__rt__tuple_start_record_output()

declare void @__quantum__rt__tuple_end_record_output()

declare void @init() #1

declare i32 @add_one(i32)

declare i32 @multi(i32, i32)

declare i32 @add_two(i32)

declare i32 @add_eleven(i32)

declare void @no_return(i32) #1

declare i32 @no_parameters()

declare i32 @new_function()

declare void @__quantum__qis__h__body(%Qubit*)

attributes #0 = { "entry_point" "num_required_qubits"="1" "num_required_results"="1" "output_labeling_schema" "qir_profiles"="custom" }

attributes #1 = { "wasm" }

!llvm.module.flags = !{!0, !1, !2, !3}

!0 = !{i32 1, !"qir_major_version", i32 1}
!1 = !{i32 7, !"qir_minor_version", i32 0}
!2 = !{i32 1, !"dynamic_qubit_management", i1 false}
!3 = !{i32 1, !"dynamic_result_management", i1 false}
Loading