Skip to content

Commit

Permalink
fix: simplify classification
Browse files Browse the repository at this point in the history
  • Loading branch information
Edoardo-Pedicillo committed Feb 13, 2025
1 parent 8796224 commit 8003090
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 173 deletions.
1 change: 1 addition & 0 deletions src/qibocal/fitting/classifier/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def train_qubit(
)
hpars_list.append(hyperpars)
classifier.dump_hyper(hyperpars)
pdb.set_trace()
model = classifier.create_model(hyperpars)

results, _y_pred, model, _ = benchmarking(
Expand Down
180 changes: 45 additions & 135 deletions src/qibocal/protocols/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,19 @@

import numpy as np
import numpy.typing as npt
import pandas as pd
import plotly.graph_objects as go
from qibolab import AcquisitionType, ExecutionParameters
from qibolab.platform import Platform
from qibolab.pulses import PulseSequence
from qibolab.qubits import QubitId
from sklearn.metrics import roc_auc_score, roc_curve

from qibocal import update
from qibocal.auto.operation import RESULTSFILE, Data, Parameters, Results, Routine
from qibocal.auto.serialize import serialize
from qibocal.fitting.classifier import run
from qibocal.fitting.classifier.qubit_fit import QubitFit
from qibocal.protocols.utils import (
LEGEND_FONT_SIZE,
MESH_SIZE,
TITLE_SIZE,
evaluate_grid,
format_error_single_cell,
get_color_state0,
plot_results,
round_report,
table_dict,
Expand All @@ -45,7 +39,7 @@ class SingleShotClassificationParameters(Parameters):
classifiers_list: Optional[list[str]] = field(
default_factory=lambda: [DEFAULT_CLASSIFIER]
)
"""List of models to classify the qubit states"""
"""List of model to classify the qubit states"""
savedir: Optional[str] = " "
"""Dumping folder of the classification results"""

Expand All @@ -67,19 +61,13 @@ class SingleShotClassificationData(Data):
classifiers_list: Optional[list[str]] = field(
default_factory=lambda: [DEFAULT_CLASSIFIER]
)
"""List of models to classify the qubit states"""
"""List of model to classify the qubit states"""


@dataclass
class SingleShotClassificationResults(Results):
"""SingleShotClassification outputs."""

names: list
"""List of models name."""
savedir: str
"""Dumping folder of the classification results."""
y_preds: dict[QubitId, list]
"""Models' predictions of the test set."""
grid_preds: dict[QubitId, list]
"""Models' prediction of the contour grid."""
threshold: dict[QubitId, float] = field(default_factory=dict)
Expand All @@ -96,16 +84,6 @@ class SingleShotClassificationResults(Results):
"""Assignment fidelity evaluated only with the `qubit_fit` model."""
effective_temperature: dict[QubitId, float] = field(default_factory=dict)
"""Qubit effective temperature from Boltzmann distribution."""
models: dict[QubitId, list] = field(default_factory=list)
"""List of trained classification models."""
benchmark_table: Optional[dict[QubitId, pd.DataFrame]] = field(default_factory=dict)
"""Benchmark tables."""
classifiers_hpars: Optional[dict[QubitId, dict]] = field(default_factory=dict)
"""Classifiers hyperparameters."""
x_tests: dict[QubitId, list] = field(default_factory=dict)
"""Test set."""
y_tests: dict[QubitId, list] = field(default_factory=dict)
"""Test set."""

def __contains__(self, key: QubitId):
"""Checking if key is in Results.
Expand All @@ -122,7 +100,7 @@ def __contains__(self, key: QubitId):

def save(self, path):
classifiers = run.import_classifiers(self.names)
for qubit in self.models:
for qubit in self.model:
for i, mod in enumerate(classifiers):
if self.savedir == " ":
save_path = pathlib.Path(path)
Expand All @@ -132,10 +110,10 @@ def save(self, path):
classifier = run.Classifier(mod, save_path / f"qubit{qubit}")
classifier.savedir.mkdir(parents=True, exist_ok=True)
dump_dir = classifier.base_dir / classifier.name / classifier.name
classifier.dump()(self.models[qubit][i], dump_dir)
classifier.dump()(self.model[qubit][i], dump_dir)
classifier.dump_hyper(self.classifiers_hpars[qubit][classifier.name])
asdict_class = asdict(self)
asdict_class.pop("models")
asdict_class.pop("model")
asdict_class.pop("classifiers_hpars")
(path / f"{RESULTSFILE}.json").write_text(json.dumps(serialize(asdict_class)))

Expand Down Expand Up @@ -251,59 +229,32 @@ def _fit(data: SingleShotClassificationData) -> SingleShotClassificationResults:
effective_temperature = {}
for qubit in qubits:
qubit_data = data.data[qubit]
state0_data = qubit_data[qubit_data.state == 0]
iq_state0 = state0_data[["i", "q"]]
benchmark_table, y_test, x_test, models, names, hpars_list = run.train_qubit(
data, qubit
)
benchmark_tables[qubit] = benchmark_table.values.tolist()
models_dict[qubit] = models
y_tests[qubit] = y_test.tolist()
x_tests[qubit] = x_test.tolist()
hpars[qubit] = {}
y_preds = []
grid_preds = []
iq_values = data["i", "q"]
states = data["state"]
model = QubitFit()
model.fit(iq_values, states)
grid = evaluate_grid(qubit_data)
for i, model_name in enumerate(names):
hpars[qubit][model_name] = hpars_list[i]
try:
y_preds.append(models[i].predict_proba(x_test)[:, 1].tolist())
except AttributeError:
y_preds.append(models[i].predict(x_test).tolist())
grid_preds.append(
np.round(np.reshape(models[i].predict(grid), (MESH_SIZE, MESH_SIZE)))
.astype(np.int64)
.tolist()
)
if model_name == "qubit_fit":
threshold[qubit] = models[i].threshold
rotation_angle[qubit] = models[i].angle
mean_gnd_states[qubit] = models[i].iq_mean0.tolist()
mean_exc_states[qubit] = models[i].iq_mean1.tolist()
fidelity[qubit] = models[i].fidelity
assignment_fidelity[qubit] = models[i].assignment_fidelity
predictions_state0 = models[i].predict(iq_state0.tolist())
effective_temperature[qubit] = models[i].effective_temperature(
predictions_state0, data.qubit_frequencies[qubit]
)
grid_preds = model.predict(grid)
threshold[qubit] = model.threshold
rotation_angle[qubit] = model.angle
mean_gnd_states[qubit] = model.iq_mean0.tolist()
mean_exc_states[qubit] = model.iq_mean1.tolist()
fidelity[qubit] = model.fidelity
assignment_fidelity[qubit] = model.assignment_fidelity
predictions_state0 = model.predict(iq_state0.tolist())
effective_temperature[qubit] = model.effective_temperature(
predictions_state0, data.qubit_frequencies[qubit]
)
y_test_predict[qubit] = y_preds
grid_preds_dict[qubit] = grid_preds
return SingleShotClassificationResults(
benchmark_table=benchmark_tables,
y_tests=y_tests,
x_tests=x_tests,
names=names,
classifiers_hpars=hpars,
models=models_dict,
threshold=threshold,
rotation_angle=rotation_angle,
mean_gnd_states=mean_gnd_states,
mean_exc_states=mean_exc_states,
fidelity=fidelity,
assignment_fidelity=assignment_fidelity,
effective_temperature=effective_temperature,
savedir=data.savedir,
y_preds=y_test_predict,
grid_preds=grid_preds_dict,
)

Expand All @@ -314,74 +265,33 @@ def _plot(
fit: SingleShotClassificationResults,
):
fitting_report = ""
models_name = data.classifiers_list
figures = plot_results(data, target, 2, fit)
if fit is not None:
y_test = fit.y_tests[target]
y_pred = fit.y_preds[target]

if len(models_name) != 1:
# Evaluate the ROC curve
fig_roc = go.Figure()
fig_roc.add_shape(
type="line", line=dict(dash="dash"), x0=0.0, x1=1.0, y0=0.0, y1=1.0
)
for i, model in enumerate(models_name):
y_pred = fit.y_preds[target][i]
fpr, tpr, _ = roc_curve(y_test, y_pred)
auc_score = roc_auc_score(y_test, y_pred)
name = f"{model} (AUC={auc_score:.2f})"
fig_roc.add_trace(
go.Scatter(
x=fpr,
y=tpr,
name=name,
mode="lines",
marker=dict(size=3, color=get_color_state0(i)),
)
)
fig_roc.update_layout(
width=ROC_WIDTH,
height=ROC_LENGHT,
title=dict(text="ROC curves", font=dict(size=TITLE_SIZE)),
legend=dict(font=dict(size=LEGEND_FONT_SIZE)),
)
fig_roc.update_xaxes(
title_text=f"False Positive Rate",
range=[0, 1],
)
fig_roc.update_yaxes(
title_text="True Positive Rate",
range=[0, 1],
)
figures.append(fig_roc)

if "qubit_fit" in models_name:
fitting_report = table_html(
table_dict(
target,
[
"Average State 0",
"Average State 1",
"Rotational Angle",
"Threshold",
"Readout Fidelity",
"Assignment Fidelity",
"Effective Qubit Temperature [K]",
],
[
np.round(fit.mean_gnd_states[target], 3),
np.round(fit.mean_exc_states[target], 3),
np.round(fit.rotation_angle[target], 3),
np.round(fit.threshold[target], 6),
np.round(fit.fidelity[target], 3),
np.round(fit.assignment_fidelity[target], 3),
format_error_single_cell(
round_report([fit.effective_temperature[target]])
),
],
)
fitting_report = table_html(
table_dict(
target,
[
"Average State 0",
"Average State 1",
"Rotational Angle",
"Threshold",
"Readout Fidelity",
"Assignment Fidelity",
"Effective Qubit Temperature [K]",
],
[
np.round(fit.mean_gnd_states[target], 3),
np.round(fit.mean_exc_states[target], 3),
np.round(fit.rotation_angle[target], 3),
np.round(fit.threshold[target], 6),
np.round(fit.fidelity[target], 3),
np.round(fit.assignment_fidelity[target], 3),
format_error_single_cell(
round_report([fit.effective_temperature[target]])
),
],
)
)

return figures, fitting_report

Expand Down
38 changes: 0 additions & 38 deletions src/qibocal/protocols/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1147,44 +1147,6 @@ def plot_results(data: Data, qubit: QubitId, qubit_states: list, fit: Results):
),
)
figures.append(fig)

if fit is not None and len(models_name) != 1:
fig_benchmarks = make_subplots(
rows=1,
cols=3,
horizontal_spacing=SPACING,
vertical_spacing=SPACING,
subplot_titles=(
"accuracy",
"testing time [s]",
"training time [s]",
),
# pylint: disable=E1101
)
for i, model in enumerate(models_name):
for plot in range(3):
fig_benchmarks.add_trace(
go.Scatter(
x=[model],
y=[fit.benchmark_table[qubit][i][plot]],
mode="markers",
showlegend=False,
marker=dict(size=10, color=get_color_state1(i)),
),
row=1,
col=plot + 1,
)

fig_benchmarks.update_yaxes(type="log", row=1, col=2)
fig_benchmarks.update_yaxes(type="log", row=1, col=3)
fig_benchmarks.update_layout(
autosize=False,
height=COLUMNWIDTH,
width=COLUMNWIDTH * 3,
title=dict(text="Benchmarks", font=dict(size=TITLE_SIZE)),
)

figures.append(fig_benchmarks)
return figures


Expand Down

0 comments on commit 8003090

Please sign in to comment.