Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cleanup #1

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/circuit.sig
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
signature CIRCUIT = sig

datatype t = I | X | Y | Z | H | SW
datatype t = I | X | Y | Z | H | T | SW
| C of t
| Tensor of t * t
| Seq of t * t

val oo : t * t -> t
val ** : t * t -> t
val ++ : t * t -> t
val ** : t * t -> t

val pp : t -> string
val draw : t -> string
val dim : t -> int
val pp : t -> string
val draw : t -> string
val draw_latex : t -> string
val height : t -> int

end
42 changes: 33 additions & 9 deletions src/circuit.sml
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@ structure Circuit : CIRCUIT = struct

infix |> fun a |> f = f a

datatype t = I | X | Y | Z | H | SW
datatype t = I | X | Y | Z | H | T | SW
| Tensor of t * t
| Seq of t * t
| C of t

val oo = op Seq
val ++ = op Seq
val ** = op Tensor

fun pp t =
let fun maybePar P s = if P then "(" ^ s ^ ")" else s
fun pp p t =
case t of
Tensor(t1,t2) => maybePar (p > 4) (pp 4 t1 ^ " ** " ^ pp 4 t2)
| Seq(t1,t2) => maybePar (p > 3) (pp 3 t1 ^ " oo " ^ pp 3 t2)
| Seq(t1,t2) => maybePar (p > 3) (pp 3 t1 ^ " ++ " ^ pp 3 t2)
| C t => "C" ^ pp 8 t
| I => "I" | X => "X" | Y => "Y" | Z => "Z" | H => "H" | SW => "SW"
| I => "I" | X => "X" | Y => "Y" | Z => "Z" | H => "H" | T => "T" | SW => "SW"
in pp 0 t
end

Expand All @@ -32,6 +32,7 @@ structure Circuit : CIRCUIT = struct
| Y => Diagram.box "Y"
| Z => Diagram.box "Z"
| H => Diagram.box "H"
| T => Diagram.box "T"
| C X => Diagram.cntrl "X"
| C Y => Diagram.cntrl "Y"
| C Z => Diagram.cntrl "Z"
Expand All @@ -41,17 +42,40 @@ structure Circuit : CIRCUIT = struct
in dr t |> Diagram.toString
end

fun dim t =
structure DiagramL = DiagramLatex

fun draw_latex t =
let fun dr t =
case t of
SW => DiagramL.swap
| Tensor(a,b) => DiagramL.par(dr a, dr b)
| Seq(a,b) => DiagramL.seq(dr a, dr b)
| I => DiagramL.line
| X => DiagramL.box "X"
| Y => DiagramL.box "Y"
| Z => DiagramL.box "Z"
| H => DiagramL.box "H"
| T => DiagramL.box "T"
| C X => DiagramL.cntrl "X"
| C Y => DiagramL.cntrl "Y"
| C Z => DiagramL.cntrl "Z"
| C H => DiagramL.cntrl "H"
| C _ => raise Fail ("Circuit.draw_latex: Controlled circuit " ^
pp t ^ " cannot be drawn")
in dr t |> DiagramL.toString
end

fun height t =
case t of
Tensor(a,b) => dim a + dim b
Tensor(a,b) => height a + height b
| Seq(a,b) =>
let val d = dim a
in if d <> dim b
let val d = height a
in if d <> height b
then raise Fail "Sequence error: mismatching dimensions"
else d
end
| SW => 2
| C t => 1 + dim t
| C t => 1 + height t
| _ => 1

end
36 changes: 21 additions & 15 deletions src/comp.sml
Original file line number Diff line number Diff line change
Expand Up @@ -35,29 +35,30 @@ structure Comp :> COMP = struct
| Y => ret (APP("Y",[]))
| Z => ret (APP("Z",[]))
| H => ret (APP("H",[]))
| T => ret (APP("T",[]))
| SW => ret (APP("SW",[]))
| Seq(t1,t2) =>
comp t1 >>= (fn e1 =>
comp t2 >>= (fn e2 =>
let val n = Int.toString (pow2 (dim t))
let val n = Int.toString (pow2 (height t))
val ty = "[" ^ n ^ "][" ^ n ^ "]C.complex"
in ret (TYPED(APP("matmul", [e2,e1]),ty))
end))
| Tensor(t1,t2) =>
comp t1 >>= (fn e1 =>
comp t2 >>= (fn e2 =>
let val n = Int.toString (pow2 (dim t))
let val n = Int.toString (pow2 (height t))
val ty = "[" ^ n ^ "][" ^ n ^ "]C.complex"
in ret (TYPED(APP("tensor", [e1,e2]),ty))
end))
| C t' => comp t' >>= (fn e =>
let val n = Int.toString (pow2 (dim t))
let val n = Int.toString (pow2 (height t))
val ty = "[" ^ n ^ "][" ^ n ^ "]C.complex"
in ret (TYPED(APP("control",[e]),ty))
end)
end

fun vecTyFromDim d =
fun vecTyFromHeight d =
"[" ^ Int.toString(pow2 d) ^ "]C.complex"

local
Expand All @@ -73,15 +74,15 @@ structure Comp :> COMP = struct
end
fun FunC A (f:F.exp -> F.exp F.M) : F.var option F.M =
if allI A then ret NONE
else let val ty = vecTyFromDim (Circuit.dim A)
else let val ty = vecTyFromHeight (Circuit.height A)
in Fun f ty ty >>= (ret o SOME)
end
fun splitF d v =
let val ty = "[" ^ Int.toString (pow2 d) ^ "+" ^
Int.toString (pow2 d) ^ "]C.complex"
in APP("split",[TYPED(v,ty)])
end
fun concatF d a b = TYPED(APP("concat",[a,b]),vecTyFromDim d)
fun concatF d a b = TYPED(APP("concat",[a,b]),vecTyFromHeight d)
fun unvecF (e,ty) = APP("unvec", [TYPED(e,ty)])
fun vecF e = APP("vec",[e])
fun mapF f e = APP("map", [VAR f,e])
Expand All @@ -95,25 +96,26 @@ structure Comp :> COMP = struct
Circuit.I => ret v
| Circuit.Seq(t1,t2) => icomp t1 v >>= (icomp t2)
| Circuit.C t' =>
Let (splitF (Circuit.dim t') v) >>= (fn p =>
Let (splitF (Circuit.height t') v) >>= (fn p =>
icomp t' (SEL(1,VAR p)) >>= (fn v1 =>
ret (concatF (Circuit.dim t) (SEL(0,VAR p)) v1)))
ret (concatF (Circuit.height t) (SEL(0,VAR p)) v1)))
| Circuit.Tensor(A,B) =>
FunC A (icomp A) >>= (fn Af =>
FunC B (icomp B) >>= (fn Bf =>
let val dA = pow2(Circuit.dim A)
val dB = pow2(Circuit.dim B)
let val dA = pow2(Circuit.height A)
val dB = pow2(Circuit.height B)
val ty = "[" ^ Int.toString dA ^ "*" ^
Int.toString dB ^ "]C.complex"
in Let (unvecF(v,ty)) >>= (fn V =>
Let (mapF' Bf (transposeF (VAR V))) >>= (fn W =>
Let (mapF' Af (transposeF (VAR W))) >>= (fn Y =>
ret (TYPED(vecF (VAR Y),vecTyFromDim (Circuit.dim t))))))
ret (TYPED(vecF (VAR Y),vecTyFromHeight (Circuit.height t))))))
end))
| Circuit.H => ret (matvecmulF (APP("H",[])) v)
| Circuit.X => ret (matvecmulF (APP("X",[])) v)
| Circuit.Y => ret (matvecmulF (APP("Y",[])) v)
| Circuit.Z => ret (matvecmulF (APP("Z",[])) v)
| Circuit.H => ret (matvecmulF (APP("H",[])) v)
| Circuit.T => ret (matvecmulF (APP("T",[])) v)
| Circuit.SW => ret (matvecmulF (APP("SW",[])) v)
end

Expand All @@ -126,20 +128,24 @@ structure Comp :> COMP = struct
val cni = APP("C.mk_im", [CONST "(-1)"])
val rsqrt2 = APP("C.mk_re", [CONST "(1.0 / f64.sqrt(2.0))"])
val rnsqrt2 = APP("C.mk_re", [CONST "((-1.0) / f64.sqrt(2.0))"])
val tmp = APP("C.exp", [APP("C.mk_im",[CONST "(f64.pi/4)"])])
val rsqrt2eipi4 = APP("C.*", [rsqrt2,tmp])
fun ty n = "[" ^ Int.toString n ^ "][" ^ Int.toString n ^ "]C.complex"
fun binds nil = ret ()
| binds ((s,n,e)::rest) =
FunNamed s (fn _ => ret e) "()" (ty n) >>= (fn _ => binds rest)
in binds [("I", 2, ARR[ARR[c1,c0],
ARR[c0,c1]]),
("H", 2, ARR[ARR[rsqrt2,rsqrt2],
ARR[rsqrt2,rnsqrt2]]),
("X", 2, ARR[ARR[c0,c1],
ARR[c1,c0]]),
("Y", 2, ARR[ARR[c0,cni],
ARR[ci,c0]]),
("Z", 2, ARR[ARR[c1,c0],
ARR[c0,cn1]]),
("H", 2, ARR[ARR[rsqrt2,rsqrt2],
ARR[rsqrt2,rnsqrt2]]),
("T", 2, ARR[ARR[rsqrt2,c0],
ARR[c0,rsqrt2eipi4]]),
("SW", 4, ARR[ARR[c1,c0,c0,c0],
ARR[c0,c0,c1,c0],
ARR[c0,c1,c0,c0],
Expand All @@ -156,7 +162,7 @@ structure Comp :> COMP = struct

fun circuitToFutFunBind (f:string) (t:Circuit.t) : string =
let open F infix >>=
val ty = vecTyFromDim (Circuit.dim t)
val ty = vecTyFromHeight (Circuit.height t)
in runBinds (FunNamed f (icomp t) ty ty >>= (fn _ =>
ret()))
end
Expand Down
124 changes: 124 additions & 0 deletions src/diagram-latex.sml
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
structure DiagramLatex :> DIAGRAM = struct

datatype t = Box of string
| Line
| Cntrl of string
| Swap
| Par of t * t
| Seq of t * t

fun depth t : int =
case t of
Line => 1
| Box _ => 1
| Cntrl _ => 1
| Swap => 1
| Par (t1,t2) => Int.max(depth t1, depth t2)
| Seq(t1,t2) => depth t1 + depth t2

fun height t : int =
case t of
Line => 1
| Box _ => 1
| Cntrl _ => 2
| Swap => 2
| Par (t1,t2) => height t1 + height t2
| Seq(t1,t2) => Int.max(height t1, height t2)

val dy = 10
val dx = 10

fun i2s x = if x < 0 then "-" ^ i2s (~x)
else Int.toString x

fun put (x,y) c =
"\\put(" ^ i2s x ^ "," ^ i2s y ^ "){" ^ c ^ "}"

fun circ () = "\\circle*{" ^ i2s (dy div 10) ^ "}"

fun line (x,y) l =
"\\line(" ^ i2s x ^ "," ^ i2s y ^ "){" ^ i2s l ^ "}"

fun framebox (sx,sy) s =
"\\framebox(" ^ i2s sx ^ "," ^ i2s sy ^ "){" ^ s ^ "}"

fun put_line (x,y) a =
put (x,y + dy div 2) (line (1,0) dx) :: a

fun put_swap (x,y) a =
let val (x1,y1) = (x + dx div 2, y + dy div 2)
val (x2,y2) = (x1,y1-dy)
in put_line (x,y)
(put_line (x,y-dy)
(put (x1,y2) (line (0,1) dy) ::
put (x1,y1) (circ()) ::
put (x2,y2) (circ()) :: a))
end

fun put_cntrl (x,y) a =
let val (x1,y1) = (x + dx div 2, y + dy div 2)
val dy' = dy - 3 * dy div 10
in put_line (x,y)
(put (x1,y1) (line (0,~1) dy') ::
put (x1,y1) (circ()) :: a)
end

fun put_box s (x,y) a =
let val dx' = dx div 5
val x' = x + dx'
val dy' = dy div 5
val y' = y + dy'
val sx = dx - 2 * dx'
val sy = dy - 2 * dy'
in put (x',y') (framebox(sx,sy) s) ::
put (x,y + dy div 2) (line(1,0) dx') ::
put (x+dx,y + dy div 2) (line(~1,0) dx') :: a
end

fun lines n =
if n > 1 then Par(Line,lines(n-1))
else Line

fun padl t =
Seq(lines (height t), t)

fun padr t =
Seq(t,lines (height t))

fun toStr x y t a =
case t of
Box s => put_box s (x,y) a
| Line => put_line (x,y) a
| Swap => put_swap (x,y) a
| Cntrl s => put_box s (x,y - dy) (put_cntrl (x,y) a)
| Seq (t1,t2) => toStr (x + dx*(depth t1)) y t2 (toStr x y t1 a)
| Par (t1,t2) =>
let val d1 = depth t1
val d2 = depth t2
in if d1 > d2 + 1 then
toStr x y (Par(t1,padl (padr t2))) a
else if d1 > d2 then
toStr x y (Par(t1,padl t2)) a
else if d2 > d1 + 1 then
toStr x y (Par(padl (padr t1),t2)) a
else if d2 > d1 then
toStr x y (Par(padl t1,t2)) a
else
toStr x (y - dy*(height t1)) t2 (toStr x y t1 a)
end

fun toString t =
let val (h,d) = (height t, depth t)
in String.concatWith "\n"
("\\begin{picture}(" ^ i2s (dx*d) ^ "," ^ i2s (dy*h) ^ ")(0,0)" ::
toStr 0 ((h-1)*dy) t ["\\end{picture}"])
end

val box = Box
val line = Line
val cntrl = Cntrl
val swap = Swap
val seq = Seq
val par = Par

end
4 changes: 4 additions & 0 deletions src/diagram.mlb
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
local $(SML_LIB)/basis/basis.mlb
in diagram.sml
diagram-latex.sml
end
2 changes: 1 addition & 1 deletion src/diagram.sml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ sig
val toString : t -> string
end

structure Diagram : DIAGRAM =
structure Diagram :> DIAGRAM =
struct
type t = string list (* lines; invariant: lines have equal size *)

Expand Down
2 changes: 1 addition & 1 deletion src/quantum.mlb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
local $(SML_LIB)/basis/basis.mlb
../lib/github.com/diku-dk/sml-matrix/matrix.mlb
../lib/github.com/diku-dk/sml-complex/complex.mlb
in diagram.sml
in diagram.mlb
circuit.sig
circuit.sml
semantics.sig
Expand Down
5 changes: 3 additions & 2 deletions src/quantum_ex1.sml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@

open Circuit Semantics
infix 3 oo
infix 3 ++
infix 4 **

fun run c k =
(print ("Circuit for c = " ^ pp c ^ ":\n");
print (draw c ^ "\n");
print (draw_latex c ^ "\n");
print ("Semantics of c:\n" ^ pp_mat(sem c) ^ "\n");
print ("Result distribution when evaluating c on " ^ pp_ket k ^ " :\n");
let val v0 = init k
Expand All @@ -15,4 +16,4 @@ fun run c k =
; print ("V2: " ^ pp_state (interp c v0) ^ "\n")
end)

val () = run ((I ** H oo C X oo Z ** Z oo C X oo I ** H) ** I oo I ** SW oo C X ** Y) (ket[1,0,1])
val () = run ((I ** H ++ C X ++ Z ** Z ++ C X ++ I ** H) ** I ++ I ** SW ++ C X ** Y) (ket[1,0,1])
Loading