Skip to content

Commit

Permalink
Merge pull request kind2-mc#1034 from lorchrob/array-length-constraints
Browse files Browse the repository at this point in the history
Array length constraints, update node call type checking context
  • Loading branch information
daniel-larraz authored Dec 6, 2023
2 parents 99f890c + 33fd534 commit 3cd4816
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 14 deletions.
71 changes: 71 additions & 0 deletions src/lustre/lustreAstHelpers.ml
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,77 @@ let rec substitute_naive (var:HString.t) t = function
| Call (pos, id, expr_list) ->
Call (pos, id, List.map (fun e -> substitute_naive var t e) expr_list)

(* Substitute t for var. ChooseOp and Quantifier are not supported due to introduction of bound variables. *)
let rec apply_subst_in_expr sigma = function
| Ident (pos, i) -> (
match List.assoc_opt i sigma with
| Some expr -> expr
| None -> Ident (pos, i)
)
| ModeRef (_, _) as e -> e
| RecordProject (pos, e, idx) -> RecordProject (pos, apply_subst_in_expr sigma e, idx)
| TupleProject (pos, e, idx) -> TupleProject (pos, apply_subst_in_expr sigma e, idx)
| Const (_, _) as e -> e
| UnaryOp (pos, op, e) -> UnaryOp (pos, op, apply_subst_in_expr sigma e)
| BinaryOp (pos, op, e1, e2) ->
BinaryOp (pos, op, apply_subst_in_expr sigma e1, apply_subst_in_expr sigma e2)
| TernaryOp (pos, op, e1, e2, e3) ->
TernaryOp (pos, op, apply_subst_in_expr sigma e1, apply_subst_in_expr sigma e2, apply_subst_in_expr sigma e3)
| ConvOp (pos, op, e) -> ConvOp (pos, op, apply_subst_in_expr sigma e)
| CompOp (pos, op, e1, e2) ->
CompOp (pos, op, apply_subst_in_expr sigma e1, apply_subst_in_expr sigma e2)
| AnyOp _ -> assert false (* Not supported due to introduction of bound variables *)
| Quantifier _ -> assert false (* Not supported due to introduction of bound variables *)
| RecordExpr (pos, ident, expr_list) ->
RecordExpr (pos, ident, List.map (fun (i, e) -> (i, apply_subst_in_expr sigma e)) expr_list)
| GroupExpr (pos, kind, expr_list) ->
GroupExpr (pos, kind, List.map (fun e -> apply_subst_in_expr sigma e) expr_list)
| StructUpdate (pos, e1, idx, e2) ->
StructUpdate (pos, apply_subst_in_expr sigma e1, idx, apply_subst_in_expr sigma e2)
| ArrayConstr (pos, e1, e2) ->
ArrayConstr (pos, apply_subst_in_expr sigma e1, apply_subst_in_expr sigma e2)
| ArrayIndex (pos, e1, e2) ->
ArrayIndex (pos, apply_subst_in_expr sigma e1, apply_subst_in_expr sigma e2)
| When (pos, e, clock) -> When (pos, apply_subst_in_expr sigma e, clock)
| Condact (pos, e1, e2, id, expr_list1, expr_list2) ->
let e1, e2 = apply_subst_in_expr sigma e1, apply_subst_in_expr sigma e2 in
let expr_list1 = List.map (fun e -> apply_subst_in_expr sigma e) expr_list1 in
let expr_list2 = List.map (fun e -> apply_subst_in_expr sigma e) expr_list2 in
Condact (pos, e1, e2, id, expr_list1, expr_list2)
| Activate (pos, ident, e1, e2, expr_list) ->
let e1, e2 = apply_subst_in_expr sigma e1, apply_subst_in_expr sigma e2 in
let expr_list = List.map (fun e -> apply_subst_in_expr sigma e) expr_list in
Activate (pos, ident, e1, e2, expr_list)
| Merge (pos, ident, expr_list) ->
Merge (pos, ident, List.map (fun (i, e) -> (i, apply_subst_in_expr sigma e)) expr_list)
| RestartEvery (pos, ident, expr_list, e) ->
let expr_list = List.map (fun e -> apply_subst_in_expr sigma e) expr_list in
let e = apply_subst_in_expr sigma e in
RestartEvery (pos, ident, expr_list, e)
| Pre (pos, e) -> Pre (pos, apply_subst_in_expr sigma e)
| Arrow (pos, e1, e2) -> Arrow (pos, apply_subst_in_expr sigma e1, apply_subst_in_expr sigma e2)
| Call (pos, id, expr_list) ->
Call (pos, id, List.map (fun e -> apply_subst_in_expr sigma e) expr_list)

let rec apply_subst_in_type sigma = function
| ArrayType (pos, (ty, expr)) -> (
let expr = apply_subst_in_expr sigma expr in
let ty = apply_subst_in_type sigma ty in
ArrayType (pos, (ty, expr))
)
| TupleType(pos, tys) ->
TupleType(pos, List.map (apply_subst_in_type sigma) tys)
| GroupType(pos, tys) ->
GroupType(pos, List.map (apply_subst_in_type sigma) tys)
| TArr(pos, ty1, ty2) ->
TArr(pos, apply_subst_in_type sigma ty1, apply_subst_in_type sigma ty2)
| RecordType (pos, name, tis) ->
let tis =
List.map (fun (p, id, ty) -> (p, id, apply_subst_in_type sigma ty)) tis
in
RecordType (pos, name, tis)
| ty -> ty

let rec has_unguarded_pre ung = function
| Const _ | Ident _ | ModeRef _ -> false

Expand Down
6 changes: 6 additions & 0 deletions src/lustre/lustreAstHelpers.mli
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ val substitute_naive : HString.t -> expr -> expr -> expr
(** Substitute second param for first param in third param.
AnyOp and Quantifier are not supported due to introduction of bound variables. *)

val apply_subst_in_type : (HString.t * expr) list -> lustre_type -> lustre_type
(** [apply_subst_in_type s t] applies the substitution defined by association list [s]
to the expressions of (possibly dependent) type [t]
AnyOp and Quantifier are not supported due to introduction of bound variables. *)


val has_unguarded_pre : expr -> bool
(** Returns true if the expression has unguareded pre's *)

Expand Down
70 changes: 60 additions & 10 deletions src/lustre/lustreTypeChecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,46 @@ let check_constant_args ctx i arg_exprs =
else R.ok ()
)

let rec type_extract_array_lens ctx ty = match ty with
| LA.ArrayType (_, (ty, expr)) -> expr :: type_extract_array_lens ctx ty
| TupleType (_, tys) -> List.map (type_extract_array_lens ctx) tys |> List.flatten
| GroupType (_, tys) -> List.map (type_extract_array_lens ctx) tys |> List.flatten
| TArr (_, ty1, ty2) ->
type_extract_array_lens ctx ty1 @ type_extract_array_lens ctx ty2
| RecordType (_, _, tis) ->
let tys = List.map (fun (_, _, ty) -> ty) tis in
List.map (type_extract_array_lens ctx) tys |> List.flatten
| UserType (_, id) ->
(match (lookup_ty_syn ctx id) with
| Some ty -> type_extract_array_lens ctx ty;
| None -> [])
| _ -> []

let update_ty_with_ctx node_ty call_params ctx arg_exprs =
let call_param_len_idents =
type_extract_array_lens ctx node_ty
|> List.map (LH.vars_without_node_call_ids)
(* Remove duplicates *)
|> List.fold_left (fun acc vars -> LA.SI.union vars acc) LA.SI.empty
|> LA.SI.elements
(* Filter out constants. If "id" is a constant, it must be a local constant *)
|> List.filter (fun id -> not (member_val ctx id) || (List.mem id call_params) )
in
match call_param_len_idents with
| [] -> node_ty
| _ -> (
let find_matching_len_params array_param_lens =
List.map (fun len -> (Lib.list_index (fun id2 -> len = id2) call_params)) array_param_lens
in
(* Find indices of array length parameters. E.g. in Call(m :: const int, A :: int^m), the index
of array length param "m" is 0. *)
let array_len_indices = find_matching_len_params call_param_len_idents in
(* Retrieve concrete arguments passed as array lengths *)
let array_len_exprs = List.map (List.nth arg_exprs) array_len_indices in
(* Do substitution to express exp_arg_tys and exp_ret_tys in terms of the current context *)
LH.apply_subst_in_type (List.combine call_param_len_idents array_len_exprs) node_ty
)

let rec infer_type_expr: tc_context -> LA.expr -> (tc_type, [> error]) result
= fun ctx -> function
(* Identifiers *)
Expand Down Expand Up @@ -588,17 +628,23 @@ let rec infer_type_expr: tc_context -> LA.expr -> (tc_type, [> error]) result
if List.length arg_tys = 1 then R.ok (List.hd arg_tys)
else R.ok (LA.GroupType (pos, arg_tys))
in
match (lookup_node_ty ctx i) with
| Some (TArr (_, exp_arg_tys, exp_ret_tys)) -> (
match (lookup_node_param_ids ctx i), (lookup_node_ty ctx i) with
| Some call_params, Some node_ty -> (
(* Express exp_arg_tys and exp_ret_tys in terms of the current context *)
let node_ty = update_ty_with_ctx node_ty call_params ctx arg_exprs in
let exp_arg_tys, exp_ret_tys = match node_ty with
| TArr (_, exp_arg_tys, exp_ret_tys) -> exp_arg_tys, exp_ret_tys
| _ -> assert false
in
let* given_arg_tys = infer_type_node_args ctx arg_exprs in
let* are_equal = eq_lustre_type ctx exp_arg_tys given_arg_tys in
if are_equal then
(check_constant_args ctx i arg_exprs >> (R.ok exp_ret_tys))
else
(type_error pos (IlltypedCall (exp_arg_tys, given_arg_tys)))
)
| Some ty -> type_error pos (ExpectedFunctionType ty)
| None -> type_error pos (UnboundNodeName i)
| _, Some ty -> type_error pos (ExpectedFunctionType ty)
| _, None -> type_error pos (UnboundNodeName i)
)
(** Infer the type of a [LA.expr] with the types of free variables given in [tc_context] *)

Expand Down Expand Up @@ -767,14 +813,18 @@ and check_type_expr: tc_context -> LA.expr -> tc_type -> (unit, [> error]) resul

(* Node calls *)
| Call (pos, i, args) ->
R.seq (List.map (infer_type_expr ctx) args) >>= fun arg_tys ->
let* arg_tys = R.seq (List.map (infer_type_expr ctx) args) in
let arg_ty = if List.length arg_tys = 1 then List.hd arg_tys
else GroupType (pos, arg_tys) in
(match (lookup_node_ty ctx i) with
| None -> type_error pos (UnboundNodeName i)
| Some ty ->
R.guard_with (eq_lustre_type ctx ty (LA.TArr (pos, arg_ty, exp_ty)))
(type_error pos (MismatchedNodeType (i, (TArr (pos, arg_ty, exp_ty)), ty))))
(match (lookup_node_ty ctx i), (lookup_node_param_ids ctx i) with
| None, _
| _, None -> type_error pos (UnboundNodeName i)
| Some ty, Some call_params ->
(* Express ty in terms of the current context *)
let ty = update_ty_with_ctx ty call_params ctx args in
let* b = (eq_lustre_type ctx ty (LA.TArr (pos, arg_ty, exp_ty))) in
if b then R.ok ()
else (type_error pos (MismatchedNodeType (i, (TArr (pos, arg_ty, exp_ty)), ty))))
(** Type checks an expression and returns [ok]
* if the expected type is the given type [tc_type]
* returns an [Error of string] otherwise *)
Expand Down
6 changes: 6 additions & 0 deletions src/lustre/typeCheckerContext.ml
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ let lookup_node_ty: tc_context -> LA.ident -> tc_type option
let lookup_node_param_attr: tc_context -> LA.ident -> (HString.t * bool) list option
= fun ctx i -> IMap.find_opt i (ctx.node_param_attr)

let lookup_node_param_ids: tc_context -> LA.ident -> HString.t list option
= fun ctx i ->
match IMap.find_opt i (ctx.node_param_attr) with
| Some l -> Some (List.map fst l)
| None -> None

let lookup_const: tc_context -> LA.ident -> (LA.expr * tc_type option) option
= fun ctx i -> IMap.find_opt i (ctx.vl_ctx)
(** Lookup a constant identifier *)
Expand Down
2 changes: 2 additions & 0 deletions src/lustre/typeCheckerContext.mli
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ val lookup_node_ty: tc_context -> LA.ident -> tc_type option

val lookup_node_param_attr: tc_context -> LA.ident -> (HString.t * bool) list option

val lookup_node_param_ids: tc_context -> LA.ident -> HString.t list option

val lookup_const: tc_context -> LA.ident -> (LA.expr * tc_type option) option
(** Lookup a constant identifier *)

Expand Down
8 changes: 4 additions & 4 deletions tests/ounit/lustre/testLustreFrontend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ let _ = run_test_tt_main ("frontend LustreSyntaxChecks error tests" >::: [
(* Lustre Ast Array Dependencies Checks *)
(* *************************************************************************** *)
let _ = run_test_tt_main ("frontend lustreArrayDependencies error tests" >::: [
mk_test "test illtyped call" (fun () ->
match load_file "./lustreTypeChecker/SteamBoiler2.lus" with
| Error (`LustreArrayDependencies (_, Cycle _)) -> true
| _ -> false);
mk_test "test invalid inductive array def 1" (fun () ->
match load_file "./lustreArrayDependencies/inductive_array1.lus" with
| Error (`LustreArrayDependencies (_, Cycle _)) -> true
Expand Down Expand Up @@ -479,10 +483,6 @@ let _ = run_test_tt_main ("frontend LustreTypeChecker error tests" >::: [
match load_file "./lustreTypeChecker/mode_reqs_by_idents_shadowing.lus" with
| Error (`LustreTypeCheckerError (_, Redeclaration _)) -> true
| _ -> false);
mk_test "test illtyped call" (fun () ->
match load_file "./lustreTypeChecker/SteamBoiler2.lus" with
| Error (`LustreTypeCheckerError (_, IlltypedCall _)) -> true
| _ -> false);
mk_test "test expected type 1" (fun () ->
match load_file "./lustreTypeChecker/test_array_group.lus" with
| Error (`LustreTypeCheckerError (_, ExpectedType _)) -> true
Expand Down

0 comments on commit 3cd4816

Please sign in to comment.