Skip to content

Commit

Permalink
Move boxing/unboxing outside of some primitives
Browse files Browse the repository at this point in the history
This makes it visible to Binaryen, which is then able to eliminate some
unnecessary boxing.
  • Loading branch information
vouillon committed Sep 26, 2024
1 parent 87f2119 commit 2368b9d
Show file tree
Hide file tree
Showing 7 changed files with 252 additions and 229 deletions.
117 changes: 104 additions & 13 deletions compiler/lib/wasm/wa_generate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,71 @@ module Generate (Target : Wa_target_sig.S) = struct
; debug : Parse_bytecode.Debug.t
}

(* Representation of an argument or of the result of a specialized
   primitive (see [specialized_primitives]). [repr_type] gives the
   corresponding Wasm type, and [box_value] / [unbox_value] convert between
   this representation and a regular value. *)
type repr =
  | Value (* passed as is: no boxing or unboxing is performed *)
  | Float (* unboxed float, passed as [F64] *)
  | Int32 (* unboxed int32, passed as [I32] *)
  | Nativeint (* unboxed nativeint, passed as [I32] *)
  | Int64 (* unboxed int64, passed as [I64] *)

(* Wasm value type associated with a representation. [Int32] and
   [Nativeint] share the same Wasm type, [I32]. *)
let repr_type = function
  | Value -> Value.value
  | Float -> F64
  | Int32 | Nativeint -> I32
  | Int64 -> I64

(* Wasm function type of a specialized primitive, computed from its
   signature: a pair made of the list of argument representations and the
   result representation. The resulting type always has exactly one
   result. *)
let specialized_func_type (arg_reprs, result_repr) =
  let params = List.map ~f:repr_type arg_reprs in
  let result = [ repr_type result_repr ] in
  { W.params; result }

(* Convert an expression [unboxed] of representation [repr] back to a
   regular value, using the [Memory.box_*] function matching [repr].
   [stack_ctx] and [x] are forwarded unchanged to that function. An
   expression of representation [Value] is returned as is. *)
let box_value stack_ctx x repr unboxed =
  (* Apply the selected boxing function to the shared arguments. *)
  let box_with boxer = boxer stack_ctx x unboxed in
  match repr with
  | Value -> unboxed
  | Float -> box_with Memory.box_float
  | Int32 -> box_with Memory.box_int32
  | Nativeint -> box_with Memory.box_nativeint
  | Int64 -> box_with Memory.box_int64

(* Convert the expression [boxed] to the representation [repr], using the
   [Memory.unbox_*] function matching [repr]. With [Value], the expression
   is returned as is. *)
let unbox_value repr boxed =
  let unbox =
    match repr with
    | Value -> (fun v -> v)
    | Float -> Memory.unbox_float
    | Int32 -> Memory.unbox_int32
    | Nativeint -> Memory.unbox_nativeint
    | Int64 -> Memory.unbox_int64
  in
  unbox boxed

(* Table of the primitives that are imported with a specialized signature.
   It maps the name of a primitive to a pair made of the representations of
   its arguments and the representation of its result. *)
let specialized_primitives =
  let signatures =
    [ (* Integer byte swapping and comparison *)
      "caml_int32_bswap", ([ Int32 ], Int32)
    ; "caml_nativeint_bswap", ([ Nativeint ], Nativeint)
    ; "caml_int64_bswap", ([ Int64 ], Int64)
    ; "caml_int32_compare", ([ Int32; Int32 ], Value)
    ; "caml_nativeint_compare", ([ Nativeint; Nativeint ], Value)
    ; "caml_int64_compare", ([ Int64; Int64 ], Value)
      (* 32-bit and 64-bit accesses to strings and bytes *)
    ; "caml_string_get32", ([ Value; Value ], Int32)
    ; "caml_string_get64", ([ Value; Value ], Int64)
    ; "caml_bytes_get32", ([ Value; Value ], Int32)
    ; "caml_bytes_get64", ([ Value; Value ], Int64)
    ; "caml_bytes_set32", ([ Value; Value; Int32 ], Value)
    ; "caml_bytes_set64", ([ Value; Value; Int64 ], Value)
    ; "caml_lxm_next", ([ Value ], Int64)
      (* 32-bit and 64-bit accesses to byte bigarrays *)
    ; "caml_ba_uint8_get32", ([ Value; Value ], Int32)
    ; "caml_ba_uint8_get64", ([ Value; Value ], Int64)
    ; "caml_ba_uint8_set32", ([ Value; Value; Int32 ], Value)
    ; "caml_ba_uint8_set64", ([ Value; Value; Int64 ], Value)
      (* Floating-point primitives *)
    ; "caml_nextafter_float", ([ Float; Float ], Float)
    ; "caml_classify_float", ([ Float ], Value)
    ; "caml_ldexp_float", ([ Float; Value ], Float)
    ; "caml_signbit_float", ([ Float ], Value)
    ; "caml_erf_float", ([ Float ], Float)
    ; "caml_erfc_float", ([ Float ], Float)
    ; "caml_float_compare", ([ Float; Float ], Value)
    ]
  in
  let table = Hashtbl.create 18 in
  List.iter signatures ~f:(fun (name, signature) ->
      Hashtbl.add table name signature);
  table

(* Wasm function type with [n] parameters of type [Value.value] and a single
   result of type [Value.value]. *)
let func_type n =
  let params = List.init ~len:n ~f:(fun _ -> Value.value) in
  { W.params; result = [ Value.value ] }

Expand Down Expand Up @@ -424,6 +489,10 @@ module Generate (Target : Wa_target_sig.S) = struct
| Extern "caml_int32_to_int", [ i ] -> Value.val_int (Memory.unbox_int32 i)
| Extern "caml_int32_of_int", [ i ] ->
Memory.box_int32 stack_ctx x (Value.int_val i)
| Extern "caml_nativeint_of_int32", [ i ] ->
Memory.box_nativeint stack_ctx x (Memory.unbox_int32 i)
| Extern "caml_nativeint_to_int32", [ i ] ->
Memory.box_int32 stack_ctx x (Memory.unbox_nativeint i)
| Extern "caml_int64_bits_of_float", [ f ] ->
let* f = Memory.unbox_float f in
Memory.box_int64 stack_ctx x (return (W.UnOp (I64 ReinterpretF, f)))
Expand Down Expand Up @@ -634,21 +703,43 @@ module Generate (Target : Wa_target_sig.S) = struct
~init:(return [])
in
Memory.allocate stack_ctx x ~tag:0 l
| Extern name, l ->
| Extern name, l -> (
let name = Primitive.resolve name in
(*ZZZ Different calling convention when large number of parameters *)
let* f = register_import ~name (Fun (func_type (List.length l))) in
let* () = Stack.perform_spilling stack_ctx (`Instr x) in
let rec loop acc l =
match l with
| [] ->
Stack.kill_variables stack_ctx;
return (W.Call (f, List.rev acc))
| x :: r ->
let* x = x in
loop (x :: acc) r
in
loop [] l
try
let typ = Hashtbl.find specialized_primitives name in
let* f = register_import ~name (Fun (specialized_func_type typ)) in
let* () = Stack.perform_spilling stack_ctx (`Instr x) in
let rec loop acc arg_typ l =
match arg_typ, l with
| [], [] ->
Stack.kill_variables stack_ctx;
box_value
stack_ctx
x
(snd typ)
(return (W.Call (f, List.rev acc)))
| repr :: rem, x :: r ->
let* x = unbox_value repr x in
loop (x :: acc) rem r
| [], _ :: _ | _ :: _, [] ->
Format.eprintf "ZZZ %s@." name;
assert false
in
loop [] (fst typ) l
with Not_found ->
let* f = register_import ~name (Fun (func_type (List.length l))) in
let* () = Stack.perform_spilling stack_ctx (`Instr x) in
let rec loop acc l =
match l with
| [] ->
Stack.kill_variables stack_ctx;
return (W.Call (f, List.rev acc))
| x :: r ->
let* x = x in
loop (x :: acc) r
in
loop [] l)
| Not, [ x ] -> Value.not x
| Lt, [ x; y ] -> Value.lt x y
| Le, [ x; y ] -> Value.le x y
Expand Down
108 changes: 52 additions & 56 deletions runtime/wasm/bigarray.wat
Original file line number Diff line number Diff line change
Expand Up @@ -1919,7 +1919,7 @@
(i32.const 8)))))

(func (export "caml_ba_uint8_get32")
(param $vba (ref eq)) (param $i (ref eq)) (result (ref eq))
(param $vba (ref eq)) (param $i (ref eq)) (result i32)
(local $ba (ref $bigarray))
(local $data (ref extern))
(local $p i32)
Expand All @@ -1933,23 +1933,22 @@
(struct.get $bigarray $ba_dim (local.get $ba))
(i32.const 0)))
(then (call $caml_bound_error)))
(return_call $caml_copy_int32
(i32.or
(i32.or
(call $ta_get_ui8 (local.get $data) (local.get $p))
(i32.shl (call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 1)))
(i32.const 8)))
(i32.or
(i32.or
(call $ta_get_ui8 (local.get $data) (local.get $p))
(i32.shl (call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 1)))
(i32.const 8)))
(i32.or
(i32.shl (call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 2)))
(i32.const 16))
(i32.shl (call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 3)))
(i32.const 24))))))
(i32.shl (call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 2)))
(i32.const 16))
(i32.shl (call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 3)))
(i32.const 24)))))

(func (export "caml_ba_uint8_get64")
(param $vba (ref eq)) (param $i (ref eq)) (result (ref eq))
(param $vba (ref eq)) (param $i (ref eq)) (result i64)
(local $ba (ref $bigarray))
(local $data (ref extern))
(local $p i32)
Expand All @@ -1963,44 +1962,43 @@
(struct.get $bigarray $ba_dim (local.get $ba))
(i32.const 0)))
(then (call $caml_bound_error)))
(return_call $caml_copy_int64
(i64.or
(i64.or
(i64.or
(i64.extend_i32_u
(call $ta_get_ui8 (local.get $data) (local.get $p)))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 1))))
(i64.const 8)))
(i64.or
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 2))))
(i64.const 16))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 3))))
(i64.const 24))))
(i64.or
(i64.or
(i64.or
(i64.extend_i32_u
(call $ta_get_ui8 (local.get $data) (local.get $p)))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 1))))
(i64.const 8)))
(i64.or
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 2))))
(i64.const 16))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 3))))
(i64.const 24))))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 4))))
(i64.const 32))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 5))))
(i64.const 40)))
(i64.or
(i64.or
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 4))))
(i64.const 32))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 5))))
(i64.const 40)))
(i64.or
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 6))))
(i64.const 48))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 7))))
(i64.const 56)))))))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 6))))
(i64.const 48))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 7))))
(i64.const 56))))))

(func (export "caml_ba_uint8_set16")
(param $vba (ref eq)) (param $i (ref eq)) (param $v (ref eq))
Expand All @@ -2026,15 +2024,14 @@
(ref.i31 (i32.const 0)))

(func (export "caml_ba_uint8_set32")
(param $vba (ref eq)) (param $i (ref eq)) (param $v (ref eq))
(param $vba (ref eq)) (param $i (ref eq)) (param $d i32)
(result (ref eq))
(local $ba (ref $bigarray))
(local $data (ref extern))
(local $p i32) (local $d i32)
(local $p i32)
(local.set $ba (ref.cast (ref $bigarray) (local.get $vba)))
(local.set $data (struct.get $bigarray $ba_data (local.get $ba)))
(local.set $p (i31.get_s (ref.cast (ref i31) (local.get $i))))
(local.set $d (call $Int32_val (local.get $v)))
(if (i32.lt_s (local.get $p) (i32.const 0))
(then (call $caml_bound_error)))
(if (i32.ge_u (i32.add (local.get $p) (i32.const 3))
Expand All @@ -2056,15 +2053,14 @@
(ref.i31 (i32.const 0)))

(func (export "caml_ba_uint8_set64")
(param $vba (ref eq)) (param $i (ref eq)) (param $v (ref eq))
(param $vba (ref eq)) (param $i (ref eq)) (param $d i64)
(result (ref eq))
(local $ba (ref $bigarray))
(local $data (ref extern))
(local $p i32) (local $d i64)
(local $p i32)
(local.set $ba (ref.cast (ref $bigarray) (local.get $vba)))
(local.set $data (struct.get $bigarray $ba_data (local.get $ba)))
(local.set $p (i31.get_s (ref.cast (ref i31) (local.get $i))))
(local.set $d (call $Int64_val (local.get $v)))
(if (i32.lt_s (local.get $p) (i32.const 0))
(then (call $caml_bound_error)))
(if (i32.ge_u (i32.add (local.get $p) (i32.const 7))
Expand Down
Loading

0 comments on commit 2368b9d

Please sign in to comment.