Skip to content

Commit

Permalink
Fix mark_as_seen, docstrings and add tests
Browse files Browse the repository at this point in the history
Signed-off-by: Paul-Elliot <[email protected]>
  • Loading branch information
panglesd committed Sep 26, 2023
1 parent db877d5 commit 8f09201
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 25 deletions.
36 changes: 17 additions & 19 deletions src/attribute.ml
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,13 @@ let declare_with_attr_loc name context pattern k =
declare_with_all_args name context pattern (fun ~attr_loc ~name_loc:_ ->
k ~attr_loc)

type 'a flag = ('a, unit) t

let declare_flag name context =
let payload_pattern = Ast_pattern.(pstr nil) in
let continuation ~attr_loc:_ ~name_loc:_ = () in
declare_with_all_args name context payload_pattern continuation

module Attribute_table = Caml.Hashtbl.Make (struct
type t = string loc

Expand Down Expand Up @@ -356,6 +363,16 @@ let get t ?mark_as_seen:do_mark_as_seen x =
get_res t ?mark_as_seen:do_mark_as_seen x
|> Result.handle_error ~f:(fun (err, _) -> Location.Error.raise err)

let has_flag_res t ?mark_as_seen x =
match get_res ?mark_as_seen t x with
| Ok (Some ()) -> Ok true
| Ok None -> Ok false
| Error _ as e -> e

let has_flag t ?mark_as_seen x =
has_flag_res t ?mark_as_seen x
|> Result.handle_error ~f:(fun (err, _) -> Location.Error.raise err)

let consume_res t x =
let open Result in
let attrs = Context.get_attributes t.context x in
Expand Down Expand Up @@ -841,22 +858,3 @@ let dropped_so_far_signature sg =
Attribute_table.fold
(fun name loc acc -> { txt = name.txt; loc } :: acc)
table []

let declare_flag name context =
let payload_pattern = Ast_pattern.(pstr nil) in
let continuation ~attr_loc:_ ~name_loc:_ = () in
declare_with_all_args name context payload_pattern continuation
(* registers a flag using Ast_pattern with an empty pattern,
a continuation function that doesn't perform any operations,
and declare_with_all_args function *)

let has_flag (attr : ('a, unit) t) ?(mark_as_seen = false) x =
let seen = ref false in
match get attr x with
| Some () ->
if mark_as_seen then seen := true;
true
| None -> false
(* takes a data structure attr of type ('a, unit) t, an element x of type 'a,
and an optional boolean argument mark_as_seen, and returns a boolean indicating
whether the element x has a flag in the data structure *)
23 changes: 17 additions & 6 deletions src/attribute.mli
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,6 @@ module Context : sig
val object_type_field : object_field t
end

val declare_flag : string -> 'a Context.t -> ('a, unit) t
(*takes a string and context containing values of type 'a as input and returns
a data structure t that can store elements of type 'a along with associated flags. *)

val declare :
string -> 'a Context.t -> (payload, 'b, 'c) Ast_pattern.t -> 'b -> ('a, 'c) t
(** [declare fully_qualified_name context payload_pattern k] declares an
Expand Down Expand Up @@ -139,6 +135,13 @@ val declare_with_attr_loc :
('a, 'c) t
(** Same as [declare] but the callback receives the location of the attribute. *)

type 'a flag = ('a, unit) t
(** Types for attributes without payload. *)

val declare_flag : string -> 'a Context.t -> 'a flag
(** Same as {!declare}, but the payload is expected to be empty. It is supposed
to be used in conjunction with {!has_flag}. *)

val name : _ t -> string
val context : ('a, _) t -> 'a Context.t

Expand All @@ -154,8 +157,16 @@ val get :
('a, 'b) t -> ?mark_as_seen:bool (** default [true] *) -> 'a -> 'b option
(** See {!get_res}. Raises a located error if the attribute is duplicated *)

val has_flag : ('a, unit) t -> ?mark_as_seen:bool -> 'a -> bool
(* takes data structure as input,returns a boolean with an optional named argument mark_as_seen. *)
val has_flag_res :
'a flag ->
?mark_as_seen:bool (** default [true] *) ->
'a ->
(bool, Location.Error.t NonEmptyList.t) result
(** Answers whether the given flag is attached as an attribute. See {!get_res}
for the meaning of [mark_as_seen]. *)

val has_flag : 'a flag -> ?mark_as_seen:bool (** default [true] *) -> 'a -> bool
(** See {!has_flag_res}. Raises a located error if the attribute is duplicated. *)

val consume_res :
('a, 'b) t -> 'a -> (('a * 'b) option, Location.Error.t NonEmptyList.t) result
Expand Down
33 changes: 33 additions & 0 deletions test/driver/attributes/test.ml
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,36 @@ let x = (42 [@baz.qux3])
Line _, characters 14-22:
Error: Attribute `baz.qux3' was silently dropped
|}]

(* Testing flags *)

let flag = Attribute.declare_flag "flag" Attribute.Context.expression
[%%expect{|
val flag : expression Attribute.flag = <abstr>
|}]

let replace_flagged = object
inherit Ast_traverse.map as super

method! expression e =
match Attribute.has_flag_res flag e with
| Ok true -> Ast_builder.Default.estring ~loc:e.pexp_loc "Found flag"
| Ok false -> super#expression e
| Error (err, _) -> Ast_builder.Default.estring ~loc:e.pexp_loc (Location.Error.message err)
end
[%%expect{|
val replace_flagged : Ast_traverse.map = <obj>
|}]

let () =
Driver.register_transformation "" ~impl:replace_flagged#structure

let e1 = "flagged" [@flag]
[%%expect{|
val e1 : string = "Found flag"
|}]

let e1 = "flagged" [@flag 12]
[%%expect{|
val e1 : string = "[] expected"
|}]

0 comments on commit 8f09201

Please sign in to comment.