diff --git a/src/lustre/lustreNodeGen.ml b/src/lustre/lustreNodeGen.ml index 910c75788..0207238aa 100644 --- a/src/lustre/lustreNodeGen.ml +++ b/src/lustre/lustreNodeGen.ml @@ -288,18 +288,52 @@ let extract_normalized = function | A.ArrayIndex (_, A.Ident (_, ident), _) -> mk_ident ident | _ -> assert false - -let flatten_list_indexes (e:E.t X.t) = - let over_indices k (indices, e) = - let rec flatten = function - | (X.ListIndex _) :: (ListIndex _) :: t -> flatten ((X.ListIndex k) :: t) - | h :: t -> h :: t - | [] -> [] - in - (flatten indices, e) +module XMap = Map.Make(struct + type t = X.index + let compare = X.compare_indexes +end) + +let flatten_list_indexes (e:'a X.t) = + let top_is_list = + try X.top_max_index e >= 0 + with Invalid_argument _ -> false in - let flattened = List.mapi over_indices (X.bindings e) in - List.fold_left (fun acc (idx, e) -> X.add idx e acc) X.empty flattened + if not top_is_list then e + else + let rec extract_list_prefix acc = function + | (X.ListIndex i) :: tl -> + extract_list_prefix ((X.ListIndex i) :: acc) tl + | rest -> (List.rev acc), rest + in + let m = + List.fold_left (fun acc (indices, e) -> + let prefix, other = extract_list_prefix [] indices in + XMap.update + prefix + (function + | None -> Some [(other, e)] + | Some l -> Some ((other, e) :: l) + ) + acc + ) + XMap.empty + (X.bindings e) + in + XMap.fold + (fun _ l (acc, i) -> + let acc = + List.fold_left + (fun acc (indices, e) -> + X.add ((X.ListIndex i) :: indices) e acc + ) + acc + l + in + acc, i + 1 + ) + m + (X.empty, 0) + |> fst (* Match bindings from a trie of state variables and bindings for a trie of expressions and produce a list of equations *) @@ -1911,6 +1945,7 @@ and compile_node_decl gids is_function cstate ctx i ext inputs outputs locals it vars in H.add_seq !map.quant_vars (H.to_seq quant_var_map); let eq_rhs = compile_ast_expr cstate ctx bounds map ast_expr in + let eq_lhs = flatten_list_indexes eq_lhs in let eq_rhs = flatten_list_indexes eq_rhs in (* Format.eprintf "lhs: %a\n\n rhs: %a\n\n" (X.pp_print_index_trie true StateVar.pp_print_state_var) eq_lhs @@ -1960,6 +1995,7 @@ and compile_node_decl gids is_function cstate ctx i ext inputs outputs locals it in let lhs_bounds = gen_lhs_bounds is_generated eq_lhs ast_expr indexes in let eq_rhs = compile_ast_expr cstate ctx lhs_bounds map ast_expr in + let eq_lhs = flatten_list_indexes eq_lhs in let eq_rhs = flatten_list_indexes eq_rhs in (* Format.eprintf "lhs: %a@.rhs: %a@.@." (X.pp_print_index_trie true StateVar.pp_print_state_var) eq_lhs diff --git a/tests/regression/success/list_flattening.lus b/tests/regression/success/list_flattening.lus new file mode 100644 index 000000000..f17230ca1 --- /dev/null +++ b/tests/regression/success/list_flattening.lus @@ -0,0 +1,42 @@ + + +node F0(x:int) returns (z:int); +let + z = x; +tel + +node F1(x,y:int) returns (z1,z2:int); +let + z1 = x; z2 = y; +tel + +node F2(x1:int;y1:int;x2:int;y2:int) returns (z1,z2:int); +let + z1 = x1 + y1; z2 = x2 - y2; +tel + +node N(x,y:int; c: bool) returns (ok: bool); +var a0,a1,a2,a3,a4,a5,a6,a7,a8,a9:int; + b0,b1,b2,b3,b4,b5,b6:int; +let + a0, a1 = F2(F1(x,y),F1(x,y)); + a2, a3 = F2(F1(x,y),y,F0(x)); + a4, a5 = F2((x,y),F0(x),y); + a6, a7 = F2((F1(x,y),x),y); + a8, a9 = F2((x,(F0(y),(y,F0(x))))); + b0, b1, b2 = (F1(x,y),x); + b3, b4 = F2((F0(x),(y,F1(x,y)))); + b5, b6 = + if c then F2(F1(x,y),F1(x,y)) + else F2((x,y),F0(x),y); + + check "P1" a0 = x + y and a1 = x - y; + check "P2" a2 = x + y and a3 = y - x; + check "P3" a4 = x + y and a5 = x - y; + check "P4" a6 = x + y and a7 = x - y; + check "P5" a6 = x + y and a7 = x - y; + check "P6" a8 = x + y and a9 = y - x; + check "P7" b0 = x and b1 = y and b2 = x; + check "P8" b3 = x + y and b4 = x - y; + check "P9" c => b5 = x + y and b6 = x - y; +tel \ No newline at end of file