Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement primitive typechecker #11

Merged
merged 3 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions bin/main.ml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ let command : Command.t =
and pp_opt = flag "-pp" no_arg ~doc:" Pretty print the program"
and graph_opt = flag "-graph" no_arg ~doc:" Print the compiled graph" in
fun () ->
let open Typedprog in
if pp_opt then (
printf "Pretty-print: %s\n" filename;
print_s [%sexp (get_program filename : Program.program)]);
Expand All @@ -46,16 +47,16 @@ let command : Command.t =
if pp_opt then printf "\n";
printf "Compile: %s\n" filename;
Out_channel.flush stdout;
let graph, query = get_program filename |> Compiler.compile in
let graph, query = get_program filename |> Compiler.compile_program in
graph_query := Some (graph, query);
print_s [%sexp (graph : Graph.t)]);

print_s [%sexp (Printing.of_graph graph : Printing.graph)]);
if pp_opt || graph_opt then printf "\n";
printf "Inference: %s\n" filename;
Out_channel.flush stdout;
let graph, query =
!graph_query
|> Option.value ~default:(get_program filename |> Compiler.compile)
|> Option.value
~default:(get_program filename |> Compiler.compile_program)
in
printf "Query result saved at %s\n"
(Evaluator.infer ~filename graph query))
Expand Down
2 changes: 2 additions & 0 deletions lib/lexer.mll
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ let keywords =
("in", IN);
("sample", SAMPLE);
("observe", OBSERVE);
("true", BOOL true);
("false", BOOL false);
]
}

Expand Down
2 changes: 2 additions & 0 deletions lib/parser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ open Program

%token <int> INT
%token <float> REAL
%token <bool> BOOL
%token <string> ID
%token IF THEN ELSE FUN LET IN
%token PLUS MINUS NEG MULT DIV RPLUS RMINUS RNEG RMULT RDIV EQ NE LT GT RLT RGT AND OR NOT
Expand Down Expand Up @@ -36,6 +37,7 @@ exp:
| LPAREN; e = exp; RPAREN { e }
| i = INT { Int i }
| r = REAL { Real r }
| b = BOOL { Bool b }
| x = ID { Var x }
| f = ID; LPAREN; es = args; RPAREN { Call (f, es) }
| IF; e_pred = exp; THEN; e_con = exp; ELSE; e_alt = exp { If (e_pred, e_con, e_alt) }
Expand Down
211 changes: 0 additions & 211 deletions lib/program.ml
Original file line number Diff line number Diff line change
Expand Up @@ -203,214 +203,3 @@ end

type fn = { name : Id.t; params : Id.t list; body : Exp.t } [@@deriving sexp]
type program = { funs : fn list; exp : Exp.t } [@@deriving sexp]

module Type_safe = struct
type real = float

type _ value =
| Int : int -> int value
| Real : real -> real value
| Bool : bool -> bool value

type _ ty = Tyi : int ty | Tyr : real ty | Tyb : bool ty
type ('a, 'b, 'c) bop = ('a ty * 'b ty * 'c ty) * ('a -> 'b -> 'c)
type ('a, 'b) uop = ('a ty * 'b ty) * ('a -> 'b)

type _ params =
| [] : unit params
| ( :: ) : 'a ty * 'b params -> ('a * 'b) params

type det = Det
type non_det = Non_det
type 'a sampler = unit -> 'a
type 'a log_pmdf = 'a -> real

type 'a dist = {
name : Id.t;
ty : 'a ty;
sampler : 'a sampler;
log_pmdf : 'a log_pmdf;
}

type any_dist = Any_dist : 'a dist -> any_dist

type (_, _) args =
| [] : (unit, _) args
| ( :: ) : ('a, 'd) texp * ('b, 'd) args -> ('a * 'b, 'd) args

and (_, _) exp =
| Value : 'a value -> ('a, _) exp
| Var : Id.t -> _ exp
| Bop : ('a, 'b, 'c) bop * ('a, 'd) texp * ('b, 'd) texp -> ('c, 'd) exp
| Uop : ('a, 'b) uop * ('a, 'd) texp -> ('b, 'd) exp
(* TODO: Add list and record constructors *)
(*| List : ('a, 'd) exp list -> ('a list, 'd) exp*)
(*| Record : ('k * 'v, 'd) exp list -> ('k * 'v, 'd) exp*)
| If : (bool, 'd) texp * ('a, 'd) texp * ('a, 'd) texp -> ('a, 'd) exp
| Let : Id.t * ('a, non_det) texp * ('b, non_det) texp -> ('b, non_det) exp
| Call : Id.t * ('a, 'd) args -> ('b, 'd) exp
| Sample : ('a, non_det) texp -> ('a, non_det) exp
| Observe : ('a, non_det) texp * ('a, non_det) texp -> ('a, non_det) exp
| Dist : 'b dist -> ('b, det) exp

and ('a, 'd) texp = { ty : 'a ty; exp : ('a, 'd) exp }

let rec fv : type a. (a, det) exp -> Id.Set.t = function
| Value _ | Dist _ -> Id.Set.empty
| Var x -> Id.Set.singleton x
| Bop (_, { exp = e1; _ }, { exp = e2; _ }) -> Id.(fv e1 @| fv e2)
| Uop (_, { exp; _ }) -> fv exp
| If ({ exp = e_pred; _ }, { exp = e_cons; _ }, { exp = e_alt; _ }) ->
Id.(fv e_pred @| fv e_cons @| fv e_alt)
| Call (_, args) -> fv_args args

and fv_args : type a. (a, det) args -> Id.Set.t = function
| [] -> Id.Set.empty
| { exp; _ } :: es -> Id.(fv exp @| fv_args es)

let bop (type a b c) (op : (a, b, c) bop) (v1 : a value) (v2 : b value) :
c value =
match (op, v1, v2) with
| ((Tyi, Tyi, Tyi), op), Int i1, Int i2 -> Int (op i1 i2)
| ((Tyi, Tyi, Tyr), op), Int i1, Int i2 -> Real (op i1 i2)
| ((Tyi, Tyi, Tyb), op), Int i1, Int i2 -> Bool (op i1 i2)
| ((Tyi, Tyr, Tyi), op), Int i, Real r -> Int (op i r)
| ((Tyi, Tyr, Tyr), op), Int i, Real r -> Real (op i r)
| ((Tyi, Tyr, Tyb), op), Int i, Real r -> Bool (op i r)
| ((Tyi, Tyb, Tyr), op), Int i, Bool b -> Real (op i b)
| ((Tyi, Tyb, Tyi), op), Int i, Bool b -> Int (op i b)
| ((Tyi, Tyb, Tyb), op), Int i, Bool b -> Bool (op i b)
| ((Tyr, Tyi, Tyi), op), Real r, Int i -> Int (op r i)
| ((Tyr, Tyi, Tyr), op), Real r, Int i -> Real (op r i)
| ((Tyr, Tyi, Tyb), op), Real r, Int i -> Bool (op r i)
| ((Tyr, Tyr, Tyi), op), Real r1, Real r2 -> Int (op r1 r2)
| ((Tyr, Tyr, Tyr), op), Real r1, Real r2 -> Real (op r1 r2)
| ((Tyr, Tyr, Tyb), op), Real r1, Real r2 -> Bool (op r1 r2)
| ((Tyr, Tyb, Tyi), op), Real r, Bool b -> Int (op r b)
| ((Tyr, Tyb, Tyr), op), Real r, Bool b -> Real (op r b)
| ((Tyr, Tyb, Tyb), op), Real r, Bool b -> Bool (op r b)
| ((Tyb, Tyi, Tyr), op), Bool b, Int i -> Real (op b i)
| ((Tyb, Tyi, Tyi), op), Bool b, Int i -> Int (op b i)
| ((Tyb, Tyi, Tyb), op), Bool b, Int i -> Bool (op b i)
| ((Tyb, Tyr, Tyi), op), Bool b, Real r -> Int (op b r)
| ((Tyb, Tyr, Tyr), op), Bool b, Real r -> Real (op b r)
| ((Tyb, Tyr, Tyb), op), Bool b, Real r -> Bool (op b r)
| ((Tyb, Tyb, Tyi), op), Bool b1, Bool b2 -> Int (op b1 b2)
| ((Tyb, Tyb, Tyr), op), Bool b1, Bool b2 -> Real (op b1 b2)
| ((Tyb, Tyb, Tyb), op), Bool b1, Bool b2 -> Bool (op b1 b2)

let uop (type a b) (op : (a, b) uop) (v : a value) : b value =
match (op, v) with
| ((Tyi, Tyi), op), Int i -> Int (op i)
| ((Tyi, Tyr), op), Int i -> Real (op i)
| ((Tyi, Tyb), op), Int i -> Bool (op i)
| ((Tyr, Tyi), op), Real r -> Int (op r)
| ((Tyr, Tyr), op), Real r -> Real (op r)
| ((Tyr, Tyb), op), Real r -> Bool (op r)
| ((Tyb, Tyi), op), Bool b -> Int (op b)
| ((Tyb, Tyr), op), Bool b -> Real (op b)
| ((Tyb, Tyb), op), Bool b -> Bool (op b)

type _ vargs =
| [] : unit vargs
| ( :: ) : ('a ty * 'a) * 'b vargs -> ('a * 'b) vargs

let varg_of_value : type a. a value -> a ty * a = function
| Int i -> (Tyi, i)
| Real r -> (Tyr, r)
| Bool b -> (Tyb, b)

exception Dist_type_error of string

let get_bernoulli (type a b) (ret : a ty) (vargs : b vargs) : a dist =
let open Owl.Stats in
match (ret, vargs) with
| Tyb, [ (Tyr, p) ] ->
{
name = "bernoulli";
ty = Tyb;
sampler = (fun () -> binomial_rvs ~p ~n:1 = 1);
log_pmdf = (fun b -> binomial_logpdf ~p ~n:1 (Bool.to_int b));
}
| Tyb, [] -> raise (Dist_type_error "Bernoulli: too few args")
| Tyb, [ (Tyi, i) ] ->
raise (Dist_type_error (sprintf "Bernoulli: got %i expected real" i))
| Tyb, [ (Tyb, b) ] ->
raise (Dist_type_error (sprintf "Bernoulli: got %b expected real" b))
| Tyb, _ -> raise (Dist_type_error "Bernoulli: too many arguments")
| _, _ -> raise (Dist_type_error "Bernoulli: should return bool")

let get_normal (type a b) (ret : a ty) (vargs : b vargs) : a dist =
let open Owl.Stats in
match (ret, vargs) with
| Tyr, [ (Tyr, mu); (Tyr, sigma) ] ->
{
name = "normal";
ty = Tyr;
sampler = (fun () -> gaussian_rvs ~mu ~sigma);
log_pmdf = gaussian_logpdf ~mu ~sigma;
}
| Tyr, [] | Tyr, [ _ ] -> raise (Dist_type_error "Normal: too few args")
| Tyr, [ (Tyi, i); _ ] ->
raise (Dist_type_error (sprintf "Normal: got %i expected real" i))
| Tyr, [ (Tyr, _); (Tyi, i) ] ->
raise (Dist_type_error (sprintf "Normal: got %i expected real" i))
| Tyr, [ (Tyb, b); _ ] ->
raise (Dist_type_error (sprintf "Normal: got %b expected real" b))
| Tyr, [ (Tyr, _); (Tyb, b) ] ->
raise (Dist_type_error (sprintf "Normal: got %b expected real" b))
| Tyr, _ -> raise (Dist_type_error "Normal: too many arguments")
| _, _ -> raise (Dist_type_error "Normal: should return real")

type ('arg, 'k) cont_dist_box = {
k : 'a 'b. ('a params * 'b ty) * ('arg vargs -> 'b dist) -> 'k;
}

let dist_lookup (name : Id.t) (ret : 'a ty) (vargs : 'b vargs) : 'a dist =
match name with
| "bernoulli" -> get_bernoulli ret vargs
| "normal" -> get_normal ret vargs
(* TODO: Add more distributions *)
| _ -> raise (Invalid_argument "Distribution not found")

let rec peval : type a. (a, det) texp -> (a, det) texp =
fun { ty; exp } ->
let exp =
match exp with
| (Value _ | Var _) as e -> e
| Bop (op, te1, te2) -> (
match (peval te1, peval te2) with
| { exp = Value v1; _ }, { exp = Value v2; _ } -> Value (bop op v1 v2)
| te1, te2 -> Bop (op, te1, te2))
| Uop (op, te) -> (
match peval te with
| { exp = Value v; _ } -> Value (uop op v)
| e -> Uop (op, e))
| If (te_pred, te_cons, te_alt) -> (
match peval te_pred with
| { exp = Value (Bool true); _ } -> (peval te_cons).exp
| { exp = Value (Bool false); _ } -> (peval te_alt).exp
| te_pred -> If (te_pred, peval te_cons, peval te_alt))
| Call (f, args) -> (
match peval_args args with
| args, None -> Call (f, args)
| _, Some vargs ->
(* All arguments are fully evaluated;
Go ahead and fully evaluate the (primitive) call.
It is a primitive call as this is a deterministic expression. *)
Dist (dist_lookup f ty vargs))
| Dist _ as e -> e (* TODO: probably should not be encountered *)
in
{ ty; exp }

and peval_args : type a. (a, det) args -> (a, det) args * a vargs option =
function
| [] -> ([], Some [])
| te :: tl -> (
match (peval te, peval_args tl) with
| { ty; exp = Value v }, (tl, Some vargs) ->
({ ty; exp = Value v } :: tl, Some (varg_of_value v :: vargs))
| te, (tl, _) -> (te :: tl, None))

(*let rec convert (exp : Exp.t) : (float, non_det) exp =*)
end
Loading