Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
vouillon committed Sep 25, 2024
1 parent 87f2119 commit a3a25b8
Show file tree
Hide file tree
Showing 7 changed files with 252 additions and 229 deletions.
117 changes: 104 additions & 13 deletions compiler/lib/wasm/wa_generate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,71 @@ module Generate (Target : Wa_target_sig.S) = struct
; debug : Parse_bytecode.Debug.t
}

type repr =
| Value
| Float
| Int32
| Nativeint
| Int64

let repr_type r =
match r with
| Value -> Value.value
| Float -> F64
| Int32 -> I32
| Nativeint -> I32
| Int64 -> I64

let specialized_func_type (params, result) =
{ W.params = List.map ~f:repr_type params; result = [ repr_type result ] }

let box_value stack_ctx x r e =
match r with
| Value -> e
| Float -> Memory.box_float stack_ctx x e
| Int32 -> Memory.box_int32 stack_ctx x e
| Nativeint -> Memory.box_nativeint stack_ctx x e
| Int64 -> Memory.box_int64 stack_ctx x e

let unbox_value r e =
match r with
| Value -> e
| Float -> Memory.unbox_float e
| Int32 -> Memory.unbox_int32 e
| Nativeint -> Memory.unbox_nativeint e
| Int64 -> Memory.unbox_int64 e

let specialized_primitives =
let h = Hashtbl.create 18 in
List.iter
~f:(fun (nm, typ) -> Hashtbl.add h nm typ)
[ "caml_int32_bswap", ([ Int32 ], Int32)
; "caml_nativeint_bswap", ([ Nativeint ], Nativeint)
; "caml_int64_bswap", ([ Int64 ], Int64)
; "caml_int32_compare", ([ Int32; Int32 ], Value)
; "caml_nativeint_compare", ([ Nativeint; Nativeint ], Value)
; "caml_int64_compare", ([ Int64; Int64 ], Value)
; "caml_string_get32", ([ Value; Value ], Int32)
; "caml_string_get64", ([ Value; Value ], Int64)
; "caml_bytes_get32", ([ Value; Value ], Int32)
; "caml_bytes_get64", ([ Value; Value ], Int64)
; "caml_bytes_set32", ([ Value; Value; Int32 ], Value)
; "caml_bytes_set64", ([ Value; Value; Int64 ], Value)
; "caml_lxm_next", ([ Value ], Int64)
; "caml_ba_uint8_get32", ([ Value; Value ], Int32)
; "caml_ba_uint8_get64", ([ Value; Value ], Int64)
; "caml_ba_uint8_set32", ([ Value; Value; Int32 ], Value)
; "caml_ba_uint8_set64", ([ Value; Value; Int64 ], Value)
; "caml_nextafter_float", ([ Float; Float ], Float)
; "caml_classify_float", ([ Float ], Value)
; "caml_ldexp_float", ([ Float; Value ], Float)
; "caml_signbit_float", ([ Float ], Value)
; "caml_erf_float", ([ Float ], Float)
; "caml_erfc_float", ([ Float ], Float)
; "caml_float_compare", ([ Float; Float ], Value)
];
h

let func_type n =
{ W.params = List.init ~len:n ~f:(fun _ -> Value.value); result = [ Value.value ] }

Expand Down Expand Up @@ -424,6 +489,10 @@ module Generate (Target : Wa_target_sig.S) = struct
| Extern "caml_int32_to_int", [ i ] -> Value.val_int (Memory.unbox_int32 i)
| Extern "caml_int32_of_int", [ i ] ->
Memory.box_int32 stack_ctx x (Value.int_val i)
| Extern "caml_nativeint_of_int32", [ i ] ->
Memory.box_nativeint stack_ctx x (Memory.unbox_int32 i)
| Extern "caml_nativeint_to_int32", [ i ] ->
Memory.box_int32 stack_ctx x (Memory.unbox_nativeint i)
| Extern "caml_int64_bits_of_float", [ f ] ->
let* f = Memory.unbox_float f in
Memory.box_int64 stack_ctx x (return (W.UnOp (I64 ReinterpretF, f)))
Expand Down Expand Up @@ -634,21 +703,43 @@ module Generate (Target : Wa_target_sig.S) = struct
~init:(return [])
in
Memory.allocate stack_ctx x ~tag:0 l
| Extern name, l ->
| Extern name, l -> (
let name = Primitive.resolve name in
(*ZZZ Different calling convention when large number of parameters *)
let* f = register_import ~name (Fun (func_type (List.length l))) in
let* () = Stack.perform_spilling stack_ctx (`Instr x) in
let rec loop acc l =
match l with
| [] ->
Stack.kill_variables stack_ctx;
return (W.Call (f, List.rev acc))
| x :: r ->
let* x = x in
loop (x :: acc) r
in
loop [] l
try
let typ = Hashtbl.find specialized_primitives name in
let* f = register_import ~name (Fun (specialized_func_type typ)) in
let* () = Stack.perform_spilling stack_ctx (`Instr x) in
let rec loop acc arg_typ l =
match arg_typ, l with
| [], [] ->
Stack.kill_variables stack_ctx;
box_value
stack_ctx
x
(snd typ)
(return (W.Call (f, List.rev acc)))
| repr :: rem, x :: r ->
let* x = unbox_value repr x in
loop (x :: acc) rem r
| [], _ :: _ | _ :: _, [] ->
Format.eprintf "ZZZ %s@." name;
assert false
in
loop [] (fst typ) l
with Not_found ->
let* f = register_import ~name (Fun (func_type (List.length l))) in
let* () = Stack.perform_spilling stack_ctx (`Instr x) in
let rec loop acc l =
match l with
| [] ->
Stack.kill_variables stack_ctx;
return (W.Call (f, List.rev acc))
| x :: r ->
let* x = x in
loop (x :: acc) r
in
loop [] l)
| Not, [ x ] -> Value.not x
| Lt, [ x; y ] -> Value.lt x y
| Le, [ x; y ] -> Value.le x y
Expand Down
108 changes: 52 additions & 56 deletions runtime/wasm/bigarray.wat
Original file line number Diff line number Diff line change
Expand Up @@ -1919,7 +1919,7 @@
(i32.const 8)))))

(func (export "caml_ba_uint8_get32")
(param $vba (ref eq)) (param $i (ref eq)) (result (ref eq))
(param $vba (ref eq)) (param $i (ref eq)) (result i32)
(local $ba (ref $bigarray))
(local $data (ref extern))
(local $p i32)
Expand All @@ -1933,23 +1933,22 @@
(struct.get $bigarray $ba_dim (local.get $ba))
(i32.const 0)))
(then (call $caml_bound_error)))
(return_call $caml_copy_int32
(i32.or
(i32.or
(call $ta_get_ui8 (local.get $data) (local.get $p))
(i32.shl (call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 1)))
(i32.const 8)))
(i32.or
(i32.or
(call $ta_get_ui8 (local.get $data) (local.get $p))
(i32.shl (call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 1)))
(i32.const 8)))
(i32.or
(i32.shl (call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 2)))
(i32.const 16))
(i32.shl (call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 3)))
(i32.const 24))))))
(i32.shl (call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 2)))
(i32.const 16))
(i32.shl (call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 3)))
(i32.const 24)))))

(func (export "caml_ba_uint8_get64")
(param $vba (ref eq)) (param $i (ref eq)) (result (ref eq))
(param $vba (ref eq)) (param $i (ref eq)) (result i64)
(local $ba (ref $bigarray))
(local $data (ref extern))
(local $p i32)
Expand All @@ -1963,44 +1962,43 @@
(struct.get $bigarray $ba_dim (local.get $ba))
(i32.const 0)))
(then (call $caml_bound_error)))
(return_call $caml_copy_int64
(i64.or
(i64.or
(i64.or
(i64.extend_i32_u
(call $ta_get_ui8 (local.get $data) (local.get $p)))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 1))))
(i64.const 8)))
(i64.or
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 2))))
(i64.const 16))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 3))))
(i64.const 24))))
(i64.or
(i64.or
(i64.or
(i64.extend_i32_u
(call $ta_get_ui8 (local.get $data) (local.get $p)))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 1))))
(i64.const 8)))
(i64.or
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 2))))
(i64.const 16))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 3))))
(i64.const 24))))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 4))))
(i64.const 32))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 5))))
(i64.const 40)))
(i64.or
(i64.or
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 4))))
(i64.const 32))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 5))))
(i64.const 40)))
(i64.or
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 6))))
(i64.const 48))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 7))))
(i64.const 56)))))))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 6))))
(i64.const 48))
(i64.shl (i64.extend_i32_u
(call $ta_get_ui8 (local.get $data)
(i32.add (local.get $p) (i32.const 7))))
(i64.const 56))))))

(func (export "caml_ba_uint8_set16")
(param $vba (ref eq)) (param $i (ref eq)) (param $v (ref eq))
Expand All @@ -2026,15 +2024,14 @@
(ref.i31 (i32.const 0)))

(func (export "caml_ba_uint8_set32")
(param $vba (ref eq)) (param $i (ref eq)) (param $v (ref eq))
(param $vba (ref eq)) (param $i (ref eq)) (param $d i32)
(result (ref eq))
(local $ba (ref $bigarray))
(local $data (ref extern))
(local $p i32) (local $d i32)
(local $p i32)
(local.set $ba (ref.cast (ref $bigarray) (local.get $vba)))
(local.set $data (struct.get $bigarray $ba_data (local.get $ba)))
(local.set $p (i31.get_s (ref.cast (ref i31) (local.get $i))))
(local.set $d (call $Int32_val (local.get $v)))
(if (i32.lt_s (local.get $p) (i32.const 0))
(then (call $caml_bound_error)))
(if (i32.ge_u (i32.add (local.get $p) (i32.const 3))
Expand All @@ -2056,15 +2053,14 @@
(ref.i31 (i32.const 0)))

(func (export "caml_ba_uint8_set64")
(param $vba (ref eq)) (param $i (ref eq)) (param $v (ref eq))
(param $vba (ref eq)) (param $i (ref eq)) (param $d i64)
(result (ref eq))
(local $ba (ref $bigarray))
(local $data (ref extern))
(local $p i32) (local $d i64)
(local $p i32)
(local.set $ba (ref.cast (ref $bigarray) (local.get $vba)))
(local.set $data (struct.get $bigarray $ba_data (local.get $ba)))
(local.set $p (i31.get_s (ref.cast (ref i31) (local.get $i))))
(local.set $d (call $Int64_val (local.get $v)))
(if (i32.lt_s (local.get $p) (i32.const 0))
(then (call $caml_bound_error)))
(if (i32.ge_u (i32.add (local.get $p) (i32.const 7))
Expand Down
Loading

0 comments on commit a3a25b8

Please sign in to comment.