Skip to content

Commit

Permalink
function monomorphization
Browse files Browse the repository at this point in the history
  • Loading branch information
terencode committed Aug 24, 2023
1 parent 2e2deb6 commit 68eb2b6
Show file tree
Hide file tree
Showing 5 changed files with 427 additions and 18 deletions.
280 changes: 265 additions & 15 deletions src/passes/monomorphization/monomorphization.ml
Original file line number Diff line number Diff line change
@@ -1,26 +1,276 @@
open Common
open Monad
open TypesCommon
module E = Common.Error.Logger
open Monad.UseMonad(E)
open IrMir
open IrHir
open SailParser

type mono_body = {
monomorphics : AstMir.mir_function method_defn list;
polymorphics : AstMir.mir_function method_defn list;
processes : (HirUtils.statement,HirUtils.expression) AstParser.process_body process_defn list
}
module E = Common.Error
open Monad.MonadSyntax (E.Logger)
open IrMir.AstMir
open MonomorphizationMonad
module M = MonoMonad
open MonomorphizationUtils
open MonadSyntax(M)
open MonadOperator(M)
open MonadFunctions(M)


module Pass = Pass.Make (struct
let name = "Monomorphization"

type in_body = (AstMir.mir_function,(HirUtils.statement,HirUtils.expression) AstParser.process_body) SailModule.methods_processes
type out_body = mono_body
type in_body = (MonomorphizationUtils.in_body,(IrHir.HirUtils.statement,IrHir.HirUtils.expression) SailParser.AstParser.process_body) SailModule.methods_processes
type out_body = MonomorphizationUtils.out_body

module Env = SailModule.DeclEnv

let mono_fun (f : sailor_function) (sm : in_body SailModule.t) : unit M.t =

let mono_exp (e : expression) : sailtype M.t =
let rec aux (e : expression) : sailtype M.t =
match e.exp with
| Variable s -> M.get_var s >>| fun v -> (v |> Option.get |> snd).ty

| Literal l -> return (sailtype_of_literal l)

| ArrayRead (e, idx) ->
begin
let* t = aux e in
match t with
| ArrayType (t, _) ->
let+ idx_t = aux idx in
let _ = resolveType idx_t (Int 32) [] [] in
t
| _ -> failwith "cannot happen"
end
| UnOp (_, e) -> aux e

| BinOp (_, e1, e2) ->
let* t1 = aux e1 in
let+ t2 = aux e2 in
let _ = resolveType t1 t2 [] [] in
t1

| Ref (m, e) ->
let+ t = aux e in
RefType (t, m)

| Deref e -> (
let+ t = aux e in
match t with
| RefType _ -> t
| _ -> failwith "cannot happen"
)

| ArrayStatic (e :: h) ->
let* t = aux e in
let+ t =
ListM.fold_left (fun last_t e ->
let+ next_t = aux e in
let _ = resolveType next_t last_t [] [] in
next_t
) t h
in
ArrayType (t, List.length (e :: h))

| ArrayStatic [] -> failwith "error : empty array"
| StructAlloc (_, _, _) -> failwith "todo: struct alloc"
| EnumAlloc (_, _) -> failwith "todo: enum alloc"
| StructRead (_, _, _) -> failwith "todo: struct read"
| MethodCall _ -> failwith "no method call at this stage"
in
aux e
in

let construct_call (calle : string) (el : expression list) : (string * sailtype option) M.t =
(* we construct the types of the args (and collect extra new calls) *)
Logs.debug (fun m -> m "contructing call to %s from %s" calle f.m_proto.name);
let* monos = M.get_monos and* funs = M.get_funs in
Logs.debug (fun m -> m "current monos : %s" (String.concat ";" (List.map ( fun (g,(t:sailor_args)) -> g ^ " -> " ^ (List.map (fun (id,t) -> "(" ^ id ^ "," ^ string_of_sailtype (Some t) ^ ")") t |> String.concat "," )) monos)));
Logs.debug (fun m -> m "current funs : %s" (FieldMap.fold (fun name _ acc -> Fmt.str "%s;%s" name acc) funs ""));


let* call_args =
ListM.fold_left
(fun l e ->
Logs.debug (fun m -> m "analyze param expression");
let* t = mono_exp e in
Logs.debug (fun m -> m "param is %s " @@ string_of_sailtype @@ Some t);
return (t :: l)
)
[] el
in

(*don't do anything if the function is already added *)
let mname = mangle_method_name calle call_args in
let* funs = M.get_funs in
match FieldMap.find_opt mname funs with
| Some f ->
Logs.debug (fun m -> m "function %s already discovered, skipping" calle);
return (mname,f.methd.m_proto.rtype)
| None ->
begin
let* f = find_callable calle sm |> M.lift in
match f with
| None -> (*import *) return (mname,Some (Int 32) (*fixme*))
| Some f ->
begin
Logs.debug (fun m -> m "found call to %s, variadic : %b" f.m_proto.name f.m_proto.variadic );
match f.m_body with
| Right _ ->
(* process and method
we make sure they correspond to what the callable wants
if the callable is generic we check all the generic types are present at least once
we build a (string*sailtype) list of generic to type correspondance
if the generic is not found in the list, we add it with the corresponding type
if the generic already exists with the same type as the new one, we are good else we fail
*)
let* resolved_generics = check_args call_args f |> M.lift in
List.iter (fun (n, t) -> Logs.debug (fun m -> m "resolved %s to %s " n (string_of_sailtype (Some t)))) resolved_generics;

let* () = M.push_monos calle resolved_generics in

let* rtype =
match f.m_proto.rtype with
| Some t ->
(* Logs.warn (fun m -> m "TYPE BEFORE : %s" (string_of_sailtype (Some t))); *)
let+ t = (degenerifyType t resolved_generics|> M.lift) in
(* Logs.warn (fun m -> m "TYPE AFTER : %s" (string_of_sailtype (Some t))); *)
Some t
| None -> return None
in

let params = List.map2 (fun (p:param) ty -> {p with ty}) f.m_proto.params call_args in
let name = mname in
let methd = { f with m_proto = { f.m_proto with rtype ; params } } in
let+ () =
let* f = M.get_decl name (Self Method) in
if Option.is_none f then
M.add_decl name ((dummy_pos,name),(defn_to_proto (Method methd))) Method
else return ()
in
mname,rtype
| Left _ -> (* external method *) return (calle,f.m_proto.rtype)
end
end
in

let rec mono_body (lbl: label) (treated: LabelSet.t) (blocks : (VE.t,unit) basicBlock BlockMap.t): (LabelSet.t * (_,_) basicBlock BlockMap.t) MonoMonad.t =
(* collect calls and name correctly *)
if LabelSet.mem lbl treated then return (treated,blocks)
else
begin
let treated = LabelSet.add lbl treated in

let bb = BlockMap.find lbl blocks in
let* () = M.set_ve bb.forward_info in
let* () = ListM.iter (fun assign -> mono_exp assign.target >>= fun _ty -> mono_exp assign.expression >>| fun _ty -> ()) bb.assignments
in

match bb.terminator |> Option.get with
| Return e ->
let+ _ =
begin
match e with
| Some e -> let+ t = mono_exp e in Some t
| None -> return None
end
in treated,blocks

| Invoke new_f ->
let* (id,_) = construct_call new_f.id new_f.params in
mono_body new_f.next treated BlockMap.(update lbl (fun _ -> Some {bb with terminator=Some (Invoke {new_f with id})}) blocks)

| Goto lbl -> mono_body lbl treated blocks

| SwitchInt si ->
let* _ = mono_exp si.choice in
let* treated,blocks = mono_body si.default treated blocks in
ListM.fold_left ( fun (treated,blocks) (_,lbl) ->
mono_body lbl treated blocks
) (treated,blocks) si.paths

| Break -> failwith "no break should be there"
end
in

match f.m_body with
| Right (decls,cfg) -> mono_body cfg.input LabelSet.empty cfg.blocks >>= fun (_,blocks) ->
let params = List.map (fun (p:param) -> p.ty) f.m_proto.params in
let name = mangle_method_name f.m_proto.name params in
let methd = {m_proto = f.m_proto; m_body=Right (decls,{cfg with blocks})} in
M.add_fun name {methd; generics=[]}

| Left _ -> (* external *) return ()


let analyse_functions (sm : in_body SailModule.t) : unit M.t =

(* find the function, apply generic substitutions to its signature and monomorphize *)
let find_fun_and_mono (name, (g : sailor_args)) : unit M.t =
let* f = find_callable name sm |> M.lift in
match f with
| None -> (* fixme imports *) return ()
| Some f ->
(* monomorphize signature with resolved generics (if any) *)
let* params = ListM.map (fun (p : param) -> let+ ty = degenerifyType p.ty g |> M.lift in {p with ty}) f.m_proto.params in
let* rtype =
match f.m_proto.rtype with
| Some t -> let+ t = degenerifyType t g |> M.lift in Some t
| None -> return None
in
(* update function signature *)
let f = { f with m_proto = { f.m_proto with params; rtype } } in
(* monomorphize, updating env with any new function calls found *)
mono_fun f sm
in

let rec aux () : unit M.t =
let* empty = M.get_monos >>| (=) [] in
if not empty then (* runs until no more new monomorphic function is found *)
begin
let* name,args = M.pop_monos in
Logs.debug (fun m -> m "looking at function %s with args %s " name (List.map (fun (_,t) -> string_of_sailtype @@ Some t) args |> String.concat " "));

let mname = mangle_method_name name (List.split args |> snd) in

(* we only look at untreated functions *)
let* funs = M.get_funs in
match FieldMap.find_opt mname funs with
| Some _ ->
Logs.debug (fun m -> m "%s already checked" mname);
return ()
| None ->
Logs.debug (fun m -> m "analyzing monomorphic function %s" mname);
find_fun_and_mono (name, args) >>= aux
end
else return ()
in
let* empty = M.get_monos >>| (=) [] in M.throw_if Error.(make dummy_pos "no monomorphic callable (no main?)") empty >>= aux


let transform (smdl : in_body SailModule.t) : out_body SailModule.t E.t =
let polymorphics,monomorphics = List.partition (fun m -> m.m_proto.generics <> []) smdl.body.methods in
return {smdl with body={monomorphics;polymorphics;processes=smdl.body.processes}}
let add_if_mono name args gens else_ret =
let args = List.map (fun (p:param) -> p.id,p.ty) args in
if gens <> [] then M.pure [else_ret] else M.push_monos name args >>| fun () -> []
in


let mono_poly = (* add monomorphics to the env and collect generic methods *)
M.pure []
(* our entry points are non generic methods and processes *)
>>= fun l -> ListM.fold_left (fun acc m -> add_if_mono m.m_proto.name m.m_proto.params m.m_proto.generics m >>| Fun.flip List.append acc) l smdl.body.methods
(*
analyze them, find and resolve calls to generic functions
IMPORTANT : we must keep the generic functions : if one of them
is called from an other module and we don't have a monomorphic versio, we must generate one using the generic version
*)
>>= fun l -> analyse_functions smdl >>| fun () -> l
in

let open MonadSyntax(E) in
let+ polymorphics,mono_env = M.run smdl.declEnv mono_poly in
Logs.info (fun m -> m "generated %i monomorphic functions : " (List.length (FieldMap.bindings mono_env.functions)));
FieldMap.iter print_method_proto mono_env.functions;
let monomorphics = List.filter (fun m -> Either.is_left m.m_body) smdl.body.methods |> FieldMap.fold (fun name f acc -> {f.methd with m_proto={f.methd.m_proto with name}}::acc) mono_env.functions in

{smdl with body={monomorphics;polymorphics;processes=smdl.body.processes}}
end)
45 changes: 45 additions & 0 deletions src/passes/monomorphization/monomorphizationMonad.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
open Common
open Monad
open TypesCommon
open MonomorphizationUtils

type env = {monos: monomorphics; functions : sailor_functions; env: varTypesMap}

module MonoMonad = struct
module S = MonadState.T(Error.Logger)(struct type t = env end)
open MonadSyntax(S)
open MonadOperator(S)
include S
(* error *)
let throw e = E.throw e |> lift
let throw_if e c = E.throw_if e c |> lift

let get_decl id ty = get >>| fun e -> Env.get_decl id ty e.env
let add_decl id decl ty = update (fun e -> E.bind (Env.add_decl id decl ty e.env) (fun env -> E.pure {e with env}))
let get_var id = get >>| fun e -> Env.get_var id e.env
let set_ve ve = update (fun e -> E.pure {e with env=(ve,snd e.env)})


let add_fun mname (f: 'a sailor_method) = S.update (fun e -> E.pure {e with functions=FieldMap.add mname f e.functions})
let get_funs = let+ e = S.get in e.functions

let push_monos name generics = S.update (fun e -> E.pure {e with monos=(name,generics)::e.monos})
let get_monos = let+ e = S.get in e.monos
let pop_monos = let* e = S.get in
match e.monos with
| [] -> throw Error.(make dummy_pos "empty_monos")
| h::monos -> S.set {e with monos} >>| fun () -> h

let run (decls:Env.D.t) (x: 'a t) : ('a * env) E.t = x {monos=[];functions=FieldMap.empty;env=Env.empty decls}

end


let mangle_method_name (name : string) (args : sailtype list) : string =
let back =
List.fold_left (fun s t -> s ^ string_of_sailtype (Some t) ^ "_") "" args
in
let front = "_" ^ name ^ "_" in
let res = front ^ back in
Logs.debug (fun m -> m "renamed %s to %s" name res);
res
Loading

0 comments on commit 68eb2b6

Please sign in to comment.