
🚑 Fixed param_grid as optional and checkpoints
albertnieto committed Sep 27, 2024
1 parent bc9925b commit fc17d0f
Showing 4 changed files with 83 additions and 36 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "qcml"
version = "0.1.2"
version = "0.1.2.1"
description = "A benchmarking library for quantum and classical machine learning, with specialized support for evaluating kernel methods."
readme = "README.md"
authors = [
16 changes: 15 additions & 1 deletion qcml/bench/__init__.py
@@ -13,6 +13,14 @@
# limitations under the License.

from .grid_search import GridSearch
from .model_evaluator import ModelEvaluator
from .parameter_grid import ParameterGrid
from .checkpoint import (
save_checkpoint,
load_checkpoint,
delete_checkpoints,
get_highest_batch_checkpoint,
)
from .grids.kernel_grid import (
classical_kernel_grid,
classical_kernel_param_map,
@@ -26,7 +34,13 @@
from .grids.transformation_grid import get_kernel_transform

__all__ = [
"grid_search",
"GridSearch",
"ModelEvaluator",
"ParameterGrid",
"save_checkpoint",
"load_checkpoint",
"delete_checkpoints",
"get_highest_batch_checkpoint",
"classical_kernel_grid",
"classical_kernel_param_map",
"quantum_kernel_grid",
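The __init__.py change above promotes ModelEvaluator, ParameterGrid, and the checkpoint helpers to the package's public API. A minimal import sketch of what the package now exposes, based solely on the __all__ list in this diff:

```python
# Names newly exported from qcml.bench after this commit (per the __all__ above).
from qcml.bench import (
    GridSearch,
    ModelEvaluator,
    ParameterGrid,
    save_checkpoint,
    load_checkpoint,
    delete_checkpoints,
    get_highest_batch_checkpoint,
)
```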
94 changes: 61 additions & 33 deletions qcml/bench/checkpoint.py
@@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# File: qic/bench/checkpoint.py

import os
import json
import logging

logger = logging.getLogger(__name__)


def save_checkpoint(
results,
classifier_name,
@@ -29,26 +26,37 @@ def save_checkpoint(
dataset_name,
results_path="checkpoints/",
):
# Handle classifier_name being either a string or a list
if isinstance(classifier_name, list):
classifier_name_str = "+".join([clf.__name__ for clf in classifier_name])
else:
classifier_name_str = classifier_name

checkpoint_file_name = (
f"qcml-{experiment_name}-{classifier_name}-{dataset_name}-batch{batch_idx}.json"
f"qcml-{experiment_name}-{classifier_name_str}-{dataset_name}-batch{batch_idx}.json"
)
checkpoint_file_path = os.path.join(results_path, checkpoint_file_name)

if not os.path.exists(results_path):
os.makedirs(results_path)
# Ensure the results_path directory exists
os.makedirs(results_path, exist_ok=True)

checkpoint_data = {"results": results, "batch_idx": batch_idx}

with open(checkpoint_file_path, "w") as f:
json.dump(checkpoint_data, f)

logger.info(f"Checkpoint saved to {checkpoint_file_path}")

try:
with open(checkpoint_file_path, "w") as f:
json.dump(checkpoint_data, f)
logger.info(f"Checkpoint saved to {checkpoint_file_path}")
except IOError as e:
logger.error(f"Failed to save checkpoint: {e}")

def load_checkpoint(
classifier_name, experiment_name, dataset_name, results_path="checkpoints/"
):
classifier_name_str = "+".join([clf.__name__ for clf in classifier_name])
if isinstance(classifier_name, list):
classifier_name_str = "+".join([clf.__name__ for clf in classifier_name])
else:
classifier_name_str = classifier_name

checkpoint_file_name = (
f"qcml-{experiment_name}-{classifier_name_str}-{dataset_name}.json"
)
@@ -58,40 +66,56 @@ def load_checkpoint(
logger.info("No checkpoint found for this classifier, dataset, and experiment.")
return None

with open(checkpoint_file_path, "r") as f:
checkpoint_data = json.load(f)

logger.info(f"Checkpoint loaded from {checkpoint_file_path}")
return checkpoint_data

try:
with open(checkpoint_file_path, "r") as f:
checkpoint_data = json.load(f)
logger.info(f"Checkpoint loaded from {checkpoint_file_path}")
return checkpoint_data
except IOError as e:
logger.error(f"Failed to load checkpoint: {e}")
return None

def delete_checkpoints(
classifier_name, experiment_name, dataset_name, results_path="checkpoints/"
):
classifier_name_str = "+".join([clf.__name__ for clf in classifier_name])
if isinstance(classifier_name, list):
classifier_name_str = "+".join([clf.__name__ for clf in classifier_name])
else:
classifier_name_str = classifier_name

checkpoint_file_name = (
f"qcml-{experiment_name}-{classifier_name_str}-{dataset_name}.json"
)
checkpoint_file_path = os.path.join(results_path, checkpoint_file_name)
if os.path.exists(checkpoint_file_path):
os.remove(checkpoint_file_path)
logger.info(
f"Checkpoint {checkpoint_file_path} deleted after successful completion of the experiment."
)
try:
os.remove(checkpoint_file_path)
logger.info(
f"Checkpoint {checkpoint_file_path} deleted after successful completion of the experiment."
)
except OSError as e:
logger.error(f"Failed to delete checkpoint: {e}")
else:
logger.info(
f"No checkpoint found to delete for {classifier_name_str} on {dataset_name}."
)


def get_highest_batch_checkpoint(
classifier_name, experiment_name, dataset_name, results_path="checkpoints/"
):
classifier_name_str = "+".join([clf.__name__ for clf in classifier_name])
if isinstance(classifier_name, list):
classifier_name_str = "+".join([clf.__name__ for clf in classifier_name])
else:
classifier_name_str = classifier_name

checkpoint_file_pattern = (
f"qcml-{experiment_name}-{classifier_name_str}-{dataset_name}-batch"
)

if not os.path.exists(results_path):
logger.debug(f"No checkpoints directory found at {results_path}.")
return None, None

# List all files in the checkpoints directory that match the pattern
checkpoint_files = [
f
@@ -120,11 +144,15 @@ def get_highest_batch_checkpoint(

if latest_checkpoint_file:
checkpoint_file_path = os.path.join(results_path, latest_checkpoint_file)
with open(checkpoint_file_path, "r") as f:
checkpoint_data = json.load(f)
logger.debug(
f"Checkpoint found: {latest_checkpoint_file} with batch {highest_batch}."
)
return checkpoint_data, highest_batch

return None, None
try:
with open(checkpoint_file_path, "r") as f:
checkpoint_data = json.load(f)
logger.debug(
f"Checkpoint found: {latest_checkpoint_file} with batch {highest_batch}."
)
return checkpoint_data, highest_batch
except IOError as e:
logger.error(f"Failed to load checkpoint: {e}")
return None, None

return None, None
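For context, here is a hedged usage sketch of the updated checkpoint helpers (not part of the diff). The two parameters collapsed between classifier_name and dataset_name in save_checkpoint are assumed to be batch_idx and experiment_name, since both appear in the function body, so the calls use keyword arguments rather than relying on their order; the scikit-learn classes merely stand in for any classifier classes.

```python
# Hedged sketch: exercising the checkpoint helpers after the string/list fix.
# Assumes save_checkpoint also takes batch_idx and experiment_name (used in its
# body above but collapsed out of the visible signature).
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC

from qcml.bench import (
    delete_checkpoints,
    get_highest_batch_checkpoint,
    save_checkpoint,
)

results = [{"accuracy": 0.91, "params": {"C": 1.0}}]

# classifier_name may now be a plain string ...
save_checkpoint(
    results,
    classifier_name="SVC",
    batch_idx=0,
    experiment_name="demo",
    dataset_name="iris",
)

# ... or a list of classifier classes, joined as "SVC+KNeighborsClassifier".
save_checkpoint(
    results,
    classifier_name=[SVC, KNeighborsClassifier],
    batch_idx=1,
    experiment_name="demo",
    dataset_name="iris",
)

# Resume from the highest saved batch; returns (None, None) when nothing matches.
checkpoint_data, batch = get_highest_batch_checkpoint("SVC", "demo", "iris")

# Remove the checkpoint once the experiment completes.
delete_checkpoints("SVC", "demo", "iris")
```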
7 changes: 6 additions & 1 deletion qcml/bench/grid_search.py
@@ -31,7 +31,7 @@ class GridSearch:
def __init__(
self,
classifiers: List[Callable],
param_grid: Dict[str, List[Any]],
param_grid: Optional[Dict[str, List[Any]]] = None,
combinations: Optional[
List[Tuple[Dict[str, Any], Callable, Dict[str, Any]]]
] = None,
@@ -66,13 +66,18 @@ def __init__(
# Initialize ModelEvaluator
self.evaluator = ModelEvaluator(use_jax=self.use_jax)

# Validate input parameters
if self.combinations is None and self.param_grid is None:
raise ValueError("Either 'param_grid' or 'combinations' must be provided.")

# Initialize ParameterGrid if combinations are not provided
if self.combinations is None:
self.param_grid_obj = ParameterGrid(
self.param_grid, self.transformations, self.transformation_params
)
self.combinations = self.param_grid_obj.combinations


def run(
self,
datasets: Optional[List[Dict[str, Any]]] = None,
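The grid_search.py change makes param_grid optional and adds an explicit validation step. A hedged sketch of the resulting constructor contract follows (illustrative classifier and values; constructor parameters not visible in this diff are left at their defaults, and the identity transformation in the combinations tuple is only an assumption about the expected callable):

```python
# Sketch of the relaxed GridSearch constructor after this commit.
from sklearn.svm import SVC

from qcml.bench import GridSearch

# Option 1: pass a parameter grid, expanded internally via ParameterGrid.
search = GridSearch(
    classifiers=[SVC],
    param_grid={"C": [0.1, 1.0, 10.0], "kernel": ["linear", "rbf"]},
)

# Option 2: supply pre-built (params, transformation, transformation_params)
# tuples directly; an identity transformation is assumed here for illustration.
search = GridSearch(
    classifiers=[SVC],
    combinations=[({"C": 1.0, "kernel": "rbf"}, lambda X: X, {})],
)

# Supplying neither now fails fast with the new validation.
try:
    GridSearch(classifiers=[SVC])
except ValueError as exc:
    print(exc)  # Either 'param_grid' or 'combinations' must be provided.
```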
