Skip to content

Commit

Permalink
✨ Simplify types (#10)
Browse files Browse the repository at this point in the history
Co-authored-by: Joonhyup Lee <[email protected]>
  • Loading branch information
LimitEpsilon and LimitEpsilon authored Jun 5, 2024
1 parent 0a7607b commit 3910718
Showing 1 changed file with 8 additions and 64 deletions.
72 changes: 8 additions & 64 deletions lib/program.ml
Original file line number Diff line number Diff line change
Expand Up @@ -206,15 +206,7 @@ type program = { funs : fn list; exp : Exp.t } [@@deriving sexp]

module Type_safe = struct
type real = float

type _ value =
| Int : int -> int value
| Real : real -> real value
| Bool : bool -> bool value

type _ ty = Tyi : int ty | Tyr : real ty | Tyb : bool ty
type ('a, 'b, 'c) bop = ('a ty * 'b ty * 'c ty) * ('a -> 'b -> 'c)
type ('a, 'b) uop = ('a ty * 'b ty) * ('a -> 'b)

type _ params =
| [] : unit params
Expand All @@ -239,10 +231,10 @@ module Type_safe = struct
| ( :: ) : ('a, 'd) texp * ('b, 'd) args -> ('a * 'b, 'd) args

and (_, _) exp =
| Value : 'a value -> ('a, _) exp
| Value : 'a -> ('a, _) exp
| Var : Id.t -> _ exp
| Bop : ('a, 'b, 'c) bop * ('a, 'd) texp * ('b, 'd) texp -> ('c, 'd) exp
| Uop : ('a, 'b) uop * ('a, 'd) texp -> ('b, 'd) exp
| Bop : ('a -> 'b -> 'c) * ('a, 'd) texp * ('b, 'd) texp -> ('c, 'd) exp
| Uop : ('a -> 'b) * ('a, 'd) texp -> ('b, 'd) exp
(* TODO: Add list and record constructors *)
(*| List : ('a, 'd) exp list -> ('a list, 'd) exp*)
(*| Record : ('k * 'v, 'd) exp list -> ('k * 'v, 'd) exp*)
Expand All @@ -268,58 +260,10 @@ module Type_safe = struct
| [] -> Id.Set.empty
| { exp; _ } :: es -> Id.(fv exp @| fv_args es)

let bop (type a b c) (op : (a, b, c) bop) (v1 : a value) (v2 : b value) :
c value =
match (op, v1, v2) with
| ((Tyi, Tyi, Tyi), op), Int i1, Int i2 -> Int (op i1 i2)
| ((Tyi, Tyi, Tyr), op), Int i1, Int i2 -> Real (op i1 i2)
| ((Tyi, Tyi, Tyb), op), Int i1, Int i2 -> Bool (op i1 i2)
| ((Tyi, Tyr, Tyi), op), Int i, Real r -> Int (op i r)
| ((Tyi, Tyr, Tyr), op), Int i, Real r -> Real (op i r)
| ((Tyi, Tyr, Tyb), op), Int i, Real r -> Bool (op i r)
| ((Tyi, Tyb, Tyr), op), Int i, Bool b -> Real (op i b)
| ((Tyi, Tyb, Tyi), op), Int i, Bool b -> Int (op i b)
| ((Tyi, Tyb, Tyb), op), Int i, Bool b -> Bool (op i b)
| ((Tyr, Tyi, Tyi), op), Real r, Int i -> Int (op r i)
| ((Tyr, Tyi, Tyr), op), Real r, Int i -> Real (op r i)
| ((Tyr, Tyi, Tyb), op), Real r, Int i -> Bool (op r i)
| ((Tyr, Tyr, Tyi), op), Real r1, Real r2 -> Int (op r1 r2)
| ((Tyr, Tyr, Tyr), op), Real r1, Real r2 -> Real (op r1 r2)
| ((Tyr, Tyr, Tyb), op), Real r1, Real r2 -> Bool (op r1 r2)
| ((Tyr, Tyb, Tyi), op), Real r, Bool b -> Int (op r b)
| ((Tyr, Tyb, Tyr), op), Real r, Bool b -> Real (op r b)
| ((Tyr, Tyb, Tyb), op), Real r, Bool b -> Bool (op r b)
| ((Tyb, Tyi, Tyr), op), Bool b, Int i -> Real (op b i)
| ((Tyb, Tyi, Tyi), op), Bool b, Int i -> Int (op b i)
| ((Tyb, Tyi, Tyb), op), Bool b, Int i -> Bool (op b i)
| ((Tyb, Tyr, Tyi), op), Bool b, Real r -> Int (op b r)
| ((Tyb, Tyr, Tyr), op), Bool b, Real r -> Real (op b r)
| ((Tyb, Tyr, Tyb), op), Bool b, Real r -> Bool (op b r)
| ((Tyb, Tyb, Tyi), op), Bool b1, Bool b2 -> Int (op b1 b2)
| ((Tyb, Tyb, Tyr), op), Bool b1, Bool b2 -> Real (op b1 b2)
| ((Tyb, Tyb, Tyb), op), Bool b1, Bool b2 -> Bool (op b1 b2)

let uop (type a b) (op : (a, b) uop) (v : a value) : b value =
match (op, v) with
| ((Tyi, Tyi), op), Int i -> Int (op i)
| ((Tyi, Tyr), op), Int i -> Real (op i)
| ((Tyi, Tyb), op), Int i -> Bool (op i)
| ((Tyr, Tyi), op), Real r -> Int (op r)
| ((Tyr, Tyr), op), Real r -> Real (op r)
| ((Tyr, Tyb), op), Real r -> Bool (op r)
| ((Tyb, Tyi), op), Bool b -> Int (op b)
| ((Tyb, Tyr), op), Bool b -> Real (op b)
| ((Tyb, Tyb), op), Bool b -> Bool (op b)

type _ vargs =
| [] : unit vargs
| ( :: ) : ('a ty * 'a) * 'b vargs -> ('a * 'b) vargs

let varg_of_value : type a. a value -> a ty * a = function
| Int i -> (Tyi, i)
| Real r -> (Tyr, r)
| Bool b -> (Tyb, b)

exception Dist_type_error of string

let get_bernoulli (type a b) (ret : a ty) (vargs : b vargs) : a dist =
Expand Down Expand Up @@ -380,16 +324,16 @@ module Type_safe = struct
| (Value _ | Var _) as e -> e
| Bop (op, te1, te2) -> (
match (peval te1, peval te2) with
| { exp = Value v1; _ }, { exp = Value v2; _ } -> Value (bop op v1 v2)
| { exp = Value v1; _ }, { exp = Value v2; _ } -> Value (op v1 v2)
| te1, te2 -> Bop (op, te1, te2))
| Uop (op, te) -> (
match peval te with
| { exp = Value v; _ } -> Value (uop op v)
| { exp = Value v; _ } -> Value (op v)
| e -> Uop (op, e))
| If (te_pred, te_cons, te_alt) -> (
match peval te_pred with
| { exp = Value (Bool true); _ } -> (peval te_cons).exp
| { exp = Value (Bool false); _ } -> (peval te_alt).exp
| { exp = Value true; _ } -> (peval te_cons).exp
| { exp = Value false; _ } -> (peval te_alt).exp
| te_pred -> If (te_pred, peval te_cons, peval te_alt))
| Call (f, args) -> (
match peval_args args with
Expand All @@ -409,7 +353,7 @@ module Type_safe = struct
| te :: tl -> (
match (peval te, peval_args tl) with
| { ty; exp = Value v }, (tl, Some vargs) ->
({ ty; exp = Value v } :: tl, Some (varg_of_value v :: vargs))
({ ty; exp = Value v } :: tl, Some ((ty, v) :: vargs))
| te, (tl, _) -> (te :: tl, None))

(*let rec convert (exp : Exp.t) : (float, non_det) exp =*)
Expand Down

0 comments on commit 3910718

Please sign in to comment.