-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #208 from dstrain115/quokka_sampler
Add support for sampling from Quokka devices
- Loading branch information
Showing
2 changed files
with
261 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
# Copyright 2024 The Unitary Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Simulation using a "Quokka" device.""" | ||
|
||
from typing import Any, Callable, Dict, Optional, Sequence | ||
import warnings | ||
|
||
import cirq | ||
import numpy as np | ||
import json | ||
|
||
_REQUEST_ENDPOINT = "http://{}.quokkacomputing.com/qsim/qasm" | ||
_DEFAULT_QUOKKA_NAME = "quokka1" | ||
|
||
JSON_TYPE = Dict[str, Any] | ||
_RESULT_KEY = "result" | ||
_ERROR_CODE_KEY = "error_code" | ||
_RESULT_KEY = "result" | ||
_SCRIPT_KEY = "script" | ||
_REPETITION_KEY = "count" | ||
|
||
|
||
class QuokkaPostEndpoint: | ||
def __init__(self, name=_DEFAULT_QUOKKA_NAME): | ||
self._endpoint = _REQUEST_ENDPOINT.format(name) | ||
|
||
def __call__(self, json_request: JSON_TYPE) -> JSON_TYPE: | ||
try: | ||
import requests | ||
except ImportError as e: | ||
raise ImportError( | ||
"Please install requests library to use Quokka" | ||
"(e.g. pip install requests)" | ||
) from e | ||
result = requests.post(self._endpoint, json=json_request) | ||
return json.loads(result.content) | ||
|
||
|
||
class QuokkaSampler(cirq.Sampler): | ||
"""Sampler for querying a Quokka quantum simulation device. | ||
See https://www.quokkacomputing.com/ for more information.a | ||
Args: | ||
name: name of your quokka device | ||
post: used only for testing to override default | ||
behavior to connect to internet URLs. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
name: str = _DEFAULT_QUOKKA_NAME, | ||
post: Optional[Callable[[JSON_TYPE], JSON_TYPE]] = None, | ||
): | ||
self._post = post or QuokkaPostEndpoint(name) | ||
|
||
def run_sweep( | ||
self, | ||
program: cirq.AbstractCircuit, | ||
params: cirq.Sweepable, | ||
repetitions: int = 1, | ||
) -> Sequence[cirq.Result]: | ||
"""Samples from the given Circuit. | ||
This allows for sweeping over different parameter values, | ||
unlike the `run` method. The `params` argument will provide a | ||
mapping from `sympy.Symbol`s used within the circuit to a set of | ||
values. Unlike the `run` method, which specifies a single | ||
mapping from symbol to value, this method allows a "sweep" of | ||
values. This allows a user to specify execution of a family of | ||
related circuits efficiently. | ||
Args: | ||
program: The circuit to sample from. | ||
params: Parameters to run with the program. | ||
repetitions: The number of times to sample. | ||
Returns: | ||
Result list for this run; one for each possible parameter resolver. | ||
""" | ||
rtn_results = [] | ||
qubits = sorted(program.all_qubits()) | ||
measure_keys = {} | ||
register_names = {} | ||
meas_i = 0 | ||
|
||
# Find all measurements in the circuit and record keys | ||
# so that we can later translate between circuit and QASM registers. | ||
for op in program.all_operations(): | ||
if isinstance(op.gate, cirq.MeasurementGate): | ||
key = cirq.measurement_key_name(op) | ||
if key in measure_keys: | ||
warnings.warn( | ||
"Warning! Keys can only be measured once in Quokka simulator" | ||
f"Key {key} will only contain the last measured value" | ||
) | ||
measure_keys[key] = op.qubits | ||
if cirq.QasmOutput.valid_id_re.match(key): | ||
register_names[key] = f"m_{key}" | ||
else: | ||
register_names[key] = f"m{meas_i}" | ||
meas_i += 1 | ||
|
||
# QASM 2.0 does not support parameter sweeps, | ||
# so resolve any symbolic functions to a concrete circuit. | ||
for param_resolver in cirq.to_resolvers(params): | ||
circuit = cirq.resolve_parameters(program, param_resolver) | ||
qasm = cirq.qasm(circuit) | ||
|
||
# Hack to change sqrt-X gates into rx 0.5 gates: | ||
# Since quokka does not support sx or sxdg gates | ||
qasm = qasm.replace("\nsx ", "\nrx(pi*0.5) ").replace( | ||
"\nsxdg ", "\nrx(pi*-0.5) " | ||
) | ||
|
||
# Send data to quokka endpoint | ||
data = {_SCRIPT_KEY: qasm, _REPETITION_KEY: repetitions} | ||
json_results = self._post(data) | ||
|
||
if _ERROR_CODE_KEY in json_results and json_results[_ERROR_CODE_KEY] != 0: | ||
raise RuntimeError(f"Quokka returned an error: {json_results}") | ||
|
||
if _RESULT_KEY not in json_results: | ||
raise RuntimeError(f"Quokka did not return any results: {json_results}") | ||
|
||
# Associate results from json response to measurement keys. | ||
result_measurements = {} | ||
for key in measure_keys: | ||
register_name = register_names[key] | ||
if register_name not in json_results[_RESULT_KEY]: | ||
raise KeyError(f"Quokka did not measure key {key}: {json_results}") | ||
result_measurements[key] = np.asarray( | ||
json_results[_RESULT_KEY][register_name], dtype=np.dtype("int8") | ||
) | ||
|
||
# Append measurements to eventual result. | ||
rtn_results.append( | ||
cirq.ResultDict( | ||
params=param_resolver, | ||
measurements=result_measurements, | ||
) | ||
) | ||
return rtn_results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# Copyright 2024 The Unitary Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import pytest | ||
import cirq | ||
import sympy | ||
|
||
import unitary.alpha.quokka_sampler as quokka_sampler | ||
|
||
# Qubits for testing | ||
_Q = cirq.LineQubit.range(10) | ||
|
||
|
||
class FakeQuokkaEndpoint: | ||
def __init__(self, *responses: quokka_sampler.JSON_TYPE): | ||
self.responses = list(responses) | ||
self.requests = [] | ||
|
||
def __call__( | ||
self, json_request: quokka_sampler.JSON_TYPE | ||
) -> quokka_sampler.JSON_TYPE: | ||
self.requests.append(json_request) | ||
return self.responses.pop() | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"circuit,json_result", | ||
[ | ||
( | ||
cirq.Circuit(cirq.X(_Q[0]), cirq.measure(_Q[0], key="mmm")), | ||
{"m_mmm": [[1], [1], [1], [1], [1]]}, | ||
), | ||
( | ||
cirq.Circuit(cirq.X(_Q[0]), cirq.measure(_Q[0])), | ||
{"m0": [[1], [1], [1], [1], [1]]}, | ||
), | ||
( | ||
cirq.Circuit( | ||
cirq.X(_Q[0]), cirq.X(_Q[1]), cirq.measure(_Q[0]), cirq.measure(_Q[1]) | ||
), | ||
{"m0": [[1], [1], [1], [1], [1]], "m1": [[1], [1], [1], [1], [1]]}, | ||
), | ||
( | ||
cirq.Circuit( | ||
cirq.X(_Q[0]), | ||
cirq.CNOT(_Q[0], _Q[1]), | ||
cirq.measure(_Q[0]), | ||
cirq.measure(_Q[1]), | ||
), | ||
{"m0": [[1], [1], [1], [1], [1]], "m1": [[1], [1], [1], [1], [1]]}, | ||
), | ||
( | ||
cirq.Circuit( | ||
cirq.X(_Q[0]), | ||
cirq.CNOT(_Q[0], _Q[1]), | ||
cirq.measure(_Q[0], _Q[1], key="m2"), | ||
), | ||
{"m_m2": [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1]]}, | ||
), | ||
], | ||
) | ||
def test_quokka_deterministic_examples(circuit, json_result): | ||
sim = cirq.Simulator() | ||
expected_results = sim.run(circuit, repetitions=5) | ||
json_response = {"error": "no error", "error_code": 0, "result": json_result} | ||
quokka = quokka_sampler.QuokkaSampler( | ||
name="test_mctesterface", post=FakeQuokkaEndpoint(json_response) | ||
) | ||
quokka_results = quokka.run(circuit, repetitions=5) | ||
assert quokka_results == expected_results | ||
|
||
|
||
def test_quokka_run_sweep(): | ||
sim = cirq.Simulator() | ||
circuit = cirq.Circuit( | ||
cirq.X(_Q[0]), | ||
cirq.X(_Q[1]) ** sympy.Symbol("X_1"), | ||
cirq.measure(_Q[0], _Q[1], key="m2"), | ||
) | ||
sweep = cirq.Points("X_1", [0, 1]) | ||
expected_results = sim.run_sweep(circuit, sweep, repetitions=5) | ||
json_response = { | ||
"error": "no error", | ||
"error_code": 0, | ||
"result": {"m_m2": [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1]]}, | ||
} | ||
json_response2 = { | ||
"error": "no error", | ||
"error_code": 0, | ||
"result": {"m_m2": [[1, 0], [1, 0], [1, 0], [1, 0], [1, 0]]}, | ||
} | ||
quokka = quokka_sampler.QuokkaSampler( | ||
name="test_mctesterface", post=FakeQuokkaEndpoint(json_response, json_response2) | ||
) | ||
quokka_results = quokka.run_sweep(circuit, sweep, repetitions=5) | ||
assert quokka_results[0] == expected_results[0] |