From 0a7607bfd6eff58e1002c1e5f8e23580c3a49620 Mon Sep 17 00:00:00 2001 From: Jay Lee Date: Wed, 5 Jun 2024 19:20:43 +0900 Subject: [PATCH] :sparkles: Typesafe det/non_det expressions --- lib/program.ml | 306 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 211 insertions(+), 95 deletions(-) diff --git a/lib/program.ml b/lib/program.ml index 8f9922a..54fa4c7 100644 --- a/lib/program.ml +++ b/lib/program.ml @@ -6,101 +6,6 @@ module Id = struct include String end -module Type_safe = struct - type real = float - - type _ value = - | Int : int -> int value - | Real : real -> real value - | Bool : bool -> bool value - - type _ ty = Tyi : int ty | Tyr : real ty | Tyb : bool ty - type ('a, 'b, 'c) bop = ('a ty * 'b ty * 'c ty) * ('a -> 'b -> 'c) - type ('a, 'b) uop = ('a ty * 'b ty) * ('a -> 'b) - type det = Det - type non_det = Non_det - - type (_, _) exp = - | Value : 'a value -> ('a, _) exp - | Var : Id.t -> _ exp - | Bop : ('a, 'b, 'c) bop * ('a, 'd) exp * ('b, 'd) exp -> ('c, 'd) exp - | Uop : ('a, 'b) uop * ('a, 'd) exp -> ('b, 'd) exp - (* TODO: Add list and record constructors *) - (*| List : ('a, 'd) exp list -> ('a list, 'd) exp*) - (*| Record : ('k * 'v, 'd) exp list -> ('k * 'v, 'd) exp*) - | If : (bool, 'd) exp * ('a, 'd) exp * ('a, 'd) exp -> ('a, 'd) exp - | Let : Id.t * ('a, non_det) exp * ('b, non_det) exp -> ('b, non_det) exp - | Call : ('a -> 'b, 'd) exp * ('a, 'd) exp -> ('b, 'd) exp - | Sample : ('a, non_det) exp -> ('a, non_det) exp - | Observe : ('a, non_det) exp * ('a, non_det) exp -> ('a, non_det) exp - - let rec fv : type a. (a, det) exp -> Id.Set.t = function - | Value _ -> Id.Set.empty - | Var x -> Id.Set.singleton x - | Bop (_, e1, e2) -> Id.(fv e1 @| fv e2) - | Uop (_, e) -> fv e - | If (e_pred, e_cons, e_alt) -> Id.(fv e_pred @| fv e_cons @| fv e_alt) - | Call (f, e) -> Id.(fv f @| fv e) - - let bop (type a b c) (op : (a, b, c) bop) (v1 : a value) (v2 : b value) : - c value = - match (op, v1, v2) with - | ((Tyi, Tyi, Tyi), op), Int i1, Int i2 -> Int (op i1 i2) - | ((Tyi, Tyi, Tyr), op), Int i1, Int i2 -> Real (op i1 i2) - | ((Tyi, Tyi, Tyb), op), Int i1, Int i2 -> Bool (op i1 i2) - | ((Tyi, Tyr, Tyi), op), Int i, Real r -> Int (op i r) - | ((Tyi, Tyr, Tyr), op), Int i, Real r -> Real (op i r) - | ((Tyi, Tyr, Tyb), op), Int i, Real r -> Bool (op i r) - | ((Tyi, Tyb, Tyr), op), Int i, Bool b -> Real (op i b) - | ((Tyi, Tyb, Tyi), op), Int i, Bool b -> Int (op i b) - | ((Tyi, Tyb, Tyb), op), Int i, Bool b -> Bool (op i b) - | ((Tyr, Tyi, Tyi), op), Real r, Int i -> Int (op r i) - | ((Tyr, Tyi, Tyr), op), Real r, Int i -> Real (op r i) - | ((Tyr, Tyi, Tyb), op), Real r, Int i -> Bool (op r i) - | ((Tyr, Tyr, Tyi), op), Real r1, Real r2 -> Int (op r1 r2) - | ((Tyr, Tyr, Tyr), op), Real r1, Real r2 -> Real (op r1 r2) - | ((Tyr, Tyr, Tyb), op), Real r1, Real r2 -> Bool (op r1 r2) - | ((Tyr, Tyb, Tyi), op), Real r, Bool b -> Int (op r b) - | ((Tyr, Tyb, Tyr), op), Real r, Bool b -> Real (op r b) - | ((Tyr, Tyb, Tyb), op), Real r, Bool b -> Bool (op r b) - | ((Tyb, Tyi, Tyr), op), Bool b, Int i -> Real (op b i) - | ((Tyb, Tyi, Tyi), op), Bool b, Int i -> Int (op b i) - | ((Tyb, Tyi, Tyb), op), Bool b, Int i -> Bool (op b i) - | ((Tyb, Tyr, Tyi), op), Bool b, Real r -> Int (op b r) - | ((Tyb, Tyr, Tyr), op), Bool b, Real r -> Real (op b r) - | ((Tyb, Tyr, Tyb), op), Bool b, Real r -> Bool (op b r) - | ((Tyb, Tyb, Tyi), op), Bool b1, Bool b2 -> Int (op b1 b2) - | ((Tyb, Tyb, Tyr), op), Bool b1, Bool b2 -> Real (op b1 b2) - | ((Tyb, Tyb, Tyb), op), Bool b1, Bool b2 -> Bool (op b1 b2) - - let uop (type a b) (op : (a, b) uop) (v : a value) : b value = - match (op, v) with - | ((Tyi, Tyi), op), Int i -> Int (op i) - | ((Tyi, Tyr), op), Int i -> Real (op i) - | ((Tyi, Tyb), op), Int i -> Bool (op i) - | ((Tyr, Tyi), op), Real r -> Int (op r) - | ((Tyr, Tyr), op), Real r -> Real (op r) - | ((Tyr, Tyb), op), Real r -> Bool (op r) - | ((Tyb, Tyi), op), Bool b -> Int (op b) - | ((Tyb, Tyr), op), Bool b -> Real (op b) - | ((Tyb, Tyb), op), Bool b -> Bool (op b) - - let rec eval : type a. (a, det) exp -> (a, det) exp = function - | (Value _ | Var _) as e -> e - | Bop (op, e1, e2) -> ( - match (eval e1, eval e2) with - | Value v1, Value v2 -> Value (bop op v1 v2) - | e1, e2 -> Bop (op, e1, e2)) - | Uop (op, e) -> ( - match eval e with Value v -> Value (uop op v) | e -> Uop (op, e)) - | If (e_pred, e_cons, e_alt) -> ( - match eval e_pred with - | Value (Bool true) -> eval e_cons - | Value (Bool false) -> eval e_alt - | e_pred -> If (e_pred, eval e_cons, eval e_alt)) - | Call (f, e) -> Call (f, eval e) -end - module Det_exp = struct type t = | Int of int @@ -298,3 +203,214 @@ end type fn = { name : Id.t; params : Id.t list; body : Exp.t } [@@deriving sexp] type program = { funs : fn list; exp : Exp.t } [@@deriving sexp] + +module Type_safe = struct + type real = float + + type _ value = + | Int : int -> int value + | Real : real -> real value + | Bool : bool -> bool value + + type _ ty = Tyi : int ty | Tyr : real ty | Tyb : bool ty + type ('a, 'b, 'c) bop = ('a ty * 'b ty * 'c ty) * ('a -> 'b -> 'c) + type ('a, 'b) uop = ('a ty * 'b ty) * ('a -> 'b) + + type _ params = + | [] : unit params + | ( :: ) : 'a ty * 'b params -> ('a * 'b) params + + type det = Det + type non_det = Non_det + type 'a sampler = unit -> 'a + type 'a log_pmdf = 'a -> real + + type 'a dist = { + name : Id.t; + ty : 'a ty; + sampler : 'a sampler; + log_pmdf : 'a log_pmdf; + } + + type any_dist = Any_dist : 'a dist -> any_dist + + type (_, _) args = + | [] : (unit, _) args + | ( :: ) : ('a, 'd) texp * ('b, 'd) args -> ('a * 'b, 'd) args + + and (_, _) exp = + | Value : 'a value -> ('a, _) exp + | Var : Id.t -> _ exp + | Bop : ('a, 'b, 'c) bop * ('a, 'd) texp * ('b, 'd) texp -> ('c, 'd) exp + | Uop : ('a, 'b) uop * ('a, 'd) texp -> ('b, 'd) exp + (* TODO: Add list and record constructors *) + (*| List : ('a, 'd) exp list -> ('a list, 'd) exp*) + (*| Record : ('k * 'v, 'd) exp list -> ('k * 'v, 'd) exp*) + | If : (bool, 'd) texp * ('a, 'd) texp * ('a, 'd) texp -> ('a, 'd) exp + | Let : Id.t * ('a, non_det) texp * ('b, non_det) texp -> ('b, non_det) exp + | Call : Id.t * ('a, 'd) args -> ('b, 'd) exp + | Sample : ('a, non_det) texp -> ('a, non_det) exp + | Observe : ('a, non_det) texp * ('a, non_det) texp -> ('a, non_det) exp + | Dist : 'b dist -> ('b, det) exp + + and ('a, 'd) texp = { ty : 'a ty; exp : ('a, 'd) exp } + + let rec fv : type a. (a, det) exp -> Id.Set.t = function + | Value _ | Dist _ -> Id.Set.empty + | Var x -> Id.Set.singleton x + | Bop (_, { exp = e1; _ }, { exp = e2; _ }) -> Id.(fv e1 @| fv e2) + | Uop (_, { exp; _ }) -> fv exp + | If ({ exp = e_pred; _ }, { exp = e_cons; _ }, { exp = e_alt; _ }) -> + Id.(fv e_pred @| fv e_cons @| fv e_alt) + | Call (_, args) -> fv_args args + + and fv_args : type a. (a, det) args -> Id.Set.t = function + | [] -> Id.Set.empty + | { exp; _ } :: es -> Id.(fv exp @| fv_args es) + + let bop (type a b c) (op : (a, b, c) bop) (v1 : a value) (v2 : b value) : + c value = + match (op, v1, v2) with + | ((Tyi, Tyi, Tyi), op), Int i1, Int i2 -> Int (op i1 i2) + | ((Tyi, Tyi, Tyr), op), Int i1, Int i2 -> Real (op i1 i2) + | ((Tyi, Tyi, Tyb), op), Int i1, Int i2 -> Bool (op i1 i2) + | ((Tyi, Tyr, Tyi), op), Int i, Real r -> Int (op i r) + | ((Tyi, Tyr, Tyr), op), Int i, Real r -> Real (op i r) + | ((Tyi, Tyr, Tyb), op), Int i, Real r -> Bool (op i r) + | ((Tyi, Tyb, Tyr), op), Int i, Bool b -> Real (op i b) + | ((Tyi, Tyb, Tyi), op), Int i, Bool b -> Int (op i b) + | ((Tyi, Tyb, Tyb), op), Int i, Bool b -> Bool (op i b) + | ((Tyr, Tyi, Tyi), op), Real r, Int i -> Int (op r i) + | ((Tyr, Tyi, Tyr), op), Real r, Int i -> Real (op r i) + | ((Tyr, Tyi, Tyb), op), Real r, Int i -> Bool (op r i) + | ((Tyr, Tyr, Tyi), op), Real r1, Real r2 -> Int (op r1 r2) + | ((Tyr, Tyr, Tyr), op), Real r1, Real r2 -> Real (op r1 r2) + | ((Tyr, Tyr, Tyb), op), Real r1, Real r2 -> Bool (op r1 r2) + | ((Tyr, Tyb, Tyi), op), Real r, Bool b -> Int (op r b) + | ((Tyr, Tyb, Tyr), op), Real r, Bool b -> Real (op r b) + | ((Tyr, Tyb, Tyb), op), Real r, Bool b -> Bool (op r b) + | ((Tyb, Tyi, Tyr), op), Bool b, Int i -> Real (op b i) + | ((Tyb, Tyi, Tyi), op), Bool b, Int i -> Int (op b i) + | ((Tyb, Tyi, Tyb), op), Bool b, Int i -> Bool (op b i) + | ((Tyb, Tyr, Tyi), op), Bool b, Real r -> Int (op b r) + | ((Tyb, Tyr, Tyr), op), Bool b, Real r -> Real (op b r) + | ((Tyb, Tyr, Tyb), op), Bool b, Real r -> Bool (op b r) + | ((Tyb, Tyb, Tyi), op), Bool b1, Bool b2 -> Int (op b1 b2) + | ((Tyb, Tyb, Tyr), op), Bool b1, Bool b2 -> Real (op b1 b2) + | ((Tyb, Tyb, Tyb), op), Bool b1, Bool b2 -> Bool (op b1 b2) + + let uop (type a b) (op : (a, b) uop) (v : a value) : b value = + match (op, v) with + | ((Tyi, Tyi), op), Int i -> Int (op i) + | ((Tyi, Tyr), op), Int i -> Real (op i) + | ((Tyi, Tyb), op), Int i -> Bool (op i) + | ((Tyr, Tyi), op), Real r -> Int (op r) + | ((Tyr, Tyr), op), Real r -> Real (op r) + | ((Tyr, Tyb), op), Real r -> Bool (op r) + | ((Tyb, Tyi), op), Bool b -> Int (op b) + | ((Tyb, Tyr), op), Bool b -> Real (op b) + | ((Tyb, Tyb), op), Bool b -> Bool (op b) + + type _ vargs = + | [] : unit vargs + | ( :: ) : ('a ty * 'a) * 'b vargs -> ('a * 'b) vargs + + let varg_of_value : type a. a value -> a ty * a = function + | Int i -> (Tyi, i) + | Real r -> (Tyr, r) + | Bool b -> (Tyb, b) + + exception Dist_type_error of string + + let get_bernoulli (type a b) (ret : a ty) (vargs : b vargs) : a dist = + let open Owl.Stats in + match (ret, vargs) with + | Tyb, [ (Tyr, p) ] -> + { + name = "bernoulli"; + ty = Tyb; + sampler = (fun () -> binomial_rvs ~p ~n:1 = 1); + log_pmdf = (fun b -> binomial_logpdf ~p ~n:1 (Bool.to_int b)); + } + | Tyb, [] -> raise (Dist_type_error "Bernoulli: too few args") + | Tyb, [ (Tyi, i) ] -> + raise (Dist_type_error (sprintf "Bernoulli: got %i expected real" i)) + | Tyb, [ (Tyb, b) ] -> + raise (Dist_type_error (sprintf "Bernoulli: got %b expected real" b)) + | Tyb, _ -> raise (Dist_type_error "Bernoulli: too many arguments") + | _, _ -> raise (Dist_type_error "Bernoulli: should return bool") + + let get_normal (type a b) (ret : a ty) (vargs : b vargs) : a dist = + let open Owl.Stats in + match (ret, vargs) with + | Tyr, [ (Tyr, mu); (Tyr, sigma) ] -> + { + name = "normal"; + ty = Tyr; + sampler = (fun () -> gaussian_rvs ~mu ~sigma); + log_pmdf = gaussian_logpdf ~mu ~sigma; + } + | Tyr, [] | Tyr, [ _ ] -> raise (Dist_type_error "Normal: too few args") + | Tyr, [ (Tyi, i); _ ] -> + raise (Dist_type_error (sprintf "Normal: got %i expected real" i)) + | Tyr, [ (Tyr, _); (Tyi, i) ] -> + raise (Dist_type_error (sprintf "Normal: got %i expected real" i)) + | Tyr, [ (Tyb, b); _ ] -> + raise (Dist_type_error (sprintf "Normal: got %b expected real" b)) + | Tyr, [ (Tyr, _); (Tyb, b) ] -> + raise (Dist_type_error (sprintf "Normal: got %b expected real" b)) + | Tyr, _ -> raise (Dist_type_error "Normal: too many arguments") + | _, _ -> raise (Dist_type_error "Normal: should return real") + + type ('arg, 'k) cont_dist_box = { + k : 'a 'b. ('a params * 'b ty) * ('arg vargs -> 'b dist) -> 'k; + } + + let dist_lookup (name : Id.t) (ret : 'a ty) (vargs : 'b vargs) : 'a dist = + match name with + | "bernoulli" -> get_bernoulli ret vargs + | "normal" -> get_normal ret vargs + (* TODO: Add more distributions *) + | _ -> raise (Invalid_argument "Distribution not found") + + let rec peval : type a. (a, det) texp -> (a, det) texp = + fun { ty; exp } -> + let exp = + match exp with + | (Value _ | Var _) as e -> e + | Bop (op, te1, te2) -> ( + match (peval te1, peval te2) with + | { exp = Value v1; _ }, { exp = Value v2; _ } -> Value (bop op v1 v2) + | te1, te2 -> Bop (op, te1, te2)) + | Uop (op, te) -> ( + match peval te with + | { exp = Value v; _ } -> Value (uop op v) + | e -> Uop (op, e)) + | If (te_pred, te_cons, te_alt) -> ( + match peval te_pred with + | { exp = Value (Bool true); _ } -> (peval te_cons).exp + | { exp = Value (Bool false); _ } -> (peval te_alt).exp + | te_pred -> If (te_pred, peval te_cons, peval te_alt)) + | Call (f, args) -> ( + match peval_args args with + | args, None -> Call (f, args) + | _, Some vargs -> + (* All arguments are fully evaluated; + Go ahead and fully evaluate the (primitive) call. + It is a primitive call as this is a deterministic expression. *) + Dist (dist_lookup f ty vargs)) + | Dist _ as e -> e (* TODO: probably should not be encountered *) + in + { ty; exp } + + and peval_args : type a. (a, det) args -> (a, det) args * a vargs option = + function + | [] -> ([], Some []) + | te :: tl -> ( + match (peval te, peval_args tl) with + | { ty; exp = Value v }, (tl, Some vargs) -> + ({ ty; exp = Value v } :: tl, Some (varg_of_value v :: vargs)) + | te, (tl, _) -> (te :: tl, None)) + + (*let rec convert (exp : Exp.t) : (float, non_det) exp =*) +end