# Checkpoint helpers added to ``BasePass`` in bqskit/compiler/basepass.py.
# They are instance methods of ``BasePass``; the two underscore-prefixed
# functions are module-level utilities shared by those methods.
import pickle


def _dump_pickle(obj: object, path: str) -> None:
    """Pickle ``obj`` to ``path``, closing the file even if dumping fails."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def _load_pickle(path: str) -> object:
    """
    Unpickle and return the object stored at ``path``.

    NOTE: unpickling executes arbitrary code, so checkpoint files must
    only ever be read from a trusted location.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def checkpoint_finished(self, data: PassData, checkpoint_key: str) -> bool:
    """
    Report whether a checkpointed pass has already run to completion.

    Args:
        data (PassData): The data dictionary.

        checkpoint_key (str): The key that marks this pass as finished.

    Returns:
        bool: True only if checkpointing is enabled (``checkpoint_dir``
        is present in ``data``) and ``checkpoint_key`` holds a truthy
        value.
    """
    return 'checkpoint_dir' in data and bool(data.get(checkpoint_key, False))


def finish_checkpoint(
    self,
    circuit: Circuit,
    data: PassData,
    checkpoint_key: str,
    remove_key: str | None = None,
) -> None:
    """
    Mark the pass as finished and save the data and circuit.

    Does nothing when checkpointing is disabled, i.e. when
    ``checkpoint_dir`` is not in ``data``.

    Args:
        circuit (Circuit): The circuit to save.

        data (PassData): The data dictionary. It must hold the
            ``checkpoint_data_file`` and ``checkpoint_circ_file`` paths.

        checkpoint_key (str): The key to set to True.

        remove_key (str | None): If not None, remove this key from the
            data dictionary before saving. A missing key is ignored.

    Raises:
        KeyError: If checkpointing is enabled but a checkpoint file path
            is missing from ``data``.
    """
    if 'checkpoint_dir' not in data:
        return

    data[checkpoint_key] = True
    if remove_key is not None:
        # The in-progress key only exists if the pass saved progress at
        # least once, so its absence here is not an error.
        data.pop(remove_key, None)

    _dump_pickle(data, data['checkpoint_data_file'])
    _dump_pickle(circuit, data['checkpoint_circ_file'])


def restart_checkpoint(
    self,
    circuit: Circuit,
    data: PassData,
    checkpoint_key: str,
) -> Any | None:
    """
    Load the saved data and circuit from the checkpoint.

    This will modify (in-place) the passed in circuit and data dictionary.
    Both files are read before anything is modified, so a failed load
    leaves ``circuit`` and ``data`` untouched.

    Args:
        circuit (Circuit): The circuit to overwrite with the saved one.

        data (PassData): The data dictionary to update with saved values.

        checkpoint_key (str): The key to look up after loading.

    Returns:
        Any | None: If the checkpoint exists, whatever is stored at
        ``checkpoint_key`` after loading. None if checkpointing is
        disabled, no checkpoint file has been written yet, or the key
        is absent.

    Raises:
        KeyError: If checkpointing is enabled but a checkpoint file path
            is missing from ``data``.
    """
    if 'checkpoint_dir' not in data:
        return None

    data_file = data['checkpoint_data_file']
    circuit_file = data['checkpoint_circ_file']
    try:
        saved_data = _load_pickle(data_file)
        saved_circuit = _load_pickle(circuit_file)
    except FileNotFoundError:
        # Nothing has been checkpointed yet: start from scratch.
        return None

    data.update(saved_data)
    circuit.become(saved_circuit)
    return data.get(checkpoint_key, None)


def checkpoint_save(self, circuit: Circuit, data: PassData) -> None:
    """
    Save the circuit and data to the checkpoint.

    Does nothing when checkpointing is disabled, i.e. when
    ``checkpoint_dir`` is not in ``data``.

    Args:
        circuit (Circuit): The circuit to save.

        data (PassData): The data dictionary. It must hold the
            ``checkpoint_data_file`` and ``checkpoint_circ_file`` paths.

    Raises:
        KeyError: If checkpointing is enabled but a checkpoint file path
            is missing from ``data``.
    """
    if 'checkpoint_dir' not in data:
        return

    _dump_pickle(data, data['checkpoint_data_file'])
    _dump_pickle(circuit, data['checkpoint_circ_file'])
ConstantUnitaryGate from bqskit.ir.gates.parameterized.pauli import PauliGate from bqskit.ir.gates.parameterized.unitary import VariableUnitaryGate +from bqskit.ir.opt.cost.functions import HilbertSchmidtResidualsGenerator +from bqskit.ir.opt.cost.generator import CostFunctionGenerator from bqskit.ir.location import CircuitLocation from bqskit.ir.operation import Operation from bqskit.ir.point import CircuitPoint @@ -60,9 +68,11 @@ def __init__( self, loop_body: WorkflowLike, calculate_error_bound: bool = False, + error_cost_gen: CostFunctionGenerator = HilbertSchmidtResidualsGenerator(), collection_filter: Callable[[Operation], bool] | None = None, replace_filter: ReplaceFilterFn | str = 'always', batch_size: int | None = None, + blocks_to_run: List[int] = [] ) -> None: """ Construct a ForEachBlockPass. @@ -127,6 +137,11 @@ def __init__( Defaults to 'always'. #TODO: address importability batch_size (int): (Deprecated). + + blocks_to_run (List[int]): + A list of blocks to run the ForEachBlockPass body on. By default + you run on all blocks. This is mainly used with checkpointing, + where some blocks have already finished while others have not. 
""" if batch_size is not None: import warnings @@ -140,7 +155,8 @@ def __init__( self.collection_filter = collection_filter or default_collection_filter self.replace_filter = replace_filter or default_replace_filter self.workflow = Workflow(loop_body) - + self.blocks_to_run = sorted(blocks_to_run) + self.error_cost_gen = error_cost_gen if not callable(self.collection_filter): raise TypeError( 'Expected callable method that maps Operations to booleans for' @@ -165,15 +181,32 @@ async def run(self, circuit: Circuit, data: PassData) -> None: else: replace_filter = self.replace_filter + # If checkpoint_dir is defined in data, then we should checkpoint + # Note that checkpoint_dir is set in CheckpointRestartPass + should_checkpoint = 'checkpoint_dir' in data + checkpoint_dir = data.get('checkpoint_dir', "") + if should_checkpoint: + Path(checkpoint_dir).mkdir(parents=True, exist_ok=True) + # Make room in data for block data if self.key not in data: data[self.key] = [] - # Collect blocks + # Collect blocks to run with blocks: list[tuple[int, Operation]] = [] - for cycle, op in circuit.operations_with_cycles(): - if self.collection_filter(op): + if (len(self.blocks_to_run) == 0): + self.blocks_to_run = list(range(circuit.num_operations)) + + block_ids = self.blocks_to_run.copy() + next_id = block_ids.pop(0) + for i, (cycle, op) in enumerate(circuit.operations_with_cycles()): + if self.collection_filter(op) and i == next_id: blocks.append((cycle, op)) + try: + next_id = block_ids.pop(0) + except IndexError: + # No more blocks to run on + break # No blocks, no work if len(blocks) == 0: @@ -188,38 +221,70 @@ async def run(self, circuit: Circuit, data: PassData) -> None: subcircuits: list[Circuit] = [] block_datas: list[PassData] = [] for i, (cycle, op) in enumerate(blocks): + # Set up checkpoint data and circuit files + # Need to zero pad block ids for consistency + num_digits = len(str(circuit.num_operations)) + block_num = str(self.blocks_to_run[i]).zfill(num_digits) + 
save_data_file = join(checkpoint_dir, f'block_{block_num}.data') + save_circuit_file = join(checkpoint_dir, f'block_{block_num}.pickle') + checkpoint_found = False + if should_checkpoint and exists(save_data_file): + # If checkpointing, reload block data and circuit from + # checkpoint if it exists + _logger.debug(f'Loading block {i} from checkpoint.') + try: + subcircuit = pickle.load(open(save_circuit_file, 'rb')) + block_data = pickle.load(open(save_data_file, 'rb')) + checkpoint_found = True + except Exception as e: + # Problem reading the checkpointed files + _logger.error(f"Exception for file: {save_data_file}", e) + checkpoint_found = False + + # If no checkpoint found, form subcircuit and submodel + if not checkpoint_found: + # Form Subcircuit + if isinstance(op.gate, CircuitGate): + subcircuit = op.gate._circuit.copy() + subcircuit.set_params(op.params) + else: + subcircuit = Circuit.from_operation(op) + + # Form Submodel + subradixes = [circuit.radixes[q] for q in op.location] + subnumbering = {op.location[i]: i for i in range(len(op.location))} + submodel = MachineModel( + len(op.location), + coupling_graph.get_subgraph(op.location, subnumbering), + model.gate_set, + subradixes, + ) - # Form Subcircuit - if isinstance(op.gate, CircuitGate): - subcircuit = op.gate._circuit.copy() - subcircuit.set_params(op.params) - else: - subcircuit = Circuit.from_operation(op) - - # Form Submodel - subradixes = [circuit.radixes[q] for q in op.location] - subnumbering = {op.location[i]: i for i in range(len(op.location))} - submodel = MachineModel( - len(op.location), - coupling_graph.get_subgraph(op.location, subnumbering), - model.gate_set, - subradixes, - ) - - # Form Subdata - block_data: PassData = PassData(subcircuit) - block_data['subnumbering'] = subnumbering - block_data['model'] = submodel - block_data['point'] = CircuitPoint(cycle, op.location[0]) - block_data['calculate_error_bound'] = self.calculate_error_bound - for key in data: - if 
key.startswith(self.pass_down_key_prefix): - block_data[key] = data[key] - elif key.startswith( - self.pass_down_block_specific_key_prefix, - ) and i in data[key]: - block_data[key] = data[key][i] - block_data.seed = data.seed + # Form Subdata + block_data: PassData = PassData(subcircuit) + block_data['subnumbering'] = subnumbering + block_data['model'] = submodel + block_data['point'] = CircuitPoint(cycle, op.location[0]) + block_data['calculate_error_bound'] = self.calculate_error_bound + for key in data: + if key.startswith(self.pass_down_key_prefix): + block_data[key] = data[key] + elif key.startswith( + self.pass_down_block_specific_key_prefix, + ) and i in data[key]: + block_data[key] = data[key][i] + block_data.seed = data.seed + + # Change next subdirectory + if should_checkpoint: + # Blocks can have sub-blocks, so we need to change the ckpt dir + # for each block as a sub-folder of the main checkpoint dir + block_data["checkpoint_dir"] = join(checkpoint_dir, f'block_{block_num}') + block_data["checkpoint_circ_file"] = save_circuit_file + block_data["checkpoint_data_file"] = save_data_file + block_data['block_num'] = block_num + pickle.dump(block_data, open(save_data_file, 'wb')) + pickle.dump(subcircuit, open(save_circuit_file, 'wb')) subcircuits.append(subcircuit) block_datas.append(block_data) @@ -229,7 +294,7 @@ async def run(self, circuit: Circuit, data: PassData) -> None: _sub_do_work, [self.workflow] * len(subcircuits), subcircuits, - block_datas, + block_datas ) # Unpack results @@ -272,7 +337,6 @@ async def run(self, circuit: Circuit, data: PassData) -> None: if self.calculate_error_bound: _logger.debug(f'New circuit error is {data.error}.') - def default_collection_filter(op: Operation) -> bool: return isinstance( op.gate, ( diff --git a/bqskit/passes/io/__init__.py b/bqskit/passes/io/__init__.py index f06461538..03b225e29 100644 --- a/bqskit/passes/io/__init__.py +++ b/bqskit/passes/io/__init__.py @@ -1,12 +1,14 @@ """This package implements 
class CheckpointRestartPass(BasePass):
    """
    Restore a workflow from an on-disk checkpoint, or create one.

    Checkpoints are useful to restart a workflow from a certain point in
    the event of a crash or timeout. This pass also records the checkpoint
    directory in the pass data under ``checkpoint_dir``.
    """

    def __init__(
        self,
        checkpoint_dir: str,
        default_passes: BasePass | Sequence[BasePass],
    ) -> None:
        """
        Construct a CheckpointRestartPass.

        Args:
            checkpoint_dir (str):
                Path to the directory containing the checkpointed circuit.

            default_passes (BasePass | Sequence[BasePass]):
                The passes to run if the checkpoint does not exist.
                Typically, these will be the partitioning passes to set up
                the block structure.
        """
        if not is_sequence(default_passes):
            default_passes = [cast(BasePass, default_passes)]

        if not isinstance(default_passes, list):
            default_passes = list(default_passes)

        self.checkpoint_dir = checkpoint_dir
        self.default_passes = default_passes

    async def run(self, circuit: Circuit, data: PassData) -> None:
        """
        Set ``checkpoint_dir`` in the data and restore from it if possible.

        If both ``circuit.pickle`` and ``data.pickle`` exist in the
        checkpoint directory, the circuit and data are overwritten in-place
        with their saved versions. Otherwise the default passes are run and
        their result is written out as the initial checkpoint.
        """
        data['checkpoint_dir'] = self.checkpoint_dir
        circuit_file = join(self.checkpoint_dir, 'circuit.pickle')
        data_file = join(self.checkpoint_dir, 'data.pickle')

        # Require both files: a run interrupted between the two dumps
        # leaves only the circuit behind, which cannot be restored alone.
        if exists(circuit_file) and exists(data_file):
            _logger.info("Restoring from Checkpoint!")
            # NOTE: unpickling executes arbitrary code; the checkpoint
            # directory must be a trusted location.
            with open(circuit_file, 'rb') as f:
                new_circuit = pickle.load(f)
            with open(data_file, 'rb') as f:
                new_data = pickle.load(f)
            # Load everything before mutating so a corrupt file cannot
            # leave the circuit and data out of sync.
            circuit.become(new_circuit)
            data.update(new_data)
            return

        _logger.info("Checkpoint does not exist!")
        await Workflow(self.default_passes).run(circuit, data)
        Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)
        with open(circuit_file, 'wb') as f:
            pickle.dump(circuit, f)
        with open(data_file, 'wb') as f:
            pickle.dump(data, f)
""" + ''' + This key can be set in the checkpoint to indicate that the pass has + finished. + ''' + checkpoint_finish_key = "ScanningGateRemovalPass_finished" + + ''' + This key holds the data to restart/continue the pass from a checkpoint. + ''' + checkpoint_data_key = "ScanningGateRemovalPass_data" + def __init__( self, start_from_left: bool = True, @@ -97,6 +108,9 @@ def __init__( async def run(self, circuit: Circuit, data: PassData) -> None: """Perform the pass's operation, see :class:`BasePass` for more.""" + if self.checkpoint_finished(data, + ScanningGateRemovalPass.checkpoint_finish_key): + return instantiate_options = self.instantiate_options.copy() if 'seed' not in instantiate_options: instantiate_options['seed'] = data.seed @@ -108,8 +122,19 @@ async def run(self, circuit: Circuit, data: PassData) -> None: circuit_copy = circuit.copy() reverse_iter = not self.start_from_left - for cycle, op in circuit.operations_with_cycles(reverse=reverse_iter): + # Restart the checkpoint from the start ind if we are checkpointing, + # otherwise start from the beginning. 
+ start_ind = self.restart_checkpoint(data, + ScanningGateRemovalPass.checkpoint_data_key) + if start_ind is None: + start_ind = 0 + + operations = circuit.operations_with_cycles(reverse=reverse_iter) + for i, cycle_op in enumerate(operations): + if i < start_ind: + continue + cycle, op = cycle_op if not self.collection_filter(op): _logger.debug(f'Skipping operation {op} at cycle {cycle}.') continue @@ -131,9 +156,15 @@ async def run(self, circuit: Circuit, data: PassData) -> None: if self.cost(working_copy, target) < self.success_threshold: _logger.debug('Successfully removed operation.') circuit_copy = working_copy + # Checkpoint the circuit here if set + data[ScanningGateRemovalPass.checkpoint_data_key] = start_ind + self.checkpoint_save(circuit_copy, data) circuit.become(circuit_copy) + self.finish_checkpoint(circuit, data, + ScanningGateRemovalPass.checkpoint_finish_key, + remove_key=ScanningGateRemovalPass.checkpoint_data_key) def default_collection_filter(op: Operation) -> bool: return True