Skip to content

Commit

Permalink
Insert caught exception in place of the rewritten or generated item
Browse files Browse the repository at this point in the history
When a context free rule raises, the exception was caught and turned
into an error extension node prepended to the whole AST.
This changes this behaviour to instead insert the error extension where
the generated code would be, had the rule not raised.

Signed-off-by: Nathan Rebours <[email protected]>
  • Loading branch information
NathanReb committed Feb 9, 2024
1 parent 74342e6 commit 17a7fd5
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 34 deletions.
93 changes: 73 additions & 20 deletions src/context_free.ml
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,46 @@ module Generated_code_hook = struct
| _ -> t.f context { loc with loc_start = loc.loc_end } x
end

(* Used to insert error extensions *)
let wrap_extension : type a. loc:Location.t -> a EC.t -> a -> extension -> a =
fun ~loc t original_node extension ->
match t with
| Class_expr -> Ast_builder.Default.pcl_extension ~loc extension
| Class_field -> Ast_builder.Default.pcf_extension ~loc extension
| Class_type -> Ast_builder.Default.pcty_extension ~loc extension
| Class_type_field -> Ast_builder.Default.pctf_extension ~loc extension
| Core_type -> Ast_builder.Default.ptyp_extension ~loc extension
| Expression -> Ast_builder.Default.pexp_extension ~loc extension
| Module_expr -> Ast_builder.Default.pmod_extension ~loc extension
| Module_type -> Ast_builder.Default.pmty_extension ~loc extension
| Pattern -> Ast_builder.Default.ppat_extension ~loc extension
| Signature_item -> Ast_builder.Default.psig_extension ~loc extension []
| Structure_item -> Ast_builder.Default.pstr_extension ~loc extension []
| Ppx_import ->
(* Insert the error in the type decl manifest *)
let ptype_manifest =
Some (Ast_builder.Default.ptyp_extension ~loc extension)
in
{ original_node with ptype_manifest }

let exn_to_extension exn =
let error = exn_to_loc_error exn in
let loc = Location.Error.get_location error in
let extension = Location.Error.to_extension error in
(extension, loc)

let exn_to_error_extension context original_node exn =
let extension, loc = exn_to_extension exn in
wrap_extension ~loc context original_node extension

let exn_to_stri exn =
let extension, loc = exn_to_extension exn in
Ast_builder.Default.pstr_extension ~loc extension []

let exn_to_sigi exn =
let extension, loc = exn_to_extension exn in
Ast_builder.Default.psig_extension ~loc extension []

let rec map_node_rec context ts super_call loc base_ctxt x ~embed_errors =
let ctxt =
Expansion_context.Extension.make ~extension_point_loc:loc ~base:base_ctxt ()
Expand All @@ -207,7 +247,8 @@ let rec map_node_rec context ts super_call loc base_ctxt x ~embed_errors =
(try
E.For_context.convert_res ts ~ctxt ext
|> With_errors.of_result ~default:None
with exn when embed_errors -> (None, [ exn_to_loc_error exn ]))
with exn when embed_errors ->
With_errors.return (Some (exn_to_error_extension context x exn)))
>>= fun converted ->
match converted with
| None -> super_call base_ctxt x
Expand All @@ -227,7 +268,8 @@ let map_node context ts super_call loc base_ctxt x ~hook ~embed_errors =
(try
E.For_context.convert_res ts ~ctxt ext
|> With_errors.of_result ~default:None
with exn when embed_errors -> (None, [ exn_to_loc_error exn ]))
with exn when embed_errors ->
With_errors.return (Some (exn_to_error_extension context x exn)))
>>= fun converted ->
match converted with
| None -> super_call base_ctxt x
Expand Down Expand Up @@ -261,7 +303,8 @@ let rec map_nodes context ts super_call get_loc base_ctxt l ~hook ~embed_errors
(try
E.For_context.convert_inline_res ts ~ctxt ext
|> With_errors.of_result ~default:None
with exn when embed_errors -> (None, [ exn_to_loc_error exn ]))
with exn when embed_errors ->
With_errors.return (Some [ exn_to_error_extension context x exn ]))
>>= function
| None ->
super_call base_ctxt x >>= fun x ->
Expand Down Expand Up @@ -350,7 +393,7 @@ let context_free_attribute_modification ~loc =
of one element; it only has [@@deriving].
*)
let handle_attr_group_inline attrs rf ~items ~expanded_items ~loc ~base_ctxt
~embed_errors =
~embed_errors ~convert_exn =
List.fold_left attrs ~init:(return [])
~f:(fun acc (Rule.Attr_group_inline.T group) ->
acc >>= fun acc ->
Expand All @@ -368,10 +411,12 @@ let handle_attr_group_inline attrs rf ~items ~expanded_items ~loc ~base_ctxt
try
let expect_items = group.expand ~ctxt rf expanded_items values in
return (expect_items :: acc)
with exn when embed_errors -> (acc, [ exn_to_loc_error exn ])))
with exn when embed_errors ->
let error_item = [ convert_exn exn ] in
return (error_item :: acc)))

let handle_attr_inline attrs ~item ~expanded_item ~loc ~base_ctxt ~embed_errors
=
let handle_attr_inline attrs ~convert_exn ~item ~expanded_item ~loc ~base_ctxt
~embed_errors =
List.fold_left attrs ~init:(return []) ~f:(fun acc (Rule.Attr_inline.T a) ->
acc >>= fun acc ->
Attribute.get_res a.attribute item |> of_result ~default:None
Expand All @@ -390,7 +435,9 @@ let handle_attr_inline attrs ~item ~expanded_item ~loc ~base_ctxt ~embed_errors
try
let expect_items = a.expand ~ctxt expanded_item value in
return (expect_items :: acc)
with exn when embed_errors -> (acc, [ exn_to_loc_error exn ])))
with exn when embed_errors ->
let error_item = [ convert_exn exn ] in
return (error_item :: acc)))

module Expect_mismatch_handler = struct
type t = {
Expand Down Expand Up @@ -688,43 +735,46 @@ class map_top_down ?(expect_mismatch_handler = Expect_mismatch_handler.nop)
loop rest ~in_generated_code >>| fun rest -> items @ rest)
| _ -> (
super#structure_item base_ctxt item >>= fun expanded_item ->
let convert_exn = exn_to_stri in
match (item.pstr_desc, expanded_item.pstr_desc) with
| Pstr_type (rf, tds), Pstr_type (exp_rf, exp_tds) ->
(* No context-free rule can rewrite rec flags atm, this
assert acts as a failsafe in case it ever changes *)
assert (Poly.(rf = exp_rf));
handle_attr_group_inline attr_str_type_decls rf ~items:tds
~expanded_items:exp_tds ~loc ~base_ctxt
~expanded_items:exp_tds ~loc ~base_ctxt ~convert_exn
>>= fun extra_items ->
handle_attr_group_inline attr_str_type_decls_expect rf
~items:tds ~expanded_items:exp_tds ~loc ~base_ctxt
~convert_exn
>>= fun expect_items ->
with_extra_items expanded_item ~extra_items ~expect_items
~rest ~in_generated_code
| Pstr_modtype mtd, Pstr_modtype exp_mtd ->
handle_attr_inline attr_str_module_type_decls ~item:mtd
~expanded_item:exp_mtd ~loc ~base_ctxt
~expanded_item:exp_mtd ~loc ~base_ctxt ~convert_exn
>>= fun extra_items ->
handle_attr_inline attr_str_module_type_decls_expect
~item:mtd ~expanded_item:exp_mtd ~loc ~base_ctxt
~convert_exn
>>= fun expect_items ->
with_extra_items expanded_item ~extra_items ~expect_items
~rest ~in_generated_code
| Pstr_typext te, Pstr_typext exp_te ->
handle_attr_inline attr_str_type_exts ~item:te
~expanded_item:exp_te ~loc ~base_ctxt
~expanded_item:exp_te ~loc ~base_ctxt ~convert_exn
>>= fun extra_items ->
handle_attr_inline attr_str_type_exts_expect ~item:te
~expanded_item:exp_te ~loc ~base_ctxt
~expanded_item:exp_te ~loc ~base_ctxt ~convert_exn
>>= fun expect_items ->
with_extra_items expanded_item ~extra_items ~expect_items
~rest ~in_generated_code
| Pstr_exception ec, Pstr_exception exp_ec ->
handle_attr_inline attr_str_exceptions ~item:ec
~expanded_item:exp_ec ~loc ~base_ctxt
~expanded_item:exp_ec ~loc ~base_ctxt ~convert_exn
>>= fun extra_items ->
handle_attr_inline attr_str_exceptions_expect ~item:ec
~expanded_item:exp_ec ~loc ~base_ctxt
~expanded_item:exp_ec ~loc ~base_ctxt ~convert_exn
>>= fun expect_items ->
with_extra_items expanded_item ~extra_items ~expect_items
~rest ~in_generated_code
Expand Down Expand Up @@ -783,43 +833,46 @@ class map_top_down ?(expect_mismatch_handler = Expect_mismatch_handler.nop)
loop rest ~in_generated_code >>| fun rest -> items @ rest)
| _ -> (
super#signature_item base_ctxt item >>= fun expanded_item ->
let convert_exn = exn_to_sigi in
match (item.psig_desc, expanded_item.psig_desc) with
| Psig_type (rf, tds), Psig_type (exp_rf, exp_tds) ->
(* No context-free rule can rewrite rec flags atm, this
assert acts as a failsafe in case it ever changes *)
assert (Poly.(rf = exp_rf));
handle_attr_group_inline attr_sig_type_decls rf ~items:tds
~expanded_items:exp_tds ~loc ~base_ctxt
~expanded_items:exp_tds ~loc ~base_ctxt ~convert_exn
>>= fun extra_items ->
handle_attr_group_inline attr_sig_type_decls_expect rf
~items:tds ~expanded_items:exp_tds ~loc ~base_ctxt
~convert_exn
>>= fun expect_items ->
with_extra_items expanded_item ~extra_items ~expect_items
~rest ~in_generated_code
| Psig_modtype mtd, Psig_modtype exp_mtd ->
handle_attr_inline attr_sig_module_type_decls ~item:mtd
~expanded_item:exp_mtd ~loc ~base_ctxt
~expanded_item:exp_mtd ~loc ~base_ctxt ~convert_exn
>>= fun extra_items ->
handle_attr_inline attr_sig_module_type_decls_expect
~item:mtd ~expanded_item:exp_mtd ~loc ~base_ctxt
~convert_exn
>>= fun expect_items ->
with_extra_items expanded_item ~extra_items ~expect_items
~rest ~in_generated_code
| Psig_typext te, Psig_typext exp_te ->
handle_attr_inline attr_sig_type_exts ~item:te
~expanded_item:exp_te ~loc ~base_ctxt
~expanded_item:exp_te ~loc ~base_ctxt ~convert_exn
>>= fun extra_items ->
handle_attr_inline attr_sig_type_exts_expect ~item:te
~expanded_item:exp_te ~loc ~base_ctxt
~expanded_item:exp_te ~loc ~base_ctxt ~convert_exn
>>= fun expect_items ->
with_extra_items expanded_item ~extra_items ~expect_items
~rest ~in_generated_code
| Psig_exception ec, Psig_exception exp_ec ->
handle_attr_inline attr_sig_exceptions ~item:ec
~expanded_item:exp_ec ~loc ~base_ctxt
~expanded_item:exp_ec ~loc ~base_ctxt ~convert_exn
>>= fun extra_items ->
handle_attr_inline attr_sig_exceptions_expect ~item:ec
~expanded_item:exp_ec ~loc ~base_ctxt
~expanded_item:exp_ec ~loc ~base_ctxt ~convert_exn
>>= fun expect_items ->
with_extra_items expanded_item ~extra_items ~expect_items
~rest ~in_generated_code
Expand Down
3 changes: 1 addition & 2 deletions test/driver/error_embedding/test.t/run.t
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ is caught and prepended to the last valid AST

$ echo "let _ = [%raise]" > impl.ml
$ ../raiser.exe -embed-errors impl.ml
[%%ocaml.error "Raising inside the rewriter"]
let _ = [%raise ]
let _ = [%ocaml.error "Raising inside the rewriter"]

The same is true when using the `-as-ppx` mode (note that the error is reported
by ocaml itself)
Expand Down
20 changes: 8 additions & 12 deletions test/driver/exception_handling/run.t
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,9 @@ caught, so no AST is produced.

when the -embed-errors flag is passed
$ ./extender.exe -embed-errors impl.ml
[%%ocaml.error "A raised located error"]
[%%ocaml.error "A second raised located error"]
let x = 1 + 1.
let _ = [%gen_raise_located_error ]
let _ = [%gen_raise_located_error2 ]
let _ = [%ocaml.error "A raised located error"]
let _ = [%ocaml.error "A second raised located error"]

In the case of derivers

Expand All @@ -76,11 +74,11 @@ caught, so no AST is produced.

when the -embed-errors flag is passed
$ ./deriver.exe -embed-errors impl.ml
[%%ocaml.error "A raised located error"]
[%%ocaml.error "A second raised located error"]
type a = int
type b = int[@@deriving deriver_located_error]
[%%ocaml.error "A raised located error"]
type c = int[@@deriving deriver_located_error2]
[%%ocaml.error "A second raised located error"]

In the case of whole file transformations:

Expand All @@ -107,11 +105,9 @@ when the -embed-errors flag is not passed

when the -embed-errors flag is passed
$ ./extender.exe -embed-errors impl.ml
[%%ocaml.error "A raised located error"]
[%%ocaml.error "A second raised located error"]
let x = 1 + 1.
let _ = [%gen_raise_located_error ]
let _ = [%gen_raise_located_error2 ]
let _ = [%ocaml.error "A raised located error"]
let _ = [%ocaml.error "A second raised located error"]

In the case of derivers

Expand All @@ -127,12 +123,12 @@ when the -embed-errors flag is not passed
[1]
when the -embed-errors flag is passed
$ ./deriver.exe -embed-errors impl.ml
[%%ocaml.error "A raised located error"]
[%%ocaml.error "A second raised located error"]
let x = 1 + 1.
type a = int
type b = int[@@deriving deriver_located_error]
[%%ocaml.error "A raised located error"]
type b = int[@@deriving deriver_located_error2]
[%%ocaml.error "A second raised located error"]

In the case of whole file transformations:

Expand Down

0 comments on commit 17a7fd5

Please sign in to comment.