Skip to content

Commit

Permalink
[jvm] rework functional interface unification again
Browse files Browse the repository at this point in the history
see #11390
  • Loading branch information
Simn committed Jan 8, 2024
1 parent 05631cd commit 5f18f21
Show file tree
Hide file tree
Showing 15 changed files with 108 additions and 115 deletions.
6 changes: 6 additions & 0 deletions src-json/meta.json
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,12 @@
"targets": ["TAbstractField"],
"links": ["https://haxe.org/manual/types-abstract-implicit-casts.html"]
},
{
"name": "FunctionalInterface",
"metadata": ":functionalInterface",
"doc": "Mark an interface as a functional interface",
"platforms": ["java"]
},
{
"name": "FunctionCode",
"metadata": ":functionCode",
Expand Down
10 changes: 10 additions & 0 deletions src/codegen/javaModern.ml
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,16 @@ module Converter = struct
in
add_meta (Meta.Annotation,args,p)
end;
List.iter (fun attr -> match attr with
| AttrVisibleAnnotations ann ->
List.iter (function
| { ann_type = TObject( (["java";"lang"], "FunctionalInterface"), [] ) } ->
add_meta (Meta.FunctionalInterface,[],p);
| _ -> ()
) ann
| _ ->
()
) jc.jc_attributes;
let d = {
d_name = (class_name,p);
d_doc = None;
Expand Down
5 changes: 3 additions & 2 deletions src/context/abstractCast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,11 @@ and do_check_cast ctx uctx tleft eright p =
loop2 a.a_to
end
| TInst(c,tl), TFun _ when has_class_flag c CFunctionalInterface ->
let cf = ctx.g.functional_interface_lut#find c.cl_path in
let _,cf = ctx.com.functional_interface_lut#find c.cl_path in
let map = apply_params c.cl_params tl in
let monos = Monomorph.spawn_constrained_monos map cf.cf_params in
unify_raise_custom uctx eright.etype (map (apply_params cf.cf_params monos cf.cf_type)) p;
unify_raise_custom native_unification_context eright.etype (map (apply_params cf.cf_params monos cf.cf_type)) p;
if has_mono tright then raise_typing_error ("Cannot use this function as a functional interface because it has unknown types: " ^ (s_type (print_context()) tright)) p;
eright
| _ ->
raise Not_found
Expand Down
3 changes: 3 additions & 0 deletions src/context/common.ml
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ type context = {
mutable main : Type.texpr option;
mutable types : Type.module_type list;
mutable resources : (string,string) Hashtbl.t;
functional_interface_lut : (path,(tclass * tclass_field)) lookup;
(* target-specific *)
mutable flash_version : float;
mutable neko_lib_paths : string list;
Expand Down Expand Up @@ -871,6 +872,7 @@ let create compilation_step cs version args display_mode =
has_error = false;
report_mode = RMNone;
is_macro_context = false;
functional_interface_lut = new Lookup.hashtbl_lookup;
} in
com

Expand Down Expand Up @@ -917,6 +919,7 @@ let clone com is_macro_context =
overload_cache = new hashtbl_lookup;
module_lut = new module_lut;
std = null_class;
functional_interface_lut = new Lookup.hashtbl_lookup;
}

let file_time file = Extc.filetime file
Expand Down
1 change: 0 additions & 1 deletion src/context/typecore.ml
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ type typer_globals = {
mutable complete : bool;
mutable type_hints : (module_def_display * pos * t) list;
mutable load_only_cached_modules : bool;
functional_interface_lut : (path,tclass_field) lookup;
(* api *)
do_macro : typer -> macro_mode -> path -> string -> expr list -> pos -> macro_result;
do_load_macro : typer -> bool -> path -> string -> pos -> ((string * bool * t) list * t * tclass * Type.tclass_field);
Expand Down
9 changes: 9 additions & 0 deletions src/core/tUnification.ml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ let default_unification_context = {
equality_underlying = false;
}

(* Unify like targets (e.g. Java) probably would. *)
let native_unification_context = {
allow_transitive_cast = false;
allow_abstract_cast = false;
allow_dynamic_to_cast = false;
equality_kind = EqStrict;
equality_underlying = false;
}

module Monomorph = struct
let create () = {
tm_type = None;
Expand Down
50 changes: 39 additions & 11 deletions src/generators/genjvm.ml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ type generation_context = {
t_exception : Type.t;
t_throwable : Type.t;
anon_identification : jsignature tanon_identification;
mutable functional_interfaces : (tclass * tclass_field * JvmFunctions.JavaFunctionalInterface.t) list;
mutable preprocessor : jsignature preprocessor;
default_export_config : export_config;
typed_functions : JvmFunctions.typed_functions;
Expand Down Expand Up @@ -436,10 +437,31 @@ let generate_equals_function (jc : JvmClass.builder) jsig_arg =
save();
jm_equals,load

let create_field_closure gctx jc path_this jm name jsig =
let associate_functional_interfaces gctx f t =
if not (has_mono t) then begin
List.iter (fun (c,cf,jfi) ->
let c_monos = Monomorph.spawn_constrained_monos (fun t -> t) c.cl_params in
let map t = apply_params c.cl_params c_monos t in
let cf_monos = Monomorph.spawn_constrained_monos map cf.cf_params in
try
Type.unify_custom native_unification_context t (apply_params cf.cf_params cf_monos (map cf.cf_type));
ignore(List.map follow cf_monos);
f#add_functional_interface jfi (List.map (jsignature_of_type gctx) c_monos)
with Unify_error _ ->
()
) gctx.functional_interfaces
end

let create_field_closure gctx jc path_this jm name jsig t =
let jsig_this = object_path_sig path_this in
let context = ["this",jsig_this] in
let wf = new JvmFunctions.typed_function gctx.typed_functions (FuncMember(path_this,name)) jc jm context in
begin match t with
| None ->
()
| Some t ->
associate_functional_interfaces gctx wf t
end;
let jc_closure = wf#get_class in
ignore(wf#generate_constructor true);
let args,ret = match jsig with
Expand Down Expand Up @@ -480,12 +502,12 @@ let create_field_closure gctx jc path_this jm name jsig =
write_class gctx jc_closure#get_this_path (jc_closure#export_class gctx.default_export_config);
jc_closure#get_this_path

let create_field_closure gctx jc path_this jm name jsig f =
let create_field_closure gctx jc path_this jm name jsig f t =
let jsig_this = object_path_sig path_this in
let closure_path = try
Hashtbl.find gctx.closure_paths (path_this,name,jsig)
with Not_found ->
let closure_path = create_field_closure gctx jc path_this jm name jsig in
let closure_path = create_field_closure gctx jc path_this jm name jsig t in
Hashtbl.add gctx.closure_paths (path_this,name,jsig) closure_path;
closure_path
in
Expand Down Expand Up @@ -595,6 +617,7 @@ class texpr_to_jvm
| _ -> None
in
let wf = new JvmFunctions.typed_function gctx.typed_functions (FuncLocal name) jc jm context in
associate_functional_interfaces gctx wf e.etype;
let jc_closure = wf#get_class in
ignore(wf#generate_constructor (env <> []));
let filter = match ret with
Expand Down Expand Up @@ -678,12 +701,13 @@ class texpr_to_jvm
| None ->
default();

method read_static_closure (path : path) (name : string) (args : (string * jsignature) list) (ret : jsignature option) =
method read_static_closure (path : path) (name : string) (args : (string * jsignature) list) (ret : jsignature option) (t : Type.t) =
let jsig = method_sig (List.map snd args) ret in
let closure_path = try
Hashtbl.find gctx.closure_paths (path,name,jsig)
with Not_found ->
let wf = new JvmFunctions.typed_function gctx.typed_functions (FuncStatic(path,name)) jc jm [] in
associate_functional_interfaces gctx wf t;
let jc_closure = wf#get_class in
ignore(wf#generate_constructor false);
let jm_invoke = wf#generate_invoke args ret [] in
Expand All @@ -710,7 +734,7 @@ class texpr_to_jvm
| TFun(tl,tr) -> List.map (fun (n,_,t) -> n,self#vtype t) tl,(return_of_type gctx tr)
| _ -> die "" __LOC__
in
self#read_static_closure path cf.cf_name args ret
self#read_static_closure path cf.cf_name args ret cf.cf_type
in
let dynamic_read s =
self#texpr rvalue_any e1;
Expand Down Expand Up @@ -757,7 +781,7 @@ class texpr_to_jvm
else
create_field_closure gctx jc c.cl_path jm cf.cf_name (self#vtype cf.cf_type) (fun () ->
self#texpr rvalue_any e1;
)
) (Some cf.cf_type)

method read_write ret ak e (f : unit -> unit) =
let apply dup =
Expand Down Expand Up @@ -2228,7 +2252,7 @@ let generate_dynamic_access gctx (jc : JvmClass.builder) fields is_anon =
begin match kind,jsig with
| Method (MethNormal | MethInline),TMethod(args,_) ->
if gctx.dynamic_level >= 2 then begin
create_field_closure gctx jc jc#get_this_path jm name jsig (fun () -> jm#load_this)
create_field_closure gctx jc jc#get_this_path jm name jsig (fun () -> jm#load_this) None
end else begin
jm#load_this;
jm#string name;
Expand Down Expand Up @@ -2954,7 +2978,7 @@ module Preprocessor = struct
end else if fst mt.mt_path = [] then
mt.mt_path <- make_root mt.mt_path

let check_single_method_interface gctx c =
let check_functional_interface gctx c =
let rec loop m l = match l with
| [] ->
m
Expand All @@ -2973,7 +2997,8 @@ module Preprocessor = struct
| Some cf ->
match jsignature_of_type gctx cf.cf_type with
| TMethod(args,ret) ->
JvmFunctions.JavaFunctionalInterfaces.add args ret c.cl_path cf.cf_name (List.map extract_param_name (c.cl_params @ cf.cf_params));
let jfi = JvmFunctions.JavaFunctionalInterface.create args ret c.cl_path cf.cf_name (List.map extract_param_name (c.cl_params @ cf.cf_params)) in
gctx.functional_interfaces <- (c,cf,jfi) :: gctx.functional_interfaces;
| _ ->
()

Expand Down Expand Up @@ -3005,8 +3030,10 @@ module Preprocessor = struct
List.iter (fun mt ->
match mt with
| TClassDecl c ->
if not (has_class_flag c CInterface) then gctx.preprocessor#preprocess_class c
else check_single_method_interface gctx c;
if not (has_class_flag c CInterface) then
gctx.preprocessor#preprocess_class c
else if has_class_flag c CFunctionalInterface then
check_functional_interface gctx c
| _ -> ()
) gctx.com.types;
(* find typedef-interface implementations *)
Expand Down Expand Up @@ -3082,6 +3109,7 @@ let generate jvm_flag com =
timer = new Timer.timer ["generate";"java"];
jar_compression_level = compression_level;
dynamic_level = dynamic_level;
functional_interfaces = [];
} in
gctx.preprocessor <- new preprocessor com.basic (jsignature_of_type gctx);
gctx.typedef_interfaces <- new typedef_interfaces gctx.preprocessor#get_infos anon_identification;
Expand Down
109 changes: 18 additions & 91 deletions src/generators/jvm/jvmFunctions.ml
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ type typed_function_kind =
| FuncMember of jpath * string
| FuncStatic of jpath * string

module JavaFunctionalInterfaces = struct
module JavaFunctionalInterface = struct
type t = {
jargs: jsignature list;
jret : jsignature option;
Expand All @@ -302,93 +302,17 @@ module JavaFunctionalInterfaces = struct
"jparams",String.concat ", " jfi.jparams;
]

let java_functional_interfaces = DynArray.create ()

let add args ret path name params =
let create args ret path name params =
let jfi = {
jargs = args;
jret = ret;
jpath = path;
jname = name;
jparams = params;
} in
DynArray.add java_functional_interfaces jfi

let unify jfi args ret =
let params = ref [] in
let rec unify jsig1 jsig2 = match jsig1,jsig2 with
| TObject _,TObject((["java";"lang"],"Object"),[]) ->
true
| TObject(path1,params1),TObject(path2,params2) ->
path1 = path2 &&
unify_params params1 params2
| TTypeParameter n,jsig
| jsig,TTypeParameter n ->
List.mem_assoc n !params || begin
params := (n,jsig) :: !params;
true
end
| _ ->
jsig1 = jsig2
and unify_params params1 params2 = match params1,params2 with
| [],_
| _,[] ->
(* Assume raw type, I guess? *)
true
| param1 :: params1,param2 :: params2 ->
match param1,param2 with
| TAny,_
| _,TAny ->
(* Is this correct in both directions? *)
unify_params params1 params2
| TType(_,jsig1),TType(_,jsig2) ->
(* TODO: wildcard? *)
unify jsig1 jsig2 && unify_params params1 params2
in
let rec loop want have = match want,have with
| [],[] ->
let params = List.map (fun s ->
try
TType(WNone,List.assoc s !params)
with Not_found ->
TAny
) jfi.jparams in
Some (jfi,params)
| want1 :: want,have1 :: have ->
if unify have1 want1 then loop want have
else None
| _ ->
None
in
match jfi.jret,ret with
| None,None ->
loop jfi.jargs args
| Some jsig1,Some jsig2 ->
if unify jsig2 jsig1 then loop jfi.jargs args
else None
| _ ->
None


let find_compatible args ret filter =
DynArray.fold_left (fun acc jfi ->
if filter = [] || List.mem jfi.jpath filter then begin
if jfi.jparams = [] then begin
if jfi.jargs = args && jfi.jret = ret then
(jfi,[]) :: acc
else
acc
end else match unify jfi args ret with
| Some x ->
x :: acc
| None ->
acc
end else
acc
) [] java_functional_interfaces
jfi
end

open JavaFunctionalInterfaces
open JvmGlobals

class typed_function
Expand All @@ -400,6 +324,8 @@ class typed_function

= object(self)

val mutable functional_interfaces = []

val jc_closure =
let name = match kind with
| FuncLocal s ->
Expand Down Expand Up @@ -431,6 +357,10 @@ class typed_function
jm_ctor#return;
jm_ctor

method add_functional_interface (jfi : JavaFunctionalInterface.t) (params : jsignature list) =
let params = List.map (fun jsig -> TType(WNone,jsig)) params in
functional_interfaces <- (jfi,params) :: functional_interfaces

method generate_invoke (args : (string * jsignature) list) (ret : jsignature option) (functional_interface_filter : jpath list) =
let arg_sigs = List.map snd args in
let meth = functions#register_signature arg_sigs ret in
Expand All @@ -455,19 +385,16 @@ class typed_function
functions#make_forward_method jc_closure jm_invoke_next meth_from meth_to;
end
in
let check_functional_interfaces meth =
let l = JavaFunctionalInterfaces.find_compatible meth.dargs meth.dret functional_interface_filter in
List.iter (fun (jfi,params) ->
add_interface jfi.jpath params;
let msig = method_sig jfi.jargs jfi.jret in
if not (jc_closure#has_method jfi.jname msig) then begin
let jm_invoke_next = spawn_invoke_next jfi.jname msig false in
functions#make_forward_method_jsig jc_closure jm_invoke_next meth.name jfi.jargs jfi.jret meth.dargs meth.dret
end
) l
in
let open JavaFunctionalInterface in
List.iter (fun (jfi,params) ->
add_interface jfi.jpath params;
let msig = method_sig jfi.jargs jfi.jret in
if not (jc_closure#has_method jfi.jname msig) then begin
let jm_invoke_next = spawn_invoke_next jfi.jname msig false in
functions#make_forward_method_jsig jc_closure jm_invoke_next meth.name jfi.jargs jfi.jret meth.dargs meth.dret
end
) functional_interfaces;
let rec loop meth =
check_functional_interfaces meth;
begin match meth.next with
| Some meth_next ->
spawn_forward_function meth_next meth true;
Expand Down
Loading

0 comments on commit 5f18f21

Please sign in to comment.