diff --git a/.unreleased/features/rewrite.md b/.unreleased/features/rewrite.md new file mode 100644 index 0000000000..00fb95d91b --- /dev/null +++ b/.unreleased/features/rewrite.md @@ -0,0 +1 @@ +Handle expressions such as S \in SUBSET [ a : Int ] by rewriting the expression into \A r \in S: DOMAIN r = {"a"} /\ r.a \in Int \ No newline at end of file diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ExprOptimizer.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ExprOptimizer.scala index e85d215f85..9faa5d1b68 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ExprOptimizer.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ExprOptimizer.scala @@ -97,6 +97,34 @@ class ExprOptimizer(nameGen: UniqueNameGenerator, tracker: TransformationTracker } apply(tla.and(domEq +: fieldsEq: _*).as(b)) } + + // S ∈ SUBSET { ["a" ↦ x] : x ∈ T } + case memEx @ OperEx(TlaSetOper.in, setRec, + OperEx(TlaSetOper.powerset, + OperEx(TlaSetOper.map, OperEx(TlaFunOper.rec, fieldsAndValues @ _*), varsAndSets @ _*))) + if fieldsAndValues.length == varsAndSets.length => + val (fields, values) = TlaOper.deinterleave(fieldsAndValues) + val (vars, sets) = TlaOper.deinterleave(varsAndSets) + assert(fields.length == vars.length) + if (values.zip(vars).exists(p => p._1 != p._2)) { + // The set has a more general form: { [f_1 |-> e_1, ..., f_k |-> e_k]: x_1 \in S_1, ..., x_k \in S_k }, where + // e_1, ..., e_k are expressions over x_1, ..., x_k. + // We do not know how to optimize it. + memEx + } else { + val strSetT = SetT1(StrT1) + val b = BoolT1 + + val domType = getElemType(setRec) + val r = tla.name(nameGen.newName()).as(domType) + + val domEq = tla.eql(tla.dom(r).as(SetT1(domType)), tla.enumSet(fields: _*).as(strSetT)).as(b) + + val fieldsEq = fields.zip(values.zip(sets)).map { case (key, (value, set)) => + tla.in(tla.appFun(r, key).as(value.typeTag.asTlaType1()), set).as(b) + } + apply(tla.forall(r, setRec, tla.and(domEq +: fieldsEq: _*).as(b)).as(b)) + } } /** diff --git a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestExprOptimizer.scala b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestExprOptimizer.scala index 5f13923734..a38e77fefc 100644 --- a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestExprOptimizer.scala +++ b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestExprOptimizer.scala @@ -87,22 +87,100 @@ class TestExprOptimizer extends AnyFunSuite with BeforeAndAfterEach { // An optimization for set membership over sets of records. Note that this is the standard form produced by Keramelizer. test("""r \in { [a |-> x, b |-> y]: x \in S, y \in T } becomes DOMAIN r = { "a", "b" } /\ r.a \in S /\ r.b \in T""") { + // ... [a |-> x, b |-> y] ... val recT = RecT1("a" -> IntT1, "b" -> IntT1) - val recSetT = SetT1(recT) + // ... x \in S, y \in T ... val record = enumFun(str("a"), name("x").as(intT), str("b"), name("y").as(intT)).as(recT) + // ... S ... val S = name("S").as(intSetT) + // ... T ... val T = name("T").as(intSetT) + // { ... } + val recSetT = SetT1(recT) val recordSet = map(record, name("x").as(intT), S, name("y").as(intT), T).as(recSetT) + // r ... val r = name("r").as(recT) + // ... \in ... val input = in(r, recordSet).as(boolT) + // ~~> + + // DOMAIN r = { "a", "b" } val strSetT = SetT1(StrT1) val domEq = eql(dom(r).as(strSetT), enumSet(str("a"), str("b")).as(strSetT)).as(boolT) + // r.a \in S val memA = in(appFun(r, str("a")).as(intT), S).as(boolT) + // r.b \in T val memB = in(appFun(r, str("b")).as(intT), T).as(boolT) + // ... /\ ... /\ ... val expected = and(domEq, memA, memB).as(boolT) val output = optimizer.apply(input) + + assert(expected == output) + } + + // An optimization for set membership of the powerset of a record where the record has infinite co-domains. + test("""S \in SUBSET [a : T] ~~> \A r \in S: DOMAIN r = { "a" } /\ r.a \in T""") { + + // ... { [a |-> x] : x \in T } ... + val recT = RecT1("a" -> IntT1) + val record = + enumFun(str("a"), name("x").as(intT)).as(recT) + val T = name("T").as(intSetT) + val recSetT = SetT1(recT) + val recordSet = map(record, name("x").as(intT), T).as(recSetT) + + // ... SUBSET ... + val powSetT = powSet(recordSet).as(recSetT) + + // S ... + val s = name("S").as(recSetT) + + // ... \in ... + val input = in(s, powSetT).as(boolT) + val output = optimizer.apply(input) + + // ~~> + + // DOMAIN r = { "a" } + val r = name("t_1").as(recT) + val strSetT = SetT1(StrT1) + val domEq = eql(dom(r).as(strSetT), enumSet(str("a")).as(strSetT)).as(boolT) + + // r.a \in T + val memA = in(appFun(r, str("a")).as(intT), T).as(boolT) + + // ... /\ ... + val conjunct = and(domEq, memA).as(boolT) + + // \A ... + val expected = forall(r, s, conjunct).as(boolT) + + assert(expected == output) + } + + test("""S \in SUBSET [a : T, b : U] ~~> \A r \in S: DOMAIN r = { "a", "b" } /\ r.a \in T /\ r.b \in U""") { + val recT = RecT1("a" -> IntT1, "b" -> IntT1) + val record = + enumFun(str("a"), name("x").as(intT), str("b"), name("y").as(intT)).as(recT) + val T = name("T").as(intSetT) + val U = name("U").as(intSetT) + val recSetT = SetT1(recT) + val recordSet = map(record, name("x").as(intT), T, name("y").as(intT), U).as(recSetT) + val powSetT = powSet(recordSet).as(recSetT) + val s = name("S").as(recSetT) + val input = in(s, powSetT).as(boolT) + val output = optimizer.apply(input) + + val r = name("t_1").as(recT) + val strSetT = SetT1(StrT1) + val domEq = eql(dom(r).as(strSetT), enumSet(str("a"), str("b")).as(strSetT)).as(boolT) + val memA = in(appFun(r, str("a")).as(intT), T).as(boolT) + val memB = in(appFun(r, str("b")).as(intT), U).as(boolT) + val conjunct = and(domEq, memA, memB).as(boolT) + val expected = forall(r, s, conjunct).as(boolT) + assert(expected == output) }