Skip to content

Commit

Permalink
Merge pull request kind2-mc#1016 from lorchrob/global-const-static-ch…
Browse files Browse the repository at this point in the history
…ecks

Add static checks for global constants with int range type
  • Loading branch information
daniel-larraz authored Oct 6, 2023
2 parents 57cd0f7 + f81c873 commit 7037672
Show file tree
Hide file tree
Showing 16 changed files with 266 additions and 40 deletions.
179 changes: 147 additions & 32 deletions src/lustre/lustreAbstractInterpretation.ml
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,36 @@ module LA = LustreAst
module Ctx = TypeCheckerContext
module TC = LustreTypeChecker

module R = Res

type error_kind = Unknown of string
| ConstantOutOfSubrange of HString.t

type error = [
| `LustreAbstractInterpretationError of Lib.position * error_kind
]

let error_message kind = match kind with
| Unknown s -> s
| ConstantOutOfSubrange i -> "Constant " ^ (HString.string_of_hstring i) ^
" is assigned a value outside of its subrange type"

let inline_error pos kind = Error (`LustreAbstractInterpretationError (pos, kind))

let unwrap result = match result with
| Ok r -> r
| Error _ -> assert false

module IMap = struct
(* everything that [Stdlib.Map] gives us *)
include Map.Make(struct
type t = LA.ident
let compare i1 i2 = HString.compare i1 i2
end)
let keys: 'a t -> key list = fun m -> List.map fst (bindings m)
end
module IMap = HString.HStringMap

(** Context from a node identifier to a map of its
variable identifiers to their inferred subrange bounds *)
type context = LA.lustre_type IMap.t IMap.t

let dpos = Lib.dummy_pos

let dnode_id = HString.mk_hstring "dummy_node_id"

let empty_context = IMap.empty

let union a b = IMap.union
Expand Down Expand Up @@ -77,18 +88,18 @@ let extract_bounds_from_type ty =
| IntRange _ -> None, None
| _ -> None, None)

let subrange_from_bounds l r =
let subrange_from_bounds pos l r =
let l = HString.mk_hstring (Numeral.string_of_numeral l) in
let r = HString.mk_hstring (Numeral.string_of_numeral r) in
LA.IntRange (dpos, Some (Const (dpos, Num l)), Some (Const (dpos, Num r)))
LA.IntRange (pos, Some (Const (pos, Num l)), Some (Const (pos, Num r)))

let subrange_from_lower l =
let subrange_from_lower pos l =
let l = HString.mk_hstring (Numeral.string_of_numeral l) in
LA.IntRange (dpos, Some (Const (dpos, Num l)), None)
LA.IntRange (pos, Some (Const (pos, Num l)), None)

let subrange_from_upper r =
let subrange_from_upper pos r =
let r = HString.mk_hstring (Numeral.string_of_numeral r) in
LA.IntRange (dpos, None, Some (Const (dpos, Num r)))
LA.IntRange (pos, None, Some (Const (pos, Num r)))

let rec merge_types a b = match a, b with
| LA.ArrayType (_, (t1, e)), LA.ArrayType (_, (t2, _)) ->
Expand Down Expand Up @@ -353,13 +364,14 @@ and interpret_expr_by_type node_id ctx ty_ctx ty proj expr : LA.lustre_type =
interpret_structured_expr f node_id ctx ty_ctx ty proj expr
| ArrayType (_, (t, s)) ->
let f = function
| LA.GroupExpr (_, ArrayExpr, es) ->
| LA.GroupExpr (_, ArrayExpr, (e :: _ as es)) ->
let t = List.fold_left (fun acc e ->
let t' = interpret_expr_by_type node_id ctx ty_ctx t proj e in
merge_types acc t')
t es
(interpret_expr_by_type node_id ctx ty_ctx t 0 e) es
in
Some (LA.ArrayType (dpos, (t, s)))
| LA.GroupExpr (pos, ArrayExpr, []) -> Some (ArrayType (pos, (t, s)))
| ArrayConstr (_, e1, _) ->
let t = interpret_expr_by_type node_id ctx ty_ctx t proj e1 in
Some (ArrayType (dpos, (t, s)))
Expand All @@ -377,53 +389,53 @@ and interpret_expr_by_type node_id ctx ty_ctx ty proj expr : LA.lustre_type =
| _ -> None
in
interpret_structured_expr f node_id ctx ty_ctx ty proj expr
| IntRange (_, (Some Const (_, Num l1)), (Some Const (_, Num r1))) as t ->
| IntRange (pos, (Some Const (_, Num l1)), (Some Const (_, Num r1))) as t ->
let l1 = Numeral.of_string (HString.string_of_hstring l1) in
let r1 = Numeral.of_string (HString.string_of_hstring r1) in
let l2, r2 = interpret_int_expr node_id ctx ty_ctx proj expr in
(match l2, r2 with
| Some l2, Some r2 ->
let l, r = Numeral.max l1 l2, Numeral.min r1 r2 in
subrange_from_bounds l r
subrange_from_bounds pos l r
| Some l2, None ->
let l = Numeral.max l1 l2 in
subrange_from_bounds l r1
subrange_from_bounds pos l r1
| None, Some r2 ->
let r = Numeral.min r1 r2 in
subrange_from_bounds l1 r
subrange_from_bounds pos l1 r
| _ -> t)
| IntRange (_, (Some Const (_, Num l1)), None) as t ->
| IntRange (pos, (Some Const (_, Num l1)), None) as t ->
let l1 = Numeral.of_string (HString.string_of_hstring l1) in
let l2, r2 = interpret_int_expr node_id ctx ty_ctx proj expr in
(match l2, r2 with
| Some l2, Some r2 ->
let l = Numeral.max l1 l2 in
subrange_from_bounds l r2
subrange_from_bounds pos l r2
| Some l2, None ->
let l = Numeral.max l1 l2 in
subrange_from_lower l
subrange_from_lower pos l
| None, Some r2 ->
subrange_from_bounds l1 r2
subrange_from_bounds pos l1 r2
| _ -> t)
| IntRange (_, None, (Some Const (_, Num r1))) as t ->
| IntRange (pos, None, (Some Const (_, Num r1))) as t ->
let r1 = Numeral.of_string (HString.string_of_hstring r1) in
let l2, r2 = interpret_int_expr node_id ctx ty_ctx proj expr in
(match l2, r2 with
| Some l2, Some r2 ->
let r = Numeral.min r1 r2 in
subrange_from_bounds l2 r
subrange_from_bounds pos l2 r
| Some l2, None ->
subrange_from_bounds l2 r1
subrange_from_bounds pos l2 r1
| None, Some r2 ->
let r = Numeral.min r1 r2 in
subrange_from_upper r
subrange_from_upper pos r
| _ -> t)
| Int _ | IntRange _ ->
| Int pos | IntRange (pos, None, None) ->
let l, r = interpret_int_expr node_id ctx ty_ctx proj expr in
(match l, r with
| Some l, Some r -> subrange_from_bounds l r
| Some l, None -> subrange_from_lower l
| None, Some r -> subrange_from_upper r
| Some l, Some r -> subrange_from_bounds pos l r
| Some l, None -> subrange_from_lower pos l
| None, Some r -> subrange_from_upper pos r
| _ -> LA.Int dpos)
| t -> t

Expand Down Expand Up @@ -614,3 +626,106 @@ and interpret_int_branch_expr node_id ctx ty_ctx proj e1 e2 =
| _ -> None)
in
l, r

let expr_opt_lte e1 e2 =
match e1 with
| None -> true
| Some (LA.Const (_, Num l1)) -> (
match e2 with
| None -> false
| Some (LA.Const (_, Num l2)) ->
int_of_string (HString.string_of_hstring l1) <= int_of_string (HString.string_of_hstring l2)
| _ -> assert false (* Not possible as we require subranges to have concrete bounds *)
)
| _ -> assert false (* Not possible as we require subranges to have concrete bounds *)

let expr_opt_gte e1 e2 =
match e1 with
| None -> true
| Some (LA.Const (_, Num l1)) -> (
match e2 with
| None -> false
| Some (LA.Const (_, Num l2)) ->
int_of_string (HString.string_of_hstring l1) >= int_of_string (HString.string_of_hstring l2)
| _ -> assert false (* Not possible as we require subranges to have concrete bounds *)
)
| _ -> assert false (* Not possible as we require subranges to have concrete bounds *)

(* Compare a constant's actual range to its inferred range to see if assignment is legal *)
let rec compare_ranges id pos_map actual_ty inferred_range =
let error () =
let pos = IMap.find id pos_map in
inline_error pos (ConstantOutOfSubrange id)
in
match inferred_range with
| LA.IntRange (_, e1, e2) ->
(match actual_ty with
| LA.IntRange (_, e3, e4) ->
if expr_opt_lte e3 e1 && expr_opt_gte e4 e2 && expr_opt_lte e1 e2
then R.ok ()
else error ()
| _ -> R.ok ())
| Int _ ->
(match actual_ty with
| LA.IntRange (_, Some _, _) -> error ()
| LA.IntRange (_, _, Some _) -> error ()
| _ -> R.ok ())
| ArrayType (_, (ty1, _)) ->
(match actual_ty with
| ArrayType (_, (ty2, _)) -> compare_ranges id pos_map ty2 ty1
| _ -> R.ok ())
| TupleType (_, types1) ->
(match actual_ty with
| TupleType (_, types2) ->
R.seq_ (List.map2 (compare_ranges id pos_map) types2 types1)
| _ -> R.ok ())
| _ -> R.ok ()

let rec most_general_int_ty = function
| LA.IntRange (pos, _, _) -> LA.Int pos
| LA.GroupType (pos, types)
| TupleType (pos, types) ->
let types = List.map most_general_int_ty types in
LA.TupleType (pos, types)
| RecordType (pos, id, tis) ->
let tis = List.map (fun (p, id, ty) -> (p, id, most_general_int_ty ty)) tis in
RecordType (pos, id, tis)
| LA.ArrayType (pos, (ty, expr)) ->
let ty = most_general_int_ty ty in
LA.ArrayType (pos, (ty, expr))
| _ as t -> t


let interpret_const_decl ctx pos_map ty_ctx = function
| LA.ConstDecl (_, TypedConst (_, id, e, _))
| ConstDecl (_, UntypedConst (_, id, e)) ->
(* Get inferred bounds from expr *)
let ty = Ctx.lookup_ty ty_ctx id |> get in
let ty = Ctx.expand_nested_type_syn ty_ctx ty in
let ty = most_general_int_ty ty in
let ty = interpret_expr_by_type dnode_id ctx ty_ctx ty 0 e in
let pos = LustreAstHelpers.pos_of_expr e in
add_type ctx dnode_id id ty, IMap.add id pos pos_map
| _ -> ctx, pos_map

let rec interpret_global_consts ty_ctx decls =
let ctx, pos_map = List.fold_left (fun (ctx, pos_map) decl ->
let ctx, pos_map = interpret_const_decl ctx pos_map ty_ctx decl in
ctx, pos_map
) (empty_context, IMap.empty) decls in
Res.seq_ (check_global_const_subrange ty_ctx ctx pos_map)

and check_global_const_subrange ty_ctx ctx pos_map =
let ctx =
match IMap.find_opt dnode_id ctx with
| None -> empty_context
| Some ctx -> ctx
in
IMap.fold (fun id inferred_range acc ->
let actual_ty = Ctx.lookup_ty ty_ctx id |> get in
let actual_ty = Ctx.expand_nested_type_syn ty_ctx actual_ty in
(* Check if inferred range is outside of declared type *)
compare_ranges id pos_map actual_ty inferred_range :: acc
)
ctx
[]
23 changes: 18 additions & 5 deletions src/lustre/lustreAbstractInterpretation.mli
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,31 @@
(**
@author Andrew Marmaduke *)

module IMap : sig
(* everything that [Stdlib.Map] gives us *)
include (Map.S with type key = LustreAst.ident)
val keys: 'a t -> key list
end
module IMap = HString.HStringMap

type context


type error_kind = Unknown of string
| ConstantOutOfSubrange of HString.t

type error = [
| `LustreAbstractInterpretationError of Lib.position * error_kind
]

val error_message: error_kind -> string
(** Returns an message describing the error kind *)

val empty_context: context

val get_type: context -> LustreAst.ident -> LustreAst.ident -> LustreAst.lustre_type option

val union: context -> context -> context

val interpret_program: TypeCheckerContext.tc_context -> GeneratedIdentifiers.t GeneratedIdentifiers.StringMap.t -> LustreAst.t -> context

val interpret_global_consts: TypeCheckerContext.tc_context -> LustreAst.declaration list ->
( unit,
[> `LustreAbstractInterpretationError of
Lib.position * error_kind ] )
result
6 changes: 3 additions & 3 deletions src/lustre/lustreAstInlineConstants.ml
Original file line number Diff line number Diff line change
Expand Up @@ -282,16 +282,16 @@ and push_pre is_guarded pos =
and simplify_expr ?(is_guarded = false) ctx =
function
| LA.Const _ as c -> c
| LA.Ident (pos, i) ->
| LA.Ident (_, i) as ident ->
(match (TC.lookup_const ctx i) with
| Some (const_expr, _) ->
(match const_expr with
| LA.Ident (_, i') as ident' ->
if HString.compare i i' = 0 (* If This is a free constant *)
then ident'
then ident
else simplify_expr ~is_guarded ctx ident'
| _ -> simplify_expr ~is_guarded ctx const_expr)
| None -> LA.Ident (pos, i))
| None -> ident)
| LA.UnaryOp (pos, op, e1) ->
let e1' = simplify_expr ~is_guarded ctx e1 in
let e' = LA.UnaryOp (pos, op, e1') in
Expand Down
3 changes: 3 additions & 0 deletions src/lustre/lustreErrors.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type error = [
| `LustreArrayDependencies of Lib.position * LustreArrayDependencies.error_kind
| `LustreAstDependenciesError of Lib.position * LustreAstDependencies.error_kind
| `LustreAstInlineConstantsError of Lib.position * LustreAstInlineConstants.error_kind
| `LustreAbstractInterpretationError of Lib.position * LustreAbstractInterpretation.error_kind
| `LustreAstNormalizerError
| `LustreSyntaxChecksError of Lib.position * LustreSyntaxChecks.error_kind
| `LustreTypeCheckerError of Lib.position * LustreTypeChecker.error_kind
Expand All @@ -35,6 +36,7 @@ let error_position error = match error with
| `LustreArrayDependencies (pos, _) -> pos
| `LustreAstDependenciesError (pos, _) -> pos
| `LustreAstInlineConstantsError (pos, _) -> pos
| `LustreAbstractInterpretationError (pos, _) -> pos
| `LustreAstNormalizerError -> assert false
| `LustreSyntaxChecksError (pos, _) -> pos
| `LustreTypeCheckerError (pos, _) -> pos
Expand All @@ -47,6 +49,7 @@ let error_message error = match error with
| `LustreArrayDependencies (_, kind) -> LustreArrayDependencies.error_message kind
| `LustreAstDependenciesError (_, kind) -> LustreAstDependencies.error_message kind
| `LustreAstInlineConstantsError (_, kind) -> LustreAstInlineConstants.error_message kind
| `LustreAbstractInterpretationError (_, kind) -> LustreAbstractInterpretation.error_message kind
| `LustreAstNormalizerError -> assert false
| `LustreSyntaxChecksError (_, kind) -> LustreSyntaxChecks.error_message kind
| `LustreTypeCheckerError (_, kind) -> LustreTypeChecker.error_message kind
Expand Down
1 change: 1 addition & 0 deletions src/lustre/lustreErrors.mli
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type error = [
| `LustreArrayDependencies of Lib.position * LustreArrayDependencies.error_kind
| `LustreAstDependenciesError of Lib.position * LustreAstDependencies.error_kind
| `LustreAstInlineConstantsError of Lib.position * LustreAstInlineConstants.error_kind
| `LustreAbstractInterpretationError of Lib.position * LustreAbstractInterpretation.error_kind
| `LustreAstNormalizerError
| `LustreSyntaxChecksError of Lib.position * LustreSyntaxChecks.error_kind
| `LustreTypeCheckerError of Lib.position * LustreTypeChecker.error_kind
Expand Down
2 changes: 2 additions & 0 deletions src/lustre/lustreInput.ml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type error = [
| `LustreArrayDependencies of Lib.position * LustreArrayDependencies.error_kind
| `LustreAstDependenciesError of Lib.position * LustreAstDependencies.error_kind
| `LustreAstInlineConstantsError of Lib.position * LustreAstInlineConstants.error_kind
| `LustreAbstractInterpretationError of Lib.position * LustreAbstractInterpretation.error_kind
| `LustreAstNormalizerError
| `LustreSyntaxChecksError of Lib.position * LustreSyntaxChecks.error_kind
| `LustreTypeCheckerError of Lib.position * LustreTypeChecker.error_kind
Expand Down Expand Up @@ -180,6 +181,7 @@ let type_check declarations =
let* _ = LAD.check_inductive_array_dependencies inlined_global_ctx node_summary const_inlined_nodes_and_contracts in

(* Step 14. Infer tighter subrange constraints with abstract interpretation *)
let* _ = LIA.interpret_global_consts inlined_global_ctx const_inlined_type_and_consts in
let abstract_interp_ctx = LIA.interpret_program inlined_global_ctx gids const_inlined_nodes_and_contracts in

(* Step 15. Normalize AST: guard pres, abstract to locals where appropriate *)
Expand Down
1 change: 1 addition & 0 deletions src/lustre/lustreInput.mli
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ type error = [
| `LustreArrayDependencies of Lib.position * LustreArrayDependencies.error_kind
| `LustreAstDependenciesError of Lib.position * LustreAstDependencies.error_kind
| `LustreAstInlineConstantsError of Lib.position * LustreAstInlineConstants.error_kind
| `LustreAbstractInterpretationError of Lib.position * LustreAbstractInterpretation.error_kind
| `LustreAstNormalizerError
| `LustreSyntaxChecksError of Lib.position * LustreSyntaxChecks.error_kind
| `LustreTypeCheckerError of Lib.position * LustreTypeChecker.error_kind
Expand Down
1 change: 1 addition & 0 deletions tests/ounit/lustre/dune
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
(glob_files lustreAstDependencies/*.lus)
(glob_files lustreTypeChecker/*.lus)
(glob_files lustreArrayDependencies/*.lus)
(glob_files lustreAbstractInterpretation/*.lus)
)
(libraries ounit2 kind2dev))
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
type S02 = subrange [0, 2] of int;

const X: S02;
const O: subrange [0, 1] of int;
const M: S02 = O + X;
node main() returns ();
let
check M <= 2;
tel
Loading

0 comments on commit 7037672

Please sign in to comment.