-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
427 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,276 @@ | ||
open Common | ||
open Monad | ||
open TypesCommon | ||
module E = Common.Error.Logger | ||
open Monad.UseMonad(E) | ||
open IrMir | ||
open IrHir | ||
open SailParser | ||
|
||
type mono_body = { | ||
monomorphics : AstMir.mir_function method_defn list; | ||
polymorphics : AstMir.mir_function method_defn list; | ||
processes : (HirUtils.statement,HirUtils.expression) AstParser.process_body process_defn list | ||
} | ||
module E = Common.Error | ||
open Monad.MonadSyntax (E.Logger) | ||
open IrMir.AstMir | ||
open MonomorphizationMonad | ||
module M = MonoMonad | ||
open MonomorphizationUtils | ||
open MonadSyntax(M) | ||
open MonadOperator(M) | ||
open MonadFunctions(M) | ||
|
||
|
||
module Pass = Pass.Make (struct | ||
let name = "Monomorphization" | ||
|
||
type in_body = (AstMir.mir_function,(HirUtils.statement,HirUtils.expression) AstParser.process_body) SailModule.methods_processes | ||
type out_body = mono_body | ||
type in_body = (MonomorphizationUtils.in_body,(IrHir.HirUtils.statement,IrHir.HirUtils.expression) SailParser.AstParser.process_body) SailModule.methods_processes | ||
type out_body = MonomorphizationUtils.out_body | ||
|
||
module Env = SailModule.DeclEnv | ||
|
||
let mono_fun (f : sailor_function) (sm : in_body SailModule.t) : unit M.t = | ||
|
||
let mono_exp (e : expression) : sailtype M.t = | ||
let rec aux (e : expression) : sailtype M.t = | ||
match e.exp with | ||
| Variable s -> M.get_var s >>| fun v -> (v |> Option.get |> snd).ty | ||
|
||
| Literal l -> return (sailtype_of_literal l) | ||
|
||
| ArrayRead (e, idx) -> | ||
begin | ||
let* t = aux e in | ||
match t with | ||
| ArrayType (t, _) -> | ||
let+ idx_t = aux idx in | ||
let _ = resolveType idx_t (Int 32) [] [] in | ||
t | ||
| _ -> failwith "cannot happen" | ||
end | ||
| UnOp (_, e) -> aux e | ||
|
||
| BinOp (_, e1, e2) -> | ||
let* t1 = aux e1 in | ||
let+ t2 = aux e2 in | ||
let _ = resolveType t1 t2 [] [] in | ||
t1 | ||
|
||
| Ref (m, e) -> | ||
let+ t = aux e in | ||
RefType (t, m) | ||
|
||
| Deref e -> ( | ||
let+ t = aux e in | ||
match t with | ||
| RefType _ -> t | ||
| _ -> failwith "cannot happen" | ||
) | ||
|
||
| ArrayStatic (e :: h) -> | ||
let* t = aux e in | ||
let+ t = | ||
ListM.fold_left (fun last_t e -> | ||
let+ next_t = aux e in | ||
let _ = resolveType next_t last_t [] [] in | ||
next_t | ||
) t h | ||
in | ||
ArrayType (t, List.length (e :: h)) | ||
|
||
| ArrayStatic [] -> failwith "error : empty array" | ||
| StructAlloc (_, _, _) -> failwith "todo: struct alloc" | ||
| EnumAlloc (_, _) -> failwith "todo: enum alloc" | ||
| StructRead (_, _, _) -> failwith "todo: struct read" | ||
| MethodCall _ -> failwith "no method call at this stage" | ||
in | ||
aux e | ||
in | ||
|
||
let construct_call (calle : string) (el : expression list) : (string * sailtype option) M.t = | ||
(* we construct the types of the args (and collect extra new calls) *) | ||
Logs.debug (fun m -> m "contructing call to %s from %s" calle f.m_proto.name); | ||
let* monos = M.get_monos and* funs = M.get_funs in | ||
Logs.debug (fun m -> m "current monos : %s" (String.concat ";" (List.map ( fun (g,(t:sailor_args)) -> g ^ " -> " ^ (List.map (fun (id,t) -> "(" ^ id ^ "," ^ string_of_sailtype (Some t) ^ ")") t |> String.concat "," )) monos))); | ||
Logs.debug (fun m -> m "current funs : %s" (FieldMap.fold (fun name _ acc -> Fmt.str "%s;%s" name acc) funs "")); | ||
|
||
|
||
let* call_args = | ||
ListM.fold_left | ||
(fun l e -> | ||
Logs.debug (fun m -> m "analyze param expression"); | ||
let* t = mono_exp e in | ||
Logs.debug (fun m -> m "param is %s " @@ string_of_sailtype @@ Some t); | ||
return (t :: l) | ||
) | ||
[] el | ||
in | ||
|
||
(*don't do anything if the function is already added *) | ||
let mname = mangle_method_name calle call_args in | ||
let* funs = M.get_funs in | ||
match FieldMap.find_opt mname funs with | ||
| Some f -> | ||
Logs.debug (fun m -> m "function %s already discovered, skipping" calle); | ||
return (mname,f.methd.m_proto.rtype) | ||
| None -> | ||
begin | ||
let* f = find_callable calle sm |> M.lift in | ||
match f with | ||
| None -> (*import *) return (mname,Some (Int 32) (*fixme*)) | ||
| Some f -> | ||
begin | ||
Logs.debug (fun m -> m "found call to %s, variadic : %b" f.m_proto.name f.m_proto.variadic ); | ||
match f.m_body with | ||
| Right _ -> | ||
(* process and method | ||
we make sure they correspond to what the callable wants | ||
if the callable is generic we check all the generic types are present at least once | ||
we build a (string*sailtype) list of generic to type correspondance | ||
if the generic is not found in the list, we add it with the corresponding type | ||
if the generic already exists with the same type as the new one, we are good else we fail | ||
*) | ||
let* resolved_generics = check_args call_args f |> M.lift in | ||
List.iter (fun (n, t) -> Logs.debug (fun m -> m "resolved %s to %s " n (string_of_sailtype (Some t)))) resolved_generics; | ||
|
||
let* () = M.push_monos calle resolved_generics in | ||
|
||
let* rtype = | ||
match f.m_proto.rtype with | ||
| Some t -> | ||
(* Logs.warn (fun m -> m "TYPE BEFORE : %s" (string_of_sailtype (Some t))); *) | ||
let+ t = (degenerifyType t resolved_generics|> M.lift) in | ||
(* Logs.warn (fun m -> m "TYPE AFTER : %s" (string_of_sailtype (Some t))); *) | ||
Some t | ||
| None -> return None | ||
in | ||
|
||
let params = List.map2 (fun (p:param) ty -> {p with ty}) f.m_proto.params call_args in | ||
let name = mname in | ||
let methd = { f with m_proto = { f.m_proto with rtype ; params } } in | ||
let+ () = | ||
let* f = M.get_decl name (Self Method) in | ||
if Option.is_none f then | ||
M.add_decl name ((dummy_pos,name),(defn_to_proto (Method methd))) Method | ||
else return () | ||
in | ||
mname,rtype | ||
| Left _ -> (* external method *) return (calle,f.m_proto.rtype) | ||
end | ||
end | ||
in | ||
|
||
let rec mono_body (lbl: label) (treated: LabelSet.t) (blocks : (VE.t,unit) basicBlock BlockMap.t): (LabelSet.t * (_,_) basicBlock BlockMap.t) MonoMonad.t = | ||
(* collect calls and name correctly *) | ||
if LabelSet.mem lbl treated then return (treated,blocks) | ||
else | ||
begin | ||
let treated = LabelSet.add lbl treated in | ||
|
||
let bb = BlockMap.find lbl blocks in | ||
let* () = M.set_ve bb.forward_info in | ||
let* () = ListM.iter (fun assign -> mono_exp assign.target >>= fun _ty -> mono_exp assign.expression >>| fun _ty -> ()) bb.assignments | ||
in | ||
|
||
match bb.terminator |> Option.get with | ||
| Return e -> | ||
let+ _ = | ||
begin | ||
match e with | ||
| Some e -> let+ t = mono_exp e in Some t | ||
| None -> return None | ||
end | ||
in treated,blocks | ||
|
||
| Invoke new_f -> | ||
let* (id,_) = construct_call new_f.id new_f.params in | ||
mono_body new_f.next treated BlockMap.(update lbl (fun _ -> Some {bb with terminator=Some (Invoke {new_f with id})}) blocks) | ||
|
||
| Goto lbl -> mono_body lbl treated blocks | ||
|
||
| SwitchInt si -> | ||
let* _ = mono_exp si.choice in | ||
let* treated,blocks = mono_body si.default treated blocks in | ||
ListM.fold_left ( fun (treated,blocks) (_,lbl) -> | ||
mono_body lbl treated blocks | ||
) (treated,blocks) si.paths | ||
|
||
| Break -> failwith "no break should be there" | ||
end | ||
in | ||
|
||
match f.m_body with | ||
| Right (decls,cfg) -> mono_body cfg.input LabelSet.empty cfg.blocks >>= fun (_,blocks) -> | ||
let params = List.map (fun (p:param) -> p.ty) f.m_proto.params in | ||
let name = mangle_method_name f.m_proto.name params in | ||
let methd = {m_proto = f.m_proto; m_body=Right (decls,{cfg with blocks})} in | ||
M.add_fun name {methd; generics=[]} | ||
|
||
| Left _ -> (* external *) return () | ||
|
||
|
||
let analyse_functions (sm : in_body SailModule.t) : unit M.t = | ||
|
||
(* find the function, apply generic substitutions to its signature and monomorphize *) | ||
let find_fun_and_mono (name, (g : sailor_args)) : unit M.t = | ||
let* f = find_callable name sm |> M.lift in | ||
match f with | ||
| None -> (* fixme imports *) return () | ||
| Some f -> | ||
(* monomorphize signature with resolved generics (if any) *) | ||
let* params = ListM.map (fun (p : param) -> let+ ty = degenerifyType p.ty g |> M.lift in {p with ty}) f.m_proto.params in | ||
let* rtype = | ||
match f.m_proto.rtype with | ||
| Some t -> let+ t = degenerifyType t g |> M.lift in Some t | ||
| None -> return None | ||
in | ||
(* update function signature *) | ||
let f = { f with m_proto = { f.m_proto with params; rtype } } in | ||
(* monomorphize, updating env with any new function calls found *) | ||
mono_fun f sm | ||
in | ||
|
||
let rec aux () : unit M.t = | ||
let* empty = M.get_monos >>| (=) [] in | ||
if not empty then (* runs until no more new monomorphic function is found *) | ||
begin | ||
let* name,args = M.pop_monos in | ||
Logs.debug (fun m -> m "looking at function %s with args %s " name (List.map (fun (_,t) -> string_of_sailtype @@ Some t) args |> String.concat " ")); | ||
|
||
let mname = mangle_method_name name (List.split args |> snd) in | ||
|
||
(* we only look at untreated functions *) | ||
let* funs = M.get_funs in | ||
match FieldMap.find_opt mname funs with | ||
| Some _ -> | ||
Logs.debug (fun m -> m "%s already checked" mname); | ||
return () | ||
| None -> | ||
Logs.debug (fun m -> m "analyzing monomorphic function %s" mname); | ||
find_fun_and_mono (name, args) >>= aux | ||
end | ||
else return () | ||
in | ||
let* empty = M.get_monos >>| (=) [] in M.throw_if Error.(make dummy_pos "no monomorphic callable (no main?)") empty >>= aux | ||
|
||
|
||
let transform (smdl : in_body SailModule.t) : out_body SailModule.t E.t = | ||
let polymorphics,monomorphics = List.partition (fun m -> m.m_proto.generics <> []) smdl.body.methods in | ||
return {smdl with body={monomorphics;polymorphics;processes=smdl.body.processes}} | ||
let add_if_mono name args gens else_ret = | ||
let args = List.map (fun (p:param) -> p.id,p.ty) args in | ||
if gens <> [] then M.pure [else_ret] else M.push_monos name args >>| fun () -> [] | ||
in | ||
|
||
|
||
let mono_poly = (* add monomorphics to the env and collect generic methods *) | ||
M.pure [] | ||
(* our entry points are non generic methods and processes *) | ||
>>= fun l -> ListM.fold_left (fun acc m -> add_if_mono m.m_proto.name m.m_proto.params m.m_proto.generics m >>| Fun.flip List.append acc) l smdl.body.methods | ||
(* | ||
analyze them, find and resolve calls to generic functions | ||
IMPORTANT : we must keep the generic functions : if one of them | ||
is called from an other module and we don't have a monomorphic versio, we must generate one using the generic version | ||
*) | ||
>>= fun l -> analyse_functions smdl >>| fun () -> l | ||
in | ||
|
||
let open MonadSyntax(E) in | ||
let+ polymorphics,mono_env = M.run smdl.declEnv mono_poly in | ||
Logs.info (fun m -> m "generated %i monomorphic functions : " (List.length (FieldMap.bindings mono_env.functions))); | ||
FieldMap.iter print_method_proto mono_env.functions; | ||
let monomorphics = List.filter (fun m -> Either.is_left m.m_body) smdl.body.methods |> FieldMap.fold (fun name f acc -> {f.methd with m_proto={f.methd.m_proto with name}}::acc) mono_env.functions in | ||
|
||
{smdl with body={monomorphics;polymorphics;processes=smdl.body.processes}} | ||
end) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
open Common | ||
open Monad | ||
open TypesCommon | ||
open MonomorphizationUtils | ||
|
||
type env = {monos: monomorphics; functions : sailor_functions; env: varTypesMap} | ||
|
||
module MonoMonad = struct | ||
module S = MonadState.T(Error.Logger)(struct type t = env end) | ||
open MonadSyntax(S) | ||
open MonadOperator(S) | ||
include S | ||
(* error *) | ||
let throw e = E.throw e |> lift | ||
let throw_if e c = E.throw_if e c |> lift | ||
|
||
let get_decl id ty = get >>| fun e -> Env.get_decl id ty e.env | ||
let add_decl id decl ty = update (fun e -> E.bind (Env.add_decl id decl ty e.env) (fun env -> E.pure {e with env})) | ||
let get_var id = get >>| fun e -> Env.get_var id e.env | ||
let set_ve ve = update (fun e -> E.pure {e with env=(ve,snd e.env)}) | ||
|
||
|
||
let add_fun mname (f: 'a sailor_method) = S.update (fun e -> E.pure {e with functions=FieldMap.add mname f e.functions}) | ||
let get_funs = let+ e = S.get in e.functions | ||
|
||
let push_monos name generics = S.update (fun e -> E.pure {e with monos=(name,generics)::e.monos}) | ||
let get_monos = let+ e = S.get in e.monos | ||
let pop_monos = let* e = S.get in | ||
match e.monos with | ||
| [] -> throw Error.(make dummy_pos "empty_monos") | ||
| h::monos -> S.set {e with monos} >>| fun () -> h | ||
|
||
let run (decls:Env.D.t) (x: 'a t) : ('a * env) E.t = x {monos=[];functions=FieldMap.empty;env=Env.empty decls} | ||
|
||
end | ||
|
||
|
||
let mangle_method_name (name : string) (args : sailtype list) : string = | ||
let back = | ||
List.fold_left (fun s t -> s ^ string_of_sailtype (Some t) ^ "_") "" args | ||
in | ||
let front = "_" ^ name ^ "_" in | ||
let res = front ^ back in | ||
Logs.debug (fun m -> m "renamed %s to %s" name res); | ||
res |
Oops, something went wrong.