Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding general checkpointing infra to checkpoint inside a pass #297

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions bqskit/compiler/basepass.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import abc
import warnings
from typing import TYPE_CHECKING
import pickle

if TYPE_CHECKING:
from typing import Any
Expand Down Expand Up @@ -158,7 +159,90 @@ def execute(*args: Any, **kwargs: Any) -> Any:
'https://bqskit.readthedocs.io/en/latest/source/runtime.html'
'In a future version, this error will become an AttributeError.',
)

def checkpoint_finished(self, data: PassData, checkpoint_key: str) -> bool:
"""
Check if we are checkpointing this pass. If so, check if the
checkpoint has finished.

Args:
data (PassData): The data dictionary.
checkpoint_key (str): The key to check for in the data dictionary.

Returns:
bool: True if the pass should checkpoint and has finished.
"""
if "checkpoint_dir" in data:
if data.get(checkpoint_key, False):
return True

return False

def finish_checkpoint(self,
circuit: Circuit,
data: PassData,
checkpoint_key: str,
remove_key: str | None = None) -> None:
"""
Set the checkpoint key to True and save the data and circuit.

Args:
circuit (Circuit): The circuit to save.
data (PassData): The data dictionary.
checkpoint_key (str): The key to set to True.
remove_key (str | None): If not None, remove this key from the data
dictionary before saving.
"""
if "checkpoint_dir" in data:
data[checkpoint_key] = True
if remove_key is not None:
data.pop(remove_key)
save_data_file = data["checkpoint_data_file"]
save_circuit_file = data["checkpoint_circ_file"]
pickle.dump(data, open(save_data_file, "wb"))
pickle.dump(circuit, open(save_circuit_file, "wb"))

def restart_checkpoint(self,
circuit: Circuit,
data: PassData,
checkpoint_key: str) -> Any | None:
"""
Load the saved data and circuit from the checkpoint.

This will modify (in-place) the passed in circuit and data dictionary.

Args:
data (PassData): The data dictionary.
checkpoint_key (str): The key to check for in the data dictionary.

Returns:
Any | None: If the checkpoint exists, it returns whatever data is
stored at the checkpoint key. Otherwise, it returns None.
"""
if "checkpoint_dir" in data:
load_data_file = data["checkpoint_data_file"]
load_circuit_file = data["checkpoint_circ_file"]
new_data = pickle.load(open(load_data_file, "rb"))
new_circuit = pickle.load(open(load_circuit_file, "rb"))
data.update(new_data)
circuit.become(new_circuit)
return data.get(checkpoint_key, None)

return None

def checkpoint_save(self, circuit: Circuit, data: PassData) -> None:
"""
Save the circuit and data to the checkpoint.

Args:
circuit (Circuit): The circuit to save.
data (PassData): The data dictionary.
"""
if "checkpoint_dir" in data:
save_data_file = data["checkpoint_data_file"]
save_circuit_file = data["checkpoint_circ_file"]
pickle.dump(data, open(save_data_file, "wb"))
pickle.dump(circuit, open(save_circuit_file, "wb"))

async def _sub_do_work(
workflow: Workflow,
Expand Down
140 changes: 102 additions & 38 deletions bqskit/passes/control/foreach.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,27 @@

import functools
import logging
from typing import Callable
from typing import Callable, List
import pickle
from os.path import join, exists
from pathlib import Path

from bqskit.compiler.basepass import _sub_do_work
from collections import Counter
from bqskit.compiler.basepass import BasePass
from bqskit.compiler.machine import MachineModel
from bqskit.compiler.passdata import PassData
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike
from bqskit.ir.circuit import Circuit
from bqskit.ir.gates.circuitgate import CircuitGate
from bqskit.ir.gate import Gate
from bqskit.ir.gates import CNOTGate
from bqskit.ir.gates.constant.unitary import ConstantUnitaryGate
from bqskit.ir.gates.parameterized.pauli import PauliGate
from bqskit.ir.gates.parameterized.unitary import VariableUnitaryGate
from bqskit.ir.opt.cost.functions import HilbertSchmidtResidualsGenerator
from bqskit.ir.opt.cost.generator import CostFunctionGenerator
from bqskit.ir.location import CircuitLocation
from bqskit.ir.operation import Operation
from bqskit.ir.point import CircuitPoint
Expand Down Expand Up @@ -60,9 +68,11 @@ def __init__(
self,
loop_body: WorkflowLike,
calculate_error_bound: bool = False,
error_cost_gen: CostFunctionGenerator = HilbertSchmidtResidualsGenerator(),
collection_filter: Callable[[Operation], bool] | None = None,
replace_filter: ReplaceFilterFn | str = 'always',
batch_size: int | None = None,
blocks_to_run: List[int] = []
) -> None:
"""
Construct a ForEachBlockPass.
Expand Down Expand Up @@ -127,6 +137,11 @@ def __init__(
Defaults to 'always'. #TODO: address importability

batch_size (int): (Deprecated).

blocks_to_run (List[int]):
A list of blocks to run the ForEachBlockPass body on. By default
you run on all blocks. This is mainly used with checkpointing,
where some blocks have already finished while others have not.
"""
if batch_size is not None:
import warnings
Expand All @@ -140,7 +155,8 @@ def __init__(
self.collection_filter = collection_filter or default_collection_filter
self.replace_filter = replace_filter or default_replace_filter
self.workflow = Workflow(loop_body)

self.blocks_to_run = sorted(blocks_to_run)
self.error_cost_gen = error_cost_gen
if not callable(self.collection_filter):
raise TypeError(
'Expected callable method that maps Operations to booleans for'
Expand All @@ -165,15 +181,32 @@ async def run(self, circuit: Circuit, data: PassData) -> None:
else:
replace_filter = self.replace_filter

# If checkpoint_dir is defined in data, then we should checkpoint
# Note that checkpoint_dir is set in CheckpointRestartPass
should_checkpoint = 'checkpoint_dir' in data
checkpoint_dir = data.get('checkpoint_dir', "")
if should_checkpoint:
Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)

# Make room in data for block data
if self.key not in data:
data[self.key] = []

# Collect blocks
# Collect blocks to run with
blocks: list[tuple[int, Operation]] = []
for cycle, op in circuit.operations_with_cycles():
if self.collection_filter(op):
if (len(self.blocks_to_run) == 0):
self.blocks_to_run = list(range(circuit.num_operations))

block_ids = self.blocks_to_run.copy()
next_id = block_ids.pop(0)
for i, (cycle, op) in enumerate(circuit.operations_with_cycles()):
if self.collection_filter(op) and i == next_id:
blocks.append((cycle, op))
try:
next_id = block_ids.pop(0)
except IndexError:
# No more blocks to run on
break

# No blocks, no work
if len(blocks) == 0:
Expand All @@ -188,38 +221,70 @@ async def run(self, circuit: Circuit, data: PassData) -> None:
subcircuits: list[Circuit] = []
block_datas: list[PassData] = []
for i, (cycle, op) in enumerate(blocks):
# Set up checkpoint data and circuit files
# Need to zero pad block ids for consistency
num_digits = len(str(circuit.num_operations))
block_num = str(self.blocks_to_run[i]).zfill(num_digits)
save_data_file = join(checkpoint_dir, f'block_{block_num}.data')
save_circuit_file = join(checkpoint_dir, f'block_{block_num}.pickle')
checkpoint_found = False
if should_checkpoint and exists(save_data_file):
# If checkpointing, reload block data and circuit from
# checkpoint if it exists
_logger.debug(f'Loading block {i} from checkpoint.')
try:
subcircuit = pickle.load(open(save_circuit_file, 'rb'))
block_data = pickle.load(open(save_data_file, 'rb'))
checkpoint_found = True
except Exception as e:
# Problem reading the checkpointed files
_logger.error(f"Exception for file: {save_data_file}", e)
checkpoint_found = False

# If no checkpoint found, form subcircuit and submodel
if not checkpoint_found:
# Form Subcircuit
if isinstance(op.gate, CircuitGate):
subcircuit = op.gate._circuit.copy()
subcircuit.set_params(op.params)
else:
subcircuit = Circuit.from_operation(op)

# Form Submodel
subradixes = [circuit.radixes[q] for q in op.location]
subnumbering = {op.location[i]: i for i in range(len(op.location))}
submodel = MachineModel(
len(op.location),
coupling_graph.get_subgraph(op.location, subnumbering),
model.gate_set,
subradixes,
)

# Form Subcircuit
if isinstance(op.gate, CircuitGate):
subcircuit = op.gate._circuit.copy()
subcircuit.set_params(op.params)
else:
subcircuit = Circuit.from_operation(op)

# Form Submodel
subradixes = [circuit.radixes[q] for q in op.location]
subnumbering = {op.location[i]: i for i in range(len(op.location))}
submodel = MachineModel(
len(op.location),
coupling_graph.get_subgraph(op.location, subnumbering),
model.gate_set,
subradixes,
)

# Form Subdata
block_data: PassData = PassData(subcircuit)
block_data['subnumbering'] = subnumbering
block_data['model'] = submodel
block_data['point'] = CircuitPoint(cycle, op.location[0])
block_data['calculate_error_bound'] = self.calculate_error_bound
for key in data:
if key.startswith(self.pass_down_key_prefix):
block_data[key] = data[key]
elif key.startswith(
self.pass_down_block_specific_key_prefix,
) and i in data[key]:
block_data[key] = data[key][i]
block_data.seed = data.seed
# Form Subdata
block_data: PassData = PassData(subcircuit)
block_data['subnumbering'] = subnumbering
block_data['model'] = submodel
block_data['point'] = CircuitPoint(cycle, op.location[0])
block_data['calculate_error_bound'] = self.calculate_error_bound
for key in data:
if key.startswith(self.pass_down_key_prefix):
block_data[key] = data[key]
elif key.startswith(
self.pass_down_block_specific_key_prefix,
) and i in data[key]:
block_data[key] = data[key][i]
block_data.seed = data.seed

# Change next subdirectory
if should_checkpoint:
# Blocks can have sub-blocks, so we need to change the ckpt dir
# for each block as a sub-folder of the main checkpoint dir
block_data["checkpoint_dir"] = join(checkpoint_dir, f'block_{block_num}')
block_data["checkpoint_circ_file"] = save_circuit_file
block_data["checkpoint_data_file"] = save_data_file
block_data['block_num'] = block_num
pickle.dump(block_data, open(save_data_file, 'wb'))
pickle.dump(subcircuit, open(save_circuit_file, 'wb'))

subcircuits.append(subcircuit)
block_datas.append(block_data)
Expand All @@ -229,7 +294,7 @@ async def run(self, circuit: Circuit, data: PassData) -> None:
_sub_do_work,
[self.workflow] * len(subcircuits),
subcircuits,
block_datas,
block_datas
)

# Unpack results
Expand Down Expand Up @@ -272,7 +337,6 @@ async def run(self, circuit: Circuit, data: PassData) -> None:
if self.calculate_error_bound:
_logger.debug(f'New circuit error is {data.error}.')


def default_collection_filter(op: Operation) -> bool:
return isinstance(
op.gate, (
Expand Down
2 changes: 2 additions & 0 deletions bqskit/passes/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""This package implements various IO related passes."""
from __future__ import annotations

from bqskit.passes.io.intermediate import CheckpointRestartPass
from bqskit.passes.io.checkpoint import LoadCheckpointPass
from bqskit.passes.io.checkpoint import SaveCheckpointPass
from bqskit.passes.io.intermediate import RestoreIntermediatePass
from bqskit.passes.io.intermediate import SaveIntermediatePass

__all__ = [
'CheckpointRestartPass',
'LoadCheckpointPass',
'SaveCheckpointPass',
'SaveIntermediatePass',
Expand Down
Loading
Loading