From 83c871e9cdfd91c908e641a50f811fe5c3a67b53 Mon Sep 17 00:00:00 2001 From: Enrico Tassi Date: Tue, 15 Oct 2024 21:45:38 +0200 Subject: [PATCH] wip --- src/compiler.ml | 77 ++++++++++++++++++++++++++++++------------------- 1 file changed, 48 insertions(+), 29 deletions(-) diff --git a/src/compiler.ml b/src/compiler.ml index 62420d59e..e652d6701 100644 --- a/src/compiler.ml +++ b/src/compiler.ml @@ -666,6 +666,7 @@ module MutableOnce : sig type 'a t [@@ deriving show] val make : F.t -> 'a t + val create : 'a -> 'a t val set : 'a t -> 'a -> unit val unset : 'a t -> unit val get : 'a t -> 'a @@ -677,6 +678,8 @@ end = struct let make f = f, ref None + let create t = F.from_string "_", ref (Some t) + let is_set (_,x) = Option.is_some !x let set (_,r) x = match !r with @@ -781,6 +784,7 @@ module ScopedTerm = struct | App of scope * F.t * t * t list | Lam of F.t option * t | CData of CData.t + | Spill of t * int (* 0 is the original, 1.. its phantoms *) and t = { it : t_; loc : Loc.t; ty : TypeAssignment.t MutableOnce.t } [@@ deriving show] @@ -797,6 +801,8 @@ module ScopedTerm = struct | App(_,f,x,xs) -> fprintf fmt "(%a %a)" F.pp f (Util.pplist pretty " ") (x::xs) | Var(f,xs) -> fprintf fmt "(%a %a)" F.pp f (Util.pplist pretty " ") xs | CData c -> fprintf fmt "%a" CData.pp c + | Spill (t,0) -> fprintf fmt "{%a}" pretty t + | Spill (t,n) -> fprintf fmt "{%a}_%d" pretty t n let equal t1 t2 = @@ -809,6 +815,7 @@ module ScopedTerm = struct | App(Global,c1,x,xs), App(Global,c2,y,ys) -> F.equal c1 c2 && eq ctx x y && Util.for_all2 (eq ctx) xs ys | App(Local,c1,x,xs), App(Local,c2,y,ys) -> eq_var ctx c1 c2 && eq ctx x y && Util.for_all2 (eq ctx) xs ys | Lam(None,b1), Lam (None, b2) -> eq ctx b1 b2 + | Spill(b1,n1), Spill (b2,n2) -> n1 == n2 && eq ctx b1 b2 | Lam(Some c1,b1), Lam(Some c2, b2) -> eq (push_ctx c1 c2 ctx) b1 b2 | CData c1, CData c2 -> CData.equal c1 c2 | _ -> false @@ -1010,16 +1017,14 @@ end = struct | [x], _:: ys -> x :: extend [x] ys | x::xs, _::ys -> x :: extend [x] ys - type args_classification = Spill | NoSpill - - let classify_one_arg { it } = + let is_spill { it } = match it with - | App(Global,c,_,[]) when F.equal c F.spillf -> Spill - | _ -> NoSpill + | Spill _ -> true + | _ -> false - let rec classify_args = function - | [] -> NoSpill - | x :: xs -> if classify_one_arg x = Spill then Spill else classify_args xs + let rec any_arg_is_spill = function + | [] -> false + | x :: xs -> is_spill x || any_arg_is_spill xs let check ~type_abbrevs ~kinds ~env (t : ScopedTerm.t) ~(exp : TypeAssignment.t) = @@ -1027,22 +1032,13 @@ end = struct let needs_spill = ref false in let sigma : TypeAssignment.t F.Map.t ref = ref F.Map.empty in let fresh_name = let i = ref 0 in fun () -> incr i; F.from_string ("%dummy"^ string_of_int !i) in - let rec check ctx ~loc ~tyctx x (ety : ret) : ret list = + let rec check ctx ~loc ~tyctx x (ety : ret) : ScopedTerm.t list = + (* Format.eprintf "checking %a\n" ScopedTerm.pretty_ x; *) match x with | Const(Global,c) -> check_global ctx ~loc ~tyctx c ety | Const(Local,c) -> check_local ctx ~loc ~tyctx c ety | CData c -> check_cdata ~loc ~tyctx kinds c ety - | App(Global,c,sp,[]) when F.equal c F.spillf -> - needs_spill := true; - let inner_spills = check_loc ~tyctx:None ctx sp ~ety:(mk_uvar "Spill") in (* TODO?? *) - begin match classify_arrow (ScopedTerm.type_of sp) with - | Simple { srcs; tgt } -> - let spills = srcs @ [tgt] in - let hd = List.hd spills in - if unify hd ety then List.tl spills - else error_bad_ety ~tyctx ~loc ~ety ScopedTerm.pretty_ x hd - | _ -> error ~loc "hard spill" - end + | Spill(sp,n) -> assert(n=0); check_spill ctx ~loc ~tyctx sp ety | App(Global,c,x,xs) -> check_app ctx ~loc ~tyctx c (global_type env ~loc c) (x::xs) ety | App(Local,c,x,xs) -> check_app ctx ~loc ~tyctx c (local_type ctx ~loc c) (x::xs) ety | Lam(c,t) -> check_lam ctx ~loc ~tyctx c t ety @@ -1081,12 +1077,24 @@ end = struct else error_bad_function_ety ~loc ~tyctx ~ety c t - and todo_spill _ = () + and check_spill ctx ~loc ~tyctx sp ety = + needs_spill := true; + let inner_spills = check_loc ~tyctx:None ctx sp ~ety:(mk_uvar "Spill") in (* TODO?? *) + let phantom_of_spill_ty i ty = + { loc; it = Spill(sp,i+1); ty = MutableOnce.create (TypeAssignment.Val ty) } in + match classify_arrow (ScopedTerm.type_of sp) with + | Simple { srcs; tgt } -> + if not @@ unify tgt Prop then error ~loc "only predicated can be spilled"; + let spills = srcs in + let first_spill = List.hd spills in + if unify first_spill ety then List.mapi phantom_of_spill_ty @@ List.tl spills + else error_bad_ety ~tyctx ~loc ~ety ScopedTerm.pretty_ (Spill(sp,0)) first_spill + | _ -> error ~loc "hard spill" and check_app ctx ~loc ~tyctx c cty args ety = match cty with | Overloaded l -> - List.iter (fun x -> todo_spill @@ check_loc ~tyctx:None ctx ~ety:(mk_uvar "Ety") x) args; + let args = List.concat_map (fun x -> x :: check_loc ~tyctx:None ctx ~ety:(mk_uvar "Ety") x) args in let targs = List.map ScopedTerm.type_of args in check_app_overloaded ctx ~loc c ety args targs l l | Single ty -> @@ -1094,10 +1102,12 @@ end = struct if args = [] then error_bad_ety ~loc ~tyctx ~ety F.pp c ty (* uvar *) else error_bad_ety ~loc ~tyctx ~ety ScopedTerm.pretty_ (App(Global (* sucks *),c,List.hd args,List.tl args)) ty in let monodirectional () = + (* Format.eprintf "checking app mono %a\n" F.pp c; *) let tgt = check_app_single ctx ~loc c ty [] args in if unify tgt ety then [] else err () in let bidirectional srcs tgt = + (* Format.eprintf "checking app bidi %a\n" F.pp c; *) let rec consume args srcs = match args, srcs with | [], srcs -> arrow_of_tys srcs tgt @@ -1113,15 +1123,15 @@ end = struct | Simple { srcs; tgt } -> if List.length args > List.length srcs then monodirectional () (* will error *) else - match classify_args args with - | Spill -> monodirectional () - | NoSpill -> bidirectional srcs tgt + if any_arg_is_spill args then monodirectional () + else bidirectional srcs tgt (* XXX ... look at args, is no spill then build arrow using srcs -> tgt - args .. *) and check_app_overloaded ctx ~loc c ety args targs alltys = function | [] -> error_overloaded_app ~loc c args ~ety alltys | t::ts -> + (* Format.eprintf "checking overloaded app %a\n" F.pp c; *) match classify_arrow t with | Unknown -> error ~loc "Type too ambiguous to be assigned to an overloaded constant" | Simple { srcs; tgt } -> @@ -1139,10 +1149,10 @@ end = struct (* Format.eprintf "checking app %a @ %a\n" F.pp c ScopedTerm.pretty x; *) match ty with | TypeAssignment.Arr(Ast.Structured.Variadic,s,t) -> - todo_spill @@ check_loc ~tyctx:(Some c) ctx x ~ety:s; + let xs = check_loc_if_not_phantom ~tyctx:(Some c) ctx x ~ety:s @ xs in if xs = [] then t else check_app_single ctx ~loc c ty (x::consumed) xs | TypeAssignment.Arr(Ast.Structured.NotVariadic,s,t) -> - todo_spill @@ check_loc ~tyctx:(Some c) ctx x ~ety:s; + let xs = check_loc_if_not_phantom ~tyctx:(Some c) ctx x ~ety:s @ xs in check_app_single ctx ~loc c t (x::consumed) xs | TypeAssignment.UVar m when MutableOnce.is_set m -> check_app_single ctx ~loc c (TypeAssignment.deref m) consumed (x :: xs) @@ -1152,12 +1162,17 @@ end = struct check_app_single ctx ~loc c (TypeAssignment.Arr(Ast.Structured.NotVariadic,s,t)) consumed (x :: xs) | _ -> error_not_a_function ~loc:x.loc c (List.rev consumed) x (* TODO: trim loc up to x *) - and check_loc ~tyctx ctx { loc; it; ty } ~ety : ret list = + and check_loc ~tyctx ctx { loc; it; ty } ~ety : ScopedTerm.t list = assert (not @@ MutableOnce.is_set ty); let extra_spill = check ~tyctx ctx ~loc it ety in MutableOnce.set ty (Val ety); extra_spill - + + and check_loc_if_not_phantom ~tyctx ctx x ~ety : ScopedTerm.t list = + match x.it with + | Spill(_,n) when n > 0 -> [] + | _ -> check_loc ~tyctx ctx x ~ety + and check_matches_poly_skema_loc { loc; it } = let c, args = match it with @@ -1778,6 +1793,8 @@ end = struct if is_uvar_name c then ScopedTerm.Var(c,[]) else ScopedTerm.(Const(Global,c)) | App ({ it = App (f,l1) },l2) -> scope_term ctx ~loc (App(f, l1 @ l2)) + | App({ it = Const c }, [x]) when F.equal c F.spillf -> + ScopedTerm.Spill (scope_loc_term ctx x,0) | App({ it = Const c }, x :: xs) -> if is_discard c then error ~loc "Applied discard"; let x = scope_loc_term ctx x in @@ -3006,6 +3023,7 @@ end (* }}} *) match it with | Const(Local,_) -> it | Const(Global,c) -> let c' = f c in if c == c' then it else Const(Global,c') + | Spill(t,n) -> let t' = aux_loc t in if t' == t then it else Spill(t',n) | App(scope,c,x,xs) -> let c' = if scope = Global then f c else c in let x' = aux_loc x in @@ -3583,6 +3601,7 @@ end = struct let rec todbl ctx t = match t.it with | CData c -> D.mkCData (CData.hcons c) + | Spill(t,n) -> assert(n=0); assert false (* lists *) | Const(Global,c) when F.(equal c nilf) -> D.mkNil | App(Global,c,x,[y]) when F.(equal c consf) ->