From 10f67140c4d3140142009acc0fdcbaa127339de1 Mon Sep 17 00:00:00 2001 From: Fady Adal <2masadel@gmail.com> Date: Wed, 3 Jul 2024 15:56:29 -0500 Subject: [PATCH 01/17] three-qubit code finished --- SQIR/NDSem.v | 3 - SQIR/UnitaryOps.v | 26 +--- examples/error-correction/ErrorCorrection.v | 154 ++++++++++++++++++++ examples/error-correction/dune | 4 + 4 files changed, 159 insertions(+), 28 deletions(-) create mode 100644 examples/error-correction/ErrorCorrection.v create mode 100644 examples/error-correction/dune diff --git a/SQIR/NDSem.v b/SQIR/NDSem.v index 8ec9e1b..7b2a4a3 100644 --- a/SQIR/NDSem.v +++ b/SQIR/NDSem.v @@ -114,15 +114,12 @@ Proof. Msimpl; simpl; apply sqrt_0 ). - try apply norm_zero_iff_zero. try apply WF_Zero. try easy. + contradict H0. rewrite <- Mmult_assoc. rewrite proj_twice_neq by easy. unfold norm. Msimpl; simpl. try apply sqrt_0. - try apply norm_zero_iff_zero. try apply WF_Zero. - try easy. + rewrite <- Mmult_assoc in H0, H1. rewrite proj_twice_eq in H0, H1. apply nd_meas_f; assumption. diff --git a/SQIR/UnitaryOps.v b/SQIR/UnitaryOps.v index 5f5e84b..e395e99 100644 --- a/SQIR/UnitaryOps.v +++ b/SQIR/UnitaryOps.v @@ -416,25 +416,11 @@ Proof. - solve_matrix; autorewrite with R_db C_db RtoC_db Cexp_db trig_db; try lca; field_simplify_eq; try nonzero; group_Cexp. + simpl. try (rewrite Rplus_comm; setoid_rewrite sin2_cos2; easy). - try ( - rewrite Cplus_comm; unfold Cplus, Cmult; - autorewrite with R_db; simpl; - setoid_rewrite sin2_cos2; easy - ). + try (simpl; rewrite Copp_mult_distr_l, Copp_mult_distr_r; repeat rewrite <- Cmult_assoc; rewrite <- Cmult_plus_distr_l; autorewrite with RtoC_db; rewrite Ropp_involutive; setoid_rewrite sin2_cos2; rewrite Cmult_1_r; apply f_equal; lra). - try ( - simpl; repeat rewrite <- Cmult_assoc; simpl; - rewrite <- Cmult_plus_distr_l; - unfold Cplus, Cmult; - autorewrite with R_db; simpl; - setoid_rewrite sin2_cos2; autorewrite with R_db; - unfold Cexp; apply f_equal2; [apply f_equal; lra|] - ). - apply f_equal; lra. - rewrite <- Mscale_kron_dist_l. repeat rewrite <- Mscale_kron_dist_r. repeat (apply f_equal2; try reflexivity). @@ -452,7 +438,7 @@ Proof. unfold Cminus, Cmult; simpl; autorewrite with R_db; apply c_proj_eq; simpl; autorewrite with R_db). rewrite <- Rminus_unfold, <- cos_plus. - apply f_equal. try apply f_equal. try lra. lra. + apply f_equal. try apply f_equal. try lra. + apply f_equal2; [apply f_equal; lra|]. apply c_proj_eq; simpl; try lra. R_field_simplify. @@ -473,19 +459,9 @@ Proof. try (autorewrite with RtoC_db; rewrite Rplus_comm; rewrite <- Rminus_unfold, <- cos_plus; apply f_equal; apply f_equal; lra). - try ( - rewrite Cplus_comm; apply c_proj_eq; simpl; try lra; - autorewrite with R_db; rewrite <- Rminus_unfold; - rewrite <- cos_plus; apply f_equal; lra - ). - solve_matrix; autorewrite with R_db C_db RtoC_db Cexp_db trig_db; try lca; field_simplify_eq; try nonzero; group_Cexp. + try (rewrite Rplus_comm; setoid_rewrite sin2_cos2; easy). - try ( - simpl; rewrite Cplus_comm; unfold Cplus, Cmult; - autorewrite with R_db; simpl; - setoid_rewrite sin2_cos2; easy - ). + try (rewrite Copp_mult_distr_l, Copp_mult_distr_r; repeat rewrite <- Cmult_assoc; rewrite <- Cmult_plus_distr_l; autorewrite with RtoC_db; rewrite Ropp_involutive; diff --git a/examples/error-correction/ErrorCorrection.v b/examples/error-correction/ErrorCorrection.v new file mode 100644 index 0000000..92ef2c4 --- /dev/null +++ b/examples/error-correction/ErrorCorrection.v @@ -0,0 +1,154 @@ +Require Export SQIR.UnitaryOps. +Require Import QuantumLib.Measurement. + +Module ThreeQubitCode. + +Open Scope ucom. + +(* q at 0; encoding/decoding ancillae at 1 and 2; and recovery ancillae at 3 and 4. *) +Definition dim : nat := 5. + +Definition encode : base_ucom dim := + CNOT 0 1; CNOT 0 2. + +Theorem encode_correct : forall (α β : C), + (@uc_eval dim encode) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ∣0,0,0,0⟩ ) + = α .* ∣0,0,0,0,0⟩ .+ β .* ∣1,1,1,0,0⟩. +Proof. + intros. + simpl. + autorewrite with eval_db; simpl. + Qsimpl. + replace (I 8) with (I 2 ⊗ I 2 ⊗ I 2). + replace (I 4) with (I 2 ⊗ I 2). + 2,3: repeat rewrite id_kron; easy. + repeat (distribute_plus; + repeat rewrite <- kron_assoc by auto with wf_db; + restore_dims). +repeat rewrite kron_mixed_product. + Qsimpl. + autorewrite with ket_db. + rewrite Mplus_comm; easy. +Qed. + +Inductive error : Set := + | NoError + | BitFlip0 + | BitFlip1 + | BitFlip2. + +Definition apply_error (e : error) : base_ucom dim := + match e with + | NoError => SKIP + | BitFlip0 => X 0 + | BitFlip1 => X 1 + | BitFlip2 => X 2 + end. + +Definition error_syndrome (e : error) : Vector 4 := + match e with + | NoError => ∣0,0⟩ + | BitFlip0 => ∣0,1⟩ + | BitFlip1 => ∣1,0⟩ + | BitFlip2 => ∣1,1⟩ + end. + +Definition Toffoli_false_fst {dim} (a b c : nat) : base_ucom dim := + X a; + CCX a b c; + X a. + +Definition recover : base_ucom dim := + CNOT 0 4; CNOT 1 4; + CNOT 1 3; CNOT 2 3; + CNOT 3 4; + Toffoli_false_fst 3 4 0; + Toffoli_false_fst 4 3 1; + CCX 3 4 2. + +Definition decode : base_ucom dim := + CNOT 0 1; + CNOT 0 2. + +Theorem decode_correct : forall (α β : C) (φ : Vector 4), + WF_Matrix φ -> + (@uc_eval dim decode) × ((α .* ∣0,0,0⟩ .+ β .* ∣1,1,1⟩) ⊗ φ) + = ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ∣0,0⟩ ⊗ φ). +Proof. + intros. + simpl. + autorewrite with eval_db; simpl; Qsimpl. + rewrite Mmult_assoc. + replace (I 8) with (I 2 ⊗ I 4) by ( + repeat rewrite id_kron; + Qsimpl; easy + ). + repeat (distribute_plus; + repeat rewrite <- kron_assoc by auto with wf_db; + restore_dims). + autorewrite with ket_db. + apply Mplus_comm. +Qed. + +Definition error_recover_correct (e : error) : forall (α β : C), + (@uc_eval dim (apply_error e; recover)) × (α .* ∣0,0,0,0,0⟩ .+ β .* ∣1,1,1,0,0⟩) = + (α .* ∣0,0,0⟩ .+ β .* ∣1,1,1⟩) ⊗ (error_syndrome e). +Proof. + intros. + destruct e. + Local Opaque CCX. + all : unfold apply_error; unfold recover; simpl. + all : try rewrite denote_SKIP; Qsimpl. + 2 : unfold dim; lia. + all : repeat rewrite Mmult_assoc. + all : repeat rewrite Mmult_plus_distr_l. + all : repeat rewrite Mscale_mult_dist_r. + all : replace (∣0, 0, 0, 0, 0⟩) with (f_to_vec dim (fun _ => false)). + all : replace (∣1, 1, 1, 0, 0⟩) with (f_to_vec dim (fun n => n Date: Fri, 5 Jul 2024 21:45:33 -0500 Subject: [PATCH 02/17] three-qubit code phase-flip finished --- examples/error-correction/ErrorCorrection.v | 154 ---------- examples/error-correction/ThreeQubitCode.v | 313 ++++++++++++++++++++ examples/error-correction/dune | 2 +- 3 files changed, 314 insertions(+), 155 deletions(-) delete mode 100644 examples/error-correction/ErrorCorrection.v create mode 100644 examples/error-correction/ThreeQubitCode.v diff --git a/examples/error-correction/ErrorCorrection.v b/examples/error-correction/ErrorCorrection.v deleted file mode 100644 index 92ef2c4..0000000 --- a/examples/error-correction/ErrorCorrection.v +++ /dev/null @@ -1,154 +0,0 @@ -Require Export SQIR.UnitaryOps. -Require Import QuantumLib.Measurement. - -Module ThreeQubitCode. - -Open Scope ucom. - -(* q at 0; encoding/decoding ancillae at 1 and 2; and recovery ancillae at 3 and 4. *) -Definition dim : nat := 5. - -Definition encode : base_ucom dim := - CNOT 0 1; CNOT 0 2. - -Theorem encode_correct : forall (α β : C), - (@uc_eval dim encode) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ∣0,0,0,0⟩ ) - = α .* ∣0,0,0,0,0⟩ .+ β .* ∣1,1,1,0,0⟩. -Proof. - intros. - simpl. - autorewrite with eval_db; simpl. - Qsimpl. - replace (I 8) with (I 2 ⊗ I 2 ⊗ I 2). - replace (I 4) with (I 2 ⊗ I 2). - 2,3: repeat rewrite id_kron; easy. - repeat (distribute_plus; - repeat rewrite <- kron_assoc by auto with wf_db; - restore_dims). -repeat rewrite kron_mixed_product. - Qsimpl. - autorewrite with ket_db. - rewrite Mplus_comm; easy. -Qed. - -Inductive error : Set := - | NoError - | BitFlip0 - | BitFlip1 - | BitFlip2. - -Definition apply_error (e : error) : base_ucom dim := - match e with - | NoError => SKIP - | BitFlip0 => X 0 - | BitFlip1 => X 1 - | BitFlip2 => X 2 - end. - -Definition error_syndrome (e : error) : Vector 4 := - match e with - | NoError => ∣0,0⟩ - | BitFlip0 => ∣0,1⟩ - | BitFlip1 => ∣1,0⟩ - | BitFlip2 => ∣1,1⟩ - end. - -Definition Toffoli_false_fst {dim} (a b c : nat) : base_ucom dim := - X a; - CCX a b c; - X a. - -Definition recover : base_ucom dim := - CNOT 0 4; CNOT 1 4; - CNOT 1 3; CNOT 2 3; - CNOT 3 4; - Toffoli_false_fst 3 4 0; - Toffoli_false_fst 4 3 1; - CCX 3 4 2. - -Definition decode : base_ucom dim := - CNOT 0 1; - CNOT 0 2. - -Theorem decode_correct : forall (α β : C) (φ : Vector 4), - WF_Matrix φ -> - (@uc_eval dim decode) × ((α .* ∣0,0,0⟩ .+ β .* ∣1,1,1⟩) ⊗ φ) - = ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ∣0,0⟩ ⊗ φ). -Proof. - intros. - simpl. - autorewrite with eval_db; simpl; Qsimpl. - rewrite Mmult_assoc. - replace (I 8) with (I 2 ⊗ I 4) by ( - repeat rewrite id_kron; - Qsimpl; easy - ). - repeat (distribute_plus; - repeat rewrite <- kron_assoc by auto with wf_db; - restore_dims). - autorewrite with ket_db. - apply Mplus_comm. -Qed. - -Definition error_recover_correct (e : error) : forall (α β : C), - (@uc_eval dim (apply_error e; recover)) × (α .* ∣0,0,0,0,0⟩ .+ β .* ∣1,1,1,0,0⟩) = - (α .* ∣0,0,0⟩ .+ β .* ∣1,1,1⟩) ⊗ (error_syndrome e). -Proof. - intros. - destruct e. - Local Opaque CCX. - all : unfold apply_error; unfold recover; simpl. - all : try rewrite denote_SKIP; Qsimpl. - 2 : unfold dim; lia. - all : repeat rewrite Mmult_assoc. - all : repeat rewrite Mmult_plus_distr_l. - all : repeat rewrite Mscale_mult_dist_r. - all : replace (∣0, 0, 0, 0, 0⟩) with (f_to_vec dim (fun _ => false)). - all : replace (∣1, 1, 1, 0, 0⟩) with (f_to_vec dim (fun n => n SKIP + | BitFlip0 => X 0 + | BitFlip1 => X 1 + | BitFlip2 => X 2 + end. + +Definition error_syndrome (e : error) : Vector 4 := + match e with + | NoError => ∣0,0⟩ + | BitFlip0 => ∣0,1⟩ + | BitFlip1 => ∣1,0⟩ + | BitFlip2 => ∣1,1⟩ + end. + + +Definition recover : base_ucom dim := + CNOT 0 4; CNOT 1 4; + CNOT 1 3; CNOT 2 3; + CNOT 3 4; + Toffoli_false_fst 3 4 0; + Toffoli_false_fst 4 3 1; + CCX 3 4 2. + +Definition decode : base_ucom dim := + CNOT 0 1; + CNOT 0 2. + +Theorem decode_correct : forall (α β : C) (φ : Vector 4), + WF_Matrix φ -> + (@uc_eval dim decode) × ((α .* ∣0,0,0⟩ .+ β .* ∣1,1,1⟩) ⊗ φ) + = ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ∣0,0⟩ ⊗ φ). +Proof. + intros. + simpl. + autorewrite with eval_db; simpl; Qsimpl. + rewrite Mmult_assoc. + replace (I 8) with (I 2 ⊗ I 4) by ( + repeat rewrite id_kron; + Qsimpl; easy + ). + repeat (distribute_plus; + repeat rewrite <- kron_assoc by auto with wf_db; + restore_dims). + autorewrite with ket_db. + apply Mplus_comm. +Qed. + +Definition error_recover_correct (e : error) : forall (α β : C), + (@uc_eval dim (apply_error e; recover)) × (α .* ∣0,0,0,0,0⟩ .+ β .* ∣1,1,1,0,0⟩) = + (α .* ∣0,0,0⟩ .+ β .* ∣1,1,1⟩) ⊗ (error_syndrome e). +Proof. + intros. + destruct e. + Local Opaque CCX. + all : unfold apply_error; unfold recover; simpl. + all : try rewrite denote_SKIP by (unfold dim; lia); Qsimpl. + all : repeat rewrite Mmult_assoc. + all : repeat rewrite Mmult_plus_distr_l. + all : repeat rewrite Mscale_mult_dist_r. + all : replace (∣0, 0, 0, 0, 0⟩) with (f_to_vec dim (fun _ => false)). + all : replace (∣1, 1, 1, 0, 0⟩) with (f_to_vec dim (fun n => n SKIP + | PhaseFlip0 => SQIR.Z 0 + | PhaseFlip1 => SQIR.Z 1 + | PhaseFlip2 => SQIR.Z 2 + end. + +Definition error_syndrome (e : error) : Vector 4 := + match e with + | NoError => ∣0,0⟩ + | PhaseFlip0 => ∣0,1⟩ + | PhaseFlip1 => ∣1,0⟩ + | PhaseFlip2 => ∣1,1⟩ + end. + +Definition recover : base_ucom dim := + H 0; H 1; H 2; + BitFlip.recover. + +Definition decode := BitFlip.decode. + +Definition decode_correct := BitFlip.decode_correct. + +Theorem Hplus_spec' : hadamard × ∣+⟩ = ∣0⟩. +Proof. + replace (∣+⟩) with (∣ + ⟩) by solve_matrix. + apply Hplus_spec. +Qed. + +Theorem Hminus_spec' : hadamard × ∣-⟩ = ∣1⟩. +Proof. + replace (∣-⟩) with (∣ - ⟩) by solve_matrix. + apply Hminus_spec. +Qed. + +Definition error_recover_correct (e : error) : forall (α β : C), + (@uc_eval dim (apply_error e; recover)) × (α .* ∣+⟩ ⊗ ∣+⟩ ⊗ ∣+⟩ ⊗ ∣0,0⟩ .+ β .* ∣-⟩ ⊗ ∣-⟩ ⊗ ∣-⟩ ⊗ ∣0,0⟩) = + (α .* ∣0,0,0⟩ .+ β .* ∣1,1,1⟩) ⊗ (error_syndrome e). +Proof. + intros. + destruct e. + + Local Opaque CCX BitFlip.recover. + all : simpl. + all : repeat rewrite Mmult_assoc. + all : autorewrite with ket_db eval_db; simpl. + 2 : unfold dim; lia. + all : replace (I 16) with (I 4 ⊗ I 4) by (repeat rewrite id_kron; easy). + all : replace (I 8) with (I 4 ⊗ I 2) by (repeat rewrite id_kron; easy). + all : replace (I 4) with (I 2 ⊗ I 2) by (repeat rewrite id_kron; easy). + all : repeat ( + repeat rewrite <- kron_assoc by auto with wf_db; + restore_dims + ). + all : repeat rewrite kron_mixed_product. + all : Qsimpl; replace (σz × ∣+⟩) with (∣-⟩) by solve_matrix; + replace (σz × ∣-⟩) with (∣+⟩) by solve_matrix. + all : repeat rewrite Hplus_spec', Hminus_spec'. + all : replace (∣ 0 ⟩) with (∣0⟩) by solve_matrix. + all : replace (∣0⟩) with (f_to_vec 1 (fun _ => false)) by solve_matrix. + all : replace (∣1⟩) with (f_to_vec 1 (fun _ => true)) by solve_matrix. + all : restore_dims. + all : repeat rewrite kron_assoc by auto with wf_db. + all : repeat (rewrite f_to_vec_merge; restore_dims). + Local Transparent BitFlip.recover. + all : simpl uc_eval. + all : repeat rewrite Mmult_assoc; restore_dims. + + (* Faster that f_to_vec_simpl with + transparent CCX *) + all : repeat ( + first + [ rewrite f_to_vec_CNOT + | rewrite f_to_vec_CCX + | rewrite f_to_vec_X + ]; + unfold dim; try lia; + simpl update + ). + all : simpl; Qsimpl. + all : repeat rewrite <- kron_assoc by auto with wf_db. + all : reflexivity. +Qed. + +(** The rest of the circuit is the same as + the BitFlip case. *) + +Definition phase_flip_recover (e : error) : base_ucom dim := + encode; + apply_error e; + recover; + decode. + + +Theorem three_code_correct (e : error) : forall (α β : C), + (@uc_eval dim (phase_flip_recover e) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ∣0,0,0,0⟩)) = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ∣0,0⟩ ⊗ (error_syndrome e). +Proof. + intros. + unfold phase_flip_recover. + Local Opaque encode apply_error recover decode. + simpl uc_eval. + repeat rewrite Mmult_assoc. + rewrite encode_correct. + rewrite <- Mmult_assoc with (B := uc_eval (apply_error e)). + specialize (error_recover_correct e α β) as H. + simpl in H. + setoid_rewrite H. + apply decode_correct. + destruct e; simpl; auto with wf_db. +Qed. + +End PhaseFlip. + +End ThreeQubitCode. diff --git a/examples/error-correction/dune b/examples/error-correction/dune index 122107c..4e5c890 100644 --- a/examples/error-correction/dune +++ b/examples/error-correction/dune @@ -1,4 +1,4 @@ (coq.theory - (name error-correction) + (name ErrorCorrection) (theories SQIR)) From bcfa76d6ae353d43809a1a255a8aa224a19eff36 Mon Sep 17 00:00:00 2001 From: Fady Adal <2masadel@gmail.com> Date: Sat, 6 Jul 2024 12:29:32 -0500 Subject: [PATCH 03/17] formatting gaffes --- examples/error-correction/ThreeQubitCode.v | 23 +++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/error-correction/ThreeQubitCode.v b/examples/error-correction/ThreeQubitCode.v index 067a931..d535d7c 100644 --- a/examples/error-correction/ThreeQubitCode.v +++ b/examples/error-correction/ThreeQubitCode.v @@ -7,7 +7,7 @@ Open Scope ucom. Definition Toffoli_false_fst {dim} (a b c : nat) : base_ucom dim := X a; -CCX a b c; + CCX a b c; X a. @@ -21,17 +21,17 @@ Definition encode : base_ucom dim := Theorem encode_correct : forall (α β : C), (@uc_eval dim encode) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ∣0,0,0,0⟩ ) -= α .*∣0,0,0,0,0⟩ .+ β .* ∣1,1,1,0,0⟩. + = α .*∣0,0,0,0,0⟩ .+ β .* ∣1,1,1,0,0⟩. Proof. intros. simpl. autorewrite with eval_db; simpl. Qsimpl. -replace (I 8) with (I 2 ⊗ I 2 ⊗ I 2). + replace (I 8) with (I 2 ⊗ I 2 ⊗ I 2). replace (I 4) with (I 2 ⊗ I 2). 2,3: repeat rewrite id_kron; easy. repeat (distribute_plus; -repeat rewrite <- kron_assoc by auto with wf_db; + repeat rewrite <- kron_assoc by auto with wf_db; restore_dims). repeat rewrite kron_mixed_product. Qsimpl. @@ -41,8 +41,8 @@ Qed. Inductive error : Set := | NoError -| BitFlip0 -| BitFlip1 + | BitFlip0 + | BitFlip1 | BitFlip2. Definition apply_error (e : error) : base_ucom dim := @@ -95,8 +95,8 @@ Proof. Qed. Definition error_recover_correct (e : error) : forall (α β : C), - (@uc_eval dim (apply_error e; recover)) × (α .* ∣0,0,0,0,0⟩ .+ β .* ∣1,1,1,0,0⟩) = - (α .* ∣0,0,0⟩ .+ β .* ∣1,1,1⟩) ⊗ (error_syndrome e). + (@uc_eval dim (apply_error e; recover)) × (α .* ∣0,0,0,0,0⟩ .+ β .* ∣1,1,1,0,0⟩) + = (α .* ∣0,0,0⟩ .+ β .* ∣1,1,1⟩) ⊗ (error_syndrome e). Proof. intros. destruct e. @@ -233,8 +233,8 @@ Proof. Qed. Definition error_recover_correct (e : error) : forall (α β : C), - (@uc_eval dim (apply_error e; recover)) × (α .* ∣+⟩ ⊗ ∣+⟩ ⊗ ∣+⟩ ⊗ ∣0,0⟩ .+ β .* ∣-⟩ ⊗ ∣-⟩ ⊗ ∣-⟩ ⊗ ∣0,0⟩) = - (α .* ∣0,0,0⟩ .+ β .* ∣1,1,1⟩) ⊗ (error_syndrome e). + (@uc_eval dim (apply_error e; recover)) × (α .* ∣+⟩ ⊗ ∣+⟩ ⊗ ∣+⟩ ⊗ ∣0,0⟩ .+ β .* ∣-⟩ ⊗ ∣-⟩ ⊗ ∣-⟩ ⊗ ∣0,0⟩) + = (α .* ∣0,0,0⟩ .+ β .* ∣1,1,1⟩) ⊗ (error_syndrome e). Proof. intros. destruct e. @@ -292,7 +292,8 @@ Definition phase_flip_recover (e : error) : base_ucom dim := Theorem three_code_correct (e : error) : forall (α β : C), - (@uc_eval dim (phase_flip_recover e) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ∣0,0,0,0⟩)) = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ∣0,0⟩ ⊗ (error_syndrome e). + (@uc_eval dim (phase_flip_recover e) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ∣0,0,0,0⟩)) + = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ∣0,0⟩ ⊗ (error_syndrome e). Proof. intros. unfold phase_flip_recover. From 91c375a1384347d39bc707d022ccffd2a169b699 Mon Sep 17 00:00:00 2001 From: Fady Adal <2masadel@gmail.com> Date: Thu, 25 Jul 2024 15:02:18 -0500 Subject: [PATCH 04/17] encoding nine qubits and errors --- examples/error-correction/NineQubitCode.v | 177 ++++++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 examples/error-correction/NineQubitCode.v diff --git a/examples/error-correction/NineQubitCode.v b/examples/error-correction/NineQubitCode.v new file mode 100644 index 0000000..6d5e73b --- /dev/null +++ b/examples/error-correction/NineQubitCode.v @@ -0,0 +1,177 @@ +Require Import Vectors.Fin. +Require Export SQIR.UnitaryOps. + +Module NineQubitCode. + +Open Scope ucom. +Open Scope nat_scope. + +Definition dim : nat := 9. + +(** + Blocks + *) + +(* Encoded blocks *) +Definition block_no := Fin.t 3. + +(* Qubits in a single block *) +Definition block_offset := Fin.t 3. + +Definition block_to_qubit (n : block_no) (off : block_offset) : nat := + proj1_sig (Fin.to_nat n) * 3 + proj1_sig (Fin.to_nat off). + +Compute block_to_qubit (@Fin.of_nat_lt 2 3 ltac:(lia)) (@Fin.of_nat_lt 2 3 ltac:(lia)). + +(** + Encoding + *) + +Definition encode_block (n : block_no) : base_ucom dim := + let q0 := proj1_sig (Fin.to_nat n) * 3 in + let q1 := q0 + 1 in + let q2 := q0 + 2 in + CNOT q0 q1; + CNOT q0 q2. + +Definition encode : base_ucom dim := + CNOT 0 3; CNOT 0 6; + H 0; H 3; H 6; + encode_block (@Fin.of_nat_lt 0 3 ltac:(lia)); + encode_block (@Fin.of_nat_lt 1 3 ltac:(lia)); + encode_block (@Fin.of_nat_lt 2 3 ltac:(lia)). + +Theorem encode_correct (α β : C) : + (@uc_eval dim encode) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ 8 ⨂ ∣0⟩ ) + = /C2 .* (/√ 2 .* (α .* (∣0,0,0⟩ .+ ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ ∣1,1,1⟩))) + .+ /C2 .* (/√ 2 .* (β .* (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩))). +Proof. + simpl. Qsimpl. + + replace (∣0⟩) with (f_to_vec 1 (fun _ => false)) by lma'. + replace (∣1⟩) with (f_to_vec 1 (fun _ => true)) by lma'. + restore_dims. + replace (∣0,0,0⟩) with (f_to_vec 3 (fun _ => false)) by lma'. + replace (∣1,1,1⟩) with (f_to_vec 3 (fun _ => true)) by lma'. + + repeat rewrite Mmult_assoc. + rewrite kron_plus_distr_r. + repeat rewrite Mmult_plus_distr_l. + distribute_scale. + repeat rewrite Mscale_mult_dist_r. + restore_dims. + repeat rewrite kron_assoc by ( + repeat rewrite kron_assoc by auto with wf_db; + repeat (rewrite f_to_vec_merge; restore_dims); + auto with wf_db + ). + repeat (rewrite f_to_vec_merge; restore_dims). + repeat rewrite f_to_vec_CNOT; try lia. + simpl update. + + repeat ( + rewrite f_to_vec_H; try lia; + simpl update; simpl b2R; + restore_dims; + repeat rewrite Mmult_plus_distr_l; + repeat rewrite Mscale_mult_dist_r + ). + repeat rewrite Mmult_plus_distr_l. + repeat rewrite Mscale_mult_dist_r. + + + repeat (rewrite f_to_vec_CNOT; try lia; try rewrite kron_1_l; simpl update). + simpl. Qsimpl. + + replace (0 * PI)%R with 0%R by lra. + replace (1 * PI)%R with PI by lra. + autorewrite with Cexp_db. + group_radicals. + repeat rewrite Mscale_1_l. + + repeat rewrite <- Mscale_plus_distr_r. + repeat rewrite kron_plus_distr_r. + repeat rewrite kron_plus_distr_l. + repeat rewrite Mplus_assoc. + f_equal. + repeat rewrite Mscale_assoc. + - replace (α * / √ 2 * / √ 2 * / √ 2)%C with (/√ 2 * / C2 * α)%C. + 2: { + (* why does lca not work here? *) + rewrite Cmult_comm. + do 2 rewrite <- Cmult_assoc. + f_equal. f_equal. + symmetry. apply Cinv_sqrt2_sqrt. + } + repeat (rewrite <- kron_assoc by auto 10 with wf_db; restore_dims). + reflexivity. + - do 2 ( + repeat rewrite Mscale_assoc; + repeat rewrite (Cmult_comm (-1)%R _); + repeat rewrite <- (Mscale_assoc _ _ _ (-1)%R _); + repeat rewrite <- Mscale_plus_distr_r + ). + repeat rewrite Mscale_assoc. + replace (β * / √ 2 * / √ 2 * / √ 2)%C with (/ √ 2 * / C2 * β)%C. + 2: { + (* why does lca not work here? *) + rewrite Cmult_comm. + do 2 rewrite <- Cmult_assoc. + f_equal. f_equal. + symmetry. apply Cinv_sqrt2_sqrt. + } + f_equal. + repeat rewrite Mscale_plus_distr_r. + distribute_scale. + repeat (rewrite <- kron_assoc by auto 10 with wf_db; restore_dims). + repeat rewrite Mplus_assoc. + reflexivity. +Qed. + + +(** + Errors + *) + +Inductive phase_flip_error (n : block_no) : Set := + | OnePhaseFlip (off : block_offset) + | MorePhaseFlip (e : phase_flip_error n) (off : block_offset). + + +Inductive bit_flip_error : Set := + | OneBitFlip (n : block_no) (off : block_offset) + | TwoBitFlip (n₁ n₂ : block_no) (h : n₁ <> n₂) (off₁ off₂ : block_offset) + | ThreeBitFlip (off₁ off₂ off₃ : block_offset). + +Inductive error : Set := + | PhaseFlipError (n : block_no) (e : phase_flip_error n) + | BitFlipError (e : bit_flip_error) + | BothErrors (n : block_no) (e₁ : phase_flip_error n) (e₂ : bit_flip_error). + +Fixpoint apply_phase_flip_error {n} (e : phase_flip_error n) : base_ucom dim := + match e with + | OnePhaseFlip _ off => SQIR.Z (proj1_sig (Fin.to_nat off)) + | MorePhaseFlip _ e off => SQIR.Z (proj1_sig (Fin.to_nat off)); apply_phase_flip_error e + end. + +Definition apply_bit_flip_error (e : bit_flip_error) : base_ucom dim := + match e with + | OneBitFlip n off => X (block_to_qubit n off) + | TwoBitFlip n₁ n₂ _ off₁ off₂ => (X (block_to_qubit n₁ off₁)); (X (block_to_qubit n₂ off₂)) + | ThreeBitFlip off₁ off₂ off₃ => ( + let q1 := block_to_qubit (@Fin.of_nat_lt 0 3 ltac:(lia)) off₁ in + let q2 := block_to_qubit (@Fin.of_nat_lt 1 3 ltac:(lia)) off₂ in + let q3 := block_to_qubit (@Fin.of_nat_lt 2 3 ltac:(lia)) off₃ in + X q1; X q2; X q3 + ) + end. + +Definition apply_error (e : error) : base_ucom dim := + match e with + | PhaseFlipError _ e => apply_phase_flip_error e + | BitFlipError e => apply_bit_flip_error e + | BothErrors _ e₁ e₂ => apply_phase_flip_error e₁; apply_bit_flip_error e₂ + end. + + +End NineQubitCode. From 8850a30d9c55920560307cf9d1a92e3b31ddfd29 Mon Sep 17 00:00:00 2001 From: Fady Adal <2masadel@gmail.com> Date: Thu, 25 Jul 2024 15:05:28 -0500 Subject: [PATCH 05/17] other formatting gaffes --- examples/error-correction/ThreeQubitCode.v | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/error-correction/ThreeQubitCode.v b/examples/error-correction/ThreeQubitCode.v index d535d7c..440918a 100644 --- a/examples/error-correction/ThreeQubitCode.v +++ b/examples/error-correction/ThreeQubitCode.v @@ -36,7 +36,7 @@ Proof. repeat rewrite kron_mixed_product. Qsimpl. autorewrite with ket_db. -rewrite Mplus_comm; easy. + rewrite Mplus_comm; easy. Qed. Inductive error : Set := @@ -46,7 +46,7 @@ Inductive error : Set := | BitFlip2. Definition apply_error (e : error) : base_ucom dim := -match e with + match e with | NoError => SKIP | BitFlip0 => X 0 | BitFlip1 => X 1 From e1d683ac08c57f9e70ffba7c4750844f154f961b Mon Sep 17 00:00:00 2001 From: Ben Caldwell Date: Thu, 11 Jul 2024 21:22:47 -0500 Subject: [PATCH 06/17] fixing build --- coq-sqir.opam | 2 +- dune-project | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/coq-sqir.opam b/coq-sqir.opam index c0a318e..617d0f8 100644 --- a/coq-sqir.opam +++ b/coq-sqir.opam @@ -14,7 +14,7 @@ bug-reports: "https://github.com/inQWIRE/SQIR/issues" depends: [ "dune" {>= "3.8"} "coq-interval" {>= "4.9.0"} - "coq-quantumlib" {>= "1.3.0"} + "coq-quantumlib" {= "1.3.0"} "coq" {>= "8.16"} "odoc" {with-doc} ] diff --git a/dune-project b/dune-project index 84df113..81e4a85 100644 --- a/dune-project +++ b/dune-project @@ -17,7 +17,7 @@ ) (depends (coq-interval (>= 4.9.0)) - (coq-quantumlib (>= 1.3.0)) + (coq-quantumlib (= 1.3.0)) (coq (>= 8.16)))) (package From abe6c2c7de7113af92aecd42f440da1313b17a7f Mon Sep 17 00:00:00 2001 From: Fady Adal <2masadel@gmail.com> Date: Thu, 25 Jul 2024 19:10:35 -0500 Subject: [PATCH 07/17] minor simplifications; add no error --- examples/error-correction/NineQubitCode.v | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/examples/error-correction/NineQubitCode.v b/examples/error-correction/NineQubitCode.v index 6d5e73b..c67f916 100644 --- a/examples/error-correction/NineQubitCode.v +++ b/examples/error-correction/NineQubitCode.v @@ -43,8 +43,8 @@ Definition encode : base_ucom dim := Theorem encode_correct (α β : C) : (@uc_eval dim encode) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ 8 ⨂ ∣0⟩ ) - = /C2 .* (/√ 2 .* (α .* (∣0,0,0⟩ .+ ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ ∣1,1,1⟩))) - .+ /C2 .* (/√ 2 .* (β .* (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩))). + = /C2 .* (/√ 2 .* (α .* 3 ⨂ (∣0,0,0⟩ .+ ∣1,1,1⟩))) + .+ /C2 .* (/√ 2 .* (β .* 3 ⨂ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩))). Proof. simpl. Qsimpl. @@ -60,11 +60,7 @@ Proof. distribute_scale. repeat rewrite Mscale_mult_dist_r. restore_dims. - repeat rewrite kron_assoc by ( - repeat rewrite kron_assoc by auto with wf_db; - repeat (rewrite f_to_vec_merge; restore_dims); - auto with wf_db - ). + repeat rewrite kron_assoc by auto 10 with wf_db. repeat (rewrite f_to_vec_merge; restore_dims). repeat rewrite f_to_vec_CNOT; try lia. simpl update. @@ -79,7 +75,6 @@ Proof. repeat rewrite Mmult_plus_distr_l. repeat rewrite Mscale_mult_dist_r. - repeat (rewrite f_to_vec_CNOT; try lia; try rewrite kron_1_l; simpl update). simpl. Qsimpl. @@ -94,8 +89,8 @@ Proof. repeat rewrite kron_plus_distr_l. repeat rewrite Mplus_assoc. f_equal. - repeat rewrite Mscale_assoc. - - replace (α * / √ 2 * / √ 2 * / √ 2)%C with (/√ 2 * / C2 * α)%C. + - repeat rewrite Mscale_assoc. + replace (α * / √ 2 * / √ 2 * / √ 2)%C with (/√ 2 * / C2 * α)%C. 2: { (* why does lca not work here? *) rewrite Cmult_comm. @@ -144,6 +139,7 @@ Inductive bit_flip_error : Set := | ThreeBitFlip (off₁ off₂ off₃ : block_offset). Inductive error : Set := + | NoError | PhaseFlipError (n : block_no) (e : phase_flip_error n) | BitFlipError (e : bit_flip_error) | BothErrors (n : block_no) (e₁ : phase_flip_error n) (e₂ : bit_flip_error). @@ -168,10 +164,10 @@ Definition apply_bit_flip_error (e : bit_flip_error) : base_ucom dim := Definition apply_error (e : error) : base_ucom dim := match e with + | NoError => SKIP | PhaseFlipError _ e => apply_phase_flip_error e | BitFlipError e => apply_bit_flip_error e | BothErrors _ e₁ e₂ => apply_phase_flip_error e₁; apply_bit_flip_error e₂ end. - End NineQubitCode. From 5346a2217343d72610e9168d5334ea3dec8e5316 Mon Sep 17 00:00:00 2001 From: Fady Adal <2masadel@gmail.com> Date: Thu, 25 Jul 2024 20:56:01 -0500 Subject: [PATCH 08/17] factor out ZCCX --- _CoqProject | 1 + examples/error-correction/Common.v | 14 ++++++++++++++ examples/error-correction/ThreeQubitCode.v | 14 +++++--------- 3 files changed, 20 insertions(+), 9 deletions(-) create mode 100644 examples/error-correction/Common.v diff --git a/_CoqProject b/_CoqProject index 2c1ddb6..c0f3383 100644 --- a/_CoqProject +++ b/_CoqProject @@ -1,4 +1,5 @@ -R _build/default/SQIR SQIR -R _build/default/examples/examples Examples -R _build/default/examples/shor Shor +-R _build/default/examples/error-correction ErrorCorrection -R _build/default/VOQC VOQC diff --git a/examples/error-correction/Common.v b/examples/error-correction/Common.v new file mode 100644 index 0000000..d8d75be --- /dev/null +++ b/examples/error-correction/Common.v @@ -0,0 +1,14 @@ +Require Export SQIR.UnitaryOps. + +Module Common. + +Open Scope ucom. + +(** A toffoli gate but controlled on the first qubit + being zero. *) +Definition ZCCX {dim} (a b c : nat) : base_ucom dim := + X a; + CCX a b c; + X a. + +End Common. diff --git a/examples/error-correction/ThreeQubitCode.v b/examples/error-correction/ThreeQubitCode.v index 440918a..dabcceb 100644 --- a/examples/error-correction/ThreeQubitCode.v +++ b/examples/error-correction/ThreeQubitCode.v @@ -1,16 +1,12 @@ Require Export SQIR.UnitaryOps. Require Import QuantumLib.Measurement. +Require Import Common. + Module ThreeQubitCode. Open Scope ucom. -Definition Toffoli_false_fst {dim} (a b c : nat) : base_ucom dim := - X a; - CCX a b c; - X a. - - (* q at 0; encoding/decoding ancillae at 1 and 2; and recovery ancillae at 3 and 4. *) Definition dim : nat := 5. @@ -66,8 +62,8 @@ Definition recover : base_ucom dim := CNOT 0 4; CNOT 1 4; CNOT 1 3; CNOT 2 3; CNOT 3 4; - Toffoli_false_fst 3 4 0; - Toffoli_false_fst 4 3 1; + Common.ZCCX 3 4 0; + Common.ZCCX 4 3 1; CCX 3 4 2. Definition decode : base_ucom dim := @@ -278,7 +274,7 @@ Proof. ). all : simpl; Qsimpl. all : repeat rewrite <- kron_assoc by auto with wf_db. - all : reflexivity. + all : easy. Qed. (** The rest of the circuit is the same as From 8d5819e2b3ea1a40f08688f8a5ecda448f46a629 Mon Sep 17 00:00:00 2001 From: Fady Adal <2masadel@gmail.com> Date: Mon, 29 Jul 2024 13:38:24 -0500 Subject: [PATCH 09/17] do not name goals --- examples/error-correction/ThreeQubitCode.v | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/error-correction/ThreeQubitCode.v b/examples/error-correction/ThreeQubitCode.v index dabcceb..27384a7 100644 --- a/examples/error-correction/ThreeQubitCode.v +++ b/examples/error-correction/ThreeQubitCode.v @@ -102,11 +102,13 @@ Proof. all : repeat rewrite Mmult_assoc. all : repeat rewrite Mmult_plus_distr_l. all : repeat rewrite Mscale_mult_dist_r. - all : replace (∣0, 0, 0, 0, 0⟩) with (f_to_vec dim (fun _ => false)). - all : replace (∣1, 1, 1, 0, 0⟩) with (f_to_vec dim (fun n => n false)) by ( simpl f_to_vec; - Qsimpl; easy + Msimpl_light; easy + ). + all : replace (∣1, 1, 1, 0, 0⟩) with (f_to_vec dim (fun n => n Date: Wed, 31 Jul 2024 17:17:59 -0500 Subject: [PATCH 10/17] sync fork --- SQIR/NDSem.v | 3 + SQIR/UnitaryOps.v | 26 +++++- examples/error-correction/NineQubitCode.v | 97 ++++++++++++++++++----- 3 files changed, 105 insertions(+), 21 deletions(-) diff --git a/SQIR/NDSem.v b/SQIR/NDSem.v index 7b2a4a3..8ec9e1b 100644 --- a/SQIR/NDSem.v +++ b/SQIR/NDSem.v @@ -114,12 +114,15 @@ Proof. Msimpl; simpl; apply sqrt_0 ). + try apply norm_zero_iff_zero. try apply WF_Zero. try easy. + contradict H0. rewrite <- Mmult_assoc. rewrite proj_twice_neq by easy. unfold norm. Msimpl; simpl. try apply sqrt_0. + try apply norm_zero_iff_zero. try apply WF_Zero. + try easy. + rewrite <- Mmult_assoc in H0, H1. rewrite proj_twice_eq in H0, H1. apply nd_meas_f; assumption. diff --git a/SQIR/UnitaryOps.v b/SQIR/UnitaryOps.v index 55d03f4..f6f1d69 100644 --- a/SQIR/UnitaryOps.v +++ b/SQIR/UnitaryOps.v @@ -416,11 +416,25 @@ Proof. - solve_matrix; autorewrite with R_db C_db RtoC_db Cexp_db trig_db; try lca; field_simplify_eq; try nonzero; group_Cexp. + simpl. try (rewrite Rplus_comm; setoid_rewrite sin2_cos2; easy). + try ( + rewrite Cplus_comm; unfold Cplus, Cmult; + autorewrite with R_db; simpl; + setoid_rewrite sin2_cos2; easy + ). + try (simpl; rewrite Copp_mult_distr_l, Copp_mult_distr_r; repeat rewrite <- Cmult_assoc; rewrite <- Cmult_plus_distr_l; autorewrite with RtoC_db; rewrite Ropp_involutive; setoid_rewrite sin2_cos2; rewrite Cmult_1_r; apply f_equal; lra). + try ( + simpl; repeat rewrite <- Cmult_assoc; simpl; + rewrite <- Cmult_plus_distr_l; + unfold Cplus, Cmult; + autorewrite with R_db; simpl; + setoid_rewrite sin2_cos2; autorewrite with R_db; + unfold Cexp; apply f_equal2; [apply f_equal; lra|] + ). + apply f_equal; lra. - rewrite <- Mscale_kron_dist_l. repeat rewrite <- Mscale_kron_dist_r. repeat (apply f_equal2; try reflexivity). @@ -438,7 +452,7 @@ Proof. unfold Cminus, Cmult; simpl; autorewrite with R_db; apply c_proj_eq; simpl; autorewrite with R_db). rewrite <- Rminus_unfold, <- cos_plus. - apply f_equal. try apply f_equal. try lra. + apply f_equal. try apply f_equal. try lra. lra. + apply f_equal2; [apply f_equal; lra|]. apply c_proj_eq; simpl; try lra. R_field_simplify. @@ -459,9 +473,19 @@ Proof. try (autorewrite with RtoC_db; rewrite Rplus_comm; rewrite <- Rminus_unfold, <- cos_plus; apply f_equal; apply f_equal; lra). + try ( + rewrite Cplus_comm; apply c_proj_eq; simpl; try lra; + autorewrite with R_db; rewrite <- Rminus_unfold; + rewrite <- cos_plus; apply f_equal; lra + ). - solve_matrix; autorewrite with R_db C_db RtoC_db Cexp_db trig_db; try lca; field_simplify_eq; try nonzero; group_Cexp. + try (rewrite Rplus_comm; setoid_rewrite sin2_cos2; easy). + try ( + simpl; rewrite Cplus_comm; unfold Cplus, Cmult; + autorewrite with R_db; simpl; + setoid_rewrite sin2_cos2; easy + ). + try (rewrite Copp_mult_distr_l, Copp_mult_distr_r; repeat rewrite <- Cmult_assoc; rewrite <- Cmult_plus_distr_l; autorewrite with RtoC_db; rewrite Ropp_involutive; diff --git a/examples/error-correction/NineQubitCode.v b/examples/error-correction/NineQubitCode.v index c67f916..b00e34e 100644 --- a/examples/error-correction/NineQubitCode.v +++ b/examples/error-correction/NineQubitCode.v @@ -1,12 +1,20 @@ Require Import Vectors.Fin. Require Export SQIR.UnitaryOps. +Require Import Common. + Module NineQubitCode. Open Scope ucom. Open Scope nat_scope. -Definition dim : nat := 9. +(** + 9 qubits are for encoding/decoding. + Following that, 2 * 3 qubits are used for bit flip syndrome analysis. + 2 additional qubits are used phase-flip analysis. + This can be made more compact, but this representation makes syndrome analysis easier. + *) +Definition dim : nat := 17. (** Blocks @@ -31,22 +39,28 @@ Definition encode_block (n : block_no) : base_ucom dim := let q0 := proj1_sig (Fin.to_nat n) * 3 in let q1 := q0 + 1 in let q2 := q0 + 2 in + H q0; CNOT q0 q1; CNOT q0 q2. Definition encode : base_ucom dim := CNOT 0 3; CNOT 0 6; - H 0; H 3; H 6; encode_block (@Fin.of_nat_lt 0 3 ltac:(lia)); encode_block (@Fin.of_nat_lt 1 3 ltac:(lia)); encode_block (@Fin.of_nat_lt 2 3 ltac:(lia)). -Theorem encode_correct (α β : C) : - (@uc_eval dim encode) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ 8 ⨂ ∣0⟩ ) - = /C2 .* (/√ 2 .* (α .* 3 ⨂ (∣0,0,0⟩ .+ ∣1,1,1⟩))) - .+ /C2 .* (/√ 2 .* (β .* 3 ⨂ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩))). +Definition encoded α β := ( + /C2 .* (/√ 2 .* (α .* (3 ⨂ (∣0,0,0⟩ .+ ∣1,1,1⟩)) ⊗ 5 \otimes ∣0⟩)) + .+ /C2 .* (/√ 2 .* (β .* (3 ⨂ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩)) ⊗ ∣0⟩)) +). + +Theorem encode_correct : forall (α β : C), + (@uc_eval dim encode) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ 16 ⨂ ∣0,0⟩) + = encoded α β. Proof. - simpl. Qsimpl. + intros. + simpl. Msimpl_light. + replace (∣0⟩) with (f_to_vec 1 (fun _ => false)) by lma'. replace (∣1⟩) with (f_to_vec 1 (fun _ => true)) by lma'. @@ -62,21 +76,18 @@ Proof. restore_dims. repeat rewrite kron_assoc by auto 10 with wf_db. repeat (rewrite f_to_vec_merge; restore_dims). - repeat rewrite f_to_vec_CNOT; try lia. - simpl update. - repeat ( - rewrite f_to_vec_H; try lia; - simpl update; simpl b2R; - restore_dims; + repeat ( + first + [ rewrite f_to_vec_H + | repeat rewrite f_to_vec_CNOT; try lia + ]; + simpl update; repeat rewrite Mmult_plus_distr_l; - repeat rewrite Mscale_mult_dist_r + repeat rewrite Mscale_mult_dist_r; + restore_dims ). - repeat rewrite Mmult_plus_distr_l. - repeat rewrite Mscale_mult_dist_r. - - repeat (rewrite f_to_vec_CNOT; try lia; try rewrite kron_1_l; simpl update). - simpl. Qsimpl. + simpl. Msimpl_light. replace (0 * PI)%R with 0%R by lra. replace (1 * PI)%R with PI by lra. @@ -132,7 +143,6 @@ Inductive phase_flip_error (n : block_no) : Set := | OnePhaseFlip (off : block_offset) | MorePhaseFlip (e : phase_flip_error n) (off : block_offset). - Inductive bit_flip_error : Set := | OneBitFlip (n : block_no) (off : block_offset) | TwoBitFlip (n₁ n₂ : block_no) (h : n₁ <> n₂) (off₁ off₂ : block_offset) @@ -170,4 +180,51 @@ Definition apply_error (e : error) : base_ucom dim := | BothErrors _ e₁ e₂ => apply_phase_flip_error e₁; apply_bit_flip_error e₂ end. +Definition ancillae_for (e : error) : Vector (2 ^ 8) := + 8 ⨂ ∣0⟩. + +(** + Recover + *) +Definition recover : base_ucom dim := SKIP. + +Theorem error_recover_correct (e : error) : forall (α β : C), + (@uc_eval dim (apply_error e; recover)) × encoded α β + = encoded α β. +Proof. +Admitted. + + +(** + Decode + *) +Definition decode : base_ucom dim := SKIP. + + +(** + Full circuit + *) + +Definition shor (e : error) : base_ucom dim := + encode; + apply_error e; + recover; + decode. + +Definition shor_correct (e : error) : forall (α β : C), + (@uc_eval dim (shor e)) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ 8 ⨂ ∣0⟩ ) + = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ancillae_for e. +Proof. + intros. + Local Opaque encode. + simpl uc_eval. + repeat rewrite Mmult_assoc. + restore_dims. + rewrite (encode_correct α β). + + +Admitted. + + + End NineQubitCode. From 32e36b1c1b682f2c11754c955e5cb42ba2a5f814 Mon Sep 17 00:00:00 2001 From: Fady Adal <2masadel@gmail.com> Date: Thu, 1 Aug 2024 01:08:27 -0500 Subject: [PATCH 11/17] syndrome analysis --- examples/error-correction/NineQubitCode.v | 221 ++++++++++++++++------ 1 file changed, 167 insertions(+), 54 deletions(-) diff --git a/examples/error-correction/NineQubitCode.v b/examples/error-correction/NineQubitCode.v index b00e34e..9ee20a4 100644 --- a/examples/error-correction/NineQubitCode.v +++ b/examples/error-correction/NineQubitCode.v @@ -1,42 +1,66 @@ -Require Import Vectors.Fin. -Require Export SQIR.UnitaryOps. +Require ExportSQIR.UnitaryOps. Require Import Common. Module NineQubitCode. -Open Scope ucom. -Open Scope nat_scope. +Local Open Scope ucom. +Local Open Scope nat_scope. -(** - 9 qubits are for encoding/decoding. - Following that, 2 * 3 qubits are used for bit flip syndrome analysis. - 2 additional qubits are used phase-flip analysis. - This can be made more compact, but this representation makes syndrome analysis easier. - *) -Definition dim : nat := 17. +Definition dim : nat := 9. (** Blocks *) +Inductive up_to_three := + | Zero + | One + | Two. + +Definition t_to_nat (t : up_to_three) : nat := + match t with + | Zero => 0 + | One => 1 + | Two => 2 + end. + +Definition t_eq (t₁ t₂ : up_to_three) : bool := + match t₁, t₂ with + | Zero, Zero + | One, One + | Two, Two => true + | _, _ => false + end. + +Coercion t_to_nat : up_to_three >-> nat. + +Definition t_of_nat (n : nat) (h : n < 3) : up_to_three. +Proof. + destruct n as [| [| [| n']]]. + - exact Zero. + - exact One. + - exact Two. + - lia. +Defined. + (* Encoded blocks *) -Definition block_no := Fin.t 3. +Definition block_no := up_to_three. (* Qubits in a single block *) -Definition block_offset := Fin.t 3. +Definition block_offset := up_to_three. Definition block_to_qubit (n : block_no) (off : block_offset) : nat := - proj1_sig (Fin.to_nat n) * 3 + proj1_sig (Fin.to_nat off). + n * 3 + off. -Compute block_to_qubit (@Fin.of_nat_lt 2 3 ltac:(lia)) (@Fin.of_nat_lt 2 3 ltac:(lia)). +Compute block_to_qubit (t_of_nat 2 ltac:(lia)) (t_of_nat 2 ltac:(lia)). (** Encoding *) Definition encode_block (n : block_no) : base_ucom dim := - let q0 := proj1_sig (Fin.to_nat n) * 3 in + let q0 := n * 3 in let q1 := q0 + 1 in let q2 := q0 + 2 in H q0; @@ -45,23 +69,22 @@ Definition encode_block (n : block_no) : base_ucom dim := Definition encode : base_ucom dim := CNOT 0 3; CNOT 0 6; - encode_block (@Fin.of_nat_lt 0 3 ltac:(lia)); - encode_block (@Fin.of_nat_lt 1 3 ltac:(lia)); - encode_block (@Fin.of_nat_lt 2 3 ltac:(lia)). + encode_block (t_of_nat 0 ltac:(lia)); + encode_block (t_of_nat 1 ltac:(lia)); + encode_block (t_of_nat 2 ltac:(lia)). -Definition encoded α β := ( - /C2 .* (/√ 2 .* (α .* (3 ⨂ (∣0,0,0⟩ .+ ∣1,1,1⟩)) ⊗ 5 \otimes ∣0⟩)) - .+ /C2 .* (/√ 2 .* (β .* (3 ⨂ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩)) ⊗ ∣0⟩)) +Notation encoded α β := ( + /C2 .* (/√ 2 .* (α .* (3 ⨂ (∣0,0,0⟩ .+ ∣1,1,1⟩)))) + .+ /C2 .* (/√ 2 .* (β .* (3 ⨂ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩)))) ). Theorem encode_correct : forall (α β : C), - (@uc_eval dim encode) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ 16 ⨂ ∣0,0⟩) + (@uc_eval dim encode) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ 8 ⨂ ∣0⟩) = encoded α β. Proof. intros. simpl. Msimpl_light. - replace (∣0⟩) with (f_to_vec 1 (fun _ => false)) by lma'. replace (∣1⟩) with (f_to_vec 1 (fun _ => true)) by lma'. restore_dims. @@ -150,14 +173,15 @@ Inductive bit_flip_error : Set := Inductive error : Set := | NoError - | PhaseFlipError (n : block_no) (e : phase_flip_error n) + | PhaseFlipError {n} (e : phase_flip_error n) | BitFlipError (e : bit_flip_error) - | BothErrors (n : block_no) (e₁ : phase_flip_error n) (e₂ : bit_flip_error). + | PhaseBitErrors {phase_n} (e₁ : phase_flip_error phase_n) (e₂ : bit_flip_error) + | BitPhaseErrors (e₁ : bit_flip_error) {phase_n} (e₂ : phase_flip_error phase_n). Fixpoint apply_phase_flip_error {n} (e : phase_flip_error n) : base_ucom dim := match e with - | OnePhaseFlip _ off => SQIR.Z (proj1_sig (Fin.to_nat off)) - | MorePhaseFlip _ e off => SQIR.Z (proj1_sig (Fin.to_nat off)); apply_phase_flip_error e + | OnePhaseFlip _ off => SQIR.Z off + | MorePhaseFlip _ e off => SQIR.Z off; apply_phase_flip_error e end. Definition apply_bit_flip_error (e : bit_flip_error) : base_ucom dim := @@ -165,41 +189,113 @@ Definition apply_bit_flip_error (e : bit_flip_error) : base_ucom dim := | OneBitFlip n off => X (block_to_qubit n off) | TwoBitFlip n₁ n₂ _ off₁ off₂ => (X (block_to_qubit n₁ off₁)); (X (block_to_qubit n₂ off₂)) | ThreeBitFlip off₁ off₂ off₃ => ( - let q1 := block_to_qubit (@Fin.of_nat_lt 0 3 ltac:(lia)) off₁ in - let q2 := block_to_qubit (@Fin.of_nat_lt 1 3 ltac:(lia)) off₂ in - let q3 := block_to_qubit (@Fin.of_nat_lt 2 3 ltac:(lia)) off₃ in + let q1 := block_to_qubit (t_of_nat 0 ltac:(lia)) off₁ in + let q2 := block_to_qubit (t_of_nat 1 ltac:(lia)) off₂ in + let q3 := block_to_qubit (t_of_nat 2 ltac:(lia)) off₃ in X q1; X q2; X q3 ) end. Definition apply_error (e : error) : base_ucom dim := match e with - | NoError => SKIP - | PhaseFlipError _ e => apply_phase_flip_error e - | BitFlipError e => apply_bit_flip_error e - | BothErrors _ e₁ e₂ => apply_phase_flip_error e₁; apply_bit_flip_error e₂ + | NoError => SKIP + | PhaseFlipError e => apply_phase_flip_error e + | BitFlipError e => apply_bit_flip_error e + | PhaseBitErrors e₁ e₂ => apply_phase_flip_error e₁; apply_bit_flip_error e₂ + | BitPhaseErrors e₁ e₂ => apply_bit_flip_error e₁; apply_phase_flip_error e₂ end. -Definition ancillae_for (e : error) : Vector (2 ^ 8) := - 8 ⨂ ∣0⟩. +Fixpoint list_to_map (l : list nat) : nat -> bool := + match l with + | [] => fun _ => false + | x :: l' => update (list_to_map l') x true + end. -(** - Recover - *) -Definition recover : base_ucom dim := SKIP. +Fixpoint ancillae_for_phase_flip {n} (e : phase_flip_error n) : list nat := + match e with + | MorePhaseFlip _ (OnePhaseFlip _ _) _ => [] + | MorePhaseFlip _ (MorePhaseFlip _ e _) _ => ancillae_for_phase_flip e + | OnePhaseFlip _ _ => + match n with + | Zero => 3 :: [6] + | One => [3] + | Two => [6] + end + end. -Theorem error_recover_correct (e : error) : forall (α β : C), - (@uc_eval dim (apply_error e; recover)) × encoded α β - = encoded α β. -Proof. -Admitted. +Definition block_to_syn (b : block_no) (off : block_offset) : list nat := + match off with + | Zero => b + 1 :: [b + 2] + | One => [b + 1] + | Two => [b + 2] + end. +Definition ancillae_for_bit_flip (e : bit_flip_error) : list nat := + match e with + | OneBitFlip n off => block_to_syn n off + | TwoBitFlip n₁ n₂ h off₁ off₂ => + match n₁ as n₁', n₂ as n₂' return n₁ = n₁' -> n₂ = n₂' -> list nat with + | Zero, Zero + | One, One + | Two, Two => fun h₁ h₂ => ltac:( + exfalso; + subst; + contradiction + ) + | n₁', n₂' => fun _ _ => (block_to_syn n₁' off₁) ++ (block_to_syn n₂' off₂) + end eq_refl eq_refl + | ThreeBitFlip off₁ off₂ off₃ => + (block_to_syn (t_of_nat 0 ltac:(lia)) off₁) ++ + (block_to_syn (t_of_nat 1 ltac:(lia)) off₂) ++ + (block_to_syn (t_of_nat 2 ltac:(lia)) off₃) + end. + +Definition ancillae_for (e : error) : Vector (2 ^ 8) := + match e with + | NoError => 8 ⨂ ∣0⟩ + | PhaseFlipError e => f_to_vec 8 (list_to_map (ancillae_for_phase_flip e)) + | BitFlipError e => f_to_vec 8 (list_to_map (ancillae_for_bit_flip e)) + | PhaseBitErrors e₁ e₂ => + f_to_vec 8 (list_to_map ( + ancillae_for_phase_flip e₁ ++ + ancillae_for_bit_flip e₂ + )) + | @BitPhaseErrors e₁ phase_n e₂ => + let v := f_to_vec 8 (list_to_map ( + ancillae_for_bit_flip e₁ ++ + ancillae_for_phase_flip e₂ + )) in + match e₁ with + | OneBitFlip n off => + if t_eq n phase_n + then (-1)%R .* v + else v + | TwoBitFlip n₁ n₂ _ _ _ => + if orb (t_eq n₁ phase_n) (t_eq n₂ phase_n) + then (-1)%R .* v + else v + | ThreeBitFlip _ _ _ => (-1)%R .* v + end + end. + +Definition decode_block (n : block_no) : base_ucom dim := + let q0 := n * 3 in + let q1 := q0 + 1 in + let q2 := q0 + 2 in + CNOT q0 q1; + CNOT q0 q2; + CCX q1 q2 q0; + H q0. (** Decode *) -Definition decode : base_ucom dim := SKIP. - +Definition decode : base_ucom dim := + decode_block (t_of_nat 0 ltac:(lia)); + decode_block (t_of_nat 1 ltac:(lia)); + decode_block (t_of_nat 2 ltac:(lia)); + CNOT 0 3; CNOT 0 6; + CCX 6 3 0. (** Full circuit @@ -208,7 +304,14 @@ Definition decode : base_ucom dim := SKIP. Definition shor (e : error) : base_ucom dim := encode; apply_error e; - recover; + (* Does not use the regular: + `encode; apply_error e; recover; decode` + (because we do not recover original encoding). + Attempting to do so requires 8 addition + qubits (really classical bits), 2 per block + for bit flip, and 2 for phase flip. + This makes the following analysis rougher. + *) decode. Definition shor_correct (e : error) : forall (α β : C), @@ -216,15 +319,25 @@ Definition shor_correct (e : error) : forall (α β : C), = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ancillae_for e. Proof. intros. - Local Opaque encode. + Local Opaque encode CCX. simpl uc_eval. repeat rewrite Mmult_assoc. - restore_dims. - rewrite (encode_correct α β). + rewrite encode_correct. + simpl. Msimpl_light. restore_dims. + replace (∣0,0,0⟩) with (f_to_vec 3 (fun _ => false)) by lma'. + replace (∣1,1,1⟩) with (f_to_vec 3 (fun _ => true)) by lma'. + repeat rewrite Mmult_assoc. + rewrite kron_plus_distr_r. + repeat rewrite Mmult_plus_distr_l. + distribute_scale. + repeat rewrite Mscale_mult_dist_r. + restore_dims. + repeat rewrite kron_assoc by auto 10 with wf_db. + repeat (rewrite f_to_vec_merge; restore_dims). + + Admitted. - - End NineQubitCode. From e10330bb0c34fded0bb77ca60d2f5c42c9674f56 Mon Sep 17 00:00:00 2001 From: Fady Adal <2masadel@gmail.com> Date: Fri, 2 Aug 2024 16:58:26 -0500 Subject: [PATCH 12/17] checkpoint --- examples/error-correction/Common.v | 13 + examples/error-correction/NineQubitCode.v | 275 ++++++++++++++++++---- 2 files changed, 248 insertions(+), 40 deletions(-) diff --git a/examples/error-correction/Common.v b/examples/error-correction/Common.v index d8d75be..f2ab5c1 100644 --- a/examples/error-correction/Common.v +++ b/examples/error-correction/Common.v @@ -11,4 +11,17 @@ Definition ZCCX {dim} (a b c : nat) : base_ucom dim := CCX a b c; X a. +Lemma zero_9_f_to_vec : + ∣0,0,0⟩ = f_to_vec 3 (fun _ => false). +Proof. + lma'. simpl. auto with wf_db. +Qed. + +Lemma nine_9_f_to_vec : + ∣1,1,1⟩ = f_to_vec 3 (fun _ => true). +Proof. + lma'. simpl. auto with wf_db. +Qed. + + End Common. diff --git a/examples/error-correction/NineQubitCode.v b/examples/error-correction/NineQubitCode.v index 9ee20a4..65ce628 100644 --- a/examples/error-correction/NineQubitCode.v +++ b/examples/error-correction/NineQubitCode.v @@ -1,4 +1,4 @@ -Require ExportSQIR.UnitaryOps. +Require Export SQIR.UnitaryOps. Require Import Common. @@ -7,7 +7,7 @@ Module NineQubitCode. Local Open Scope ucom. Local Open Scope nat_scope. -Definition dim : nat := 9. +Notation dim := 9%nat. (** Blocks @@ -27,8 +27,8 @@ Definition t_to_nat (t : up_to_three) : nat := Definition t_eq (t₁ t₂ : up_to_three) : bool := match t₁, t₂ with - | Zero, Zero - | One, One + | Zero, Zero +| One, One | Two, Two => true | _, _ => false end. @@ -78,6 +78,25 @@ Notation encoded α β := ( .+ /C2 .* (/√ 2 .* (β .* (3 ⨂ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩)))) ). +Ltac f_to_vec_simpl_light := + first + [ rewrite f_to_vec_H + | rewrite f_to_vec_CCX + | rewrite f_to_vec_CNOT + ]; + try lia; + simpl update; + do 2 ( + repeat rewrite Mmult_plus_distr_l; + repeat rewrite Mscale_mult_dist_r + ). + +Ltac pull_scalars := + repeat rewrite <- Mscale_plus_distr_r; + repeat rewrite kron_plus_distr_r; + repeat rewrite kron_plus_distr_l; + repeat rewrite Mplus_assoc. + Theorem encode_correct : forall (α β : C), (@uc_eval dim encode) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ 8 ⨂ ∣0⟩) = encoded α β. @@ -88,8 +107,8 @@ Proof. replace (∣0⟩) with (f_to_vec 1 (fun _ => false)) by lma'. replace (∣1⟩) with (f_to_vec 1 (fun _ => true)) by lma'. restore_dims. - replace (∣0,0,0⟩) with (f_to_vec 3 (fun _ => false)) by lma'. - replace (∣1,1,1⟩) with (f_to_vec 3 (fun _ => true)) by lma'. + rewrite Common.zero_9_f_to_vec. + rewrite Common.nine_9_f_to_vec. repeat rewrite Mmult_assoc. rewrite kron_plus_distr_r. @@ -100,16 +119,7 @@ Proof. repeat rewrite kron_assoc by auto 10 with wf_db. repeat (rewrite f_to_vec_merge; restore_dims). - repeat ( - first - [ rewrite f_to_vec_H - | repeat rewrite f_to_vec_CNOT; try lia - ]; - simpl update; - repeat rewrite Mmult_plus_distr_l; - repeat rewrite Mscale_mult_dist_r; - restore_dims - ). + repeat f_to_vec_simpl_light. simpl. Msimpl_light. replace (0 * PI)%R with 0%R by lra. @@ -117,11 +127,8 @@ Proof. autorewrite with Cexp_db. group_radicals. repeat rewrite Mscale_1_l. - - repeat rewrite <- Mscale_plus_distr_r. - repeat rewrite kron_plus_distr_r. - repeat rewrite kron_plus_distr_l. - repeat rewrite Mplus_assoc. + + pull_scalars. f_equal. - repeat rewrite Mscale_assoc. replace (α * / √ 2 * / √ 2 * / √ 2)%C with (/√ 2 * / C2 * α)%C. @@ -232,7 +239,7 @@ Definition block_to_syn (b : block_no) (off : block_offset) : list nat := Definition ancillae_for_bit_flip (e : bit_flip_error) : list nat := match e with - | OneBitFlip n off => block_to_syn n off +| OneBitFlip n off => block_to_syn n off | TwoBitFlip n₁ n₂ h off₁ off₂ => match n₁ as n₁', n₂ as n₂' return n₁ = n₁' -> n₂ = n₂' -> list nat with | Zero, Zero @@ -300,13 +307,12 @@ Definition decode : base_ucom dim := (** Full circuit *) - Definition shor (e : error) : base_ucom dim := encode; apply_error e; (* Does not use the regular: `encode; apply_error e; recover; decode` - (because we do not recover original encoding). + (because we do not recover the original encoding). Attempting to do so requires 8 addition qubits (really classical bits), 2 per block for bit flip, and 2 for phase flip. @@ -314,30 +320,219 @@ Definition shor (e : error) : base_ucom dim := *) decode. +Lemma inv_sqrt2_cubed : (/ √ 2 * / √ 2 * / √ 2)%C = (/ C2 * /√ 2)%C. +Proof. + rewrite Cinv_sqrt2_sqrt. + easy. +Qed. + +Ltac compute_vec := + simpl uc_eval; + repeat rewrite Mmult_assoc; restore_dims; + repeat f_to_vec_simpl_light; + simpl; + replace (0 * PI)%R with 0%R by lra; + replace (1 * PI)%R with PI by lra; + autorewrite with Cexp_db; + repeat rewrite Mscale_1_l; + restore_dims; + autorewrite with ket_db. + +Ltac prep_err_compute := + simpl; Msimpl_light; restore_dims; + rewrite Common.zero_9_f_to_vec, Common.nine_9_f_to_vec; + repeat (rewrite kron_plus_distr_r, kron_plus_distr_l; restore_dims); + repeat rewrite kron_plus_distr_r; + repeat rewrite Mplus_assoc; + repeat (rewrite Mscale_kron_dist_l; restore_dims); + repeat rewrite Mscale_kron_dist_r; + repeat (rewrite Mscale_kron_dist_l; restore_dims); + repeat (rewrite f_to_vec_merge; restore_dims); + + repeat rewrite Mmult_plus_distr_l; + repeat rewrite Mscale_mult_dist_r; + + repeat rewrite Mmult_plus_distr_l; + repeat rewrite Mscale_mult_dist_r; + restore_dims. + + +Ltac post_compute_vec := + repeat rewrite Cmult_assoc; + rewrite inv_sqrt2_cubed; + repeat rewrite <- Mscale_assoc with (y := ((-1)%R : C)); + repeat rewrite <- Mscale_plus_distr_r; + repeat rewrite <- kron_assoc; auto with wf_db; + + repeat rewrite Mplus_assoc; + repeat rewrite Mscale_plus_distr_r with (x := ((-1)%R : C)); + repeat rewrite Mplus_assoc. + +Definition error_decode_correct_no_error : + forall (α β : C), + (@uc_eval dim (apply_error NoError; decode)) × (encoded α β) + = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ancillae_for NoError. +Proof. + intros. subst. + Local Opaque decode CCX. + simpl ancillae_for. Msimpl_light. + simpl uc_eval. restore_dims. + repeat rewrite denote_SKIP; try lia. + repeat rewrite Mmult_1_l; auto 10 with wf_db. + restore_dims. + prep_err_compute. + Local Transparent decode. + + assert (H1 : + (uc_eval decode + × f_to_vec (3 + 3 + 3) + (fun x : nat => + if x + if x + if x + if x + if x + if x + if x + if x false)) by lma'. - replace (∣1,1,1⟩) with (f_to_vec 3 (fun _ => true)) by lma'. - - repeat rewrite Mmult_assoc. - rewrite kron_plus_distr_r. - repeat rewrite Mmult_plus_distr_l. - distribute_scale. - repeat rewrite Mscale_mult_dist_r. - restore_dims. - repeat rewrite kron_assoc by auto 10 with wf_db. - repeat (rewrite f_to_vec_merge; restore_dims). - - + destruct e. + - simpl ancillae_for. + specialize (error_decode_correct_no_error α β) as H. + simpl uc_eval in H. + simpl ancillae_for in H. + rewrite Mmult_assoc in H. + apply H. + - Admitted. End NineQubitCode. From 83b63d78d4bb320c8d7df0d3f77ed03b5bc8f097 Mon Sep 17 00:00:00 2001 From: Fady Adal <2masadel@gmail.com> Date: Mon, 5 Aug 2024 03:36:02 -0500 Subject: [PATCH 13/17] add bit and phase flips for nine qubit code --- examples/error-correction/Common.v | 68 +- examples/error-correction/NineQubitCode.v | 896 ++++++++++++++-------- 2 files changed, 643 insertions(+), 321 deletions(-) diff --git a/examples/error-correction/Common.v b/examples/error-correction/Common.v index f2ab5c1..511645b 100644 --- a/examples/error-correction/Common.v +++ b/examples/error-correction/Common.v @@ -2,7 +2,7 @@ Require Export SQIR.UnitaryOps. Module Common. -Open Scope ucom. +Local Open Scope ucom. (** A toffoli gate but controlled on the first qubit being zero. *) @@ -11,17 +11,77 @@ Definition ZCCX {dim} (a b c : nat) : base_ucom dim := CCX a b c; X a. -Lemma zero_9_f_to_vec : +Lemma zero_3_f_to_vec : ∣0,0,0⟩ = f_to_vec 3 (fun _ => false). Proof. lma'. simpl. auto with wf_db. Qed. -Lemma nine_9_f_to_vec : +Lemma one_3_f_to_vec : + ∣1,0,0⟩ = f_to_vec 3 (fun n => n =? 0). +Proof. + lma'. simpl. auto with wf_db. +Qed. + +Lemma two_3_f_to_vec : + ∣0,1,0⟩ = f_to_vec 3 (fun n => n =? 1). +Proof. + lma'. simpl. auto with wf_db. +Qed. + +Lemma three_3_f_to_vec : + ∣1,1,0⟩ = f_to_vec 3 (fun n => orb (n =? 0) (n =? 1)). +Proof. + lma'. simpl. auto with wf_db. +Qed. + +Lemma four_3_f_to_vec : + ∣0,0,1⟩ = f_to_vec 3 (fun n => n =? 2). +Proof. + lma'. simpl. auto with wf_db. +Qed. + +Lemma five_3_f_to_vec : + ∣1,0,1⟩ = f_to_vec 3 (fun n => orb (n =? 0) (n =? 2)). +Proof. + lma'. simpl. auto with wf_db. +Qed. + +Lemma six_3_f_to_vec : + ∣0,1,1⟩ = f_to_vec 3 (fun n => orb (n =? 1) (n =? 2)). +Proof. + lma'. simpl. auto with wf_db. +Qed. + +Lemma seven_3_f_to_vec : ∣1,1,1⟩ = f_to_vec 3 (fun _ => true). Proof. lma'. simpl. auto with wf_db. Qed. - + +#[export] Hint Rewrite + zero_3_f_to_vec + one_3_f_to_vec + two_3_f_to_vec + three_3_f_to_vec + four_3_f_to_vec + five_3_f_to_vec + six_3_f_to_vec + seven_3_f_to_vec + : f_to_vec_3_db. + + +Ltac f_to_vec_simpl_light := + first + [ rewrite f_to_vec_H + | rewrite f_to_vec_CCX + | rewrite f_to_vec_CNOT + ]; + try lia; + simpl update; + do 2 ( + repeat rewrite Mmult_plus_distr_l; + repeat rewrite Mscale_mult_dist_r + ). End Common. diff --git a/examples/error-correction/NineQubitCode.v b/examples/error-correction/NineQubitCode.v index 65ce628..ecf1e93 100644 --- a/examples/error-correction/NineQubitCode.v +++ b/examples/error-correction/NineQubitCode.v @@ -1,6 +1,7 @@ Require Export SQIR.UnitaryOps. Require Import Common. +Import Common. Module NineQubitCode. @@ -8,27 +9,30 @@ Local Open Scope ucom. Local Open Scope nat_scope. Notation dim := 9%nat. +Notation block_dim := 3%nat. + +Local Opaque CCX. (** Blocks *) Inductive up_to_three := - | Zero + | Zer0 | One | Two. Definition t_to_nat (t : up_to_three) : nat := match t with - | Zero => 0 + | Zer0 => 0 | One => 1 | Two => 2 end. Definition t_eq (t₁ t₂ : up_to_three) : bool := match t₁, t₂ with - | Zero, Zero -| One, One + | Zer0, Zer0 + | One, One | Two, Two => true | _, _ => false end. @@ -38,7 +42,7 @@ Coercion t_to_nat : up_to_three >-> nat. Definition t_of_nat (n : nat) (h : n < 3) : up_to_three. Proof. destruct n as [| [| [| n']]]. - - exact Zero. + - exact Zer0. - exact One. - exact Two. - lia. @@ -53,62 +57,102 @@ Definition block_offset := up_to_three. Definition block_to_qubit (n : block_no) (off : block_offset) : nat := n * 3 + off. -Compute block_to_qubit (t_of_nat 2 ltac:(lia)) (t_of_nat 2 ltac:(lia)). +Ltac compute_vec := + simpl uc_eval; + repeat rewrite Mmult_assoc; restore_dims; + repeat Common.f_to_vec_simpl_light; + simpl; + replace (0 * PI)%R with 0%R by lra; + replace (1 * PI)%R with PI by lra; + autorewrite with Cexp_db; + repeat rewrite Mscale_1_l; + restore_dims; + autorewrite with ket_db. + +Ltac correct_inPar well_typed := + try + (replace (@uc_eval 9) with (@uc_eval (3 + 6)) by easy; + rewrite inPar_correct by well_typed); + try + (replace (@uc_eval 6) with (@uc_eval (3 + 3)) by easy; + rewrite inPar_correct by well_typed); + restore_dims. (** Encoding *) -Definition encode_block (n : block_no) : base_ucom dim := - let q0 := n * 3 in - let q1 := q0 + 1 in - let q2 := q0 + 2 in - H q0; - CNOT q0 q1; - CNOT q0 q2. +Definition encode_block : base_ucom block_dim := + H 0; + CNOT 0 1; + CNOT 0 2. + +Theorem encode_block_zero : + uc_eval encode_block × ∣0,0,0⟩ + = / √ 2 .* (∣ 0, 0, 0 ⟩ .+ ∣ 1, 1, 1 ⟩). +Proof. + rewrite Common.zero_3_f_to_vec. + compute_vec. + reflexivity. +Qed. + + +Theorem encode_block_one : + uc_eval encode_block × ∣1,0,0⟩ + = / √ 2 .* (∣ 0, 0, 0 ⟩ .+ (-1)%R .* ∣ 1, 1, 1 ⟩). +Proof. + replace (∣1,0,0⟩) with (f_to_vec 3 (fun n => n =? 0)) by lma'. + compute_vec. + reflexivity. +Qed. + +Theorem encode_block_well_typed : + uc_well_typed encode_block. +Proof. + unfold encode_block. + auto. + constructor. + - constructor. + + apply uc_well_typed_H; lia. + + apply uc_well_typed_CNOT; lia. + - apply uc_well_typed_CNOT; lia. +Qed. Definition encode : base_ucom dim := CNOT 0 3; CNOT 0 6; - encode_block (t_of_nat 0 ltac:(lia)); - encode_block (t_of_nat 1 ltac:(lia)); - encode_block (t_of_nat 2 ltac:(lia)). + inPar encode_block (inPar encode_block encode_block). Notation encoded α β := ( - /C2 .* (/√ 2 .* (α .* (3 ⨂ (∣0,0,0⟩ .+ ∣1,1,1⟩)))) - .+ /C2 .* (/√ 2 .* (β .* (3 ⨂ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩)))) + α .* (/C2 .* (/√ 2 .* (3 ⨂ (∣0,0,0⟩ .+ ∣1,1,1⟩)))) + .+ β .* (/C2 .* (/√ 2 .* (3 ⨂ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩)))) ). -Ltac f_to_vec_simpl_light := - first - [ rewrite f_to_vec_H - | rewrite f_to_vec_CCX - | rewrite f_to_vec_CNOT - ]; - try lia; - simpl update; - do 2 ( - repeat rewrite Mmult_plus_distr_l; - repeat rewrite Mscale_mult_dist_r - ). -Ltac pull_scalars := - repeat rewrite <- Mscale_plus_distr_r; - repeat rewrite kron_plus_distr_r; - repeat rewrite kron_plus_distr_l; - repeat rewrite Mplus_assoc. +Ltac reorder_scalars := + repeat rewrite Mscale_assoc; + repeat rewrite Cmult_comm with (x := ((-1)%R : C)); + repeat rewrite <- Mscale_assoc with (y := ((-1)%R : C)); + repeat rewrite <- Mscale_plus_distr_r. + +Ltac normalize_kron_notation := + repeat rewrite <- kron_assoc by auto 8 with wf_db; + try easy. + +Lemma inv_sqrt2_cubed : (/ √ 2 * / √ 2 * / √ 2)%C = (/ C2 * /√ 2)%C. +Proof. + now rewrite Cinv_sqrt2_sqrt. +Qed. Theorem encode_correct : forall (α β : C), (@uc_eval dim encode) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ 8 ⨂ ∣0⟩) = encoded α β. Proof. intros. + Local Opaque inPar. simpl. Msimpl_light. - replace (∣0⟩) with (f_to_vec 1 (fun _ => false)) by lma'. replace (∣1⟩) with (f_to_vec 1 (fun _ => true)) by lma'. restore_dims. - rewrite Common.zero_9_f_to_vec. - rewrite Common.nine_9_f_to_vec. repeat rewrite Mmult_assoc. rewrite kron_plus_distr_r. @@ -118,50 +162,26 @@ Proof. restore_dims. repeat rewrite kron_assoc by auto 10 with wf_db. repeat (rewrite f_to_vec_merge; restore_dims). - - repeat f_to_vec_simpl_light. - simpl. Msimpl_light. - replace (0 * PI)%R with 0%R by lra. - replace (1 * PI)%R with PI by lra. - autorewrite with Cexp_db. - group_radicals. - repeat rewrite Mscale_1_l. - - pull_scalars. - f_equal. - - repeat rewrite Mscale_assoc. - replace (α * / √ 2 * / √ 2 * / √ 2)%C with (/√ 2 * / C2 * α)%C. - 2: { - (* why does lca not work here? *) - rewrite Cmult_comm. - do 2 rewrite <- Cmult_assoc. - f_equal. f_equal. - symmetry. apply Cinv_sqrt2_sqrt. - } - repeat (rewrite <- kron_assoc by auto 10 with wf_db; restore_dims). - reflexivity. - - do 2 ( - repeat rewrite Mscale_assoc; - repeat rewrite (Cmult_comm (-1)%R _); - repeat rewrite <- (Mscale_assoc _ _ _ (-1)%R _); - repeat rewrite <- Mscale_plus_distr_r - ). - repeat rewrite Mscale_assoc. - replace (β * / √ 2 * / √ 2 * / √ 2)%C with (/ √ 2 * / C2 * β)%C. - 2: { - (* why does lca not work here? *) - rewrite Cmult_comm. - do 2 rewrite <- Cmult_assoc. - f_equal. f_equal. - symmetry. apply Cinv_sqrt2_sqrt. - } - f_equal. - repeat rewrite Mscale_plus_distr_r. - distribute_scale. - repeat (rewrite <- kron_assoc by auto 10 with wf_db; restore_dims). - repeat rewrite Mplus_assoc. - reflexivity. + repeat Common.f_to_vec_simpl_light. + simpl. Msimpl_light. + restore_dims. + correct_inPar ltac:(apply encode_block_well_typed). + replace (∣0, 0, 0, 0, 0, 0, 0, 0, 0⟩) with (∣0, 0, 0⟩ ⊗ ∣0, 0, 0, 0, 0, 0⟩) by normalize_kron_notation. + replace (∣0, 0, 0, 0, 0, 0⟩) with (∣0, 0, 0⟩ ⊗ ∣0, 0, 0⟩) by normalize_kron_notation. + replace (∣1, 0, 0, 1, 0, 0, 1, 0, 0⟩) with (∣1, 0, 0⟩ ⊗ ∣1, 0, 0, 1, 0, 0⟩) by normalize_kron_notation. + replace (∣1, 0, 0, 1, 0, 0⟩) with (∣1, 0, 0⟩ ⊗ ∣1, 0, 0⟩) by normalize_kron_notation. + restore_dims. + do 4 rewrite kron_mixed_product. + rewrite encode_block_zero, encode_block_one. + repeat rewrite Mscale_kron_dist_l. + repeat rewrite Mscale_kron_dist_r. + normalize_kron_notation. + repeat rewrite <- Cmult_assoc. + repeat rewrite <- inv_sqrt2_cubed. + repeat rewrite Cmult_assoc. + repeat rewrite Mscale_assoc. + reflexivity. Qed. @@ -173,33 +193,49 @@ Inductive phase_flip_error (n : block_no) : Set := | OnePhaseFlip (off : block_offset) | MorePhaseFlip (e : phase_flip_error n) (off : block_offset). +Fixpoint phase_flip_odd_parity {n : block_no} (e : phase_flip_error n) : bool := + match e with + | OnePhaseFlip _ _ => true + | MorePhaseFlip _ e _ => negb (phase_flip_odd_parity e) + end. + Inductive bit_flip_error : Set := | OneBitFlip (n : block_no) (off : block_offset) - | TwoBitFlip (n₁ n₂ : block_no) (h : n₁ <> n₂) (off₁ off₂ : block_offset) + | TwoBitFlip (safe_n : block_no) (off₁ off₂ : block_offset) | ThreeBitFlip (off₁ off₂ off₃ : block_offset). Inductive error : Set := | NoError | PhaseFlipError {n} (e : phase_flip_error n) | BitFlipError (e : bit_flip_error) - | PhaseBitErrors {phase_n} (e₁ : phase_flip_error phase_n) (e₂ : bit_flip_error) - | BitPhaseErrors (e₁ : bit_flip_error) {phase_n} (e₂ : phase_flip_error phase_n). + | PhaseBitErrors {phase_n} (e₁ : phase_flip_error phase_n) (e₂ : bit_flip_error). + +Definition apply_to_block (n : block_no) (uc : base_ucom block_dim) := + match n with + | Zer0 => inPar uc (inPar SKIP SKIP) + | One => inPar SKIP (inPar uc SKIP) + | Two => inPar SKIP (inPar SKIP uc) + end. Fixpoint apply_phase_flip_error {n} (e : phase_flip_error n) : base_ucom dim := match e with - | OnePhaseFlip _ off => SQIR.Z off - | MorePhaseFlip _ e off => SQIR.Z off; apply_phase_flip_error e + | OnePhaseFlip _ off => apply_to_block n (SQIR.Z off) + | MorePhaseFlip _ e off => + apply_to_block n (SQIR.Z off); + apply_phase_flip_error e end. -Definition apply_bit_flip_error (e : bit_flip_error) : base_ucom dim := +Definition apply_bit_flip_error (e : bit_flip_error) : base_ucom (block_dim + (block_dim + block_dim)) := match e with - | OneBitFlip n off => X (block_to_qubit n off) - | TwoBitFlip n₁ n₂ _ off₁ off₂ => (X (block_to_qubit n₁ off₁)); (X (block_to_qubit n₂ off₂)) + | OneBitFlip n off => apply_to_block n (X off) + | TwoBitFlip safe_n off₁ off₂ => + match safe_n with + | Zer0 => inPar SKIP (inPar (X off₁) (X off₂)) + | One => inPar (X off₁) (inPar SKIP (X off₂)) + | Two => inPar (X off₁) (inPar (X off₂) SKIP) + end | ThreeBitFlip off₁ off₂ off₃ => ( - let q1 := block_to_qubit (t_of_nat 0 ltac:(lia)) off₁ in - let q2 := block_to_qubit (t_of_nat 1 ltac:(lia)) off₂ in - let q3 := block_to_qubit (t_of_nat 2 ltac:(lia)) off₃ in - X q1; X q2; X q3 + inPar (X off₁) (inPar (X off₂) (X off₃)) ) end. @@ -209,7 +245,6 @@ Definition apply_error (e : error) : base_ucom dim := | PhaseFlipError e => apply_phase_flip_error e | BitFlipError e => apply_bit_flip_error e | PhaseBitErrors e₁ e₂ => apply_phase_flip_error e₁; apply_bit_flip_error e₂ - | BitPhaseErrors e₁ e₂ => apply_bit_flip_error e₁; apply_phase_flip_error e₂ end. Fixpoint list_to_map (l : list nat) : nat -> bool := @@ -224,37 +259,33 @@ Fixpoint ancillae_for_phase_flip {n} (e : phase_flip_error n) : list nat := | MorePhaseFlip _ (MorePhaseFlip _ e _) _ => ancillae_for_phase_flip e | OnePhaseFlip _ _ => match n with - | Zero => 3 :: [6] - | One => [3] - | Two => [6] + | Zer0 => 2 :: [5] + | One => [2] + | Two => [5] end end. -Definition block_to_syn (b : block_no) (off : block_offset) : list nat := +Definition block_to_bit_syn (b : block_no) (off : block_offset) : list nat := + let left_edge := b * 3 in match off with - | Zero => b + 1 :: [b + 2] - | One => [b + 1] - | Two => [b + 2] + | Zer0 => left_edge :: [left_edge + 1] + | One => [left_edge] + | Two => [left_edge + 1] end. Definition ancillae_for_bit_flip (e : bit_flip_error) : list nat := match e with -| OneBitFlip n off => block_to_syn n off - | TwoBitFlip n₁ n₂ h off₁ off₂ => - match n₁ as n₁', n₂ as n₂' return n₁ = n₁' -> n₂ = n₂' -> list nat with - | Zero, Zero - | One, One - | Two, Two => fun h₁ h₂ => ltac:( - exfalso; - subst; - contradiction - ) - | n₁', n₂' => fun _ _ => (block_to_syn n₁' off₁) ++ (block_to_syn n₂' off₂) - end eq_refl eq_refl - | ThreeBitFlip off₁ off₂ off₃ => - (block_to_syn (t_of_nat 0 ltac:(lia)) off₁) ++ - (block_to_syn (t_of_nat 1 ltac:(lia)) off₂) ++ - (block_to_syn (t_of_nat 2 ltac:(lia)) off₃) + | OneBitFlip n off => block_to_bit_syn n off + | TwoBitFlip safe_n off₁ off₂ => + match safe_n with + | Zer0 => (block_to_bit_syn One off₁) ++ (block_to_bit_syn Two off₂) + | One => (block_to_bit_syn Zer0 off₁) ++ (block_to_bit_syn Two off₂) + | Two => (block_to_bit_syn Zer0 off₁) ++ (block_to_bit_syn One off₂) + end + | ThreeBitFlip off₁ off₂ off₃ => + (block_to_bit_syn Zer0 off₁) ++ + (block_to_bit_syn One off₂) ++ + (block_to_bit_syn Two off₃) end. Definition ancillae_for (e : error) : Vector (2 ^ 8) := @@ -267,40 +298,109 @@ Definition ancillae_for (e : error) : Vector (2 ^ 8) := ancillae_for_phase_flip e₁ ++ ancillae_for_bit_flip e₂ )) - | @BitPhaseErrors e₁ phase_n e₂ => - let v := f_to_vec 8 (list_to_map ( - ancillae_for_bit_flip e₁ ++ - ancillae_for_phase_flip e₂ - )) in - match e₁ with - | OneBitFlip n off => - if t_eq n phase_n - then (-1)%R .* v - else v - | TwoBitFlip n₁ n₂ _ _ _ => - if orb (t_eq n₁ phase_n) (t_eq n₂ phase_n) - then (-1)%R .* v - else v - | ThreeBitFlip _ _ _ => (-1)%R .* v - end end. -Definition decode_block (n : block_no) : base_ucom dim := - let q0 := n * 3 in - let q1 := q0 + 1 in - let q2 := q0 + 2 in - CNOT q0 q1; - CNOT q0 q2; - CCX q1 q2 q0; - H q0. +Lemma ancillae_for_two_phases_cancel {n}: + forall (e : phase_flip_error n) (off₁ off₂ : block_offset), + ancillae_for (PhaseFlipError (MorePhaseFlip n (MorePhaseFlip n e off₁) off₂)) + = ancillae_for (PhaseFlipError e). +Proof. + easy. +Qed. + +Definition decode_block : base_ucom block_dim := + CNOT 0 1; + CNOT 0 2; + CCX 1 2 0; + H 0. + +Theorem decode_block_well_typed : uc_well_typed decode_block. +Proof. + repeat constructor. + Local Transparent CCX. + all : unfold CCX. + Local Opaque CCX. + 3 : repeat constructor. + all : unfold TDAG. + all : try apply uc_well_typed_H; try lia. + all : try apply uc_well_typed_CNOT; try lia. + all : try apply uc_well_typed_Rz; lia. +Qed. + + +Lemma decode_block_zero : + uc_eval decode_block × ∣0,0,0⟩ = / √ 2 .* (∣0,0,0⟩ .+ ∣1,0,0⟩). +Proof. + rewrite Common.zero_3_f_to_vec. + compute_vec. + reflexivity. +Qed. + +Lemma decode_block_one : + uc_eval decode_block × ∣1,0,0⟩ = / √ 2 .* (∣0,1,1⟩ .+ ∣1,1,1⟩). +Proof. + rewrite Common.one_3_f_to_vec. + now compute_vec. +Qed. + +Lemma decode_block_two : + uc_eval decode_block × ∣0,1,0⟩ = / √ 2 .* (∣0,1,0⟩ .+ ∣1,1,0⟩). +Proof. + rewrite Common.two_3_f_to_vec. + now compute_vec. +Qed. + +Lemma decode_block_three : + uc_eval decode_block × ∣1,1,0⟩ = / √ 2 .* (∣0,0,1⟩ .+ (-1)%R .* ∣1,0,1⟩). +Proof. + rewrite Common.three_3_f_to_vec. + now compute_vec. +Qed. + +Lemma decode_block_four : + uc_eval decode_block × ∣0,0,1⟩ = / √ 2 .* (∣0,0,1⟩ .+ ∣1,0,1⟩). +Proof. + rewrite Common.four_3_f_to_vec. + now compute_vec. +Qed. + +Lemma decode_block_five : + uc_eval decode_block × ∣1,0,1⟩ = / √ 2 .* (∣0,1,0⟩ .+ (-1)%R .* ∣1,1,0⟩). +Proof. + rewrite Common.five_3_f_to_vec. + now compute_vec. +Qed. + +Lemma decode_block_six : + uc_eval decode_block × ∣0,1,1⟩ = / √ 2 .* (∣0,1,1⟩ .+ (-1)%R .* ∣1,1,1⟩). +Proof. + rewrite Common.six_3_f_to_vec. + now compute_vec. +Qed. + +Lemma decode_block_seven : + uc_eval decode_block × ∣1,1,1⟩ = / √ 2 .* (∣0,0,0⟩ .+ (-1)%R .* ∣1,0,0⟩). +Proof. + rewrite Common.seven_3_f_to_vec. + now compute_vec. +Qed. + +#[export] Hint Rewrite + decode_block_zero + decode_block_one + decode_block_two + decode_block_three + decode_block_four + decode_block_five + decode_block_six + decode_block_seven + : decode_block_db. (** Decode *) Definition decode : base_ucom dim := - decode_block (t_of_nat 0 ltac:(lia)); - decode_block (t_of_nat 1 ltac:(lia)); - decode_block (t_of_nat 2 ltac:(lia)); + inPar decode_block (inPar decode_block decode_block); CNOT 0 3; CNOT 0 6; CCX 6 3 0. @@ -313,206 +413,359 @@ Definition shor (e : error) : base_ucom dim := (* Does not use the regular: `encode; apply_error e; recover; decode` (because we do not recover the original encoding). - Attempting to do so requires 8 addition + Attempting to do so requires 8 additional qubits (really classical bits), 2 per block for bit flip, and 2 for phase flip. This makes the following analysis rougher. *) decode. -Lemma inv_sqrt2_cubed : (/ √ 2 * / √ 2 * / √ 2)%C = (/ C2 * /√ 2)%C. +Lemma uc_well_typed_SKIP {d} {h : 0 < d} : + @uc_well_typed _ d SKIP. Proof. - rewrite Cinv_sqrt2_sqrt. - easy. + unfold SKIP. + apply uc_well_typed_ID. + assumption. Qed. -Ltac compute_vec := - simpl uc_eval; - repeat rewrite Mmult_assoc; restore_dims; - repeat f_to_vec_simpl_light; - simpl; - replace (0 * PI)%R with 0%R by lra; - replace (1 * PI)%R with PI by lra; - autorewrite with Cexp_db; - repeat rewrite Mscale_1_l; - restore_dims; - autorewrite with ket_db. +Ltac simplify_sums := + match goal with + | [ |- context [ (?A .+ ?B) + .+ (?A .+ RtoC (-1)%R .* ?B) ] + ] => + replace (A .+ B + .+ (A .+ RtoC (-1)%R .* B)) with (C2 .* A) by lma + | [ |- context [?A .+ ?B + .+ RtoC (-1)%R .* (?A .+ RtoC (-1)%R .* ?B)] + ] => + replace (A .+ B + .+ RtoC (-1)%R .* (A .+ RtoC (-1)%R .* B)) + with (C2 .* B) by lma + end. -Ltac prep_err_compute := - simpl; Msimpl_light; restore_dims; - rewrite Common.zero_9_f_to_vec, Common.nine_9_f_to_vec; - repeat (rewrite kron_plus_distr_r, kron_plus_distr_l; restore_dims); - repeat rewrite kron_plus_distr_r; - repeat rewrite Mplus_assoc; - repeat (rewrite Mscale_kron_dist_l; restore_dims); - repeat rewrite Mscale_kron_dist_r; - repeat (rewrite Mscale_kron_dist_l; restore_dims); - repeat (rewrite f_to_vec_merge; restore_dims); +Ltac pull_scalars := + distribute_scale; + repeat rewrite Mscale_mult_dist_r; + repeat rewrite Mscale_assoc. + +Lemma collapse_scalar : + (/ C2 * (/ √ 2 * (/ √ 2 * (C2 * (/ √ 2 * (C2 * (/ √ 2 * C2)))))))%C = C1. +Proof. C_field. Qed. +Ltac distribute_over_blocks := + repeat rewrite kron_1_l by auto 10 with wf_db; + repeat rewrite kron_assoc by auto with wf_db; repeat rewrite Mmult_plus_distr_l; repeat rewrite Mscale_mult_dist_r; - + repeat rewrite Mmult_assoc; + restore_dims; + repeat rewrite kron_mixed_product; repeat rewrite Mmult_plus_distr_l; + normalize_kron_notation; + repeat rewrite Mscale_mult_dist_r. + + +Ltac flatten := + rewrite kron_plus_distr_r; + rewrite 2 Mscale_kron_dist_l; + rewrite ket0_equiv, ket1_equiv; + repeat (rewrite <- kron_assoc by auto 9 with wf_db; restore_dims). + +Ltac compute_decoding := + repeat rewrite kron_1_l by auto with wf_db; + repeat rewrite <- Mmult_assoc; + rewrite Mmult_plus_distr_l; repeat rewrite Mscale_mult_dist_r; - restore_dims. + repeat rewrite Mmult_assoc; + correct_inPar ltac:(apply decode_block_well_typed); + restore_dims; + distribute_over_blocks; + restore_dims; + + autorewrite with decode_block_db; + restore_dims; + reorder_scalars; + repeat simplify_sums; + pull_scalars; + rewrite Common.zero_3_f_to_vec; + rewrite Common.one_3_f_to_vec; -Ltac post_compute_vec := - repeat rewrite Cmult_assoc; - rewrite inv_sqrt2_cubed; - repeat rewrite <- Mscale_assoc with (y := ((-1)%R : C)); - repeat rewrite <- Mscale_plus_distr_r; - repeat rewrite <- kron_assoc; auto with wf_db; + restore_dims; + repeat rewrite kron_assoc by auto 10 with wf_db; + repeat (rewrite f_to_vec_merge; restore_dims); + repeat Common.f_to_vec_simpl_light; + simpl; Qsimpl; + repeat rewrite <- Cmult_assoc; + rewrite collapse_scalar; + autorewrite with C_db; - repeat rewrite Mplus_assoc; - repeat rewrite Mscale_plus_distr_r with (x := ((-1)%R : C)); - repeat rewrite Mplus_assoc. + now flatten. -Definition error_decode_correct_no_error : +Theorem decode_correct : + forall (α β : C), + (@uc_eval dim decode) × encoded α β + = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ 8 ⨂ ∣0⟩. +Proof. + intros. simpl uc_eval. + Qsimpl. simpl. + now compute_decoding. +Qed. + + +Theorem error_decode_correct_no_error : forall (α β : C), (@uc_eval dim (apply_error NoError; decode)) × (encoded α β) = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ancillae_for NoError. Proof. - intros. subst. - Local Opaque decode CCX. + intros. simpl ancillae_for. Msimpl_light. + Local Opaque decode. simpl uc_eval. restore_dims. - repeat rewrite denote_SKIP; try lia. - repeat rewrite Mmult_1_l; auto 10 with wf_db. - restore_dims. - prep_err_compute. Local Transparent decode. + repeat rewrite denote_SKIP; try lia. + repeat rewrite Mmult_assoc; restore_dims. + repeat rewrite Mmult_1_l by auto 10 with wf_db. + rewrite decode_correct. + simpl (_ ⨂ _). + now rewrite kron_1_l by auto with wf_db. +Qed. - assert (H1 : - (uc_eval decode - × f_to_vec (3 + 3 + 3) - (fun x : nat => - if x - if x - if x - if x - if x - if x - if x - if x ( + α .* (/C2 .* (/√ 2 .* ( + (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩) ⊗ 2 ⨂ (∣0,0,0⟩ .+ ∣1,1,1⟩))) + ) + .+ β .* (/C2 .* (/√ 2 .* ( + (∣0,0,0⟩ .+ ∣1,1,1⟩) ⊗ 2 ⨂ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩))) + ) + ) + | One => ( + α .* (/C2 .* (/√ 2 .* ( + (∣0,0,0⟩ .+ ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ ∣1,1,1⟩))) + ) + .+ β .* (/C2 .* (/√ 2 .* ( + (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩))) + ) + ) + | Two => ( + α .* (/C2 .* (/√ 2 .* ( + (2 ⨂ (∣0,0,0⟩ .+ ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩))) + )) + .+ β .* (/C2 .* (/√ 2 .* ( + 2 ⨂ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ ∣1,1,1⟩))) + ) + ) + end. + +Theorem one_phase_flip_correct : + forall (α β : C) {n} (off : block_offset), + uc_eval (apply_error (PhaseFlipError (OnePhaseFlip n off))) + × encoded α β + = post_one_phase_flip α β n. +Proof. + intros. + simpl uc_eval. + destruct n. + all : simpl (_ ⨂ _). + all : simpl apply_to_block. + all : simpl post_one_phase_flip. + all : correct_inPar ltac:( + (destruct off; + apply uc_well_typed_Z; simpl; lia) + || apply (@uc_well_typed_SKIP block_dim); lia + ). + all : distribute_over_blocks. + all : try rewrite denote_SKIP; try lia; Msimpl_light. + all : rewrite Z_block_zero, Z_block_seven. + all : repeat rewrite Mscale_assoc. + all : restore_dims. + all : replace ((-1)%R * (-1)%R)%C with C1 by lca. + all : now rewrite Mscale_1_l. +Qed. + +Theorem two_phase_flip_correct : + forall (α β : C) {n} (off₁ off₂ : block_offset), + uc_eval (apply_error (PhaseFlipError (MorePhaseFlip n (OnePhaseFlip n off₁) off₂))) + × encoded α β + = encoded α β. +Proof. + intros. + simpl uc_eval. + destruct n. + all : simpl (_ ⨂ _). + all : simpl apply_to_block. + all : do 2 correct_inPar ltac:( + (destruct off₁; destruct off₂; + apply uc_well_typed_Z; simpl; lia) + || apply (@uc_well_typed_SKIP block_dim); lia + ). + all : distribute_over_blocks. + all : try rewrite denote_SKIP; try lia; Msimpl_light. + all : restore_dims. + all : repeat rewrite Mmult_assoc. + all : rewrite Z_block_zero, Z_block_seven. + all : repeat rewrite Mscale_mult_dist_r. + all : rewrite Z_block_zero, Z_block_seven. + all : repeat rewrite Mscale_assoc. + all : replace ((-1)%R * (-1)%R)%C with C1 by lca. + all : rewrite <- Mscale_assoc with (y := C1). + all : now rewrite Mscale_1_l. +Qed. + +Theorem more_than_two_phase_flip_correct : + forall (α β : C) {n} (off₁ off₂ : block_offset) (e : phase_flip_error n), + uc_eval (apply_error (PhaseFlipError (MorePhaseFlip n (MorePhaseFlip n e off₂) off₁))) + × encoded α β + = uc_eval (apply_error (PhaseFlipError e)) × encoded α β. +Proof. + intros. + simpl uc_eval. + destruct n. + all : simpl (_ ⨂ _). + all : simpl apply_to_block. + all : do 2 correct_inPar ltac:( + (destruct off₁; destruct off₂; + apply uc_well_typed_Z; simpl; lia) + || apply (@uc_well_typed_SKIP block_dim); lia + ). + all : distribute_over_blocks. + all : try rewrite denote_SKIP; try lia; Msimpl_light. + all : restore_dims. + all : repeat rewrite Mmult_assoc. + all : rewrite Z_block_zero, Z_block_seven. + all : repeat rewrite Mscale_mult_dist_r. + all : rewrite Z_block_zero, Z_block_seven. + all : repeat rewrite Mscale_assoc. + all : replace ((-1)%R * (-1)%R)%C with C1 by lca. + all : rewrite <- Mscale_assoc with (y := C1). + all : now rewrite Mscale_1_l. +Qed. + +Lemma error_decode_correct_phase_flip_base : + forall (α β : C) {n} off, + (@uc_eval dim (apply_error (PhaseFlipError (OnePhaseFlip n off)); decode)) × (encoded α β) + = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ancillae_for (PhaseFlipError (OnePhaseFlip n off)). +Proof. + intros. + Local Opaque apply_error. + simpl uc_eval. + rewrite Mmult_assoc. + rewrite one_phase_flip_correct. + destruct n. + all : simpl ancillae_for; simpl post_one_phase_flip. + par : now compute_decoding. +Qed. + +Theorem error_decode_correct_phase_flip : + forall (α β : C) {n} (e : @phase_flip_error n), + (@uc_eval dim (apply_error (PhaseFlipError e); decode)) × (encoded α β) + = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ancillae_for (PhaseFlipError e). +Proof. + Local Opaque decode. + intros. + enough ( + (@uc_eval dim (apply_error (PhaseFlipError e); decode)) × (encoded α β) + = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ancillae_for (PhaseFlipError e) + /\ + forall off, + (@uc_eval dim (apply_error (PhaseFlipError (MorePhaseFlip n e off)); decode)) × (encoded α β) + = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ancillae_for (PhaseFlipError (MorePhaseFlip n e off)) + ). + { destruct H; assumption. } + induction e. + - split. + + apply error_decode_correct_phase_flip_base. + + intros. + simpl uc_eval. + rewrite Mmult_assoc. + rewrite two_phase_flip_correct. + apply decode_correct. + - destruct IHe as [IHe IHme]. + split. + + apply IHme. + + intros off0. + simpl uc_eval. + rewrite Mmult_assoc. + rewrite more_than_two_phase_flip_correct. + rewrite ancillae_for_two_phases_cancel. + simpl uc_eval in IHe. + rewrite Mmult_assoc in IHe. + apply IHe. +Qed. + + +Theorem error_decode_correct_bit_flip : + forall (α β : C) e, + (@uc_eval dim (apply_error (BitFlipError e); decode)) × (encoded α β) + = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ancillae_for (BitFlipError e). +Proof. + intros. + Local Opaque decode_block. + destruct e. + all : repeat rewrite <- Mmult_assoc. + all : rewrite Mmult_plus_distr_l. + all : repeat rewrite Mscale_mult_dist_r. + 2 : destruct safe_n. + 1 : destruct n. + Local Transparent apply_error. + all : simpl uc_eval. + all : simpl (_ ⨂ _). + all : repeat rewrite Mmult_assoc; restore_dims. + all : simpl apply_to_block. + all : try rewrite kron_1_l by auto with wf_db. + all : restore_dims. + all : repeat rewrite kron_assoc by auto 10 with wf_db. + all : correct_inPar ltac:( + try apply decode_block_well_typed + ). + all : correct_inPar ltac:( + try apply uc_well_typed_X; + first [destruct off | destruct off₁; destruct off₂]; + try destruct off₃; simpl; lia + || apply (@uc_well_typed_SKIP block_dim); lia + ). + all : simpl (_ + _). + all : distribute_over_blocks. + all : try rewrite denote_SKIP; try lia; Msimpl_light. + all : first [destruct off | destruct off₁; destruct off₂]; + try destruct off₃; simpl uc_eval; simpl ancillae_for. + (* slow; around ~2m *) + all : simpl apply_error. + par : restore_dims; + autorewrite with f_to_vec_3_db; + try repeat rewrite f_to_vec_X; try lia; simpl f_to_vec; + repeat rewrite kron_1_l by auto with wf_db; + restore_dims; + autorewrite with decode_block_db; + reorder_scalars; restore_dims; + repeat simplify_sums; + pull_scalars; restore_dims; + autorewrite with f_to_vec_3_db; + restore_dims; + repeat rewrite kron_assoc by auto 10 with wf_db; + repeat (rewrite f_to_vec_merge; restore_dims); + repeat Common.f_to_vec_simpl_light; + simpl f_to_vec; Msimpl_light; + repeat rewrite <- Cmult_assoc; + rewrite collapse_scalar; autorewrite with C_db; + now flatten. +Qed. Definition shor_correct (e : error) : forall (α β : C), @@ -532,7 +785,16 @@ Proof. simpl ancillae_for in H. rewrite Mmult_assoc in H. apply H. - - + - simpl ancillae_for. + specialize (error_decode_correct_phase_flip α β e) as H. + simpl uc_eval in H. + rewrite Mmult_assoc in H. + apply H. + - simpl ancillae_for. + specialize (error_decode_correct_bit_flip α β e) as H. + simpl uc_eval in H. + rewrite Mmult_assoc in H. + apply H. Admitted. End NineQubitCode. From 294156d52a1318172331e315e4b33e02dc166e7d Mon Sep 17 00:00:00 2001 From: Fady Adal <2masadel@gmail.com> Date: Mon, 5 Aug 2024 13:25:50 -0500 Subject: [PATCH 14/17] finish nine qubit code --- examples/error-correction/NineQubitCode.v | 207 ++++++++++++++++------ 1 file changed, 155 insertions(+), 52 deletions(-) diff --git a/examples/error-correction/NineQubitCode.v b/examples/error-correction/NineQubitCode.v index ecf1e93..96b513a 100644 --- a/examples/error-correction/NineQubitCode.v +++ b/examples/error-correction/NineQubitCode.v @@ -92,8 +92,7 @@ Theorem encode_block_zero : = / √ 2 .* (∣ 0, 0, 0 ⟩ .+ ∣ 1, 1, 1 ⟩). Proof. rewrite Common.zero_3_f_to_vec. - compute_vec. - reflexivity. + now compute_vec. Qed. @@ -101,9 +100,8 @@ Theorem encode_block_one : uc_eval encode_block × ∣1,0,0⟩ = / √ 2 .* (∣ 0, 0, 0 ⟩ .+ (-1)%R .* ∣ 1, 1, 1 ⟩). Proof. - replace (∣1,0,0⟩) with (f_to_vec 3 (fun n => n =? 0)) by lma'. - compute_vec. - reflexivity. + rewrite Common.one_3_f_to_vec. + now compute_vec. Qed. Theorem encode_block_well_typed : @@ -193,12 +191,6 @@ Inductive phase_flip_error (n : block_no) : Set := | OnePhaseFlip (off : block_offset) | MorePhaseFlip (e : phase_flip_error n) (off : block_offset). -Fixpoint phase_flip_odd_parity {n : block_no} (e : phase_flip_error n) : bool := - match e with - | OnePhaseFlip _ _ => true - | MorePhaseFlip _ e _ => negb (phase_flip_odd_parity e) - end. - Inductive bit_flip_error : Set := | OneBitFlip (n : block_no) (off : block_offset) | TwoBitFlip (safe_n : block_no) (off₁ off₂ : block_offset) @@ -304,9 +296,7 @@ Lemma ancillae_for_two_phases_cancel {n}: forall (e : phase_flip_error n) (off₁ off₂ : block_offset), ancillae_for (PhaseFlipError (MorePhaseFlip n (MorePhaseFlip n e off₁) off₂)) = ancillae_for (PhaseFlipError e). -Proof. - easy. -Qed. +Proof. easy. Qed. Definition decode_block : base_ucom block_dim := CNOT 0 1; @@ -441,6 +431,17 @@ Ltac simplify_sums := replace (A .+ B .+ RtoC (-1)%R .* (A .+ RtoC (-1)%R .* B)) with (C2 .* B) by lma + (* (∣ 0, 1, 1 ⟩ .+ ∣ 1, 1, 1 ⟩ + .+ (-1)%R + .* ((-1)%R .* (∣ 0, 1, 1 ⟩ .+ (-1)%R .* ∣ 1, 1, 1 ⟩)) *) + | [ |- context [?A .+ ?B + .+ RtoC (-1)%R .* ( + RtoC (-1)%R .* (?A .+ RtoC (-1)%R .* ?B))] + ] => + replace (A .+ B + .+ RtoC (-1)%R .* ( + RtoC (-1)%R .* (A .+ RtoC (-1)%R .* B))) + with (C2 .* A) by lma end. Ltac pull_scalars := @@ -557,7 +558,7 @@ Definition post_one_phase_flip (α β : C) (n : block_no) := ) | One => ( α .* (/C2 .* (/√ 2 .* ( - (∣0,0,0⟩ .+ ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ ∣1,1,1⟩))) + (∣0,0,0⟩ .+ ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ ∣1,1,1⟩))) ) .+ β .* (/C2 .* (/√ 2 .* ( (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ ∣1,1,1⟩) ⊗ (∣0,0,0⟩ .+ (-1)%R .* ∣1,1,1⟩))) @@ -657,23 +658,8 @@ Proof. all : now rewrite Mscale_1_l. Qed. -Lemma error_decode_correct_phase_flip_base : - forall (α β : C) {n} off, - (@uc_eval dim (apply_error (PhaseFlipError (OnePhaseFlip n off)); decode)) × (encoded α β) - = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ancillae_for (PhaseFlipError (OnePhaseFlip n off)). -Proof. - intros. - Local Opaque apply_error. - simpl uc_eval. - rewrite Mmult_assoc. - rewrite one_phase_flip_correct. - destruct n. - all : simpl ancillae_for; simpl post_one_phase_flip. - par : now compute_decoding. -Qed. - Theorem error_decode_correct_phase_flip : - forall (α β : C) {n} (e : @phase_flip_error n), + forall (α β : C) {n} (e : phase_flip_error n), (@uc_eval dim (apply_error (PhaseFlipError e); decode)) × (encoded α β) = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ancillae_for (PhaseFlipError e). Proof. @@ -690,7 +676,16 @@ Proof. { destruct H; assumption. } induction e. - split. - + apply error_decode_correct_phase_flip_base. + + intros. + Local Opaque apply_error. + Local Transparent decode. + simpl uc_eval. + Local Opaque decode. + rewrite Mmult_assoc. + rewrite one_phase_flip_correct. + destruct n. + all : simpl ancillae_for; simpl post_one_phase_flip. + par : now compute_decoding. + intros. simpl uc_eval. rewrite Mmult_assoc. @@ -709,6 +704,26 @@ Proof. apply IHe. Qed. +Ltac post_offset_destruct := + restore_dims; + autorewrite with f_to_vec_3_db; + try repeat rewrite f_to_vec_X; try lia; simpl f_to_vec; + repeat rewrite kron_1_l by auto with wf_db; + restore_dims; + autorewrite with decode_block_db; + reorder_scalars; restore_dims; + repeat simplify_sums; + pull_scalars; restore_dims; + autorewrite with f_to_vec_3_db; + restore_dims; + repeat rewrite kron_assoc by auto 10 with wf_db; + repeat (rewrite f_to_vec_merge; restore_dims); + repeat Common.f_to_vec_simpl_light; + simpl f_to_vec; Msimpl_light; + repeat rewrite <- Cmult_assoc; + rewrite collapse_scalar; autorewrite with C_db; + now flatten. + Theorem error_decode_correct_bit_flip : forall (α β : C) e, @@ -716,6 +731,7 @@ Theorem error_decode_correct_bit_flip : = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ancillae_for (BitFlipError e). Proof. intros. + Local Transparent decode. Local Opaque decode_block. destruct e. all : repeat rewrite <- Mmult_assoc. @@ -746,25 +762,107 @@ Proof. all : first [destruct off | destruct off₁; destruct off₂]; try destruct off₃; simpl uc_eval; simpl ancillae_for. (* slow; around ~2m *) - all : simpl apply_error. - par : restore_dims; - autorewrite with f_to_vec_3_db; - try repeat rewrite f_to_vec_X; try lia; simpl f_to_vec; - repeat rewrite kron_1_l by auto with wf_db; - restore_dims; - autorewrite with decode_block_db; - reorder_scalars; restore_dims; - repeat simplify_sums; - pull_scalars; restore_dims; - autorewrite with f_to_vec_3_db; - restore_dims; - repeat rewrite kron_assoc by auto 10 with wf_db; - repeat (rewrite f_to_vec_merge; restore_dims); - repeat Common.f_to_vec_simpl_light; - simpl f_to_vec; Msimpl_light; - repeat rewrite <- Cmult_assoc; - rewrite collapse_scalar; autorewrite with C_db; - now flatten. + par : now post_offset_destruct. +Qed. + +Lemma error_decode_correct_bit_one_phase_flip : + forall (α β : C) e off {phase_n}, + (@uc_eval dim (apply_error (PhaseBitErrors (OnePhaseFlip phase_n off) e); decode)) × (encoded α β) + = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ancillae_for (PhaseBitErrors (OnePhaseFlip phase_n off) e). +Proof. + intros. + simpl uc_eval. + destruct phase_n; destruct e. + all : repeat rewrite <- Mmult_assoc. + all : rewrite Mmult_plus_distr_l. + all : repeat rewrite Mscale_mult_dist_r. + all : try destruct safe_n; try destruct n. + all : simpl uc_eval. + all : simpl (_ ⨂ _). + all : repeat rewrite Mmult_assoc; restore_dims. + all : simpl apply_to_block. + all : try rewrite kron_1_l by auto with wf_db. + all : restore_dims. + all : repeat rewrite kron_assoc by auto 10 with wf_db. + all : correct_inPar ltac:( + try apply decode_block_well_typed + ). + all : correct_inPar ltac:( + try apply uc_well_typed_X; + first [destruct off0 | destruct off₁; destruct off₂]; + try destruct off₃; simpl; lia + || apply (@uc_well_typed_SKIP block_dim); lia + ). + all : correct_inPar ltac:( + try apply uc_well_typed_Z; + destruct off; simpl; lia + || apply (@uc_well_typed_SKIP block_dim); lia + ). + all : restore_dims. + all : simpl (_ + _). + all : distribute_over_blocks. + all : try rewrite denote_SKIP; try lia; Msimpl_light. + all : repeat rewrite Mmult_assoc. + all : rewrite Z_block_zero, Z_block_seven. + all : try rewrite denote_SKIP; try lia; Msimpl_light. + all : repeat rewrite Mscale_mult_dist_r. + all : first [destruct off0 | destruct off₁; destruct off₂]; + try destruct off₃; simpl uc_eval; simpl ancillae_for. + par : post_offset_destruct. +Qed. + +Theorem error_decode_correct_bit_phase_flip : + forall (α β : C) {phase_n} (e₁ : phase_flip_error phase_n) (e₂ : bit_flip_error), + (@uc_eval dim (apply_error (PhaseBitErrors e₁ e₂); decode)) × (encoded α β) + = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ancillae_for (PhaseBitErrors e₁ e₂). +Proof. + Local Opaque decode. + intros. + enough ( + (@uc_eval dim (apply_error (PhaseBitErrors e₁ e₂); decode)) × (encoded α β) + = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ancillae_for (PhaseBitErrors e₁ e₂) + /\ + forall off, + (@uc_eval dim (apply_error (PhaseBitErrors (MorePhaseFlip phase_n e₁ off) e₂); decode)) × (encoded α β) + = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ancillae_for (PhaseBitErrors (MorePhaseFlip phase_n e₁ off) e₂) + ). + { destruct H; assumption. } + induction e₁. + - split. + + apply error_decode_correct_bit_one_phase_flip. + + intros. + unfold apply_error. + change (apply_phase_flip_error ?a) with (apply_error (PhaseFlipError a)). + change (apply_bit_flip_error ?a) with (apply_error (BitFlipError a)). + Local Opaque apply_error. + simpl uc_eval. + do 2 rewrite Mmult_assoc. + rewrite two_phase_flip_correct. + specialize (error_decode_correct_bit_flip α β e₂) as He. + simpl uc_eval in He. + rewrite Mmult_assoc in He. + Set Printing Implicit. + restore_dims. + simpl in *. + apply He. + - destruct IHe₁ as [IHe IHme]. + split. + + apply IHme. + + intros off0. + Local Transparent apply_error. + unfold apply_error. + change (apply_phase_flip_error ?a) with (apply_error (PhaseFlipError a)). + change (apply_bit_flip_error ?a) with (apply_error (BitFlipError a)). + Local Opaque apply_error. + simpl uc_eval. + do 2 rewrite Mmult_assoc. + rewrite more_than_two_phase_flip_correct. + change (apply_error (@PhaseBitErrors phase_n e₁ e₂)) with (apply_error (PhaseFlipError e₁); apply_error (BitFlipError e₂)) in IHe. + simpl uc_eval in IHe. + repeat rewrite Mmult_assoc in IHe. + restore_dims. + simpl in *. + apply IHe. Qed. @@ -795,6 +893,11 @@ Proof. simpl uc_eval in H. rewrite Mmult_assoc in H. apply H. -Admitted. + - simpl ancillae_for. + specialize (error_decode_correct_bit_phase_flip α β e₁ e₂) as H. + simpl uc_eval in H. + rewrite Mmult_assoc in H. + apply H. +Qed. End NineQubitCode. From ee32b527993f3faa91d5ebf315c2be3757ca7cbf Mon Sep 17 00:00:00 2001 From: Fady Adal <2masadel@gmail.com> Date: Mon, 5 Aug 2024 14:19:10 -0500 Subject: [PATCH 15/17] some cleanup --- examples/error-correction/NineQubitCode.v | 251 +++++++++++----------- 1 file changed, 123 insertions(+), 128 deletions(-) diff --git a/examples/error-correction/NineQubitCode.v b/examples/error-correction/NineQubitCode.v index 96b513a..eb8796b 100644 --- a/examples/error-correction/NineQubitCode.v +++ b/examples/error-correction/NineQubitCode.v @@ -13,6 +13,104 @@ Notation block_dim := 3%nat. Local Opaque CCX. +(** + Utilities + *) + +Ltac compute_vec := + simpl uc_eval; + repeat rewrite Mmult_assoc; restore_dims; + repeat Common.f_to_vec_simpl_light; + simpl; + replace (0 * PI)%R with 0%R by lra; + replace (1 * PI)%R with PI by lra; + autorewrite with Cexp_db; + repeat rewrite Mscale_1_l; + restore_dims; + autorewrite with ket_db. + +Ltac correct_inPar well_typed := + try + (replace (@uc_eval 9) with (@uc_eval (3 + 6)) by easy; + rewrite inPar_correct by well_typed); + try + (replace (@uc_eval 6) with (@uc_eval (3 + 3)) by easy; + rewrite inPar_correct by well_typed); + restore_dims. + +Ltac reorder_scalars := + repeat rewrite Mscale_assoc; + repeat rewrite Cmult_comm with (x := ((-1)%R : C)); + repeat rewrite <- Mscale_assoc with (y := ((-1)%R : C)); + repeat rewrite <- Mscale_plus_distr_r. + +Ltac normalize_kron_notation := + repeat rewrite <- kron_assoc by auto 8 with wf_db; + try easy. + +Lemma inv_sqrt2_cubed : (/ √ 2 * / √ 2 * / √ 2)%C = (/ C2 * /√ 2)%C. +Proof. + now rewrite Cinv_sqrt2_sqrt. +Qed. + +Lemma uc_well_typed_SKIP {d} {h : 0 < d} : + @uc_well_typed _ d SKIP. +Proof. + unfold SKIP. + apply uc_well_typed_ID. + assumption. +Qed. + +Ltac simplify_sums := + match goal with + | [ |- context [ (?A .+ ?B) + .+ (?A .+ RtoC (-1)%R .* ?B) ] + ] => + replace (A .+ B + .+ (A .+ RtoC (-1)%R .* B)) with (C2 .* A) by lma + | [ |- context [?A .+ ?B + .+ RtoC (-1)%R .* (?A .+ RtoC (-1)%R .* ?B)] + ] => + replace (A .+ B + .+ RtoC (-1)%R .* (A .+ RtoC (-1)%R .* B)) + with (C2 .* B) by lma + | [ |- context [?A .+ ?B + .+ RtoC (-1)%R .* ( + RtoC (-1)%R .* (?A .+ RtoC (-1)%R .* ?B))] + ] => + replace (A .+ B + .+ RtoC (-1)%R .* ( + RtoC (-1)%R .* (A .+ RtoC (-1)%R .* B))) + with (C2 .* A) by lma + end. + +Ltac pull_scalars := + distribute_scale; + repeat rewrite Mscale_mult_dist_r; + repeat rewrite Mscale_assoc. + +Lemma collapse_scalar : + (/ C2 * (/ √ 2 * (/ √ 2 * (C2 * (/ √ 2 * (C2 * (/ √ 2 * C2)))))))%C = C1. +Proof. C_field. Qed. + +Ltac distribute_over_blocks := + repeat rewrite kron_1_l by auto 10 with wf_db; + repeat rewrite kron_assoc by auto with wf_db; + repeat rewrite Mmult_plus_distr_l; + repeat rewrite Mscale_mult_dist_r; + repeat rewrite Mmult_assoc; + restore_dims; + repeat rewrite kron_mixed_product; + repeat rewrite Mmult_plus_distr_l; + normalize_kron_notation; + repeat rewrite Mscale_mult_dist_r. + +Ltac flatten := + rewrite kron_plus_distr_r; + rewrite 2 Mscale_kron_dist_l; + rewrite ket0_equiv, ket1_equiv; + repeat (rewrite <- kron_assoc by auto 9 with wf_db; restore_dims). + (** Blocks *) @@ -54,30 +152,6 @@ Definition block_no := up_to_three. (* Qubits in a single block *) Definition block_offset := up_to_three. -Definition block_to_qubit (n : block_no) (off : block_offset) : nat := - n * 3 + off. - -Ltac compute_vec := - simpl uc_eval; - repeat rewrite Mmult_assoc; restore_dims; - repeat Common.f_to_vec_simpl_light; - simpl; - replace (0 * PI)%R with 0%R by lra; - replace (1 * PI)%R with PI by lra; - autorewrite with Cexp_db; - repeat rewrite Mscale_1_l; - restore_dims; - autorewrite with ket_db. - -Ltac correct_inPar well_typed := - try - (replace (@uc_eval 9) with (@uc_eval (3 + 6)) by easy; - rewrite inPar_correct by well_typed); - try - (replace (@uc_eval 6) with (@uc_eval (3 + 3)) by easy; - rewrite inPar_correct by well_typed); - restore_dims. - (** Encoding *) @@ -126,21 +200,6 @@ Notation encoded α β := ( ). -Ltac reorder_scalars := - repeat rewrite Mscale_assoc; - repeat rewrite Cmult_comm with (x := ((-1)%R : C)); - repeat rewrite <- Mscale_assoc with (y := ((-1)%R : C)); - repeat rewrite <- Mscale_plus_distr_r. - -Ltac normalize_kron_notation := - repeat rewrite <- kron_assoc by auto 8 with wf_db; - try easy. - -Lemma inv_sqrt2_cubed : (/ √ 2 * / √ 2 * / √ 2)%C = (/ C2 * /√ 2)%C. -Proof. - now rewrite Cinv_sqrt2_sqrt. -Qed. - Theorem encode_correct : forall (α β : C), (@uc_eval dim encode) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ 8 ⨂ ∣0⟩) = encoded α β. @@ -274,10 +333,10 @@ Definition ancillae_for_bit_flip (e : bit_flip_error) : list nat := | One => (block_to_bit_syn Zer0 off₁) ++ (block_to_bit_syn Two off₂) | Two => (block_to_bit_syn Zer0 off₁) ++ (block_to_bit_syn One off₂) end - | ThreeBitFlip off₁ off₂ off₃ => - (block_to_bit_syn Zer0 off₁) ++ - (block_to_bit_syn One off₂) ++ - (block_to_bit_syn Two off₃) + | ThreeBitFlip off₁ off₂ off₃ => + (block_to_bit_syn Zer0 off₁) ++ + (block_to_bit_syn One off₂) ++ + (block_to_bit_syn Two off₃) end. Definition ancillae_for (e : error) : Vector (2 ^ 8) := @@ -298,6 +357,9 @@ Lemma ancillae_for_two_phases_cancel {n}: = ancillae_for (PhaseFlipError e). Proof. easy. Qed. +(** + Decode + *) Definition decode_block : base_ucom block_dim := CNOT 0 1; CNOT 0 2; @@ -317,13 +379,11 @@ Proof. all : try apply uc_well_typed_Rz; lia. Qed. - Lemma decode_block_zero : uc_eval decode_block × ∣0,0,0⟩ = / √ 2 .* (∣0,0,0⟩ .+ ∣1,0,0⟩). Proof. rewrite Common.zero_3_f_to_vec. - compute_vec. - reflexivity. + now compute_vec. Qed. Lemma decode_block_one : @@ -386,92 +446,11 @@ Qed. decode_block_seven : decode_block_db. -(** - Decode - *) Definition decode : base_ucom dim := inPar decode_block (inPar decode_block decode_block); CNOT 0 3; CNOT 0 6; CCX 6 3 0. -(** - Full circuit - *) -Definition shor (e : error) : base_ucom dim := - encode; - apply_error e; - (* Does not use the regular: - `encode; apply_error e; recover; decode` - (because we do not recover the original encoding). - Attempting to do so requires 8 additional - qubits (really classical bits), 2 per block - for bit flip, and 2 for phase flip. - This makes the following analysis rougher. - *) - decode. - -Lemma uc_well_typed_SKIP {d} {h : 0 < d} : - @uc_well_typed _ d SKIP. -Proof. - unfold SKIP. - apply uc_well_typed_ID. - assumption. -Qed. - -Ltac simplify_sums := - match goal with - | [ |- context [ (?A .+ ?B) - .+ (?A .+ RtoC (-1)%R .* ?B) ] - ] => - replace (A .+ B - .+ (A .+ RtoC (-1)%R .* B)) with (C2 .* A) by lma - | [ |- context [?A .+ ?B - .+ RtoC (-1)%R .* (?A .+ RtoC (-1)%R .* ?B)] - ] => - replace (A .+ B - .+ RtoC (-1)%R .* (A .+ RtoC (-1)%R .* B)) - with (C2 .* B) by lma - (* (∣ 0, 1, 1 ⟩ .+ ∣ 1, 1, 1 ⟩ - .+ (-1)%R - .* ((-1)%R .* (∣ 0, 1, 1 ⟩ .+ (-1)%R .* ∣ 1, 1, 1 ⟩)) *) - | [ |- context [?A .+ ?B - .+ RtoC (-1)%R .* ( - RtoC (-1)%R .* (?A .+ RtoC (-1)%R .* ?B))] - ] => - replace (A .+ B - .+ RtoC (-1)%R .* ( - RtoC (-1)%R .* (A .+ RtoC (-1)%R .* B))) - with (C2 .* A) by lma - end. - -Ltac pull_scalars := - distribute_scale; - repeat rewrite Mscale_mult_dist_r; - repeat rewrite Mscale_assoc. - -Lemma collapse_scalar : - (/ C2 * (/ √ 2 * (/ √ 2 * (C2 * (/ √ 2 * (C2 * (/ √ 2 * C2)))))))%C = C1. -Proof. C_field. Qed. - -Ltac distribute_over_blocks := - repeat rewrite kron_1_l by auto 10 with wf_db; - repeat rewrite kron_assoc by auto with wf_db; - repeat rewrite Mmult_plus_distr_l; - repeat rewrite Mscale_mult_dist_r; - repeat rewrite Mmult_assoc; - restore_dims; - repeat rewrite kron_mixed_product; - repeat rewrite Mmult_plus_distr_l; - normalize_kron_notation; - repeat rewrite Mscale_mult_dist_r. - - -Ltac flatten := - rewrite kron_plus_distr_r; - rewrite 2 Mscale_kron_dist_l; - rewrite ket0_equiv, ket1_equiv; - repeat (rewrite <- kron_assoc by auto 9 with wf_db; restore_dims). - Ltac compute_decoding := repeat rewrite kron_1_l by auto with wf_db; repeat rewrite <- Mmult_assoc; @@ -503,6 +482,23 @@ Ltac compute_decoding := now flatten. +(** + Full circuit + *) +Definition shor (e : error) : base_ucom dim := + encode; + apply_error e; + (* Does not use the regular: + `encode; apply_error e; recover; decode` + (because we do not recover the original encoding). + Attempting to do so requires 8 additional + qubits (really classical bits), 2 per block + for bit flip, and 2 for phase flip. + This makes the following analysis rougher. + *) + decode. + + Theorem decode_correct : forall (α β : C), (@uc_eval dim decode) × encoded α β @@ -724,7 +720,6 @@ Ltac post_offset_destruct := rewrite collapse_scalar; autorewrite with C_db; now flatten. - Theorem error_decode_correct_bit_flip : forall (α β : C) e, (@uc_eval dim (apply_error (BitFlipError e); decode)) × (encoded α β) From 69ae8a48918d382edc06035934187d12ead0e892 Mon Sep 17 00:00:00 2001 From: Fady Adal <2masadel@gmail.com> Date: Wed, 14 Aug 2024 21:52:55 -0500 Subject: [PATCH 16/17] first pass with arbitrary --- examples/error-correction/Common.v | 1 + examples/error-correction/NineQubitCode.v | 436 ++++++++++++++++++++-- 2 files changed, 398 insertions(+), 39 deletions(-) diff --git a/examples/error-correction/Common.v b/examples/error-correction/Common.v index 511645b..dd0ddfa 100644 --- a/examples/error-correction/Common.v +++ b/examples/error-correction/Common.v @@ -74,6 +74,7 @@ Qed. Ltac f_to_vec_simpl_light := first [ rewrite f_to_vec_H + | rewrite f_to_vec_X | rewrite f_to_vec_CCX | rewrite f_to_vec_CNOT ]; diff --git a/examples/error-correction/NineQubitCode.v b/examples/error-correction/NineQubitCode.v index eb8796b..e41d24f 100644 --- a/examples/error-correction/NineQubitCode.v +++ b/examples/error-correction/NineQubitCode.v @@ -1,4 +1,5 @@ Require Export SQIR.UnitaryOps. +Require Export QuantumLib.Measurement. Require Import Common. Import Common. @@ -33,7 +34,7 @@ Ltac correct_inPar well_typed := try (replace (@uc_eval 9) with (@uc_eval (3 + 6)) by easy; rewrite inPar_correct by well_typed); - try + try (replace (@uc_eval 6) with (@uc_eval (3 + 3)) by easy; rewrite inPar_correct by well_typed); restore_dims. @@ -42,7 +43,7 @@ Ltac reorder_scalars := repeat rewrite Mscale_assoc; repeat rewrite Cmult_comm with (x := ((-1)%R : C)); repeat rewrite <- Mscale_assoc with (y := ((-1)%R : C)); - repeat rewrite <- Mscale_plus_distr_r. +repeat rewrite <- Mscale_plus_distr_r. Ltac normalize_kron_notation := repeat rewrite <- kron_assoc by auto 8 with wf_db; @@ -82,6 +83,23 @@ Ltac simplify_sums := .+ RtoC (-1)%R .* ( RtoC (-1)%R .* (A .+ RtoC (-1)%R .* B))) with (C2 .* A) by lma + | [ |- context [?A .+ RtoC (-1)%R .* ?B + .+ (?A .+ ?B)] + ] => + replace (A .+ RtoC (-1)%R .* B + .+ (A .+ B)) with (C2 .* A) by lma + | [ |- context [RtoC (-1)%R .* ?A .+ ?B + .+ (?A .+ ?B)] + ] => + replace (RtoC (-1)%R .* A .+ B + .+ (A .+ B)) + with (C2 .* B) by lma + | [ |- context [(?A .+ ?B + .+ (RtoC (-1)%R .* ?A .+ ?B))] + ] => + replace (A .+ B + .+ (RtoC (-1)%R .* A .+ B)) + with (C2 .* B) by lma end. Ltac pull_scalars := @@ -127,25 +145,8 @@ Definition t_to_nat (t : up_to_three) : nat := | Two => 2 end. -Definition t_eq (t₁ t₂ : up_to_three) : bool := - match t₁, t₂ with - | Zer0, Zer0 - | One, One - | Two, Two => true - | _, _ => false - end. - Coercion t_to_nat : up_to_three >-> nat. -Definition t_of_nat (n : nat) (h : n < 3) : up_to_three. -Proof. - destruct n as [| [| [| n']]]. - - exact Zer0. - - exact One. - - exact Two. - - lia. -Defined. - (* Encoded blocks *) Definition block_no := up_to_three. @@ -169,7 +170,6 @@ Proof. now compute_vec. Qed. - Theorem encode_block_one : uc_eval encode_block × ∣1,0,0⟩ = / √ 2 .* (∣ 0, 0, 0 ⟩ .+ (-1)%R .* ∣ 1, 1, 1 ⟩). @@ -700,12 +700,59 @@ Proof. apply IHe. Qed. +Lemma X_0_block_zero : + @uc_eval block_dim (X 0) × ∣0,0,0⟩ = ∣1,0,0⟩. +Proof. + rewrite Common.zero_3_f_to_vec. + now compute_vec. +Qed. + +Lemma X_1_block_zero : + @uc_eval block_dim (X 1) × ∣0,0,0⟩ = ∣0,1,0⟩. +Proof. + rewrite Common.zero_3_f_to_vec. + now compute_vec. +Qed. + +Lemma X_2_block_zero : + @uc_eval block_dim (X 2) × ∣0,0,0⟩ = ∣0,0,1⟩. +Proof. + rewrite Common.zero_3_f_to_vec. + now compute_vec. +Qed. + +Lemma X_0_block_seven : + @uc_eval block_dim (X 0) × ∣1,1,1⟩ = ∣0,1,1⟩. +Proof. + rewrite Common.seven_3_f_to_vec. + now compute_vec. +Qed. + +Lemma X_1_block_seven : + @uc_eval block_dim (X 1) × ∣1,1,1⟩ = ∣1,0,1⟩. +Proof. + rewrite Common.seven_3_f_to_vec. + now compute_vec. +Qed. + +Lemma X_2_block_seven : + @uc_eval block_dim (X 2) × ∣1,1,1⟩ = ∣1,1,0⟩. +Proof. + rewrite Common.seven_3_f_to_vec. + now compute_vec. +Qed. + +#[export] Hint Rewrite + X_0_block_zero + X_1_block_zero + X_2_block_zero + X_0_block_seven + X_1_block_seven + X_2_block_seven + : X_off_block_db. + Ltac post_offset_destruct := - restore_dims; - autorewrite with f_to_vec_3_db; - try repeat rewrite f_to_vec_X; try lia; simpl f_to_vec; - repeat rewrite kron_1_l by auto with wf_db; - restore_dims; + autorewrite with X_off_block_db; autorewrite with decode_block_db; reorder_scalars; restore_dims; repeat simplify_sums; @@ -801,9 +848,13 @@ Proof. all : rewrite Z_block_zero, Z_block_seven. all : try rewrite denote_SKIP; try lia; Msimpl_light. all : repeat rewrite Mscale_mult_dist_r. - all : first [destruct off0 | destruct off₁; destruct off₂]; - try destruct off₃; simpl uc_eval; simpl ancillae_for. - par : post_offset_destruct. + all : restore_dims. + all : try (destruct off0; simpl uc_eval; simpl ancillae_for). + all : try post_offset_destruct. + all : destruct off₁; destruct off₂; simpl uc_eval; simpl ancillae_for. + 1-81 : try post_offset_destruct. + all : try destruct off₃; simpl uc_eval; simpl ancillae_for. + all : now post_offset_destruct. Qed. Theorem error_decode_correct_bit_phase_flip : @@ -836,7 +887,6 @@ Proof. specialize (error_decode_correct_bit_flip α β e₂) as He. simpl uc_eval in He. rewrite Mmult_assoc in He. - Set Printing Implicit. restore_dims. simpl in *. apply He. @@ -861,7 +911,7 @@ Proof. Qed. -Definition shor_correct (e : error) : forall (α β : C), +Theorem shor_correct (e : error) : forall (α β : C), (@uc_eval dim (shor e)) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ 8 ⨂ ∣0⟩) = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ancillae_for e. Proof. @@ -871,28 +921,336 @@ Proof. repeat rewrite Mmult_assoc. rewrite encode_correct. - destruct e. - - simpl ancillae_for. - specialize (error_decode_correct_no_error α β) as H. + destruct e; simpl ancillae_for. + - specialize (error_decode_correct_no_error α β) as H. simpl uc_eval in H. simpl ancillae_for in H. rewrite Mmult_assoc in H. apply H. - - simpl ancillae_for. - specialize (error_decode_correct_phase_flip α β e) as H. + - specialize (error_decode_correct_phase_flip α β e) as H. simpl uc_eval in H. rewrite Mmult_assoc in H. apply H. - - simpl ancillae_for. - specialize (error_decode_correct_bit_flip α β e) as H. + - specialize (error_decode_correct_bit_flip α β e) as H. simpl uc_eval in H. rewrite Mmult_assoc in H. apply H. - - simpl ancillae_for. - specialize (error_decode_correct_bit_phase_flip α β e₁ e₂) as H. + - specialize (error_decode_correct_bit_phase_flip α β e₁ e₂) as H. simpl uc_eval in H. rewrite Mmult_assoc in H. apply H. Qed. +Lemma pauli_spans_2_by_2 : + forall (M : Square 2), WF_Matrix M -> + exists λ₁ λ₂ λ₃ λ₄, + M = λ₁ .* (I 2) .+ λ₂ .* σx .+ λ₃ .* σy .+ λ₄ .* σz. +Proof. + intros. + exists ((M 0 0 + M 1 1) / C2)%C. + exists ((M 0 1 + M 1 0) / C2)%C. + exists (Ci * (M 0 1 - M 1 0) / C2)%C. + exists ((M 0 0 - M 1 1) / C2)%C. + solve_matrix. +Qed. + +Lemma pauli_spans_unitary_2_by_2 : + forall (M : Square 2), WF_Unitary M -> + exists λ₁ λ₂ λ₃ λ₄, + M = λ₁ .* (I 2) .+ λ₂ .* σx .+ λ₃ .* σy .+ λ₄ .* σz + /\ (Cmod λ₁ ^ 2 + Cmod λ₂ ^ 2 + Cmod λ₃ ^ 2 + Cmod λ₄ ^ 2)%C = C1. +Proof. + intros ? [Hwf Hinv]. + specialize (pauli_spans_2_by_2 M Hwf) as [λ₁ [λ₂ [λ₃ [λ₄ Heq]]]]. + exists λ₁, λ₂, λ₃, λ₄. + split. + apply Heq. + rewrite Heq in Hinv. + + repeat rewrite Mplus_adjoint in Hinv. + repeat rewrite Mscale_adj in Hinv. + repeat rewrite Mmult_plus_distr_l in Hinv. + repeat rewrite Mmult_plus_distr_r in Hinv. + repeat rewrite Mscale_mult_dist_r in Hinv. + repeat rewrite Mscale_mult_dist_l in Hinv. + specialize σx_unitary as [_ Hinvσx]. + specialize σy_unitary as [_ Hinvσy]. + specialize σz_unitary as [_ Hinvσz]. + rewrite Hinvσx in Hinv. clear Hinvσx. + rewrite Hinvσy in Hinv. clear Hinvσy. + rewrite Hinvσz in Hinv. clear Hinvσz. + + replace ((σx) †) with σx in Hinv by solve_matrix. + replace ((σy) †) with σy in Hinv by solve_matrix. + replace ((σz) †) with σz in Hinv by solve_matrix. + + autorewrite with M_db M_db_light in Hinv. + replace (σx × σy) with (Ci .* σz) in Hinv by lma'. + replace (σy × σx) with (-Ci .* σz) in Hinv by lma'. + replace (σz × σx) with (Ci .* σy) in Hinv by lma'. + replace (σx × σz) with (-Ci .* σy) in Hinv by lma'. + replace (σy × σz) with (Ci .* σx) in Hinv by lma'. + replace (σz × σy) with (-Ci .* σx) in Hinv by lma'. + assert (H00 := Hinv). + assert (H11 := Hinv). + clear Hinv. + apply (f_equal (fun m => m 0 0)) in H00. + apply (f_equal (fun m => m 1 1)) in H11. + unfold scale, Mplus, I, σx, σy, σz in H00, H11; simpl in H00, H11. + specialize (Cplus_simplify _ _ _ _ H00 H11) as H. + clear H00. clear H11. + ring_simplify in H. + repeat rewrite <- Cplus_assoc in H. + repeat rewrite <- Cmult_assoc in H. + repeat rewrite <- Cmult_plus_distr_l in H. + + replace (((R1 + R1)%R, (R0 + R0)%R)) with C2 in H. + 2 :{ + unfold C2. + apply c_proj_eq; simpl. + field. + field. + } + apply Cmult_cancel_l with (a := C2); try nonzero. + rewrite Cmult_1_r. + + repeat rewrite <- Cmod_sqr in H. + rewrite Cmult_comm with (x := λ₁) in H. + rewrite <- Cmod_sqr in H. + repeat rewrite Cplus_assoc in H. + + exact H. +Qed. + + +Lemma YeqiXZ : + σy = Ci .* σx × σz. +Proof. solve_matrix. Qed. + +Definition block_to_qubit (n : block_no) (off : block_offset) : nat := + n * 3 + off. + +Definition ancillae_for_arbitrary + (λ₁ λ₂ λ₃ λ₄ : C) + (n : block_no) + (off : block_offset) : Vector (2 ^ 8) + := ( + λ₁ .* ancillae_for NoError + .+ λ₂ .* ancillae_for (BitFlipError (OneBitFlip n off)) + .+ λ₃ * Ci .* ancillae_for (PhaseBitErrors (OnePhaseFlip n off) (OneBitFlip n off)) + .+ λ₄ .* ancillae_for (PhaseFlipError (OnePhaseFlip n off)) + ). + +Lemma Cmod_Ci : Cmod Ci = 1%R. +Proof. + unfold Ci, Cmod; simpl. + rewrite Rmult_0_l. + rewrite Rplus_0_l. + do 2 rewrite Rmult_1_l. + exact sqrt_1. +Qed. + +Lemma ancillae_pure_vector_cond : + forall (λ₁ λ₂ λ₃ λ₄ : C) (n : block_no) (off : block_offset), + (Cmod λ₁ ^ 2 + Cmod λ₂ ^ 2 + Cmod λ₃ ^ 2 + Cmod λ₄ ^ 2)%C = C1 -> + Pure_State_Vector (ancillae_for_arbitrary λ₁ λ₂ λ₃ λ₄ n off). +Proof. + intros. + unfold Pure_State_Vector. + split. + 1: { + destruct n; destruct off; unfold ancillae_for_arbitrary; simpl. + all : auto 18 with wf_db. + } + destruct n; destruct off. + all : unfold ancillae_for_arbitrary; simpl. + all : repeat rewrite kron_1_l by auto with wf_db. + all : repeat rewrite Mplus_adjoint. + all : repeat rewrite Mscale_adj. + all : restore_dims. + all : rewrite <- ket0_equiv, <- ket1_equiv. + all : repeat rewrite kron_adjoint. + all : repeat rewrite Mmult_plus_distr_r. + all : autorewrite with ket_db. + all : repeat rewrite Mplus_assoc. + all : repeat rewrite <- Mscale_plus_distr_l. + all : repeat rewrite <- Cmod_sqr. + all : rewrite Cmod_mult. + all : rewrite Cmod_Ci. + all : rewrite Rmult_1_r. + all : repeat rewrite Cplus_assoc. + all : rewrite H. + all : now rewrite Mscale_1_l. +Qed. + + +Theorem shor_arbitrary_correct (M : Square 2) : + WF_Unitary M -> + forall (α β : C) (n : block_no) (off : block_offset), + exists (φ : Vector (2^8)), + ( uc_eval decode + × pad_u dim (block_to_qubit n off) M + × uc_eval encode) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ 8 ⨂ ∣0⟩) + = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ φ. +Proof. + intros. + repeat rewrite Mmult_assoc. + rewrite encode_correct. + specialize (pauli_spans_unitary_2_by_2 M H) as Hpauli. + destruct Hpauli as [λ₁ [λ₂ [λ₃ [λ₄ [Hpauli Hmod]]]]]. + rewrite Hpauli. + exists (ancillae_for_arbitrary λ₁ λ₂ λ₃ λ₄ n off). + destruct n; destruct off. + all : cbn. + all : repeat rewrite kron_1_l by auto with wf_db. + all : try rewrite kron_1_r by auto with wf_db. + 1 : replace (I 256) with (I 4 ⊗ I 8 ⊗ I 8) by (repeat rewrite id_kron; easy). + 9 : replace (I 256) with (I 8 ⊗ I 8 ⊗ I 4) by (repeat rewrite id_kron; easy). + 2 : replace (I 128) with (I 2 ⊗ I 8 ⊗ I 8) by (repeat rewrite id_kron; easy). + 8 : replace (I 128) with (I 8 ⊗ I 8 ⊗ I 2) by (repeat rewrite id_kron; easy). + 3,7 : replace (I 64) with (I 8 ⊗ I 8) by (repeat rewrite id_kron; easy). + 7 : replace (I 32) with (I 8 ⊗ I 4) by (repeat rewrite id_kron; easy). + 5 : replace (I 32) with (I 4 ⊗ I 8) by (repeat rewrite id_kron; easy). + 6 : replace (I 16) with (I 8 ⊗ I 2) by (repeat rewrite id_kron; easy). + + all : restore_dims. + all : repeat rewrite kron_assoc by auto 10 with wf_db. + 6 : replace (I 8 ⊗ I 2) with (I 2 ⊗ I 8) by (repeat rewrite id_kron; easy); + restore_dims. + 1-3 : repeat rewrite <- kron_assoc by auto 10 with wf_db. + 5-6 : rewrite <- kron_assoc with (C := I 8) by auto 10 with wf_db. + 6 : repeat (rewrite kron_assoc by auto 10 with wf_db; restore_dims). + 6 : repeat rewrite <- kron_assoc with (A := I 2) by auto 10 with wf_db. + 6 : repeat rewrite <- kron_assoc with (B := I 2) by auto 10 with wf_db. + 7 : repeat rewrite <- kron_assoc with (A := I 4) by auto 10 with wf_db. + + all : restore_dims. + all : do 2 rewrite Mmult_plus_distr_l. + all : pull_scalars; restore_dims. + all : repeat rewrite kron_mixed_product. + all : repeat rewrite Mmult_1_l by auto with wf_db. + all : replace (I 4) with (I 2 ⊗ I 2) by (repeat rewrite id_kron; easy). + all : do 2 rewrite Mmult_plus_distr_l. + all : rewrite Mscale_mult_dist_r. + all : restore_dims. + all : repeat (rewrite kron_assoc by auto 10 with wf_db; restore_dims). + all : repeat rewrite kron_mixed_product. + all : repeat rewrite Mmult_plus_distr_r. + all : repeat rewrite Mscale_mult_dist_l. + all : rewrite ket0_equiv, ket1_equiv. + all : restore_dims. + all : repeat rewrite Mmult_1_l by auto with wf_db. + all : rewrite X0_spec, X1_spec, Y0_spec, Y1_spec, Z0_spec, Z1_spec. + Local Transparent decode. + all : simpl uc_eval. + all : repeat rewrite Mmult_assoc by auto 10 with wf_db. + all : correct_inPar ltac:(apply decode_block_well_typed). + all : repeat rewrite kron_mixed_product. + all : repeat rewrite kron_plus_distr_r. + all : repeat rewrite kron_plus_distr_l. + all : repeat rewrite Mmult_plus_distr_l. + all : repeat rewrite Mscale_mult_dist_r. + all : repeat rewrite Mmult_plus_distr_l. + all : repeat rewrite Mscale_kron_dist_l. + all : repeat rewrite Mscale_mult_dist_r. + all : repeat rewrite Mscale_kron_dist_r. + all : repeat rewrite Mscale_mult_dist_r. + all : restore_dims. + all : repeat rewrite <- kron_assoc by auto 10 with wf_db. + all : restore_dims. + all : repeat rewrite Mscale_plus_distr_r with (x := ((-1)%R : C)). + all : repeat rewrite Mscale_assoc. + all : repeat rewrite Cmult_comm with (x := ((-1)%R : C)). + all : repeat rewrite <- Mscale_assoc. + all : repeat rewrite Mplus_assoc. + all : repeat rewrite Mplus_comm with (A := λ₁ .* _). + all : repeat rewrite Mplus_assoc. + all : do 2 rewrite <- Mscale_plus_distr_r. + all : repeat rewrite Mplus_comm with (A := λ₂ .* _). + all : repeat rewrite Mplus_assoc. + all : do 2 rewrite <- Mscale_plus_distr_r. + all : repeat rewrite Mplus_comm with (A := λ₃ .* _). + all : repeat rewrite Mplus_assoc. + all : do 2 rewrite <- Mscale_plus_distr_r. + all : repeat rewrite Mplus_comm with (A := λ₄ .* _). + all : repeat rewrite Mplus_assoc. + all : do 2 rewrite <- Mscale_plus_distr_r. + all : autorewrite with decode_block_db. + all : restore_dims. + all : replace (-Ci) with (Ci * (-1)%R)%C by lca. + all : reorder_scalars. + all : repeat rewrite <- Cmult_assoc with (y := ((-1)%R : C)). + all : rewrite Cmult_comm with (x := ((-1)%R : C)). + all : reorder_scalars. + all : repeat rewrite Cmult_assoc with (z := ((-1)%R : C)). + all : repeat rewrite <- Mscale_assoc with (y := ((-1)%R : C)). + all : reorder_scalars. + all : pull_scalars. + all : repeat rewrite Mscale_plus_distr_r with (x := ((-1)%R : C)). + all : repeat rewrite Mscale_assoc. + all : replace ((-1)%R * (-1)%R)%C with C1 by lca. + all : repeat rewrite Mscale_1_l. + all : restore_dims. + all : repeat simplify_sums. + all : autorewrite with f_to_vec_3_db. + all : distribute_scale. + all : distribute_plus. + all : repeat rewrite Mscale_mult_dist_r. + all : repeat rewrite Mmult_plus_distr_l. + all : repeat rewrite Mscale_kron_dist_r. + all : repeat rewrite Mscale_kron_dist_l. + all : repeat rewrite Mscale_kron_dist_r. + all : repeat rewrite kron_assoc by auto 10 with wf_db. + all : repeat rewrite Mscale_assoc. + all : repeat rewrite Mscale_mult_dist_r. + all : restore_dims. + all : repeat rewrite kron_assoc by auto 10 with wf_db. + all : repeat (rewrite f_to_vec_merge; restore_dims). + all : repeat f_to_vec_simpl_light. + all : simpl. + all : repeat rewrite kron_1_l by auto with wf_db. + all : repeat rewrite kron_assoc by auto with wf_db. + all : repeat rewrite <- Cmult_assoc. + all : rewrite <- Mscale_assoc with (x := α); + rewrite <- Mscale_assoc with (x := β). + all : repeat rewrite Mscale_plus_distr_r. + all : repeat rewrite Mscale_assoc. + all : repeat rewrite <- Cmult_assoc. + all : repeat rewrite Cmult_comm with (x := λ₁); + repeat rewrite Cmult_comm with (x := λ₂); + repeat rewrite Cmult_comm with (x := λ₃); + repeat rewrite Cmult_comm with (x := λ₄). + all : repeat rewrite Cmult_comm with (x := Ci). + all : repeat rewrite Cmult_assoc. + all : do 2 rewrite Cmult_comm with (y := λ₁); + do 2 rewrite Cmult_comm with (y := λ₂); + do 2 rewrite Cmult_comm with (y := λ₃); + do 2 rewrite Cmult_comm with (y := λ₄). + all : repeat rewrite Cmult_comm with (y := Ci). + all : repeat rewrite <- Cmult_assoc. + all : match goal with + | [ |- context [ + ?λ * (?γ * (/ C2 * ?c)) .* _ + ] + ] => replace (/ C2 * c)%C with (C1) by C_field + end. + all : repeat rewrite Cmult_1_r. + all : unfold ancillae_for_arbitrary; simpl. + all : repeat rewrite kron_1_l by auto with wf_db. + all : rewrite ket0_equiv. + all : repeat rewrite kron_plus_distr_l. + all : repeat rewrite Mscale_kron_dist_r. + all : repeat rewrite Mscale_plus_distr_r. + all : restore_dims. + all : repeat rewrite <- kron_assoc by auto 10 with wf_db. + all : repeat rewrite Mscale_assoc. + all : repeat rewrite Cmult_assoc. + all : repeat rewrite Cmult_comm with (y := α); + repeat rewrite Cmult_comm with (y := β). + all : repeat rewrite Cmult_assoc. + all : repeat rewrite Mplus_assoc. + all : reflexivity. +Qed. + + End NineQubitCode. From 38e5fc5445c21f22cbb35efe7128454a1b50124d Mon Sep 17 00:00:00 2001 From: Fady Adal <2masadel@gmail.com> Date: Thu, 15 Aug 2024 00:29:35 -0500 Subject: [PATCH 17/17] continous case --- examples/error-correction/NineQubitCode.v | 62 +++++++++++++++++++---- 1 file changed, 53 insertions(+), 9 deletions(-) diff --git a/examples/error-correction/NineQubitCode.v b/examples/error-correction/NineQubitCode.v index e41d24f..12fe7a0 100644 --- a/examples/error-correction/NineQubitCode.v +++ b/examples/error-correction/NineQubitCode.v @@ -43,7 +43,7 @@ Ltac reorder_scalars := repeat rewrite Mscale_assoc; repeat rewrite Cmult_comm with (x := ((-1)%R : C)); repeat rewrite <- Mscale_assoc with (y := ((-1)%R : C)); -repeat rewrite <- Mscale_plus_distr_r. + repeat rewrite <- Mscale_plus_distr_r. Ltac normalize_kron_notation := repeat rewrite <- kron_assoc by auto 8 with wf_db; @@ -153,6 +153,11 @@ Definition block_no := up_to_three. (* Qubits in a single block *) Definition block_offset := up_to_three. + +Definition block_to_qubit (n : block_no) (off : block_offset) : nat := + n * 3 + off. + + (** Encoding *) @@ -163,7 +168,7 @@ Definition encode_block : base_ucom block_dim := CNOT 0 2. Theorem encode_block_zero : - uc_eval encode_block × ∣0,0,0⟩ + uc_eval encode_block × ∣ 0, 0, 0 ⟩ = / √ 2 .* (∣ 0, 0, 0 ⟩ .+ ∣ 1, 1, 1 ⟩). Proof. rewrite Common.zero_3_f_to_vec. @@ -171,7 +176,7 @@ Proof. Qed. Theorem encode_block_one : - uc_eval encode_block × ∣1,0,0⟩ + uc_eval encode_block × ∣ 1, 0 , 0 ⟩ = / √ 2 .* (∣ 0, 0, 0 ⟩ .+ (-1)%R .* ∣ 1, 1, 1 ⟩). Proof. rewrite Common.one_3_f_to_vec. @@ -509,6 +514,9 @@ Proof. now compute_decoding. Qed. +(** + Correctness + *) Theorem error_decode_correct_no_error : forall (α β : C), @@ -911,6 +919,9 @@ Proof. Qed. +(** + Main correctness proof for the discrete error case. +*) Theorem shor_correct (e : error) : forall (α β : C), (@uc_eval dim (shor e)) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ 8 ⨂ ∣0⟩) = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ ancillae_for e. @@ -941,6 +952,10 @@ Proof. apply H. Qed. +(** + Generalized errors on single qubits + *) + Lemma pauli_spans_2_by_2 : forall (M : Square 2), WF_Matrix M -> exists λ₁ λ₂ λ₃ λ₄, @@ -1027,9 +1042,6 @@ Lemma YeqiXZ : σy = Ci .* σx × σz. Proof. solve_matrix. Qed. -Definition block_to_qubit (n : block_no) (off : block_offset) : nat := - n * 3 + off. - Definition ancillae_for_arbitrary (λ₁ λ₂ λ₃ λ₄ : C) (n : block_no) @@ -1083,23 +1095,33 @@ Proof. all : now rewrite Mscale_1_l. Qed. +Definition shor_arbitrary_unitary_matrix (M : Square 2) (n : block_no) (off : block_offset) := + uc_eval decode + × pad_u dim (block_to_qubit n off) M + × uc_eval encode. +(** + Main correctness proof for the continuous error case. +*) Theorem shor_arbitrary_correct (M : Square 2) : WF_Unitary M -> forall (α β : C) (n : block_no) (off : block_offset), exists (φ : Vector (2^8)), - ( uc_eval decode - × pad_u dim (block_to_qubit n off) M - × uc_eval encode) × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ 8 ⨂ ∣0⟩) + Pure_State_Vector φ /\ + shor_arbitrary_unitary_matrix M n off × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ 8 ⨂ ∣0⟩) = (α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ φ. Proof. intros. repeat rewrite Mmult_assoc. + unfold shor_arbitrary_unitary_matrix. + repeat rewrite Mmult_assoc. rewrite encode_correct. specialize (pauli_spans_unitary_2_by_2 M H) as Hpauli. destruct Hpauli as [λ₁ [λ₂ [λ₃ [λ₄ [Hpauli Hmod]]]]]. rewrite Hpauli. exists (ancillae_for_arbitrary λ₁ λ₂ λ₃ λ₄ n off). + split. + 1 : exact (ancillae_pure_vector_cond λ₁ λ₂ λ₃ λ₄ n off Hmod). destruct n; destruct off. all : cbn. all : repeat rewrite kron_1_l by auto with wf_db. @@ -1252,5 +1274,27 @@ Proof. all : reflexivity. Qed. +Theorem shor_arbitrary_correct_prob (M : Square 2) : + WF_Unitary M -> + forall (α β : C) (n : block_no) (off : block_offset), + let r := shor_arbitrary_unitary_matrix M n off × ((α .* ∣0⟩ .+ β .* ∣1⟩) ⊗ 8 ⨂ ∣0⟩) in + @prob_partial_meas 1 (dim - 1) ∣0⟩ r = (Cmod α ^ 2)%R + /\ @prob_partial_meas 1 (dim - 1) ∣1⟩ r = (Cmod β ^ 2)%R. +Proof. + intros. + specialize (shor_arbitrary_correct M H α β n off) as [R [[HWFR HDag] HR]]. + subst r. + rewrite HR. + do 2 rewrite prob_partial_meas_alt. + distribute_adjoint. + Msimpl. + autorewrite with ket_db. + do 2 rewrite norm_scale. + unfold norm. + unfold inner_product. + restore_dims. + rewrite HDag. + split; simpl; rewrite sqrt_1; repeat rewrite Rmult_1_r; easy. +Qed. End NineQubitCode.